이번 포스팅은 Back propagation에 대해서 적어보려고 한다. 어마어마한 양의 식이 포스팅될 예정이다. 후 이거 공부하려고 오늘 하루 다 썼다.


우선 back propagation이 뭔지부터 설명해보자. back propagation은 ANN를 training하는 과정이다. 우리가 이전 포스팅에서 ANN의 단 한번의 forwarding으로는 전혀 그럴 듯한 결과를 얻지 못한다고 했었다. 그래서 weight을 update한다고 했었다. 그 weight update하는 과정을 우리는 미분을 써서 dJ/dw가 0이 되는 지점을 찾는다고 했다. 즉 weight을 update하면서 J는 최소값으로 가까워진다는 말이다. 이것을 하는 과정이 back propagation이다. 즉 ANN의 training하는 알고리즘 중 하나가 back propagation algorithm이다.


자 그러면 미분을 시작해보자. 우선 뭘 어떻게 해야하는지 outline을 잡고 시작해보자. 우선 우리는 dJ/dw를 구할 것인데 우리가 예제로 삼고 있는 XOR ANN은 두 개의 weight vector가 있다.

하나 W(1)은 Input layer에서 Hidden layer로 갈 때 사용하는 weight vector이고, 또 다른 하나 W(2)는 Hidden layer에서 Output layer로 갈 때 사용하는 weight vector이다.

이전의 J를 위의 w(1)과 w(2)로 미분을 하면 아래와 같다.

우리는 dJ/dw(1)와 dJ/dw(2)를 구해야 되는데 먼저 dJ/dw(2)를 구하도록 하자. 밑에 식에서 시그마(sum)는 일단 잠깐 빼고 시작하자.


아래의 식은 https://www.youtube.com/watch?v=GlcnxUlrtek 여기에서 나오는 식과 같다. 이 영상과 밑의 식을 같이 보시길 바란다. notation에 있는 숫자도 조금 틀릴 수 있다.

y와 yHat 그리고 z(3) 미분한 값은 각각 아래와 같다.


dz(3)/dw(2)는 왜 저 행렬이 될까? 간단하다. z(3) = aw(2)이기 때문에 w(2)로 미분해주면 a만 남는 것이다. 여기서 우리가 처음에 잠깐 생략했었던 sigma를 다시 넣어준다. 어떻게 넣어줄까? 바로 행렬 곱셈이다. 지금까지는 전부 스칼라 곱이었으나, 마지막에 행렬곱셈으로 바꿔주므로서 아래와 같이 summation을 할 수 있다. 

시그마(3) = -(y-yHat) f'(z(3))는 3x1 행렬이므로 위의 식을 행렬 곱셈으로 바꿔주기 위해서는 3x3 행렬을 시그마(3) 앞에 곱해줘야 되고, 적절한 순서로 summation해주기 위해 transpose matrix로 바꿔준다. 이 부분을 잘 이해해야 된다.


행렬곱셈으로 바꿔주므로서 우리는 앞서 생략했었던 summation도 다시 적용시켰다. 자 이제 dJ/w(1)을 구할 차례다. 이 값도 비슷하게 구할 수 있다. J를 w(1)으로 미분하여 시작하면 된다.

위의 식에서 w(2)^T = dz(3)/da(2) 이 부분.. 이 부분 나도 잘 모르겠다. 대충은 알겠는데, 왜 Transpose를 해줘야 되는지 정확하게는 모르겠다. 그거 빼고는 위의 식은 이전에 dJ/dw(2) 했을 때를 생각하면 이해가 될 것이다.


이제 dJ/dw(2)과 dJ/dw(1)을 구했으니 실제로 weight vector를 업데이트해주면 된다. 이 때 새로운 변수가 들어간다. 바로 learning rate이다.


new w(1) = old w(1) + (learning rate) * dJ/dw(1)

new w(2) = old w(2) + (learning rate) * dJ/dw(2)


이 식으로 w(1)과 w(2)를 update해준다. 이렇게 하면 이제 하나의 epoch가 끝난 것이다. w(1)과 w(2)가 update되었으니 x1과 x2를 update된 w(1)과 w(2)에 집어넣어서 위의 과정을 또 하면 된다. ANN은 이 과정을 무수히 많이 반복하여 error rate을 최대한 작아지게 만든다.(Training) 그 과정을 눈으로 보고 싶다면 아래의 링크를 통해 보면 된다.


http://www.emergentmind.com/neural-network


epoch를 반복반복반복할 수록 XOR 계산에 대한 predicted value가 정확히 0과 1은 아니지만 0.2 0.3 또는 0.97, 0.98로 가까워지는 것을 볼 수 있다. 이걸 만든 사람이 자기 개인 포스팅도 해놓았더라. 


http://mattmazur.com/2015/03/17/a-step-by-step-backpropagation-example/


우와. 정말 engineer다. engineer는 이래야 된다.ㅠㅠ 도무지 실제 예제를 설명해둔 것을 찾을 수가 없어서 자기가 직접 했덴다. 정말 나도 이런 engineer가 되어야 되는데...


아무튼 우리는 back propagation은 이만 끝내고 다음 포스팅으로 넘어가보자. 아마 다음 포스팅은 Word2Vec이 되지 않을까 싶다. 참고로 back propagation에 대한 실제 값을 통한 예제도 언젠간 반드시 포스팅할 거다. 

Posted by 빛나유
,