본문 바로가기
개발 Tools/파이썬_Deep learning & ML

머신러닝 Stratified KFold

by 전컴반 2021. 7. 13.
반응형

Stratified KFold는 기존의 KFold의 단점을 보완하기 위해 나왔다.

만약에 암을 예측하는 프로그램을 만든다고 가정하자. 100명의 사람이 있는데 이중엔 암인 사람인 많아봐야 1-2 명 있을 것이다. 

기존의 KFold 로 나눈다면, 1-2명뿐이라 예측하기도 어렵고, 찾아내기도 어렵다. 이런 단점을 보완하고 너무 적거나 너무 많을 때 골고루 분류해주는 작업이다.

 

라이브러리

 

from sklearn.datasets import load_iris
import pandas as pd
from sklearn.model_selection import StratifiedKFold

 

분배

 

이번에도 붓꽃데이터를 가지고 와서 분류해보겠다. iris.target에는 0,1,2 이렇게 3가지 종류가 있다.  

바로 불러 와보겠다. 일부러 딱 떨어지지 않는 n_splits을 3으로 했다.

 

iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df["target"] = iris.target
print(df["target"].value_counts())

skf = StratifiedKFold(n_splits=3)
n_iter = 0

for train_index, test_index in skf.split(df, df["target"]):
    n_iter += 1
    target_train = df["target"].iloc[train_index]
    target_test = df["target"].iloc[test_index]
    print(f"교차검증: {n_iter}")
    print(f"학습 타겟 데이터 분포:\n{target_train.value_counts()}")
    print(f"검증 타겟 데이터 분포:\n{target_test.value_counts()}\n")


출력
교차검증: 1
학습 타겟 데이터 분포:
2    34
0    33
1    33
Name: target, dtype: int64
검증 타겟 데이터 분포:
0    17
1    17
2    16
Name: target, dtype: int64

교차검증: 2
학습 타겟 데이터 분포:
1    34
0    33
2    33
Name: target, dtype: int64
검증 타겟 데이터 분포:
0    17
2    17
1    16
Name: target, dtype: int64

교차검증: 3
학습 타겟 데이터 분포:
0    34
1    33
2    33
Name: target, dtype: int64
검증 타겟 데이터 분포:
1    17
2    17
0    16
Name: target, dtype: int64

 

0,1,2에 골고루 들어가는 걸 확인할  수 있다. 

반응형

댓글