뮤제로(MuZero) 강화학습 알고리즘 - 오픈소스 코드 구경하기 (3)

개요

이번 글에서는 뮤제로(MuZero)에 대해 공개된 pseudo code나 오픈소스 구현된 코드에 대한 링크와 정보를 공유한다. 

1. Deepmind official pseudocode (.py)
supplementary information : https://www.nature.com/articles/s41586-020-03051-4 페이지 아래쪽에서 supplementary information 다운로드 가능함. 파일 내의 pseudocode.py.
아래는 원저자 github. 내용은 동일함


2. 오픈소스 구현된 코드

werner-dauvad의 github에 올려진 오픈소스 코드이다. Official pseudocode에 기반했지만 완전히 동일하다고 보장할 수 없고 성능이나 구현 정도도 논문에 발표된 것과 동일하다고 할 수는 없다. 

그래도 대부분의 기능들이 구현되어 있고 실제로 돌아가는 모습을 볼 수 있는 구현된 코드이다. 해당 github README 페이지에서는 소개 및 구현 사항들, 코드 구조에 대한 그림, 실행방법 등이 소개되어 있다. 

Pseudocode는 tf로 표현되어있는데 해당 소스는 pytorch기반이다. Environment역할을 하는 game들은 games폴더 내에 있는 파일들을 통해서 openai gym에서 게임을 불러오고 게임마다 필요한 config들을 설정한다.

기본적으로 self-play, replay buffer, trainer, shared storage 순서로 계속 순환하는 구조를 가지고 있다. 

muzero.py는 main역할을 하는 파일로 여러 모델들을 엮어서 init하고 self-play 및 train을 실행한다.

model.py는 Muzero모델,  representation function, dynamics function, prediction function에 관련된 class들을 담고 있다.

self-play, replay buffer, trainer, shared storage .py는 각각 관련 구현 내용들을 담고 있다.

아래는 코드를 구경하면서 대략 정리한 흐름 및 기타 메모들이다. 

muzero.py
def train()
    options.remote는 ray(병렬처리관련)코드
    (https://docs.ray.io/en/latest/ray-core/package-ref.html)
    train내에서 여러가지 worker들을 init and launch하고 
    self_play하는것으로 시작, 
    continuous_self_play & continuous_update_weights
    (둘다 self.shared_storage_worker, self.replay_buffer_worker 중간에 위치해서 remote로
    흐름이 연결되는듯) 

self_play.py
def continuous_self_play()
model parameter init from shared storage
    def play_game()      
        observation들로부터 mcts.run()는 root(Node객체) 리턴, 
    inside mcts.run() 
    mcts.run()시에 model.py의 여러 network들을 사용해서 initial_inference & recurrent_inference
            initial_inference에서 리턴되는 value – 그냥 prediction network에서 value 출력값            
    initial_inference = representation and prediction
    recurrent_inference = dynamics and prediction
        root의 children정보들로부터 select_action() (여기서 search policy visit count 정보)
        interact with game env
        append (root stat, a, observation, r) to history
        loop until < max_moves and not done
history는 continuous_self_play 통해 replay_buffer로 save_game()

replay_buffer.py
    save_game()
    compute_target_value계산에서 game_history로부터 무슨 value들을 꺼내와서 target_value z를 계산하는 것 같은데
    이 계산에서 쓰는 game_history의 root_values=[]는 Node객체의 value()로부터 쌓이는데 이는 value_sum/visit_count 
    (mcts 서치 결과로 나오는 search value관련된 것인듯) priotized replay관련 계산
    및 기타 길이 계산 및 저장 등
외부에서는 get_batch로 불러오고
get_batch는 make_target()에서 sample_game()으로 샘플해오고 여러 히스토리들 및 target 계산된 배치들을 계산해 리턴해줌
    compute_target_value()에서는 root_value(search value)로 bootstrap
    
Trainer.py
def Continous_update_weights()
replay_buffer에서 .get_batch로 데이터 가져옴
    update_weight()
    unroll prediction하기 with initial_inference & recurrent_inference
    loss 구하기(cross-entropy)
    optimize,  loss.backward()
log
model save to shared storage

shared storage.py
간단히 weight정보 등 저장 역할

model.py
외부에서 모델 불러올 때 MuzeroNetwork로 init하고 FC/Residual 로 이어짐
MuZeroResidualNetwork
Representation
Dynamics
Prediction
모두 한 모델 내 선언됨

외부에서는 initial_inference & recurrent_inference로 사용되고
이 내에서는 Representation() Dynamics() Prediction()
메소드는 해당하는 class network 사용
셋다 기본 backbone 구조는 resnet block 이고
Representation은 downsample되는 네트워크를 먼저 지남 
Dynamics는 cat은 그냥 image형태로 one hot encode된 action채널 추가하고 채널 다시 줄여줌
그리고 s’는 다음 prediction으로 넘기고, s’로부터 reward inference

아래는 initial_inference, recurrent_inference에 대한 설명이다.
initial_inference = representation and prediction
recurrent_inference = dynamics and prediction

첫 inference를 제외하고 나머지 state transition은 recurrent_inference가 책임진다.

3. EfficientZero 오픈소스 코드

아래는 후속 연구들에 대한 리뷰이다.

댓글

가장 많이 본 글

구글 람다(LaMDA)란? - 구글의 언어 모델

알파고 강화학습 원리

버텍스 AI란? - 구글 인공지능 플랫폼

카타고와 바둑 두어보기

뉴럴 네트워크란?

블로그 글 목록

뉴럴 네트워크를 학습시키는 방법