Muesli LunarLander-v2 implementation

개요

DeepMind에서 2021년도에 발표한 model-based 강화학습 알고리즘인 Muesli(뮤즐리)를 LunarLander-v2 환경에 구현하였습니다. Cartpole-v1 환경에도 실행 가능합니다. 아래는 코드를 돌려볼 수 있는 colab 링크입니다. (학습 및 Tensorboard, game play 가능)




학습된 네트워크로 플레이한 LunarLander-v2.
Score 그래프

Introduction

Muesli(뮤즐리)는 DeepMind에서 알파고 이후로부터 꾸준히 발전시켜온 알고리즘 중 비교적 최신에 위치하는 알고리즘으로 바둑, 체스와 같은 보드게임 뿐 아니라 Atari와 같이 시각적으로 복잡한 task에 대해서도 높은 성능을 내주는 model-based 강화학습 알고리즘입니다.

뮤즐리는 이전 연구인 뮤제로(MuZero)와 동일한 네트워크 구조 위에 MCTS(Monte Carlo Tree Search)없이 구성 가능한 loss를 사용함으로서 다른 model-free 알고리즘들과 비슷한 수준으로 연산량을 낮추고 뮤제로의 높은 성능을 유지합니다. 많은 연산량을 요구하는 MCTS를 사용하지 않고도 적은 비용으로 효율적인 업데이트가 가능해진 것입니다.

구현을 진행해보니, 사람들이 왜 PPO를 좋아하는지 알 것 같다는 생각이 들기도 합니다. PPO는 모델이나 로스가 간단한 편이고 하이퍼퍼라미터도 적은데 성능도 좋고 안정적이고 구현되어 있는 모델도 많고 여러모로 장점이 많습니다. 한편 뮤즐리는 네트워크들이 얽혀있다보니 여러가지 신경써줘야 될 부분들도 많고 하이퍼파라미터도 적지 않아서 실제로 사용하려면 손이 많이 가겠다는 생각이 듭니다. 그래도 잘 짠다면 성능 상한은 더 높을 것이니 높은 성능에 대한 요구가 있다면 시도해볼만 하지 않을까 싶습니다. 
Figure 1. 논문에 소개된 Atari에 대한 성능

네트워크 구조

뮤즐리는 model-based 알고리즘으로써 model-free 알고리즘과 다르게 state transition에 대한 학습을 가능하게 하는 네트워크들을 가지고 있고, 학습을 위해 네트워크들을 몇개의 time step만큼 펼치고(unroll)하고 한번에 학습시키는 방법을 사용합니다. 펼쳐진 네트워크의 각 policy, value, reward 부분들로부터 계산되어 앞쪽까지 흘러오는 gradient들이 중첩되는 형태로, unroll된 네트워크 전체가 한번에 학습되게 됩니다. (이를 논문에서는 jointly trained, end to end, by backpropagation through time이라고 표현을 합니다.)

아래는 뮤즐리의 네트워크가 펼쳐진 경우의 그림입니다.

Self-play시 얻은 데이터 observation을 representation network의 도움을 받아 hidden state로 변환하면서 unroll이 시작됩니다. Hidden state는 이후 time step으로 넘어갈때마다 dynamics network를 통해서 state transition이 됩니다. 그 과정에서 reward 및 policy, value 도 인퍼런스됩니다. 

네트워크가 이상적으로 학습된 상태라면 이러하게 unroll하는 과정에서 알고리즘은 simulator의 도움 없이도 미래의 time step에 예상되는 p, v, r을 예측할 수 있게 되고 이는 알고리즘이 planning할 수 있게 해주는 원천이 됩니다. 이러한 planning의 사용이 model-based 알고리즘이 model-free 알고리즘보다 성능 면에서 유리하게 만들어주는 부분인데, 한편으로는 네트워크가 이에 대한 학습을 안정적으로 지원해줘야 함이 우선일 것입니다. 

이를 위해서 네트워크에는 몇가지 최적화 기법들이 적용이 됩니다. 

1. Hidden state는 항상 적정 범위 내로 조정이 됩니다. (Min-max normalization)

2. Gradient 계산시 dynamics network가 매번 뒤로 흘려보내주는 gradient를 절반씩 줄이게끔 해서 앞쪽까지 과도한 gradient가 중첩되는 것을 막습니다.

3. 각 p, v, r로 들어가는 loss를 unroll step만큼 나눠주고, 하이퍼파라미터로도 적정량 나눠줘서 네트워크가 과도하게 학습되지 않게 합니다.

4. Categorical reparametrization 기법을 사용합니다. 이는 네트워크가 스칼라 값을 그대로 학습하는 게 아니라 categorical하게 변환된 분포를 학습하게 하는 기법으로 reward scale이 클 때 네트워크가 흔들리게 되는 것을 막아줘서 학습에 큰 도움이 됩니다. (Categorical reparametrization 설명글)

이러한 여러 기법들을 통해 네트워크가 서로 엮여 있어서 생기는 불안정성을 크게 줄여 안정적인 학습을 기대할 수 있게 됩니다.

알고리즘 요약

뮤즐리 알고리즘의 policy update는 크게 두가지로 구성된다고 볼 수 있는데, 논문에서 제시한 regularized objective를 optimize함으로써, 그리고 model을 학습하는데 도움이 되는 auxiliary loss들로부터의 학습이 이를 뒷받침하며 업데이트가 이뤄집니다. 

논문에서 제시하는 regularized objective는 policy 학습의 주된 역할을 하는 부분으로, TRPO와 같은 regularized policy optimization 계열 식들에서 볼 수 있는 것처럼 레귤라이저를 사용하며 레귤라이저는 한 단계의 탐색을 통해 planning하여 계산하는 특수한 타겟을 사용하고 이는 뮤즐리가 뮤제로의 성능을 낼 수 있는 기반이 됩니다.

그리고 미래의 time step들에 있는 policy들로 들어가는 model loss와 value, reward의 학습이 auxiliary loss로서 작용해 네트워크가 hidden state들을 안정적으로 학습하게끔 해 주어 성능을 끌어올려 줍니다. 이러한 loss들이 모두 합해져서 네트워크가 한번에 학습되게 됩니다. (Loss들에 대한 자세한 설명은 Loss설명글에 작성)

아래는 뮤즐리 논문에 소개된 pseudocode입니다. (Self-play, 미니배치 구성, loss구성, advantage normalization, target network parameter update 등의 내용을 담고 있는데 자세한 설명은 상세구현글에 작성)


학습 결과

아래는 학습이 완료될때까지 Tensorboard를 통해 확인한 LunarLander-v2 self-play시의 점수, 게임들의 에피소드 길이 및 마지막 reward에 대한 그래프들입니다. 


Score 그래프, 게임 길이 및 마지막 reward 그래프 (파랑,빨강)

이러한 결과들로부터 agent가 어떻게 환경을 학습해나가는지 볼 수 있는데, 먼저 LunarLander 환경에 대해 간략한 설명을 하겠습니다. LunarLander환경은 우주선이 깃발 사이의 평평한 곳에 착륙하게끔 하는 목적을 가진 환경으로, 처음에 우주선이 랜덤한 힘과 방향으로 튕겨지게 되는데 이를 엔진으로 제어해 착륙 지점 가까이 온 뒤 부서지지 않게 천천히 착륙하고 엔진을 꺼야 하는 환경입니다.

환경의 보상은 대략 이러한 형태로 이루어지는데, 우주선이 좋은 궤적으로 중심 가까이를 따라 내려오는 경우 그 과정에서 약간의 보상들이 주어지고, 착륙해 엔진을 끄고 정지하는데 성공하면 +100의 보상을 주고, 우주선이 넘어지거나 너무 빨리 내려와 부서지면 -100의 보상을 줍니다. 

학습 초반에는 모델이 환경을 학습해가며 이리저리 랜덤하게 움직이다가, 점점 땅에 박으면 안되고 화면 밖으로 나가면 안되고 시간제한을 넘으면 안되는 것을 배우고, 좋은 궤적을 따라 중심으로 내려오는 것이 좋다는 것을 조금씩 학습해가면서 땅에 가까워지려 하는 모습을 보입니다.

학습 중반에는 내려가는 궤적에서 주는 보상을 따라 착지하는 것까지는 성공을 하지만 당장 바닥에 붙어서 괜히 크게 움직이다 -100을 받기 싫어서, 혹은 아직 +100의 착륙을 경험하지 못해 좌우 엔진을 조금씩 써가며 버티는 모습을 보입니다. 그렇게 데이터를 모으면서 내려가는 궤적을 조금씩 개선하는데, 이러한 구간이 대략 0에서 50근처의 점수로 올라가는 구간입니다. 깃발 근처까지 좋은 궤적으로 내려오면서 약 150정도의 점수를 쌓고, 바닥에서 버티다 시간제한을 넘어가 -100을 받는 것을 반복합니다.

그러다가 우연히 착륙에 성공해 +100을 받게 되고 이것이 데이터로 쌓인 뒤로는 점점 땅에 붙어 비비는 시간이 줄어들고 바로 착륙하려는 모습이 강해지면서 학습이 완료되게 됩니다. 아래는 좋은 궤적을 따라 내려가서 이상적으로 착륙했을 경우 reward가 쌓여가는 모습을 보여주는 그래프입니다.
(착륙 성공한 경우) 한 game play 내에서 reward 쌓이는 모습


해당 game play 영상


이렇게 agent가 단계적으로, 전략적으로 환경을 학습해나가는 것을 직접 들여다보니 강화학습 알고리즘이라는 것이 참 신기한 물건이라는 생각이 듭니다.



참 여러모로 읽기 쉽지 않은 논문이었고 구현과정도 여러 난관들이 있었지만 일단 붙잡고 몇번이고 다시 읽어보고 온갖 군데를 돌아다니다 보면 조금씩 이해가 늘고 어딘가 숨어 있는 좋은 자료들이 찾아지면서 처음에는 막막해서 몸이 움직이지 않던 난관들이 하나씩 넘어가지는 소중한 경험들을 할 수 있었던 프로젝트였던 것 같습니다.

추후 구현을 더 다듬고 확장해서 Atari와 같은 환경에 적용해보려 합니다. GitHub - Itomigna2/Muesli-lunarlander: Muesli RL algorithm implementation (PyTorch) (LunarLander-v2)에 해당 프로젝트를 올려두었고, 혹시 질문이 있으시면 E-mail : emtgit2@gmail.com 로 메일 주세요!

참고링크

논문 저자 발표자료 : https://icml.cc/virtual/2021/poster/10769
뮤제로 supplementary information : https://www.nature.com/articles/s41586-020-03051-4 페이지 아래쪽


코드

class Representation(nn.Module): 
    """Representation Network

    Representation network produces hidden state from observations.
    Hidden state scaled within the bounds of [-1,1]. 
    Simple mlp network used with 1 skip connection.

    input : raw input
    output : hs(hidden state) 
    """
    def __init__(selfinput_dimoutput_dimwidth):
        super().__init__()
        self.skip = torch.nn.Linear(input_dim, output_dim)  
        self.layer1 = torch.nn.Linear(input_dim, width)
        self.layer2 = torch.nn.Linear(width, width)
        self.layer3 = torch.nn.Linear(width, width) 
        self.layer4 = torch.nn.Linear(width, width)  
        self.layer5 = torch.nn.Linear(width, output_dim)     
        
    def forward(selfx):
        s = self.skip(x)
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        x = self.layer3(x)
        x = torch.nn.functional.relu(x)
        x = self.layer4(x)
        x = torch.nn.functional.relu(x)
        x = self.layer5(x)    
        x = torch.nn.functional.relu(x+s)
        x = 2*(x - x.min(-1,keepdim=True)[0])/(x.max(-1,keepdim=True)[0] - x.min(-1,keepdim=True)[0])-1 
        return x


class Dynamics(nn.Module): 
    """Dynamics Network

    Dynamics network transits (hidden state + action) to next hidden state and inferences reward model.
    Hidden state scaled within the bounds of [-1,1]. Action encoded to one-hot representation. 
    Zeros tensor is used for action -1.
    
    Output of the reward head is categorical representation, instaed of scalar value.
    Categorical output will be converted to scalar value with 'to_scalar()',and when 
    traning target value will be converted to categorical target with 'to_cr()'.
    
    input : hs, action
    output : next_hs, reward 
    """
    def __init__(selfinput_dimoutput_dimwidthaction_space):
        super().__init__()
        self.layer1 = torch.nn.Linear(input_dim + action_space, width)
        self.layer2 = torch.nn.Linear(width, width) 
        self.hs_head = torch.nn.Linear(width, output_dim)
        self.reward_head = nn.Sequential(
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Linear(width,support_size*2+1)           
        ) 
        self.one_hot_act = torch.cat((F.one_hot(torch.arange(0, action_space) % action_space, num_classes=action_space),
                                      torch.zeros(action_space).unsqueeze(0)),
                                      dim=0).to(device)        
       
    def forward(selfxaction):
        action = self.one_hot_act[action.squeeze(1)]
        x = torch.cat((x,action.to(device)), dim=1)
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        hs = self.hs_head(x)
        hs = torch.nn.functional.relu(hs)
        reward = self.reward_head(x)    
        hs = 2*(hs - hs.min(-1,keepdim=True)[0])/(hs.max(-1,keepdim=True)[0] - hs.min(-1,keepdim=True)[0])-1
        return hs, reward


class Prediction(nn.Module): 
    """Prediction Network

    Prediction network inferences probability distribution of policy and value model from hidden state. 

    Output of the value head is categorical representation, instaed of scalar value.
    Categorical output will be converted to scalar value with 'to_scalar()',and when 
    traning target value will be converted to categorical target with 'to_cr()'.
        
    input : hs
    output : P, V 
    """
    def __init__(selfinput_dimoutput_dimwidth):
        super().__init__()
        self.layer1 = torch.nn.Linear(input_dim, width)
        self.layer2 = torch.nn.Linear(width, width) 
        self.policy_head = nn.Sequential(
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Linear(width,output_dim)           
        ) 
        self.value_head = nn.Sequential(
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Linear(width,support_size*2+1)           
        ) 
   
    def forward(selfx):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        P = self.policy_head(x)
        P = torch.nn.functional.softmax(P, dim=-1
        V = self.value_head(x)      
        return P, V


"""
For categorical representation
reference : https://github.com/werner-duvaud/muzero-general
In my opinion, support size have to cover the range of maximum absolute value of 
reward and value of entire trajectories. Support_size 30 can cover almost [-900,900].
"""
support_size = 30
eps = 0.001

def to_scalar(x):
    x = torch.softmax(x, dim=-1)
    probabilities = x
    support = (torch.tensor([x for x in range(-support_size, support_size + 1)]).expand(probabilities.shape).float().to(device))
    x = torch.sum(support * probabilities, dim=1, keepdim=True)
    scalar = torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps))** 2 - 1)
    return scalar

def to_scalar_no_soft(x): ## test purpose
    probabilities = x 
    support = (torch.tensor([x for x in range(-support_size, support_size + 1)]).expand(probabilities.shape).float().to(device))
    x = torch.sum(support * probabilities, dim=1, keepdim=True)
    scalar = torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps))** 2 - 1)
    return scalar

def to_cr(x):
    x = x.squeeze(-1).unsqueeze(0)
    x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x
    x = torch.clip(x, -support_size, support_size)
    floor = x.floor()
    under = x - floor
    floor_prob = (1 - under)
    under_prob = under
    floor_index = floor + support_size
    under_index = floor + support_size + 1
    logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).type(torch.float32).to(device)
    logits.scatter_(2, floor_index.long().unsqueeze(-1), floor_prob.unsqueeze(-1))
    under_prob = under_prob.masked_fill_(2 * support_size < under_index, 0.0)
    under_index = under_index.masked_fill_(2 * support_size < under_index, 0.0)
    logits.scatter_(2, under_index.long().unsqueeze(-1), under_prob.unsqueeze(-1))
    return logits.squeeze(0)


##Target network
class Target(nn.Module):
    """Target Network
    
    Target network is used to approximate v_pi_prior, q_pi_prior, pi_prior.
    It contains older network parameters. (exponential moving average update)
    """
    def __init__(selfstate_dimaction_dimwidth):
        super().__init__()
        self.representation_network = Representation(state_dim*8, state_dim*4, width) 
        self.dynamics_network = Dynamics(state_dim*4, state_dim*4, width, action_dim)
        self.prediction_network = Prediction(state_dim*4, action_dim, width) 
        self.to(device)

##Muesli agent
class Agent(nn.Module):
    """Agent Class"""
    def __init__(selfstate_dimaction_dimwidth):
        super().__init__()
        self.representation_network = Representation(state_dim*8, state_dim*4, width) 
        self.dynamics_network = Dynamics(state_dim*4, state_dim*4, width, action_dim)
        self.prediction_network = Prediction(state_dim*4, action_dim, width) 
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=0.00032, weight_decay=0)
        self.scheduler = PolynomialLRDecay(self.optimizer, max_decay_steps=4000, end_learning_rate=0.0000)   
        self.to(device)

        self.state_traj = []
        self.action_traj = []
        self.P_traj = []
        self.r_traj = []      

        self.state_replay = []
        self.action_replay = []
        self.P_replay = []
        self.r_replay = []   

        self.action_space = action_dim
        self.env = gym.make(game_name)     

        self.var = 0
        self.beta_product = 1.0

        self.var_m = [0 for _ in range(5)]
        self.beta_product_m = [1.0 for _ in range(5)] 


    def self_play_mu(selfmax_timestep=10000):        
        """Self-play and save trajectory to replay buffer

        (Originally target network used to inference policy, but i used agent network instead)

        Eight previous observations stacked -> representation network -> prediction network 
        -> sampling action follow policy -> next env step
        """      
        game_score = 0
        state = self.env.reset()
        state_dim = len(state)
        for i in range(max_timestep):   
            start_state = state
            if i == 0:
                stacked_state = np.concatenate((state, state, state, state, state, state, state, state), axis=0)
            else:
                stacked_state = np.roll(stacked_state,-state_dim,axis=0)                
                stacked_state[-state_dim:]=state

            with torch.no_grad():
                hs = self.representation_network(torch.from_numpy(stacked_state).float().to(device))
                P, v = self.prediction_network(hs)    
            action = np.random.choice(np.arange(self.action_space), p=P.detach().cpu().numpy())   
            state, r, done, info = self.env.step(action)                    
            
            if i == 0:
                for _ in range(8):
                    self.state_traj.append(start_state)
            else:
                self.state_traj.append(start_state)
            self.action_traj.append(action)
            self.P_traj.append(P.cpu().numpy())
            self.r_traj.append(r)

            game_score += r

            ## For fix lunarlander-v2 env does not return reward -100 when 'TimeLimit.truncated'
            if done:
                if (info['TimeLimit.truncated'] == Trueand abs(r)!=100:
                    game_score -= 100
                    self.r_traj[-1] = -100
                    r = -100
                last_frame = i
                break

        #print('self_play: score, r, done, info, lastframe', int(game_score), r, done, info, i)


        # for update inference over trajectory length
        for _ in range(5):
            self.state_traj.append(np.zeros_like(state))

        for _ in range(6):
            self.r_traj.append(0.0)
            self.action_traj.append(-1)  

        # traj append to replay
        self.state_replay.append(self.state_traj)
        self.action_replay.append(self.action_traj)
        self.P_replay.append(self.P_traj)
        self.r_replay.append(self.r_traj)  

        writer.add_scalars('Selfplay',
                           {'lastreward': r,
                            'lastframe': last_frame+1
                           },global_i)

        return game_score , r, last_frame


    def update_weights_mu(selftarget):
        """Optimize network weights.

        Iteration: 20
        Mini-batch size: 16 (4 replay, 4 seqeuences in 1 replay)
        Replay: Uniform replay without on-policy data
        Discount: 0.997
        Unroll: 5 step
        L_m: 5 step(Muesli)
        Observations: Stack 8 frame
        regularizer_multiplier: 5 
        Loss: L_pg_cmpo + L_v/6/4 + L_r/5/1 + L_m
        """

        for _ in range(20): 
            state_traj = []
            action_traj = []
            P_traj = []
            r_traj = []      
            G_arr_mb = []

            for epi_sel in range(4):
                if(epi_sel>-1):## replay proportion
                    sel = np.random.randint(0,len(self.state_replay)) 
                else:
                    sel = -1

                ## multi step return G (orignally retrace used)
                G = 0
                G_arr = []
                for r in self.r_replay[sel][::-1]:
                    G = 0.997 * G + r
                    G_arr.append(G)
                G_arr.reverse()
                
                for i in np.random.randint(len(self.state_replay[sel])-5-7,size=4):
                    state_traj.append(self.state_replay[sel][i:i+13])
                    action_traj.append(self.action_replay[sel][i:i+5])
                    r_traj.append(self.r_replay[sel][i:i+5])
                    G_arr_mb.append(G_arr[i:i+6])                        
                    P_traj.append(self.P_replay[sel][i])


            state_traj = torch.from_numpy(np.array(state_traj)).to(device)
            action_traj = torch.from_numpy(np.array(action_traj)).unsqueeze(2).to(device)
            P_traj = torch.from_numpy(np.array(P_traj)).to(device)
            G_arr_mb = torch.from_numpy(np.array(G_arr_mb)).unsqueeze(2).float().to(device)
            r_traj = torch.from_numpy(np.array(r_traj)).unsqueeze(2).float().to(device)
            inferenced_P_arr = []

            ## stacking 8 frame
            stacked_state_0 = torch.cat((state_traj[:,0], state_traj[:,1], state_traj[:,2], state_traj[:,3],
                                         state_traj[:,4], state_traj[:,5], state_traj[:,6], state_traj[:,7]), dim=1)


            ## agent network inference (5 step unroll)
            first_hs = self.representation_network(stacked_state_0)
            first_P, first_v_logits = self.prediction_network(first_hs)      
            inferenced_P_arr.append(first_P)

            second_hs, r_logits = self.dynamics_network(first_hs, action_traj[:,0])    
            second_P, second_v_logits = self.prediction_network(second_hs)
            inferenced_P_arr.append(second_P)

            third_hs, r2_logits = self.dynamics_network(second_hs, action_traj[:,1])    
            third_P, third_v_logits = self.prediction_network(third_hs)
            inferenced_P_arr.append(third_P)

            fourth_hs, r3_logits = self.dynamics_network(third_hs, action_traj[:,2])    
            fourth_P, fourth_v_logits = self.prediction_network(fourth_hs)
            inferenced_P_arr.append(fourth_P)

            fifth_hs, r4_logits = self.dynamics_network(fourth_hs, action_traj[:,3])    
            fifth_P, fifth_v_logits = self.prediction_network(fifth_hs)
            inferenced_P_arr.append(fifth_P)

            sixth_hs, r5_logits = self.dynamics_network(fifth_hs, action_traj[:,4])    
            sixth_P, sixth_v_logits = self.prediction_network(sixth_hs)
            inferenced_P_arr.append(sixth_P)


            ## target network inference
            with torch.no_grad():
                t_first_hs = target.representation_network(stacked_state_0)
                t_first_P, t_first_v_logits = target.prediction_network(t_first_hs)  


            ## normalized advantage
            beta_var = 0.99
            self.var = beta_var*self.var + (1-beta_var)*(torch.sum((G_arr_mb[:,0] - to_scalar(t_first_v_logits))**2)/16)
            self.beta_product *= beta_var
            var_hat = self.var/(1-self.beta_product)
            under = torch.sqrt(var_hat + 1e-12)


            ## L_pg_cmpo first term (eq.10)
            importance_weight = torch.clip(first_P.gather(1,action_traj[:,0])
                                        /(P_traj.gather(1,action_traj[:,0])),
                                        01
            )
            first_term = -1 * importance_weight * (G_arr_mb[:,0] - to_scalar(t_first_v_logits))/under


            ##lookahead inferences (one step look-ahead to some actions to estimate q_prior, from target network)
            with torch.no_grad():                                
                r1_arr = []
                v1_arr = []
                a1_arr = []
                for _ in range(self.action_space): #sample <= N(action space), now N    
                    action1_stack = []
                    for p in t_first_P:             
                        action1_stack.append(np.random.choice(np.arange(self.action_space), p=p.detach().cpu().numpy()))    
                    hs, r1 = target.dynamics_network(t_first_hs, torch.unsqueeze(torch.tensor(action1_stack),1))
                    _, v1 = target.prediction_network(hs)

                    r1_arr.append(to_scalar(r1))
                    v1_arr.append(to_scalar(v1) )
                    a1_arr.append(torch.tensor(action1_stack))               

            ## z_cmpo_arr (eq.12)
            with torch.no_grad():   
                exp_clip_adv_arr = [torch.exp(torch.clip((r1_arr[k] + 0.997 * v1_arr[k] - to_scalar(t_first_v_logits))/under, -11))
                                    .tolist() for k in range(self.action_space)]
                exp_clip_adv_arr = torch.tensor(exp_clip_adv_arr).to(device)
                z_cmpo_arr = []
                for k in range(self.action_space):
                    z_cmpo = (1 + torch.sum(exp_clip_adv_arr[k],dim=0) - exp_clip_adv_arr[k]) / self.action_space 
                    z_cmpo_arr.append(z_cmpo.tolist())
            z_cmpo_arr = torch.tensor(z_cmpo_arr).to(device)


            ## L_pg_cmpo second term (eq.11)
            second_term = 0
            for k in range(self.action_space):
                second_term += exp_clip_adv_arr[k]/z_cmpo_arr[k] * torch.log(first_P.gather(1, torch.unsqueeze(a1_arr[k],1).to(device)))
            regularizer_multiplier = 5 
            second_term *= -1 * regularizer_multiplier / self.action_space


            ## L_pg_cmpo               
            L_pg_cmpo = first_term + second_term


            ## L_m
            L_m  = 0
            for i in range(5):
                stacked_state = torch.cat(( state_traj[:,i+1], state_traj[:,i+2], state_traj[:,i+3], state_traj[:,i+4],
                                            state_traj[:,i+5], state_traj[:,i+6], state_traj[:,i+7], state_traj[:,i+8]), dim=1)
                with torch.no_grad():
                    t_hs = target.representation_network(stacked_state)
                    t_P, t_v_logits = target.prediction_network(t_hs) 

                beta_var = 0.99
                self.var_m[i] = beta_var*self.var_m[i] + (1-beta_var)*(torch.sum((G_arr_mb[:,i+1] - to_scalar(t_v_logits))**2)/16)
                self.beta_product_m[i]  *= beta_var
                var_hat = self.var_m[i] /(1-self.beta_product_m[i])
                under = torch.sqrt(var_hat + 1e-12)

                with torch.no_grad():                                
                    r1_arr = []
                    v1_arr = []
                    a1_arr = []
                    for j in range(self.action_space):  
                        action1_stack = []
                        for _ in t_P:             
                            action1_stack.append(j)    
                        hs, r1 = target.dynamics_network(t_hs, torch.unsqueeze(torch.tensor(action1_stack),1))
                        _, v1 = target.prediction_network(hs)

                        r1_arr.append(to_scalar(r1))
                        v1_arr.append(to_scalar(v1) )
                        a1_arr.append(torch.tensor(action1_stack))    

                with torch.no_grad():   
                    exp_clip_adv_arr = [torch.exp(torch.clip((r1_arr[k] + 0.997 * v1_arr[k] - to_scalar(t_v_logits))/under,-11))
                                        .tolist() for k in range(self.action_space)]
                    exp_clip_adv_arr = torch.tensor(exp_clip_adv_arr).to(device)

                ## Paper appendix F.2 : Prior policy
                t_P = 0.967*t_P + 0.03*P_traj + 0.003*torch.tensor([[0.25,0.25,0.25,0.25for _ in range(16)]).to(device) 

                pi_cmpo_all = [(t_P.gather(1, torch.unsqueeze(a1_arr[k],1).to(device)) 
                                * exp_clip_adv_arr[k])
                                .squeeze(-1).tolist() for k in range(self.action_space)]                
        
                pi_cmpo_all = torch.tensor(pi_cmpo_all).transpose(0,1).to(device)
                pi_cmpo_all = pi_cmpo_all/torch.sum(pi_cmpo_all,dim=1).unsqueeze(-1)
                kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
                L_m += kl_loss(torch.log(inferenced_P_arr[i+1]), pi_cmpo_all) 
            
            L_m/=5


            ## L_v
            ls = nn.LogSoftmax(dim=-1)

            L_v = -1 * (
                (to_cr(G_arr_mb[:,0])*ls(first_v_logits)).sum(-1, keepdim=True)
                +  (to_cr(G_arr_mb[:,1])*ls(second_v_logits)).sum(-1, keepdim=True)
                +  (to_cr(G_arr_mb[:,2])*ls(third_v_logits)).sum(-1, keepdim=True)
                +  (to_cr(G_arr_mb[:,3])*ls(fourth_v_logits)).sum(-1, keepdim=True)
                +  (to_cr(G_arr_mb[:,4])*ls(fifth_v_logits)).sum(-1, keepdim=True)
                +  (to_cr(G_arr_mb[:,5])*ls(sixth_v_logits)).sum(-1, keepdim=True)
            )

            ## L_r     
            L_r = -1 * (
                (to_cr(r_traj[:,0])*ls(r_logits)).sum(-1, keepdim=True)
                + (to_cr(r_traj[:,1])*ls(r2_logits)).sum(-1, keepdim=True)
                + (to_cr(r_traj[:,2])*ls(r3_logits)).sum(-1, keepdim=True)
                + (to_cr(r_traj[:,3])*ls(r4_logits)).sum(-1, keepdim=True)
                + (to_cr(r_traj[:,4])*ls(r5_logits)).sum(-1, keepdim=True)
            )


            ## start of dynamics network gradient *0.5
            first_hs.register_hook(lambda grad: grad * 0.5)
            second_hs.register_hook(lambda grad: grad * 0.5
            third_hs.register_hook(lambda grad: grad * 0.5)    
            fourth_hs.register_hook(lambda grad: grad * 0.5)   
            fifth_hs.register_hook(lambda grad: grad * 0.5)  
            sixth_hs.register_hook(lambda grad: grad * 0.5)  


            ## total loss
            L_total = L_pg_cmpo + L_v/6/4 + L_r/5/1 + L_m   
          
            
            ## optimize
            self.optimizer.zero_grad()
            L_total.mean().backward()
            nn.utils.clip_grad_value_(self.parameters(), clip_value=1.0)
            self.optimizer.step()
            

            ## target network(prior parameters) moving average update
            alpha_target = 0.01 
            params1 = self.named_parameters()
            params2 = target.named_parameters()
            dict_params2 = dict(params2)
            for name1, param1 in params1:
                if name1 in dict_params2:
                    dict_params2[name1].data.copy_(alpha_target*param1.data + (1-alpha_target)*dict_params2[name1].data)
            target.load_state_dict(dict_params2)          


        self.scheduler.step()

        self.state_traj.clear()
        self.action_traj.clear()
        self.P_traj.clear()
        self.r_traj.clear()

        writer.add_scalars('Loss',{'L_total': L_total.mean(),
                                  'L_pg_cmpo': L_pg_cmpo.mean(),
                                  'L_v': (L_v/6/4).mean(),
                                  'L_r': (L_r/5/1).mean(),
                                  'L_m': (L_m).mean()
                                  },global_i)
        
        writer.add_scalars('vars',{'self.var':self.var,
                                   'self.var_m':self.var_m[0]
                                  },global_i)
        
        return



%rm -rf scalar/
%load_ext tensorboard
%tensorboard --logdir scalar --port=6008

device = torch.device('cuda:0'if torch.cuda.is_available() else torch.device('cpu')
score_arr = []
game_name = 'LunarLander-v2' 
env = gym.make(game_name) 
target = Target(env.observation_space.shape[0], env.action_space.n, 128)
agent = Agent(env.observation_space.shape[0], env.action_space.n, 128)  
print(agent)
env.close()

## initialization
target.load_state_dict(agent.state_dict())

## Self play & Weight update loop
episode_nums = 4000
for i in range(episode_nums):
    writer = SummaryWriter(logdir='scalar/')
    global_i = i    
    game_score , last_r, frame= agent.self_play_mu()       
    writer.add_scalar('score', game_score, global_i)
    
    score_arr.append(game_score)  
    if i%10==0:
        mean_score = np.mean(np.array(score_arr[i-30:i]))
        #print('episode, avg, score, last_r, len\n', i, mean_score, int(game_score), last_r, frame)

    if i%100==0:
        torch.save(agent.state_dict(), 'weights.pt'

    if mean_score > 250 and np.mean(np.array(score_arr[-5:])) > 250:
        torch.save(agent.state_dict(), 'weights.pt')
        print('Done')
        break

    agent.update_weights_mu(target) 
    writer.close()

torch.save(agent.state_dict(), 'weights.pt'
agent.env.close()




댓글

가장 많이 본 글

구글 람다(LaMDA)란? - 구글의 언어 모델

알파고 강화학습 원리

버텍스 AI란? - 구글 인공지능 플랫폼

카타고와 바둑 두어보기

뉴럴 네트워크란?

블로그 글 목록

뉴럴 네트워크를 학습시키는 방법