GridSearchCV API를 이용하여 최적의 모델 학습 시키기
API URL: sklearn.model_selection.GridSearchCV
- GridSearchCV() API의 Argument(인자)
- estimator: 구현하고자 하는 모델(Classification, Regression)
- param_grid: 모델의 튜닝에 사용될 파라미터 정보, 딕셔너리
- scoring: 검증 지표(성능 평가 지표)
- cv: Cross Validation의 Fold 숫자
- refit: 최적의 파라미터로 모델을 재 학습 시킬지 여부
* Grid Search CV를 활용하여 iris 데이터셋 분류 모델을 구현하는 실습
'''
-------- [최종 출력 결과] --------
Optimal parameter: {'max_depth': 3, 'min_samples_split': 2}
Max accuracy: 0.9750
Test accuracy: 0.9667
----------------------------------
'''
# 하이퍼파라미터 튜닝을 을 위한 GridSearchCV 라이브러리 로딩
from sklearn.model_selection import GridSearchCV
# 모델 구현을 위한 라이브러리 로딩
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
# load_iris() 메서드를 이용하여 iris 데이터 셋 로드
iris = load_iris()
# 학습, 테스트 데이터셋 분리
x_train, x_test, y_train, y_test = train_test_split(iris.data,
iris.target,
test_size=0.2,
random_state=121)
# DecisionTreeClassifier 모델 객체 생성
dtree = DecisionTreeClassifier()
# 모델의 후보 파라미터 셋(param_grid)을 지정한 딕셔너리 객체 생성
parameters = {'max_depth':[1,2,3], 'min_samples_split':[2,3]}
# GridSearchCV 객체 생성
grid_dtree = GridSearchCV(estimator=dtree, param_grid=parameters, cv=3, refit=True)
# GridSearchCV 객체의 fit() 메서드를 이용하여
# 후보 파라미터 셋의 성능 검증
grid_dtree.fit(x_train, y_train)
# 후보 파라미터 셋의 성능 검증 결과 출력
print('Optimal parameter:', grid_dtree.best_params_)
print('Max accuracy: {0:.4f}'.format(grid_dtree.best_score_))
# 최적의 파라미터 모델을 이용하여 예측값 생성
estimator = grid_dtree.best_estimator_
pred = estimator.predict(x_test)
# 최적의 파라미터 모델의 성능지표 출력
print('Test accuracy: {0:.4f}'.format(accuracy_score(y_test,pred)))
'프로젝트 > 코드프레소 체험단' 카테고리의 다른 글
이미지 데이터 처리를 위한 CNN 완벽 가이드 - Numpy를 이용한 인공신경망 데이터 이해 (0) | 2022.02.17 |
---|---|
이미지 데이터 처리를 위한 CNN 완벽 가이드 - Colab 기본 사용법 (0) | 2022.02.15 |
머신러닝을 위한 사이킷런 활용 팁 - 머신러닝 모델의 최적 파라미터 탐색 (0) | 2022.02.13 |
머신러닝을 위한 사이킷런 활용 팁 - K-Fold 교차 검증 실습 (0) | 2022.02.13 |
머신러닝을 위한 사이킷런 활용 팁 - 머신러닝 모델 선택 및 K-Fold 교차 검증 (0) | 2022.02.13 |