개요
Categorical reparametrization은 뮤제로나 뮤즐리에서 다양한 reward scale에도 robust한 학습을 가능하게 하기 위해 사용된 기법 중 하나로, 뮤제로 논문에 소개되어 있습니다.
Reward, value가 -1,0,1로 표현이 될 수 있는 보드게임류 환경과는 다르게 general MDP환경들(Atari, LunarLander 등)의 reward 범위는 다양합니다. Atari와 같은 경우는 수백만 이상의 값이 들어가기도 하고, LunarLander도 -몇백에서 +몇백 정도의 값과, -0.03, -0.3과 같은 작은 값도 학습이 되야 하는 reward 구조를 가지고 있습니다.
이러한 general MDP의 경우 reward model, value model을 학습할 때 스칼라 값을 그대로 mse loss를 사용해 regression해버리면 loss는 매우 크게 변동하게 되고, reward model, value model 이전에 연결되어 있는 representation network, dynamics network들까지 크게 영향을 받아 네트워크 전체가 무너지게 되며 안정적인 학습이 불가능하게 됩니다.
때문에 이러한 문제를 해결하기 위해서 사용된 categorical reparametrization은 크게 두가지 동작으로 구성이 되는데, 첫번째는 스케일 조절 기능이 있는(squash하는) invertible transform을 통해서 크기를 조절하고, 그 값을 categorical한 표현으로 바꿔서 사용하고 학습할 수 있게 합니다. (반대 방향으로도 동작)
이것이 가능하게 하려면 먼저 reward model, value model도 categorical한 표현을 출력하는 구조여야 하고, 인퍼런스된 categorical한 표현을 스칼라 값으로 바꿔주는 함수와, 스칼라 값을 categorical한 표현으로 바꿔주는 함수가 필요합니다.
Categorical to scalar
(여기서 사용되는 invertible transform 식은 아래와 같음, eps=0.001)
scalar = torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps))** 2 - 1)
예시)
1. reward model, value model로부터 인퍼런스된 categorical representation
[-7.9622, -8.4434, -8.0816, -8.2011, -8.8170, -8.2907, -8.4424, -8.6694,
-7.3396, -5.0082, -3.3734, -1.9358, -1.6974, -1.5668, -0.7269, -0.1581,
-0.0480, -0.0675, 0.1998, 0.5082, 0.7352, 0.9445, 0.8766, 0.2834,
-0.4747, -1.5115, -3.3294, -4.5949, -5.0458, -5.0399, -2.8231, -4.5544,
-4.0168, -4.2494, -3.2774, -2.8336, -3.3649, -4.2941, -4.9554, -5.8936,
-5.9546, -7.0445, -8.0818, -8.4032, -7.9950, -8.3060, -8.1257, -8.5512,
-8.4400, -8.1185, -8.0508, -8.2104, -8.0700, -8.6658, -7.8899, -8.1327,
-8.0833, -7.8514, -8.6610, -8.1846, -8.4705]
2. softmax 처리
[2.1449e-05, 1.3256e-05, 1.9035e-05, 1.6891e-05, 9.1236e-06, 1.5443e-05,
1.3269e-05, 1.0574e-05, 3.9975e-05, 4.1142e-04, 2.1100e-03, 8.8843e-03,
1.1277e-02, 1.2850e-02, 2.9763e-02, 5.2562e-02, 5.8681e-02, 5.7550e-02,
7.5180e-02, 1.0234e-01, 1.2842e-01, 1.5831e-01, 1.4792e-01, 8.1735e-02,
3.8298e-02, 1.3580e-02, 2.2049e-03, 6.2200e-04, 3.9626e-04, 3.9859e-04,
3.6582e-03, 6.4774e-04, 1.1088e-03, 8.7873e-04, 2.3226e-03, 3.6202e-03,
2.1281e-03, 8.4031e-04, 4.3374e-04, 1.6974e-04, 1.5969e-04, 5.3698e-05,
1.9032e-05, 1.3800e-05, 2.0757e-05, 1.5208e-05, 1.8214e-05, 1.1901e-05,
1.3301e-05, 1.8344e-05, 1.9630e-05, 1.6734e-05, 1.9258e-05, 1.0613e-05,
2.3056e-05, 1.8086e-05, 1.9002e-05, 2.3961e-05, 1.0664e-05, 1.7172e-05,
1.2902e-05]
3. support 준비 (support size = 30)
[-30., -29., -28., -27., -26., -25., -24., -23., -22., -21., -20., -19.,
-18., -17., -16., -15., -14., -13., -12., -11., -10., -9., -8., -7.,
-6., -5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.,
6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
30.]
4. softmax 처리된 표현과 support 텐서 곱해서 원소들 더하기
[-10.2160]
5. invertible transform 통해 스케일 늘려줘서 scalar값 출력
[-122.0749]
Scalar to categorical
(여기서 사용되는 invertible transform 식은 아래와 같음)
x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x
예시)
1. 변환할 스칼라 값 받기
-194.1837
2. invertible transform을 통해서 스케일 줄이기
-13.1650
3. 스케일 줄여진 스칼라 값을 소숫점 윗자리 정수와 아랫자리로 분리(floor사용해서 내림+나머지)
-14.
0.8350
4. 소숫점 윗자리와 아랫자리 각각의 확률과 support에 들어갈때의 인덱스 구하기
0.1650, 16
0.8350, 17
5. 인덱스에 따라 빈 support에 확률 할당해주기
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.1650, 0.8350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
해당 형태의 분포는 reward model, value model을 학습시킬 때 사용하는 cross entropy loss의 타겟으로 사용이 됩니다.
해당 기법이 적용된 LunarLander-v2 학습시의 loss그래프. Reward model 및 value model에 대한 loss인 L_r, L_v이 안정적으로 들어가는 모습을 볼 수 있음.
코드
support_size = 30
eps = 0.001
def to_scalar(x):
x = torch.softmax(x, dim=-1)
probabilities = x
support = (torch.tensor([x for x in range(-support_size, support_size + 1)]).expand(probabilities.shape).float().to(device))
x = torch.sum(support * probabilities, dim=1, keepdim=True)
scalar = torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps))** 2 - 1)
return scalar
def to_cr(x):
x = x.squeeze(-1).unsqueeze(0)
x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x
x = torch.clip(x, -support_size, support_size)
floor = x.floor()
under = x - floor
floor_prob = (1 - under)
under_prob = under
floor_index = floor + support_size
under_index = floor + support_size + 1
logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).type(torch.float32).to(device)
logits.scatter_(2, floor_index.long().unsqueeze(-1), floor_prob.unsqueeze(-1))
under_prob = under_prob.masked_fill_(2 * support_size < under_index, 0.0)
under_index = under_index.masked_fill_(2 * support_size < under_index, 0.0)
logits.scatter_(2, under_index.long().unsqueeze(-1), under_prob.unsqueeze(-1))
return logits.squeeze(0)
댓글
댓글 쓰기