시계열 데이터 처리를 위한 RNN 완벽 가이드 - GRU 모델을 이용한 영화리뷰 데이터셋 분류 모델 구현

2022. 3. 27. 00:29·프로젝트/코드프레소 체험단

# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence

print(tf.__version__)

""" Input tensor와 Target tensor 준비(훈련 데이터) """
# imdb의 빈도수 기준으로 상위 10000개의 데이터를 load함.
(input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=10000)

print(len(input_train))
# 25000
print(input_train.shape)
# (25000,)

""" 입력 데이터의 전처리
 - LSTM 모델에 데이터를 입력하기 위해 시퀀스 데이터의 길이를 통함
"""
input_train = sequence.pad_sequences(input_train, maxlen=800)
input_test = sequence.pad_sequences(input_test, maxlen=800)

print(input_train.shape)
# (25000, 800)

""" RNN 모델 디자인
 - embedding layer: 32차원
 - hidden layer: 1개(32)
 - activation: tanh
"""
gru_model = models.Sequential()

gru_model.add(layers.Embedding(input_dim=10000, output_dim=32))
gru_model.add(layers.GRU(units=32, activation='tanh'))
gru_model.add(layers.Dense(units=1, activation='sigmoid'))

gru_model.summary()

""" 모델의 학습 정보 설정
 - loss: binary crossentropy
 - optimiizer: rmsprop
 - metric: accuracy
"""
gru_model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics='acc')

""" 모델에 input, target 데이터 연결 후 학습
 - batch size: 128
 - epochs: 10
 - validation data set percent: 20%
"""
history = gru_model.fit(x=input_train, y=y_train, batch_size=128, epochs=10, validation_split=0.2)

""" 학습과정의 시각화 및 성능 테스트 """
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc)+1)

plt.plot(epochs, acc, 'b', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()

plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

""" 테스트 데이터 셋을 통한 성능 측정"""
test_loss, test_acc = model.evaluate(x=input_test, y=y_test)

 

* overfitting이 일어나는 것은 동일하지만, 단층 LSTM 모델보다 정확도가 1% 정도 높고 적층 LSTM 모델보다는 정확도가 낮은 모습을 보인다.

 

 

 

'프로젝트 > 코드프레소 체험단' 카테고리의 다른 글

파이썬으로 배우는 데이터 분석: NumPy - NumPy 라이브러리 소개  (0) 2022.03.27
파이썬으로 배우는 데이터 분석 - 데이터 분석을 위한 파이썬 라이브러리  (0) 2022.03.27
시계열 데이터 처리를 위한 RNN 완벽 가이드 - 적층 LSTM 모델을 이용한 영화리뷰 데이터셋 분류 모델 구현  (0) 2022.03.27
시계열 데이터 처리를 위한 RNN 완벽 가이드 - LSTM 모델을 이용한 영화리뷰 데이터셋 분류 모델 구현  (0) 2022.03.26
시계열 데이터 처리를 위한 RNN 완벽 가이드 - GRU 모델  (0) 2022.03.24
'프로젝트/코드프레소 체험단' 카테고리의 다른 글
  • 파이썬으로 배우는 데이터 분석: NumPy - NumPy 라이브러리 소개
  • 파이썬으로 배우는 데이터 분석 - 데이터 분석을 위한 파이썬 라이브러리
  • 시계열 데이터 처리를 위한 RNN 완벽 가이드 - 적층 LSTM 모델을 이용한 영화리뷰 데이터셋 분류 모델 구현
  • 시계열 데이터 처리를 위한 RNN 완벽 가이드 - LSTM 모델을 이용한 영화리뷰 데이터셋 분류 모델 구현
KimCookieYa
KimCookieYa
무엇이 나를 살아있게 만드는가
  • KimCookieYa
    쿠키의 주저리
    KimCookieYa
  • 전체
    오늘
    어제
    • 분류 전체보기 (576)
      • 혼잣말 (88)
      • TIL (3)
      • 커리어 (24)
        • Sendy (21)
        • 외부활동 기록 (2)
      • 프로젝트 (186)
        • 티스토리 API (5)
        • 코드프레소 체험단 (89)
        • Web3 (3)
        • Pint OS (16)
        • 나만무 (14)
        • 대회 (6)
        • 정글 FE 스터디 (16)
        • MailBadara (12)
        • github.io (1)
        • 인공지능 동아리, AID (5)
        • 졸업과제 (18)
        • OSSCA 2024 (1)
      • 크래프톤 정글 2기 (80)
      • IT (169)
        • 코딩 (4)
        • CS (18)
        • 에러 (5)
        • 블록체인 (23)
        • Front-End (40)
        • 알고리즘&자료구조 정리 (3)
        • 코딩테스트 (3)
        • BOJ 문제정리 (41)
        • WILT (12)
        • ML-Agents (4)
        • 강화학습 (1)
        • Android (0)
        • LLM (2)
      • 전공 (1)
        • 머신러닝 (1)
      • 자기계발 (20)
        • 빡공단X베어유 (2)
        • 독서 (15)
  • 블로그 메뉴

    • 홈
    • 방명록
    • Github
    • Velog
    • 관리
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    Pint OS
    센디
    블록체인
    핀토스
    MailBadara
    Flutter
    코드프레소
    딥러닝
    OS
    해커톤
    머신러닝
    자바스크립트
    부산대
    글리치해커톤
    react
    파이썬
    프로그래머스
    사이드프로젝트
    리액트
    알고리즘
    RNN
    크래프톤정글
    나만무
    docker
    니어프로토콜
    NEAR Protocol
    JavaScript
    numpy
    졸업과제
    pintos
  • hELLO· Designed By정상우.v4.10.3
KimCookieYa
시계열 데이터 처리를 위한 RNN 완벽 가이드 - GRU 모델을 이용한 영화리뷰 데이터셋 분류 모델 구현
상단으로

티스토리툴바