Data Mining

[Data Mining] Adaptive Boosting Algorithm

빛나유 2016. 9. 11. 02:55

Adaptive Boosting Algorithm이다. 우선 이 알고리즘 설명하기 전에 내가 왜 Boosting을 다시 하냐면... Boosting 공부했던거 되세기기 위해서 예전에 써둔거 다시 읽어봤는데 뭔가 많이 부족하고 맘에 안들어서 다시하려고 그런다.


그런데 문제는 이 포스팅 정확하지 않을 수도 있다는 것이다. 그러니까 내가 쓴 것이 맞는지 틀리는지는 있는지 없는지 모르는 독자분들이 잘 판단해주세요ㅋㅋ 


이전 포스팅에서 디테일이 많이 생략되어있드라. 그 디테일을 이번 포스팅에서 살려보려고 한다.


우선 Boosting의 기본 개념을 다시 한번 되세기자. 쉽게 말해서 Weak learner 여러개를 합쳐서 트레이닝한 것이 Strong learner 하나보다 좋다는 것이다. 초등학교 축구 선수 100명이 박지성 한명보다 낫다는 것이다.


우선 weak learner의 error rate을 구하는 방법부터 알아보자. Weak learner h(x)는 아래와 같이 error rate을 구한다.

뭔가 어렵다 error rate 구하는 방법을 예제로 알아보자. 이거 헷깔린다 잘 봐둬라.


h1(x<thr)값은 x<thr라는 조건이 참이면 1 거짓이면 -1이다. 단순하다. 그래서 Weak learner이다. 그렇게 구한 값을 y와 비교한다. I(h1(x) != y)는 h1(x) != y라는 조건이 참이면 1 거짓이면 0이다. 이렇게 한 후에 d1이랑 곱해준다. d1은 weight1이라고 보면 된다. 젤 처음의 weight은 항상 1/데이터개수 이기 때문에 0.1로 초기화되는 것이다. error rate은 d1*I(h1(x) != y)의 합이기 때문에 우리는 0.3이라는 error를 가지게 된다. 또 다른 예제를 봐보자.

여기서는 x>thr이다. 부등호 방향이 바뀌었다. 그 외에는 구하는 방법은 같으니 한번 연습장에 해보시길 바란다.


여기서 thr은 무엇이냐면, threshold이다. 그 값을 정하는 기준은 error rate을 제일 작게 하는 threshold를 선택하는 것이다. 즉, threshold 0.5부터 8.5까지 error rate을 구해서 error rate이 젤 작은 threshold를 구하는 것이다. 첫번째 예제에서는 그 threshold가 2.5였고, 두번째 예제에서는 5.5였던 것이다.


이렇게 error rate을 구한 후에는 그 weak learner를 평가해야 한다. 평가치는 a(t)로 나타낸다. t번째의 weak learner의 평가치라고 보면 된다.


그래서 지금까지 한 것을 정리하면 각각의 t번째의 weak learner는 각각 error rate와 a값을 가지고 있었다. 이 값을 가지고 우리는 weight을 update 해줄 것이다. update해줄 때 중요한 원리는 t-1 weak learner에서 틀렸던 것들에 대해서 t weak learner에서는 더 큰 weight을 두고 맞았던 것들에 대해서는 더 작은 weight을 둔다는 것이다. 우리도 시험볼 때 평소에 자주 틀리는 것들에 대해서 중점적으로 더 공부하고 평소에 잘 맞는 것들에 대해서는 감만 유지하는 식으로 준비하는 것과 비슷하다. 그래서 weight을 update하는 식은 h(x)가 맞았는지 틀렸는지에 따라 두가지 경우로 나뉠 수 있다. 아래와 같이 나타낼 수 있다.

(위의 식에서 d는 weight을 의미한다는 점 참고)


예측이 맞을 때는 exp(-a)이고 틀릴 때는 exp(a)이다. 틀리면 커지게 마련이고 맞으면 작아지게 마련이다. 위의 식에서 Zt는 우선 그런게 있다고 치자. 그냥 주어진 값이라고 알아두고 저거 어떻게 구하는지는 조금 있다가 설명하려고 한다.


지금까지 한 것들이 weight(d1, d2, ...)이 한번 update될 때까지의 전체적인 과정이다. 그러면 실제로 아래의 예제를 통해서 training을 해보자.

위의 경우에서는 d3까지만(weight3) 구했다. 왜냐?? d3까지만 구해도 되니까... 왜냐?? 최종적으로 여러개의 weak learner들을 하나로 합치는 공식이 아래와 같은데


이 공식을 이용해서 첫번째 iteration, 두번째 iteration에서의 맞고 틀리고를 구해보면 둘다 세 개씩 틀린다. 그런데 세번째 iteration은 다 맞는다. 그래서 "오 잘됐군ㅋ" 하고 세번째 iteration에서 멈추는 것이다.


참 생각보다 그렇게 어려운 포스팅은 아니라고 생각하는데 왜 이렇게 오래 걸렸는지 잘 모르겠다. 하루 걸렸다. 뭐 쉬엄쉬엄해서 그런 것도 있긴 한데.. 


아무튼 이번 포스팅은 여기서 마치려고 한다. 다음 포스팅은 뭐가 될까? 몇 가지 생각해두고 있는 것이 Hadoop, Mahout, Spark 관련된 것일 수도 있고 또 다른 것이 될 수도 있고.. 아 배고파 점심 먹으러 가야지...