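9-4 Text classification using an LSTM recurrent neural network
-LSTM
Uses time steps to identify relationships between words that are far apart in a sentence
Invented by Hochreiter and Schmidhuber in 1997, this model mitigates the vanishing gradient problem, which lets it model long sequences successfully.
Cell structure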
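The cell structure can be illustrated with a minimal NumPy sketch of one LSTM time step. This is the standard textbook formulation, not the TensorFlow implementation; the function name `lstm_step`, the weight names `W`, `U`, `b`, the gate ordering inside the weight matrices, and the random toy inputs are all assumptions made for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM time step (illustrative sketch, hypothetical helper).

    x_t: (input_dim,)   h_prev, c_prev: (units,)
    W: (input_dim, 4*units)   U: (units, 4*units)   b: (4*units,)
    """
    units = h_prev.shape[0]
    z = x_t @ W + h_prev @ U + b            # pre-activations for all four blocks
    i = sigmoid(z[:units])                  # input gate
    f = sigmoid(z[units:2 * units])         # forget gate
    g = np.tanh(z[2 * units:3 * units])     # candidate cell values
    o = sigmoid(z[3 * units:])              # output gate
    c_t = f * c_prev + i * g                # additive cell-state update
    h_t = o * np.tanh(c_t)                  # hidden state passed to the next step
    return h_t, c_t

# Toy run: 5 time steps of 32-dimensional inputs through 8 units (random data).
rng = np.random.default_rng(0)
input_dim, units = 32, 8
W = rng.normal(scale=0.1, size=(input_dim, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)

h = np.zeros(units)
c = np.zeros(units)
for x_t in rng.normal(size=(5, input_dim)):
    h, c = lstm_step(x_t, h, c, W, U, b)

print(h.shape)  # (8,)
```

The cell state is updated by addition (`f * c_prev + i * g`) rather than by repeated matrix multiplication, which is the usual explanation for why gradients survive over more time steps than in a plain recurrent layer.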
– Building an LSTM recurrent neural network with TensorFlow
1. Structure of the LSTM recurrent network
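from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

model_lstm = Sequential()
model_lstm.add(Embedding(1000, 32))             # 1000-word vocabulary, 32-dimensional vectors
model_lstm.add(LSTM(8))                         # 8 LSTM units
model_lstm.add(Dense(1, activation='sigmoid'))  # single probability for binary classification
model_lstm.summary()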
"""
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 embedding_1 (Embedding)     (None, None, 32)          32000
 lstm (LSTM)                 (None, 8)                 1312
 dense_2 (Dense)             (None, 1)                 9
=================================================================
Total params: 33,321
Trainable params: 33,321
Non-trainable params: 0
_________________________________________________________________
"""
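The parameter counts in the summary can be checked by hand. The sketch below assumes the standard LSTM layout of four weight blocks (three gates plus the candidate cell values), each with input weights, recurrent weights, and a bias; `lstm_param_count` is a hypothetical helper, not a Keras function.

```python
def lstm_param_count(input_dim, units):
    # Four blocks, each with (input_dim x units) input weights,
    # (units x units) recurrent weights, and a bias of length units.
    return 4 * (input_dim * units + units * units + units)

embedding_params = 1000 * 32             # vocabulary size x embedding dimension
lstm_params = lstm_param_count(32, 8)    # 4 * (256 + 64 + 8)
dense_params = 8 * 1 + 1                 # weights + bias
total_params = embedding_params + lstm_params + dense_params

print(embedding_params, lstm_params, dense_params, total_params)
# 32000 1312 9 33321
```

These match the 32000, 1312, 9, and 33,321 reported by `model_lstm.summary()`.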
2. Training the model
model_lstm.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])
history = model_lstm.fit(x_train_seq, y_train, epochs=10, batch_size=32,
                         validation_data=(x_val_seq, y_val))
"""
Epoch 1/10
625/625 [==============================] - 27s 39ms/step - loss: 0.4462 - accuracy: 0.8015 - val_loss: 0.3765 - val_accuracy: 0.8376
Epoch 2/10
625/625 [==============================] - 24s 38ms/step - loss: 0.3355 - accuracy: 0.8598 - val_loss: 0.3745 - val_accuracy: 0.8422
Epoch 3/10
625/625 [==============================] - 23s 37ms/step - loss: 0.3124 - accuracy: 0.8705 - val_loss: 0.3617 - val_accuracy: 0.8422
Epoch 4/10
625/625 [==============================] - 22s 34ms/step - loss: 0.2948 - accuracy: 0.8770 - val_loss: 0.3690 - val_accuracy: 0.8352
Epoch 5/10
625/625 [==============================] - 23s 36ms/step - loss: 0.2787 - accuracy: 0.8859 - val_loss: 0.3740 - val_accuracy: 0.8394
Epoch 6/10
625/625 [==============================] - 23s 36ms/step - loss: 0.2720 - accuracy: 0.8860 - val_loss: 0.4034 - val_accuracy: 0.8324
Epoch 7/10
625/625 [==============================] - 23s 37ms/step - loss: 0.2581 - accuracy: 0.8921 - val_loss: 0.3881 - val_accuracy: 0.8324
Epoch 8/10
625/625 [==============================] - 24s 38ms/step - loss: 0.2456 - accuracy: 0.8982 - val_loss: 0.4137 - val_accuracy: 0.8256
Epoch 9/10
625/625 [==============================] - 22s 35ms/step - loss: 0.2371 - accuracy: 0.9015 - val_loss: 0.4231 - val_accuracy: 0.8342
Epoch 10/10
625/625 [==============================] - 23s 37ms/step - loss: 0.2259 - accuracy: 0.9083 - val_loss: 0.4448 - val_accuracy: 0.8306
"""
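One way to read the log above is to find the epoch where the validation loss is lowest. The following is a small sketch that copies the loss values from the log by hand; the variable names are illustrative, and with a live run the same lists would presumably come from the history object instead.

```python
# Values copied from the training log above (epochs 1 through 10).
train_loss = [0.4462, 0.3355, 0.3124, 0.2948, 0.2787,
              0.2720, 0.2581, 0.2456, 0.2371, 0.2259]
val_loss = [0.3765, 0.3745, 0.3617, 0.3690, 0.3740,
            0.4034, 0.3881, 0.4137, 0.4231, 0.4448]

# Epoch (1-based) with the lowest validation loss.
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1

# Does the training loss fall at every epoch?
train_always_decreasing = all(b < a for a, b in zip(train_loss, train_loss[1:]))

print(best_epoch, val_loss[best_epoch - 1])  # 3 0.3617
print(train_always_decreasing)               # True
```

In this run the training loss keeps falling while the validation loss bottoms out at epoch 3 and then drifts upward, which suggests the model starts to overfit after the first few epochs.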
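3. Plotting loss and accuracy
import matplotlib.pyplot as plt

# Training vs. validation loss per epoch
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.show()
# Training vs. validation accuracy per epoch
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.show()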
4. Evaluating accuracy on the validation set
loss, accuracy = model_lstm.evaluate(x_val_seq, y_val, verbose=0)
print(accuracy)
## Output: 0.8306000232696533
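The reported value equals the val_accuracy of the final epoch (0.8306), since both are measured on the same validation set with the final weights. As a rough sketch of what that number measures, the snippet below computes binary accuracy by hand, assuming the sigmoid outputs are thresholded at 0.5; `binary_accuracy` is a hypothetical helper and the five labels and probabilities are made-up toy data, not values from this model.

```python
import numpy as np

def binary_accuracy(y_true, y_prob, threshold=0.5):
    # Turn probabilities into 0/1 predictions, then take the share that match.
    y_pred = (np.asarray(y_prob).ravel() > threshold).astype(int)
    return float(np.mean(y_pred == np.asarray(y_true).ravel()))

# Toy data (illustrative only): column-shaped like the output of a Dense(1) layer.
y_true = np.array([1, 0, 1, 1, 0])
y_prob = np.array([[0.9], [0.2], [0.4], [0.7], [0.6]])

acc = binary_accuracy(y_true, y_prob)
print(acc)  # 0.6
```

With a trained model, `y_prob` would come from something like `model_lstm.predict(x_val_seq)`, and the result should be close to the value printed by `evaluate`.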
※ Contents