뮤제로(MuZero) 강화학습 알고리즘 - (1) 에 이어 작성
먼저 replay buffer에 저장해 두었던 기록들에서 어느 state를 샘플해온다. (prioritized replay라는 기법이 사용된다. 학습하기에 중요하다고 생각되는 부분들을 우선적으로 가져와서 학습하는 것이다). 가져온 state로부터 그 이후 time step들에 해당하는 replay 정보들을 가져와 펼치고(B, observed steps), 또 동일한 가져왔던 state로부터 muzero 모델에 기반해 단순하게 unroll한 것을 동일 선상에 둔다(C, hypothetical steps). (논문에서는 5-steps unroll한다).
모델이 unroll해서 만든 hypothetical steps의 p,r,v 는 observed steps의 π,u,v으로부터 알 수 있는 search policy π, observed immediate reward u, value target z에 가까워져야 한다. 모델이 예측한 hypothetical steps에서의 정보들보다 MCTS를 통해서 보다 나은 선택을 한 기록인 observed steps의 π,u,v가 담고 있는 정보가 더 낫기 때문이다.(u,v또한 MCTS, planning at decision time에 따른 action의 결과로부터 얻어진 것들이기 때문에).
Hypothetical steps와 observed steps의 차이를 최소화하게끔 loss를 계산한다. 각 네트워크별 loss계산에는 cross-entropy loss가 권장된다.(MSE도 사용될 수 있다.) Overall loss는 위 식과 같이 계산된다. K=5 steps가 사용되고, value target z 계산에는 n=10 steps가 사용된다. (보드게임 같은 경우는 imm reward u가 없어서 해당 loss계산 없음) 5 step동안 unroll 되었던 네트워크들로부터 계산된 loss들이 합쳐지고, 이 overall loss가 unroll했던 네트워크들에 jointly trained 된다. (end-to-end, backpropagation through time).
이렇게 training 과정을 통해 뮤제로 모델은 improve되고, improved된 모델로 다시 self-play및 training을 하는 과정을 반복하다 보면 바둑, 체스, 쇼기, 아타리 모두 높은 성능에 도달할 수 있게 된다. (한 모델이 general하게 모든 게임을 다 하는 것은 아니고 각각에 대해 따로 학습함. 내부도 각자 미묘하게 다름)
이후 여러가지 측면에서 개선된 후속 논문들이 나왔는데 그 또한 간단하게 리뷰해 보려고 한다. 읽는 도중 오픈소스 코드도 일부 보았는데 관련 내용도 공유하려 한다.
논문 제목 : 'Mastering Atari, Go, chess and shogi by planning with a learned model'
supplementary information : https://www.nature.com/articles/s41586-020-03051-4 페이지 아래쪽
Self-play
알고리즘을 학습시키기 위해 데이터를 모으는 self-play 과정이 먼저 필요하다. 알고리즘은 처음에 아무렇게나 움직이는, 전혀 학습되지 않은 상태에서 시작한다. 아래 그림에서는 바둑의 경우로 설명한다. 각 바둑판은 실제 state이며, 그 밑의 화살표 갈래들은 MCTS(몬테카를로 트리 서치)로 lookahead search(planning at decision time)를 하는 과정을 표현한다.
MCTS로 lookahead search를 하는 이유는, self-play시에 lookahead search를 통해서 얻은 search policy π를 따라 action selection을 하는 것이 그냥 policy p로 부터 action selection을 하는 것보다 더 나은 선택을 할 가능성이 높기 때문이다. (이 MCTS과정이 뮤제로 알고리즘의 성능이 점진적으로 높아질 수 있게 하는 중요한 역할을 한다.)
self-play 과정, paper figure 1.b
self-play시 MCTS lookahead search, paper figure 1.a
MCTS로 lookahead search를 하는 이유는, self-play시에 lookahead search를 통해서 얻은 search policy π를 따라 action selection을 하는 것이 그냥 policy p로 부터 action selection을 하는 것보다 더 나은 선택을 할 가능성이 높기 때문이다. (이 MCTS과정이 뮤제로 알고리즘의 성능이 점진적으로 높아질 수 있게 하는 중요한 역할을 한다.)
MCTS는 실행 과정에서 여러가지 노드들을 펼치고 방문하면서 그 가치를 재고 위쪽 노드로 통계를 넘기는 형태의 동작을 한다. MCTS를 충분히 실행한 이후에는 visit count나 가치 등에 대한 정보가 가장 위 root node아래에 쌓이게 된다. 이 정보에 기반해 search policy π를 계산할 수 있다.
기존 알파제로 같은 경우는 위 그림처럼 복잡하게 h,g(각각 representation function, dynamics function)가 개입하지 않는다. State가 날것 그대로인 real state로 표현되었기 때문에 p,v(policy, value function)을 predict하는 데만 신경쓰면 되는 부분이다.
뮤제로는 state가 representation function과 dynamics function의 뒷받침 하에 hidden state로 관리된다. 처음 hidden state인 s0, root state는 representation function에 의해 만들어진다. 그리고 그 이후 state transition은 dynamics function이 담당한다.
(위 MCTS 그림에 나오는 p, v, r, a등은 이후 학습에 관련된 그림에 나오는 것들과는 다른 부분이며 큰 의미는 없다. 원리는 동일하겠지만 MCTS실행시 내부 동작을 설명하기 위함이지 이후 학습에서 사용되는 부분은 아니다.) 다시 처음 그림으로 돌아가서,
Self-play동안 search policy π에 따라 action selection을 하면서 계속 게임을 진행한다. 게임이 끝날 때까지나 정해진 길이,시간까지 진행하고 지나왔던 trajectory에서 얻은 정보들을 replay buffer(저장소)에 저장한다. 정보들은 각각 search policy π, search value v, selected action a, observed immediate reward u, observed state o 이다. (v가 논문 상에서 value function v와 헷갈릴 수 있는데 notation을 잘 봐야한다. 위 그림에서 나오는 search value v는 mcts로 search한 stat에 기반해서 계산한 value이다. 나중에 target value z계산에 쓰인다). 이렇게 얻어내고 저장한 정보들은 학습(training)시에 사용된다.
Training
이제 얻어낸 데이터로부터 네트워크들을 improve할 차례이다. 다른 supervised learning기법들과 유사하게 네트워크의 prediction과 target의 차이를 loss로 설정하고 그에 기반해 네트워크를 학습시킨다.
targets : search policy π, observed immediate reward u, value target z, 원저자 Julian Schrittwieser youtube
모델이 unroll해서 만든 hypothetical steps의 p,r,v 는 observed steps의 π,u,v으로부터 알 수 있는 search policy π, observed immediate reward u, value target z에 가까워져야 한다. 모델이 예측한 hypothetical steps에서의 정보들보다 MCTS를 통해서 보다 나은 선택을 한 기록인 observed steps의 π,u,v가 담고 있는 정보가 더 낫기 때문이다.(u,v또한 MCTS, planning at decision time에 따른 action의 결과로부터 얻어진 것들이기 때문에).
Overall loss
Loss가 저렇게 여러 군데서 합쳐져도 학습이 잘 되는 이유는 다른 head들에서 발생하는 loss들은 일종의 noise로 받아들여지거나 baseline function의 역할을 해주는 것이지 않은가 생각한다. 여러 unroll step들에 대해서 합치는 것은 batch 학습과 비슷한 원리인 듯 하다.
안정적으로 학습되게끔 몇가지 최적화가 적용된다. 여러 네트워크들의 head들에서 나오는 gradient들을 적정 비율로 조절해서 합친 gradient가 적정 크기에서 유지되게끔 unroll step K에 따라 1/K로 loss를 조절하거나 dynamics function에 적정 비율을 곱하는 등의 방법을 취한다.
성능 및 한계
성능은 바둑, 체스, 쇼기, 아타리 모두 superhuman이고 기존 모든 state-of-the-art 들의 성능 이상이다. 알파제로와의 보드 게임의 비교시에는 거의 비슷하거나 살짝 더 강한 모습을 보인다. Simulator의 도움 없이 planning하는 것이 simulator의 도움을 받은 것만큼 성능이 나오는 것이니 뮤제로의 model-based approach가 성공적으로 먹혔다는 이야기이다.
성공적인 model-based approach로 알고리즘의 적용가능 범위가 넓어지긴 했지만, 아직 충분히 general하다고 하기는 어려운 것 같다. 한 모델이 general하게 모든 게임을 다 하는 것이 아니고 각각에 대해 따로 학습해야 한다. 보드게임과 아타리에 적용되는 설정도 약간씩 다르고, 각 게임들마다 action plane 설정도 따로 해줘야 하는 등 각각에 신경써야 하는 부분들이 있다.
그리고 atari게임의 경우 action space 크기가 18이다. 더 크기가 큰 문제에도 적용 가능할 수 있는지는 이번 논문에서는 알 수 없다. 또 dynamics function이 stochastic 한 경우에는 아직 대응가능하지 않다.
글을 마치며
알파고가 처음 나왔을 때는 그냥 바둑 프로그램인 것 같은데 그걸로 뭘 할 수 있을까 싶었지만 이렇게 조금씩 한계가 하나씩 풀려나가면서 뮤제로까지 왔다. 그 이후 발전은 어떻게 흘러갈지 기대가 된다.
댓글
댓글 쓰기