본문 바로가기

Deep Learning

LSTM 쉽게 이해하기

이 글은 고려대학교 산업경영공학부 김성범 교수님의 유튜브 영상 [핵심 머신러닝] RNN, LSTM, GRU를 정리한 것임을 밝힙니다.

 

RNN 모델은 기울기 소실(Vanishing Gradient) 문제 때문에, 모델이 과거 시점의 정보를 기억하기 어렵다고 했었다.

이를 보완한 모델이 바로 LSTM(Long Short-Term Memory)이다. 이 글에서는 LSTM에 대해 알아보도록 하자.

 

LSTM(Long Short-Term Memory)

LSTM이 RNN과 다른 점은 크게 두 가지이다.

 

1. 세 가지 gate: Forget gate(ft), Input gate(it), Output gate(ot)

2. Cell state 구조 (ct): 장기적으로 정보를 유지

 

* gate, state는 vector이며 일차원 집합으로 볼 수 있다. (ex. ft = [0, 0.2, 0.1, 0.3 ...] )

LSTM의 구조

 

위 그림이 LSTM 모델을 도식화한 것이다. 

모델의 목적은 hidden state(ht)를 구하는 것이다! htwhy만 곱해주면 예측값을 구할 수 있기 때문이다.

LSTM의 hidden state를 구하는 과정은 RNN보다는 조금은 복잡하다..! 

 

여기서 곱셈은 elmentwise product, 합성곱을 의미한다.

1. ht를 구하기위해, output gate cell state tanh 함수를 취한 값을 곱해준다.

2. cell state 는 이전 시점의 cell state forget gate 를 곱한 것과 임시 cell state input gate 를 곱한  것의 합으로 이루어진다.

3. 임시 cell state 는 현재 시점의 x와 이전 시점의 h들을 각각 가중치로 곱한 선형결합을 tanh 씌운 형태가 된다.

 

LSTM 학습

LSTM 학습 프로세스

 LSTM이 파라미터를 어떻게 학습하는지 알아보자.

Gate & Cell state

3개의 gate는 sigmoid 함수가 취해져 0~1의 값을 갖는다.

따라서, gate 들은 각각 다른 가중치의 역할을 하게 된다.

보시다시피, 각 gate는 가중합으로 이루어져 있다.

sigmoid 함수(σ) 안의 수식을 보면, RNN에서 ht를 구할 때 활성화함수 안에 있는 수식과 유사하다는 것을 알 수 있다.

그런데, 여기서 각 gate에 있는 wxhwhh , bh의 파라미터는 각각 다른 값을 가진다. 

이 파라미터들은 비용함수를 최소화하는 방향으로 학습된다.

 

그렇다면, 각각의 gate는 어떠한 가중치의 역할을 할까?

1. forget gate: 불필요한 과거정보를 잊어버림

2. input gate: 현재 정보를 기억함

 

불필요한 정보는 잊고, 추가할 정보는 추가해서 cell state를 구성한다

 

3. output gate: 어떤 정보를 output으로 내보낼지 결정

 

예를 들어, 벡터 연산을 통해 ht가 구해지는 그림을 한 번 보자.

여기서 ⊙는 위에서 본 곱과 똑같은 연산(elementwise product) 이다.

정리

LSTM은 최종적으로 hidden state(ht)와 cell state(ct)를 도출하는 모델이다.

 

- hidden state: 단기기억 (Short term)

- cell state: 장기기억 (Long term)

 

LSTM은 RNN에 비해 파라미터의 수가 많다 보니, 학습 시간이 오래 걸린다는 단점이 있다. 

(실제로 내 프로젝트에 LSTM 모델을 사용할 때... 충분한 인내가 필요했다..) 

 

LSTM을 개선한 모델로, 셀의 정보를 gate에 추가한 peephole connection이나 모델을 경량화한 GRU가 있다.

이런 모델은 시간이 되면, 다뤄보도록 하겠다..! (GRU는 조만간 올릴 듯..!)

'Deep Learning' 카테고리의 다른 글

Sequence-to-Sequence(Seq2Seq)  (0) 2024.11.19
Attention 메커니즘  (0) 2024.05.15
Transformer  (0) 2024.04.28
[신경망] Neural Network와 Back propagation  (1) 2023.12.03
RNN 쉽게 이해하기  (2) 2023.09.14