프로젝트/코드프레소 체험단

이미지 데이터 처리를 위한 CNN 완벽 가이드 - MNIST 데이터셋 분류 CNN 모델 구현

KimCookieYa 2022. 3. 6. 18:35

# model.summary() prints each layer's output shape and parameter count.
# NOTE: call it only after `model` has been built (see the "CNN 모델 디자인"
# section below); calling it at this point would raise NameError.

 

import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# Prepare the input and target tensors.
# Shapes: train images (60000, 28, 28), train labels (60000,);
# test images (10000, 28, 28), test labels (10000,).
(train_x, train_y), (test_x, test_y) = mnist.load_data()

# Input preprocessing: scale pixel values into the [0, 1] range.
# Both the training AND the test *images* (x) are normalized; the labels (y)
# are handled separately by one-hot encoding below. float32 matches the dtype
# the model computes in and halves memory versus the float64 that `/ 255.`
# would otherwise produce.
train_x = train_x.astype('float32') / 255.
test_x = test_x.astype('float32') / 255.

# An MLP takes a 1-D tensor per sample, but a convolution layer takes a 3-D
# tensor (height, width, channels) so that it can handle color images.
# MNIST is grayscale, so add a trailing channel axis of size 1.
# -1 lets NumPy infer the sample count instead of hard-coding 60000 / 10000.
train_x = train_x.reshape((-1, 28, 28, 1))
test_x = test_x.reshape((-1, 28, 28, 1))

# One-hot encode the integer digit labels for categorical_crossentropy.
# num_classes=10 pins the encoded width to the 10-unit softmax output layer,
# even if some split happened not to contain every digit.
from tensorflow.keras.utils import to_categorical
train_y = to_categorical(train_y, num_classes=10)
test_y = to_categorical(test_y, num_classes=10)

# --- CNN model design ---
from tensorflow.keras import models, layers

# Feature extractor: three 3x3 ReLU convolutions (32, 64, 64 filters), with
# 2x2 max pooling after the first two. The 3-D feature maps are then
# flattened to 1-D and classified by a fully connected head: a 64-unit ReLU
# layer followed by a 10-way softmax output.
model = models.Sequential([
    layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu',
                  input_shape=(28, 28, 1)),
    layers.MaxPool2D(pool_size=(2, 2)),
    layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPool2D(pool_size=(2, 2)),
    layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(units=64, activation='relu'),
    layers.Dense(units=10, activation='softmax'),
])

# Training configuration: RMSprop optimizer, categorical cross-entropy over
# the one-hot labels, and accuracy tracked as a metric.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train the model: 15 epochs, mini-batches of 256, with the last 20% of the
# training arrays held out as validation data (Keras takes the split from the
# end of the arrays, before shuffling).
history = model.fit(x=train_x, y=train_y, epochs=15, batch_size=256,
                    validation_split=0.2)

# Evaluate on the separate test set and report the result, which was
# previously computed but never shown.
test_loss, test_accuracy = model.evaluate(x=test_x, y=test_y)
print(f'test loss: {test_loss:.4f}, test accuracy: {test_accuracy:.4f}')

# Visualize the training process from the per-epoch metrics recorded by fit().
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Keras numbers epochs starting at 1 in its training log; use the same
# numbering on the x-axis.
epochs = range(1, len(acc) + 1)


def _plot_curve(x, train_values, val_values, label, title):
    """Plot training (blue dots) vs. validation (blue line) values per epoch."""
    plt.plot(x, train_values, 'bo', label=f'Training {label}')
    plt.plot(x, val_values, 'b', label=f'Validation {label}')
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(label)
    plt.legend()
    plt.show()


_plot_curve(epochs, acc, val_acc, 'acc', 'Training and Validation Accuracy')
_plot_curve(epochs, loss, val_loss, 'loss', 'Training and Validation Loss')