https://insengnewbie.tistory.com/215
전이학습은 특정 도메일 데이터를 분석하는데 좋은 성능을 발휘하는 최신 딥러닝 모델에 수천 만 건의 데이터를 학습시킨 모델을 가져다가 우리의 문제에 적용하는 기술이다. 전이 학습은 가장 성능이 좋은 모델, 그리고 방대한 양의 데이터의 학습이라는 두가지 장점을 쉽고 편하게 사용하여 딥러닝 기술의 확산에 큰 기여를 하였다. Tensorflow Keras 모듈의 applications API 는 전이학습을 위한 사전에 학습된 최신 모델들을 쉽게 사용할 수 있게 도와준다. 앞서 실습한 가위-바위-보 데이터 셋 분류 모델에 전이학습을 적용하여 좀 더 성능 좋은 VGG16 Backbone 모델 기반 CNN 모델을 구현하는 실습을 진행한다.
import tensorflow as tf
from tensorflow import keras
import os
""" Step 1. Input tensor 와 Target tensor 준비(훈련데이터)
Step 1-(1) 가위-바위-보 데이터셋 다운로드
"""
train_url = 'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/rps.zip'
test_url = 'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/rps-test-set.zip'
# 수강생 작성 코드
# 1. tensorflow.keras.utils 모듈의 get_file API를 이용하여 가위-바위-보 학습 데이터 셋 다운로드
keras.utils.get_file(fname='rps.zip', origin=train_url, extract=True, cache_dir='/content')
# 수강생 작성 코드
# 1. tensorflow.keras.utils 모듈의 get_file API를 이용하여 가위-바위-보 테스트 데이터 셋 다운로드
keras.utils.get_file(fname='rps-test-set.zip', origin=test_url, extract=True, cache_dir='/content')
""" Step 1-(2) ImageDataGenerator를 이용해 이미지 파일을 load 하기 위한 경로 지정"""
# 수강생 작성 코드
# 1. 저장된 학습, 테스트 데이터를 읽어 오기 위한 경로 정보 생성
# - hint : os.path.dirname() 메서드를 이용하여 데이터 셋이 저장된 경로 추출
# => /root/.keras/datasets/rps_test-set.zip => /root/.keras/datasets/
train_dir = os.path.dirname('/content/datasets/rps.zip')
test_dir = os.path.dirname('/content/datasets/rps-test-set.zip')
""" Step 1-(3) ImageDataGenerator 객체 생성
- 객체 생성 시 rescale 인자를 이용하여 텐서 내 원소의 범위를 [0 ~ 255] => [0 ~ 1] 로 ReScaling 진행
- 이미지의 사이즈를 VGG16 Backbone 모델에 적합한 (224, 224) 로 지정
- label 데이터는 one-hot-encoding 수행하세요
"""
# 수강생 작성 코드
# 1. ImageDataGenerator API 를 이용하여 학습 데이터와 테스트 데이터를 읽어오기 위한 객체 생성
# - feature 데이터를 [0, 1] 사이로 scailing을 수행하세요
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
""" - flow_from_directory() 메서드를 이용하여 학습데이터와 검증데이터를 위한 DirectoryIterator 객체 생성"""
# 수강생 작성 코드
# 1. ImageDataGenerator 객체의 flow_from_directory 메서드를 이용하여 데이터를 읽어오기 위한 정보를 설정하세요
# - 이미지의 사이즈를 VGG16 Backbone 모델에 적합한 (224, 224) 로 지정하세요
# - label 데이터는 one-hot-encoding 수행하세요
train_generator = train_datagen.flow_from_directory(directory=train_dir,
target_size=(224, 224),
batch_size=32,
shuffle=True,
class_mode='categorical')
test_generator = test_datagen.flow_from_directory(directory=test_dir,
target_size=(224, 224),
batch_size=32,
class_mode='categorical')
""" Step 2. VGG16을 Backbone 으로 하는 모델 디자인 및 학습 정보 설정"""
""" Step 2-(1) Pre-trained 된 VGG16 모델 객체 생성
- imagenet 데이터를 이용해 학습된 모델 객체 생성
- classification layer 제외
"""
# 수강생 작성 코드
# 1. tensorflow keras 에서 제공하는 이미 학습된 VGG16 Backbone 모델 객체를 생성하세요
# - imagenet 데이터를 이용해 학습된 Backbone 모델 객체를 생성하세요
# - classification layer 제외하여 VGG16 모델의 Backbone 객체만 생성하세요
# - 이미지의 사이즈를 VGG16 Backbone 모델에 적합한 (224, 224) 로 지정하세요
from tensorflow.keras.applications import VGG16
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(224, 224, 3))
""" Step 2-(2) VGG16 Backbone 모델에 classification layer 추가"""
# 수강생 작성 코드
# 1. Sequential API를 이용하여 가위-바위-보 데이터셋 을 분석 하기 위한 CNN 모델을 디자인 하세요
# - VGG16의 Backbone 모델에 classification layer를 직접 추가하여 모델을 디자인 하세요
# - label 데이터를 one-hot-encoding 한 것을 반영 하여 모델을 디자인 하세요
from tensorflow.keras import models, layers
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(units=256, activation='relu'))
model.add(layers.Dropout(0.3))
model.add(layers.Dense(units=2, activation='softmax'))
conv_base.trainable = False
""" Step 3. 모델의 학습 정보 설정"""
# 수강생 작성 코드
# 1. tf.keras.Model 객체의 compile 메서드를 이용하여 학습을 위한 정보들을 설정하세요
# - optimizer : 기 학습된 VGG16 Backbone을 사용하기 때문에 optimizer 의 학습율을 작게 지정 하세요
# - loss : label 데이터를 one-hot-encoding 한 것을 반영 하여 지정하세요
# - metrics : 체점 기준인 accuracy 로 설정
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=2e-5),
loss='categorical_crossentropy',
metrics=['accuracy'])
""" Step 4. 모델에 데이터 generator 연결 후 학습"""
# 수강생 작성 코드
# 1. tf.keras.Model 객체의 fit 메서드를 이용하여 모델을 학습하세요
# - fit 메서드의 verbose=2 로 설정 하세요
history = model.fit(train_generator,
steps_per_epoch=len(train_generator),
epochs=30, verbose=2,
validation_data=test_generator,
validation_steps=len(test_generator))
import matplotlib.pyplot as plt
# 모델의 학습과정 시각화
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
Epoch 1/30 91/91 - 68s - loss: 0.2392 - accuracy: 0.9111 - val_loss: 0.0754 - val_accuracy: 0.9855 - 68s/epoch - 751ms/step
Epoch 2/30 91/91 - 51s - loss: 0.0562 - accuracy: 0.9893 - val_loss: 0.0220 - val_accuracy: 0.9979 - 51s/epoch - 564ms/step
Epoch 3/30 91/91 - 51s - loss: 0.0199 - accuracy: 0.9976 - val_loss: 0.0080 - val_accuracy: 1.0000 - 51s/epoch - 564ms/step
Epoch 4/30 91/91 - 51s - loss: 0.0072 - accuracy: 1.0000 - val_loss: 0.0029 - val_accuracy: 1.0000 - 51s/epoch - 565ms/step
...
Epoch 27/30 91/91 - 51s - loss: 1.3871e-07 - accuracy: 1.0000 - val_loss: 4.4807e-08 - val_accuracy: 1.0000 - 51s/epoch - 563ms/step
Epoch 28/30 91/91 - 51s - loss: 1.5441e-07 - accuracy: 1.0000 - val_loss: 3.7717e-08 - val_accuracy: 1.0000 - 51s/epoch - 565ms/step
Epoch 29/30 91/91 - 51s - loss: 1.7655e-07 - accuracy: 1.0000 - val_loss: 3.4378e-08 - val_accuracy: 1.0000 - 51s/epoch - 564ms/step
Epoch 30/30 91/91 - 51s - loss: 1.0927e-07 - accuracy: 1.0000 - val_loss: 3.2811e-08 - val_accuracy: 1.0000 - 51s/epoch - 563ms/step
'프로젝트 > 코드프레소 체험단' 카테고리의 다른 글
시계열 데이터 처리를 위한 RNN 완벽 가이드 - RNN 모델 (0) | 2022.03.16 |
---|---|
이미지 데이터 처리를 위한 CNN 완벽 가이드 - 완강 후기 (0) | 2022.03.15 |
이미지 데이터 처리를 위한 CNN 완벽 가이드 - DataAugmentation을 활용한 성능 개선 (0) | 2022.03.15 |
이미지 데이터 처리를 위한 CNN 완벽 가이드 - Transfer Learning (0) | 2022.03.14 |
이미지 데이터 처리를 위한 CNN 완벽 가이드 - DataAugmentation (0) | 2022.03.14 |