// BPTT?
BPTT는 RNN에서 계산되는 back propagation으로
sequential data의 특성으로 인해 발생하는 hidden state를 따라 역행하면서 전파되는 gradient의 계산 방법이다.
Back Propagation Through Time의 약자이며 사실 따지고 보면 단순한 역전파 계산과 동일하긴 하다.
아래에서 진행되는 연산은 캠프 내용중 과제로 나왔던 자료의 표기를 참고하여 작성되었다.
// Hidden state 에 연산되는 가중치를 위한 기울기
우선은 many-to-one 모델에서 생각해보자.
이 글의 전반에 걸쳐서 요 many-to-one 에만 집중할 것이다.
이것을 이해하면 나머지도 다 따라갈 수 있을 것이라 생각한다.
하나의 셀이 만드는 hidden state를 s 라고 할때 위첨자는 시계열 데이터의 시점을 나타내고
아래첨자는 벡터의 몇번째 원소인지 나타내준다.
이때 첫번째 시점에서 셀이 계산한 결과는 다음과 같다.
여기서 $W^{(r)}_{ip}$ 는 hidden state 와 계산되는 가중치 이고 $W^{(x)}_{iq}$ 는 입력데이터와 계산되는 가중치다.
$\sigma$는 activation function 이고 bias는 생략했다.
그럼 이제 계속해서 셀을 데이터를 따라 계산하며 진행할 수 있다.
그러면 이렇게 마지막 t 번째 시점의 데이터까지 전개된 것을 알 수 있다.
이제 마지막 t 번째 시점의 데이터로부터 output을 구하면
이렇게 loss function 까지 구하는 것이 가능하다.
여기서 $W^{(r)}_{ip}$ 의 $i,p$ 번째 원소의 gradient를 어떻게 구할 수 있을까가 바로 BPTT의 핵심 질문이다.
우선은 이걸 구해야 한다는 사실을 먼저 인지하고 그 다음으로 넘어가보자.
내가 위에 식들을 굉장히 자세하게 풀어 써 놓은 이유는 back propagation 에서 쉽게 계산 하기 위해서 이다.
우선 loss function은 위와 같이
$\bf{y}^{(t)}$ 의 함수이므로
우선은 이렇게 체인룰을 적용할 수 있다. 그리고 계속해서
여기까지 미분이 가능하다. 모든 원소에 대해서 적용해야 하는 미분은 벡터표시로 기울임 없이 볼드체를 사용했다.
문제는 지금부터 인데 저 $\bf{s}^{(t)}$ 를 미분할때 발생한다.
여기서 $\bf{h}^{(t)}$ 를 미분할때는 $W^{(r)}_{ip}$ 이 있으므로 $\bf{s}^{(t-1)}$ 만 남기면 될 것 같지만
$\bf{s}^{(t-1)}$ 가 다시 $W^{(r)}_{ip}$ 에 대한 함수 이므로 결국 이 둘에 대한 곱의 미분을 해주어야한다.
이렇게 되는 것이고 계속 미분을 해보면
이렇게 될것이고
이런식의 규칙성을 가지면서 끝까지 미분되게 된다.
잘 보면 한번 미분이 진행되면 $W^{(r)}_{ip}$ 가 한번씩 나오는 것을 볼 수 있는데
굳이 시간을 따라서 미분을 해주려는 어떤 작업을 하지 않아도
자연스럽게 수학은 우리를 시계열 미분으로 이끌어주는 것을 볼 수 있다.
그럼 저 위에 식을 좀 정리해보면
이렇게 바뀌는 것을 알 수 있다.
그런데 이건 $W^{(r)}_{ip}$ 에 대한 미분이고 $W^{(x)}_{iq}$ 에 대한 미분까지 구해야 진정한 BPTT를 구햇다고 할 수 있겠다.
// Input vector 에 연산되는 가중치를 위한 기울기
동일한 방식으로 진행한다.
우선은 loss function을 원하는 weight로 미분하는 것을 시작으로
이걸 최대한 자세하고 간단한 함수들의 도함수로 풀어서 쓴다.
저렇게 까지 되었다면 사실 마지막 도함수만 빼고 모두 알고있는 셈이다.
마지막 도함수를 구해보면
이렇게 바뀌는데 이유는 이전의 입력에 $W^{(x)}_{iq}$ 를 곱한 값이
이전 층으로 부터 넘어오는 hidden state를 만들기 때문에 미분을 해줘야 하기 때문이다.
잘보면 순전파의 matmul이 가진 연산 순서 때문에 $W^{(r)}_{ip}$ 의 아래첨자가 더해지는 방향이 바뀐것을 볼 수 있다.
모든 index 들은 dummy로 running 된다는 것에 목적을 두고 자세한 부분은 넘어가도록 하겠다.
이 작업을 계속해서 해주면 아래와 같은 식이 된다.
딱보면 아무런 의미도 없고 난해해 보이기만 한다.
$W^{(r)}_{ip}$ 처럼 의미있는 접근으로 수식을 정리할 수는 없을까?
back propagation 의 모토를 잘 생각하보면 우리는 목적지만 알 뿐
중간의 모든 과정을 수식으로 전개해서 알고자하는 것이 아니다.
특히 우리가 편미분을 통해서 도함수를 구할때는
다변수 함수에서 하나의 축에 사영된 변화량만을 신경쓰겠다는 의미이다.
그러므로 한번만 전개했던 식을 분배법칙으로 풀어써보면
이 식이
이렇게 나눠지고 더하기 우측의 항은 입력벡터 $x^{(t)}_{q}$와 $W^{(x)}_{iq}$ 의 곱에 대한 대한 미분임을 알 수 있고
좌측은 크로니컬델타와 같은 자세한 항들은 생략했지만
결국 $\bf{s}^{(t-1)}$으로 부터 순전파된 항들에대한 미분임을 알 수 있다.
따라서 아래와 같이 고치면
이런식으로 표현할 수 있게 되는데 우리가 이미 알고 있듯이 더하기 좌측에 있는 항으로 부터 계속해서
이전의 셀이 연산한 항에 대한 미분들이 나오게 되고 한번 더 해보면
이렇게 된다는 것을 알 수있고 결국은 쭉쭉 미분되면서
이렇게 전개 되고 t가 0이 될때까지 이 식은 계속 미분될 것이다.
너무 눈이 아프니까 좀 정리를 하면
체인룰을 이용해서 중간 도함수들을 좀 정리하면 이렇게 표현이 가능하고
이걸 더 깔끔하게 바꾸면
이렇게 정리가 가능해진다.
알고보면 각 시점에서 발생한 hidden state까지의 미분값을 모두 합친 것으로 정리된다.
// 마치며
Neural Network는 graphical 한 접근이 우선되고 동작하는 방법에 대한 이해가 먼저 되다보니
수학적으로 시원하게 설명되는 접근을 찾기가 힘들다.
특히나 빠르게 성장하는 이 분야는 어떤 주제가 핫하게 연구되던 시기를 놓쳐버리면
너무 순식간에 '잘 알려진 개념'이 되면서 더더욱 공부하기 힘들어지는 경향이 있다.
그래도 수학적으로 깔끔하게 생각을 정리하는 능력은 굉장히 큰 역량이라고 생각하고
지금도 계속 그런 자세를 유지하기 위해서 노력중이다.
나도 항상 BPTT에 대해서는 어딘가 숙제같은 느낌을 받으면서 살고있었는데
이번 기회를 통해서 많이 공부하고 발전한 것 같다.
항상 부지런하게 배우고 공부하자.
// 참고자료
https://mmuratarat.github.io/2019-02-07/bptt-of-rnn
https://arxiv.org/pdf/1610.02583.pdf
https://d2l.ai/chapter_recurrent-neural-networks/bptt.html#fig-rnn-bptt
'딥러닝 머신러닝 데이터 분석 > BoostCampAITech' 카테고리의 다른 글
[ Boost Camp ] Day-9 학습로그 (0) | 2021.08.10 |
---|---|
[ Boost Camp ] Day-8 학습로그 (0) | 2021.08.10 |
[ Boost Camp ] Day-4 학습로그 ( CNN, RNN ) (0) | 2021.08.06 |
[ Boost Camp ] Day-2 학습로그( 신경망 학습 ) (0) | 2021.08.06 |
[ Boost Camp ] Day-1 학습로그 (벡터, 행렬, 선형회귀, SGD) (0) | 2021.08.06 |
댓글