One-step Actor-Critic implementation

개요

One-step Actor-Critic 알고리즘을 PyTorch 코드로 구현해보았으며 Colab에서 코드 실행 및 결과 확인이 가능합니다.

설명

One-step Actor-Critic 알고리즘은 REINFORCE 처럼 trajectory를 구하고 그로부터 cumulative future rewards G를 구해 업데이트 하는 대신, 한번의 시행 결과의 reward와 s'의 state-value 값을 더해서 업데이트에 사용하는 구조입니다(bootstrapping). Online, incremental한 구조로 업데이트가 진행됩니다.
위의 pseudocode는 아래의 코드로 표현됩니다.

    def play_and_update_actor_critic(self):
        game_score = 0
        state = self.env.reset() # env 시작
        I = 1            
        while True:
            output = self.P(torch.from_numpy(state).float().to(device)) # inference
            inferenced_v = self.V(torch.from_numpy(state).float().to(device))
            prob_distribution = Categorical(output) #확률분포 표현
            action = prob_distribution.sample() #확률분포로부터 action 선택
            state, r, done, _ = self.env.step(action.item()) # env 진행  
            gradient_policy_a_s = prob_distribution.log_prob(action)
            with torch.no_grad():
                inferenced_v_from_next_s = self.V(torch.from_numpy(state).float().to(device))
            if done==True:
                inferenced_v_from_next_s = 0
            delta = r + 0.99*inferenced_v_from_next_s - inferenced_v
            V_loss = -1 * I * delta.item() * inferenced_v            
            self.V_optimizer.zero_grad()
            V_loss.backward()
            self.V_optimizer.step()
            P_loss = -1 * I *  delta.item() * gradient_policy_a_s
            self.P_optimizer.zero_grad()
            P_loss.backward()
            self.P_optimizer.step() 
            I = 0.99 * I
            game_score += r                 
            if done:
                break
        return game_score

계산그래프상 PolicyNetwork랑 ValueNetwork의 충돌을 막기 위해서 delta는 .item()으로 값만 사용해 loss를 구성해야 합니다. (.detach() 도 가능)

아래는 학습시 측정한 게임 스코어 그래프입니다. 학습이 약간 불안정한 모습을 보입니다. 적당한 시점에 파라미터를 저장하는게 바람직해 보입니다.

One-step Actor-Critic(cartpole) 
  REINFORCE with baseline(cartpole)
REINFORCE(cartpole)

Full code

import gym
import torch
import torch.nn as nn
from torch.distributions import Categorical
import matplotlib.pyplot as plt
!pip install gym[classic_control]
#pip install gym[box2d] #for lunarlander
!apt update
!apt install xvfb
!pip install pyvirtualdisplay
!pip install gym-notebook-wrapper
import gnwrapper
!nvidia-smi
print(torch.cuda.is_available())

class Agent(nn.Module):
    def __init__(selfinput_dimoutput_dimwidth):
        super().__init__()
        self.P = PolicyNetwork(input_dim, output_dim, width)
        self.P.to(device)
        self.P.train()  
        self.P_optimizer = torch.optim.Adam(self.P.parameters(), lr=0.0003)
        self.V = ValueNetwork(input_dim, output_dim, width)
        self.V.to(device)        
        self.V.train()
        self.V_optimizer = torch.optim.Adam(self.V.parameters(), lr=0.0003)   ## ?? 0,01
        self.trajectory = []       
        self.env = gym.make(game_name)        

    def play_and_update_actor_critic(self):
        game_score = 0
        state = self.env.reset() # env 시작
        I = 1            
        while True:
            output = self.P(torch.from_numpy(state).float().to(device)) # inference
            inferenced_v = self.V(torch.from_numpy(state).float().to(device))
            prob_distribution = Categorical(output) #확률분포 표현
            action = prob_distribution.sample() #확률분포로부터 action 선택
            state, r, done, _ = self.env.step(action.item()) # env 진행  
            gradient_policy_a_s = prob_distribution.log_prob(action)
            with torch.no_grad():
                inferenced_v_from_next_s = self.V(torch.from_numpy(state).float().to(device))
            if done==True:
                inferenced_v_from_next_s = 0
            delta = r + 0.99*inferenced_v_from_next_s - inferenced_v
            V_loss = -1 * I * delta.item() * inferenced_v               
            self.V_optimizer.zero_grad()
            V_loss.backward()
            self.V_optimizer.step()
            P_loss = -1 * I *  delta.item() * gradient_policy_a_s            
            self.P_optimizer.zero_grad()
            P_loss.backward()
            self.P_optimizer.step() 
            I = 0.99 * I
            game_score += r                 
            if done:
                break
        return game_score

class PolicyNetwork(nn.Module):  
    def __init__(selfinput_dimoutput_dimwidth):
        super().__init__()
        self.layer1 = torch.nn.Linear(input_dim, width)
        self.layer2 = torch.nn.Linear(width, width) 
        self.layer3 = torch.nn.Linear(width, output_dim) 

    def forward(selfx):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        x = self.layer3(x)
        x = torch.nn.functional.softmax(x, dim=0)
        return x

class ValueNetwork(nn.Module): 
    def __init__(selfinput_dimoutput_dimwidth):
        super().__init__()
        self.layer1 = torch.nn.Linear(input_dim, width)
        self.layer2 = torch.nn.Linear(width, width) 
        self.layer3 = torch.nn.Linear(width, 1

    def forward(selfx):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        x = self.layer3(x)
        return x

device = torch.device('cuda:0'if torch.cuda.is_available() else torch.device('cpu')
score_arr = []
game_name = 'CartPole-v1' #LunarLander-v2
env = gym.make(game_name) 
agent = Agent(env.observation_space.shape[0], env.action_space.n, 128
print(agent)
env.close()

#Self play 및 weight update
episode_nums = 500 #LunarLander-v2는 1500
for i in range(episode_nums):
    game_score = agent.play_and_update_actor_critic()
    score_arr.append(game_score)  
    if i%50==0 : print('episode', i)    
torch.save(agent.state_dict(), 'weights.pt'
agent.env.close()

#Episode별 얻은 score
plt.plot(score_arr, label ='score')
plt.legend(loc='upper left')

#학습된 모델로 게임 play한 영상
agent.load_state_dict(torch.load("weights.pt"))
env = gnwrapper.LoopAnimation(gym.make(game_name)) 
state = env.reset()
for _ in range(200):
    with torch.no_grad():
        output = agent.P(torch.from_numpy(state).float().to(device)) # inference
        prob_distribution = Categorical(output) #확률분포 표현
        action = prob_distribution.sample() #확률분포로부터 action 선택
    env.render()
    state, rew, done, _ = env.step(action.item())
    if done:
        state = env.reset()
env.display()

댓글

가장 많이 본 글

구글 람다(LaMDA)란? - 구글의 언어 모델

알파고 강화학습 원리

버텍스 AI란? - 구글 인공지능 플랫폼

카타고와 바둑 두어보기

뉴럴 네트워크란?

블로그 글 목록

뉴럴 네트워크를 학습시키는 방법