뉴럴 네트워크를 학습시키는 방법




앞서 작성한 글인 뉴럴 네트워크 개요 에서, 뉴럴 네트워크는 데이터 간의 연관관계를 기억하는 능력을 가지고 있고, 데이터로부터 연관관계를 표현할 수 있는 특정한 수식을 자동으로 찾아내는 물건이라고 이야기 하였습니다.

어떻게 구성되어 있길래 특정 수식을 자동으로 찾아낼 수 있는 것일까요? 조금 바꾸어 말하면 뉴럴 네트워크는 어떻게 학습을 하는 것일까요?

간단한 뉴럴 네트워크

아래 그림은 간단한 형태의 뉴럴 네트워크입니다.


뉴럴 네트워크는 가장 왼쪽에 입력 데이터가 들어오면, 가장 오른쪽에 출력 데이터를 내보내 줍니다. 중간에는 많은 선들과 동그라미가 있습니다.

들은 각각 값을 가지고 있으며, 선은 자신의 왼쪽 동그라미의 값과 자신의 값을 곱하여 오른쪽 동그라미로 넘겨줍니다. 동그라미는 들어오는 값을 모두 더하는 역할을 합니다. 그리고 더한 값이 0보다 크면 오른쪽으로 값을 내보내고, 0보다 작으면 아무 것도 하지 않습니다. (내부는 좀 더 복잡한 수식들로 구성되어 있지만, 이 글에서는 생략하고 넘어가도록 하겠습니다.)

이런 구조가 한층, 한층 반복되며 연결되어 있습니다. 전체를 함께 보면 복잡해 보이지만 일부만 따로 떼서 본다면,
1번째 동그라미 x 0.1 + 2번째 동그라미 x 0.2 + 3번째 동그라미 x 0.5 와 같은 형태로 표현이 가능하게 됩니다.

이 구조가 의미하는 바는 무엇일까요?

뉴럴 네트워크에 존재하는 선들의 값이 어떤 수식을 구성하는 데에 사용된다는 것입니다. 선들의 값을 바꾸면, 뉴럴 네트워크가 표현하는 수식 또한 바뀌게 됩니다. 

뉴럴 네트워크 학습의 핵심은 선들의 값을 잘 찾아내는 것

올바른 선들의 값을 찾아낸다면, 연관관계를 잘 표현하는 수식을 만들어 낼 수 있게 되고, 결국 연관관계를 잘 기억해내는 잘 학습된 뉴럴 네트워크가 완성되는 것입니다.

그런데, 선들의 값을 찾는 것은 어려운 문제입니다. 최신 뉴럴 네트워크들은 수천만개, 많게는 수천억개 이상을 가지고 있습니다. 이 선들의 값을 어떻게 찾아내야할까요?

수십년 동안 수많은 방법들이 제시되어 왔는데, 그중 많이 사용되는 역전파 알고리즘에 대해서 설명드리도록 하겠습니다.

역전파 알고리즘은 수식으로 표현하는 것이 정확하지만, 이를 쉽게 풀어 이야기드릴까 합니다. 먼저, 역전파 알고리즘이 사용되려면 데이터셋이 준비되어야 합니다.

데이터셋과 뉴럴 네트워크

실제 데이터셋 
(입력:이미지, 출력:숫자로 구성된 데이터셋)


이상적인 뉴럴 네트워크
(이상적인 네트워크는 데이터셋의 입력-출력을 잘 따라해준다.)

이렇게 준비된 데이터셋을 뉴럴 네트워크가 잘 따라하게 만드는 것이 목적이며 이 과정에서 역전파 알고리즘이 사용됩니다.

학습되기 전에 이미지를 넣어보면?

네트워크는 아직 학습되지 않았기 때문에 데이터셋을 잘 따라하지 못합니다. 강아지 이미지를 넣었을 때 (1,0)이 나오는 게 이상적이지만 지금은 (0.4, -0.7)로 잘못된 값이 나옵니다. (0.4, -0.7)이 아니라 (1,0)이 나오게 선의 값을 변경해야 합니다.

조금 바꾸어 말하면, 이상(데이터셋)과 현실(현재 네트워크의 동작)의 차이를 최소화하는 방향으로 선의 값을 변경해야 하는 것입니다. 현재 차이를 계산해보면 (0.6, 0.7)이며, 이만큼 잘못 동작하고 있다는 것입니다.

이제 이 둘의 차이를 최소화하는 것에 집중하면, 뉴럴 네트워크를 올바르게 학습시킬 수 있습니다. 이 둘의 차이를 손실 함수(Loss function)라고 이야기합니다. 손실 함수를 최소화하기 위해서는 선을 어떻게 수정해야할까요?

수많은 선 중 한개의 선을 A라고 할때, 이 A의 값을 바꾸면, 손실 함수를 작아지게 할 수 있습니다.

그러기 위해서는 A의 값 변화와 손실 함수 값 사이에 어떤 관계가 있는지 알아내야 합니다. A의 값의 변화에 따라 손실 함수가 커질 수도 있고 작아질 수도 있습니다. 이 정보는 미분 계산을 통해서 구해낼 수 있습니다. (자세한 계산은 글 아래 링크를 참조)

이렇게 알아낸 정보를 기반으로 A값을 손실 함수가 줄어드는 방향으로 수정하는 것이 역전파 알고리즘의 아이디어입니다.

이러한 역전파 알고리즘을 A선,B선,C선...뉴럴 네트워크에 존재하는 모든 선들에 대해 사용합니다. 또, 데이터셋에 있는 다른 사진들에 대해서도 같은 과정을 반복합니다. 이미지를 넣어봐서 손실 함수를 계산하고, 미분 계산을 하여, 선들의 값을 변경하는 과정을 계속 반복하는 것입니다. 

역전파 알고리즘을 반복해서 사용하면

잘못된 값을 내보내던 네트워크는 역전파 알고리즘을 반복할수록 손실 함수의 값이 0에 점점 가까워지게 되며,

결국 올바른 동작을 하는 네트워크가 되게 됩니다.

글을 마치며

정리하자면 뉴럴 네트워크를 학습하는 것은 뉴럴 네트워크가 최대한 데이터셋에 가까운 출력을 내게끔 선의 값을 계속 수정하면서 뉴럴 네트워크를 완성시키는 것입니다.

뉴럴 네트워크를 학습시키는 것은 꼭 역전파 알고리즘만 있는 것은 아니며, 뉴럴 네트워크의 형태도 꼭 위와 같은 형태로 구성되어 있지 않습니다.뉴럴 네트워크들은 다양한 형태와 수식으로 구성되어 있으며, 학습 알고리즘 또한 개선된 방법들이 계속해서 발표되고 있습니다.

하지만 이러한 개선된 뉴럴 네트워크와 학습 알고리즘들도 데이터셋으로부터 올바른 수식을 찾아낸다는 아이디어에서 크게 벗어나지 않을 것입니다.

이 글에서는 수학적 설명이나 코드 없이 설명하면서 생략한 부분이 많습니다. 혹시 조금 더 자세히 뉴럴 네트워크에 대해서 공부해보고 싶으신 분은 "밑바닥부터 시작하는 딥러닝"이라는 책을 참고해 보시면 좋을 것 같습니다. 찾아보니 3편까지 나왔는데, 1,2편은 pdf가 웹상에 공개되어 있어서 무료로 보실 수 있습니다. 


댓글

가장 많이 본 글

구글 람다(LaMDA)란? - 구글의 언어 모델

알파고 강화학습 원리

버텍스 AI란? - 구글 인공지능 플랫폼

카타고와 바둑 두어보기

뉴럴 네트워크란?

블로그 글 목록