Muesli (MuZero 후속 알고리즘) implementation - (3)

개요

(해당 글은 초기 버전에 대한 것이고 다른 글에 새 버전으로 업데이트하였습니다.)

앞선 글에 이어서 설명합니다. 

[부분 7]

                    L_total = L_pg_cmpo + L_v/3/4 + L_r/2 + L_m/4                

                    Cumul_L_total += L_total

            Cumul_L_total /= 30
            self.optimizer.zero_grad()
            Cumul_L_total.backward()

            nn.utils.clip_grad_value_(self.parameters(), clip_value=1.0)

            ## dynamics network gradient scale 1/2  
            for d in self.dynamics_network.parameters():
                d.grad *= 0.5

            self.optimizer.step()

이 부분은 전에 구했던 로스들을 합해주는 부분입니다. 각각의 로스는 unroll한 만큼 나눠져 있고, Value model loss는 논문에 소개된 대로 0.25를 곱하는 것이 학습 안정성에 도움이 됩니다. L_m 또한 저는 0.25를 곱해주었습니다. 

제 구현은 루프에서 한 transition마다 loss를 구하고 accumlate한 뒤에 나눠주는 방식의 미니배치라서 저렇게 한 배치에 대한 계산이 끝난 뒤 사이즈만큼 나눠주고, backward 해줍니다. 그리고 논문에 소개된 대로 gradient를 [-1,1]로 클립해서 사용하였고, dynamics network로 흘러가는 gradient를 절반으로 처리해줍니다. 그리고 최종적으로 옵티마이저를 통해 weight들을 업데이트 해줍니다.

[부분 8]
            ## target network(prior parameters) moving average update
            alpha_target = 0.01
            params1 = self.named_parameters()
            params2 = target.named_parameters()
            dict_params2 = dict(params2)
            for name1, param1 in params1:
                if name1 in dict_params2:
                    dict_params2[name1].data.copy_(alpha_target*param1.data + (1-alpha_target)*dict_params2[name1].data)
            target.load_state_dict(dict_params2)

        #self.scheduler.step()

        ##trajectory clear
        self.state_traj.clear()
        self.action_traj.clear()
        self.P_traj.clear()
        self.r_traj.clear()
        return

해당 부분은 target network에 업데이트가 진행된 네트워크의 파라미터를 섞어주면서 진행하는 부분입니다. 이렇게 하면 target network는 이전에 업데이트되었던 여러 네트워크들을 섞어둔 네트워크가 되는 효과가 있다고 합니다. 그리고 online queue의 역할을 하는 traj array들을 clear해주면서 한번의 업데이트 함수 호출이 끝이 납니다. 


## Earned score per episode
plt.plot(score_arr, label ='score')
plt.legend(loc='upper left')

## game play video
agent.load_state_dict(torch.load("weights.pt"))
env = gnwrapper.LoopAnimation(gym.make(game_name))
state = env.reset()
for _ in range(100):
    with torch.no_grad():
        hs = agent.representation_network(torch.from_numpy(state).float().to(device))
        P, v = agent.prediction_network(hs)
        action = np.random.choice(np.arange(agent.action_space), p=P.detach().numpy())  
    env.render()
    state, rew, done, _ = env.step(action.item())
    if done:
        state = env.reset()
env.display()

self play를 진행하면서 모았던 점수를 그래프로 그려서 확인하고, 학습된 네트워크를 가지고 play하는 영상을 볼 수 있는 코드입니다.



여기까지 cartpole에 대해 구현해본 Muesli 코드에 대해 설명하였습니다. LunarLander에도 사용할 수 있게 시도해보면서 네트워크에 몇가지 최적화를 추가한 버전이 있는데 다음 글에서는 이에 대해 추가로 작성하도록 하겠습니다.



아래에는 몇가지 느낀점을 메모합니다.

일단 오해없이 잘 읽는게 제일 중요하다. 한편으로는 이해가 잘 가지 않으면 일단 만들면서 진행하는것도 한 방법이다.
논문 appendix쪽에 있는 내용들이나 최적화 관련 내용이나 하이퍼파라미터 알려준 것 따라서 가는게 생각보다 큰 진전을 준다. 
하이퍼파라미터 찾는건 항상 빡세고, 강력한 컴퓨팅파워가 필요함을 느낀다.
가능하면 들여다 볼 수 있는 변수나 활성화값들을 찍어보는 것이 큰 도움이 된다.
딥러닝 라이브러리는 놀라울 정도로 잘 만들어진 물건이지만 이 역시 사람이 만든 것이라 가끔 오류가 있고 오인할 부분도 많아서 가능하면 다큐멘트를 최대한 잘 읽고 값도 직접 찍어봐야 한다고 생각이 듬.
네트워크 입출력 값은 왠만하면 [0.1] 사이로 정규화 하는게 좋은 듯 함. 꼭 그러지 않아도 학습이 되긴 하지만 모델이 복잡해질수록 도움이 되는 부분이 있는 듯 하다.

 

댓글

가장 많이 본 글

구글 람다(LaMDA)란? - 구글의 언어 모델

알파고 강화학습 원리

버텍스 AI란? - 구글 인공지능 플랫폼

카타고와 바둑 두어보기

뉴럴 네트워크란?

블로그 글 목록

뉴럴 네트워크를 학습시키는 방법