프로젝트/코드프레소 체험단

파이썬으로 구현하는 머신러닝 : 회귀분석 - 사이킷런 활용한 라쏘 규제 실습

KimCookieYa 2022. 1. 18. 00:07

머신러닝에서는 과대적합(overfitting)을 줄이면서, 일반성을 가지는 모델을 생성하는 것이 중요함

이를 위해 사용되는 규제 기법에는 릿지(ridge) 회귀, 라쏘(lasso) 회귀 기법 등이 있음

 

라쏘 회귀(Lasso Regression)

 - 회귀계수의 절대값에 페널티를 부여하는 방식

 - 불필요한 회귀 계수를 0에 근사하도록 만들어 과대적합 개선

 - 주로 Feature Selection 의 목적으로 사용

 - L1 규제라고도 함

 - 모델러에 의해 지정된 alpha 값을 통해서 페널티를 조정할 수 있음

 

<적용 프로세스>
 1) alpha 값 정의
 2) Lasso(alpha) 클래스 객체 생성
 3) fit(X, y) 을 통해 학습 데이터 연결 및 규제 학습 수행
 4) predict(X) 통해 학습된 모델의 예측 수행
 5) score(X, y) 통해 R^2 값 확인(모델의 성능 지표 계산)

'''
-------- [최종 출력 결과] --------
Training-datasset R2 : 0.736
Test-datasset R2 : 0.693
Lasso Regression Coefficients :
RM         3.4
CHAS       1.8
RAD        0.3
ZN         0.1
INDUS      0.0
NOX       -0.0
AGE        0.0
TAX       -0.0
B          0.0
CRIM      -0.1
PTRATIO   -0.6
LSTAT     -0.7
DIS       -1.1
dtype: float64
----------------------------------
'''
# 필요한 라이브러리 로딩
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
# sklearn.linear_model 모듈의 Lasso 클래스 로딩
from sklearn.linear_model import Lasso

# 데이터셋 로딩
boston = load_boston()

# 데이터셋 분할
# test_size=0.3, random_state=12 로 지정하세요.
x_train, x_test, y_train, y_test = train_test_split(boston.data, boston.target, test_size=0.3, random_state=12)

# 규제를 위한 alpha 값 초기화
# 학습시에는 alpha 값을 바꾸가면서 테스트해보시고,
# 최종 코드 제출시에는 0.1 로 지정후 제출하세요.
alpha = 0.1

# Lasso 클래스 객체 생성
lasso = Lasso(alpha=alpha)

# fit() 을 통한 규제 학습 수행
lasso.fit(x_train, y_train)

# predict() 를 통한 학습된 모델 기반 예측
lasso_pred = lasso.predict(x_test)

# score() 를 통해 회귀 모델의 R^2 출력
# 학습된 모델에 대한 R^2 계산
r2_train = lasso.score(x_train, y_train)
r2_test = lasso.score(x_test, y_test)
print('Training-datasset R2 : {0:.3f}'.format(r2_train))
print('Test-datasset R2 : {0:.3f}'.format(r2_test))

# 회귀 계수 저장을 위한 Seriess 객체 생성 및 출력
lasso_coef_table = pd.Series(data=np.round(lasso.coef_, 1), index=boston.feature_names)
print('Lasso Regression Coefficients :')
print(lasso_coef_table.sort_values(ascending=False))

# 막대그래프 시각화 
plt.figure(figsize=(10,5))
lasso_coef_table.plot(kind='bar')
plt.ylim(-10, 4)
plt.show()

라쏘 규제 실습 결과