scikit-learn의 class_weight 매개 변수는 어떻게 작동합니까?


116

class_weightscikit-learn의 Logistic Regression의 매개 변수가 어떻게 작동하는지 이해하는 데 많은 문제 가 있습니다.

그 상황

로지스틱 회귀를 사용하여 매우 불균형 한 데이터 세트에서 이진 분류를 수행하고 싶습니다. 등급은 0 (음성) 및 1 (양성)으로 표시되며 관찰 된 데이터의 비율은 약 19 : 1이며 대부분의 샘플은 음성 결과를 나타냅니다.

첫 번째 시도 : 수동으로 훈련 데이터 준비

훈련 및 테스트를 위해 보유한 데이터를 분리 된 세트로 분할했습니다 (약 80/20). 그런 다음 훈련 데이터를 무작위로 추출하여 19 : 1과 다른 비율로 훈련 데이터를 얻었습니다. 2 : 1-> 16 : 1.

그런 다음 이러한 다양한 훈련 데이터 하위 집합에 대한 로지스틱 회귀를 훈련하고 다른 훈련 비율의 함수로 재현율 (= TP / (TP + FN))을 플로팅했습니다. 물론 재현율은 관찰 된 비율이 19 : 1 인 분리 된 테스트 샘플에서 계산되었습니다. 다른 훈련 데이터에 대해 다른 모델을 훈련했지만 동일한 (분리 된) 테스트 데이터에서 모든 모델에 대한 재현율을 계산했습니다.

결과는 예상대로였습니다. 리콜은 2 : 1 훈련 비율에서 약 60 % 였고 16 : 1에 도달했을 때 다소 빠르게 감소했습니다. 리콜 률이 5 % 이상인 2 : 1-> 6 : 1 비율이 여러 개있었습니다.

두 번째 시도 : 그리드 검색

다음으로, 나는 다른 정규화 매개 변수를 테스트하고 싶어하고 그래서 GridSearchCV를 사용하여 여러 값의 그리드 만든 C매개 변수뿐만 아니라 class_weight매개 변수를. 네거티브 : 포지티브 훈련 샘플의 n : m 비율을 사전 언어로 번역하기 위해 class_weight다음과 같이 여러 사전을 지정한다고 생각했습니다.

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

나는 또한 포함 None하고 auto.

이번에는 결과가 완전히 엉망이었습니다. 내 모든 리콜은를 class_weight제외한 모든 값에 대해 아주 작게 (<0.05) 나왔습니다 auto. 따라서 class_weight사전 설정 방법에 대한 나의 이해 가 잘못 되었다고 가정 할 수 있습니다 . 흥미롭게도 class_weight그리드 검색에서 'auto' 의 값은 모든의 값에 대해 약 59 % 였는데 C1 : 1로 균형이 맞다고 생각 했나요?

내 질문

  1. class_weight훈련 데이터에서 실제로 제공하는 것과 다른 균형을 이루기 위해 어떻게 적절하게 사용 합니까? 특히 class_weight네거티브 : 포지티브 훈련 샘플의 n : m 비율을 사용하려면 어떤 사전을 전달해야 합니까?

  2. 다양한 class_weight사전을 GridSearchCV에 전달하는 경우 교차 검증 중에 사전에 따라 훈련 폴드 데이터의 균형을 재조정하지만 테스트 폴드에서 내 점수 함수를 계산하기 위해 실제 주어진 샘플 비율을 사용합니까? 모든 측정 항목이 관찰 된 비율의 데이터에서 나온 경우에만 유용하기 때문에 이것은 매우 중요합니다.

  3. 비율만큼 의 auto가치는 무엇입니까 class_weight? 나는 문서를 읽고 "데이터의 빈도에 반비례하여 균형을 잡는다"는 것은 단지 1 : 1로 만든다는 것을 의미한다고 가정합니다. 이 올바른지? 그렇지 않다면 누군가 명확히 할 수 있습니까?


class_weight를 사용하면 손실 함수가 수정됩니다. 예를 들어, 교차 엔트로피 대신 위그 트 교차 엔트로피가됩니다. againstdatascience.com/…
prashanth

답변:


123

먼저, 리콜 만하는 것은 좋지 않을 수 있습니다. 모든 것을 포지티브 클래스로 분류하여 100 % 리콜을 달성 할 수 있습니다. 일반적으로 AUC를 사용하여 매개 변수를 선택한 다음 관심있는 작동 지점 (예 : 지정된 정밀도 수준)에 대한 임계 값을 찾는 것이 좋습니다.

class_weight작동 방식 : 1 대신 class[i]with 샘플의 실수에 페널티 를줍니다 class_weight[i]. 따라서 클래스 가중치가 높을수록 클래스에 더 중점을두고 싶다는 의미입니다. 당신이 말한 바에 따르면 클래스 0은 클래스 1보다 19 배 더 자주 발생합니다. 따라서 class_weight클래스 0에 비해 클래스 1 의 값 을 늘려야합니다 . 예를 들어 {0 : .1, 1 : .9}입니다. (가) 경우 class_weight1의 합계가 1이되지 않습니다, 그것은 기본적으로 정규화 매개 변수를 변경합니다.

class_weight="auto"작동 방식에 대해서는 이 토론을 참조하십시오 . 개발 버전에서는을 사용할 수 있습니다 class_weight="balanced". 이것은 이해하기 더 쉽습니다. 기본적으로 큰 샘플만큼 많은 샘플을 가질 때까지 작은 클래스를 복제하는 것을 의미하지만 암시적인 방식으로합니다.


1
감사! 빠른 질문 : 명확성을 위해 리콜을 언급했으며 실제로 어떤 AUC를 내 측정으로 사용할지 결정하려고합니다. 내 이해는 매개 변수를 찾기 위해 ROC 곡선 아래 영역 또는 리콜 영역 대 정밀도 곡선 아래 영역을 최대화해야한다는 것입니다. 이런 식으로 매개 변수를 선택한 후 곡선을 따라 슬라이딩하여 분류 임계 값을 선택한다고 생각합니다. 이것이 당신이 의미하는 바입니까? 그렇다면 두 곡선 중 어느 것이 내 목표가 가능한 한 많은 TP를 캡처하는 것인지 살펴 보는 것이 가장 합리적입니까? 또한 scikit-learn에 대한 귀하의 작업과 기여에 감사드립니다 !!!
kilgoretrout

1
ROC를 사용하는 것이 더 표준적인 방법이라고 생각하지만 큰 차이는 없을 것입니다. 하지만 곡선에서 점을 선택하려면 몇 가지 기준이 필요합니다.
안드레아스 뮬러

3
@MiNdFrEaK Andrew가 의미하는 바는 추정기가 소수 클래스의 샘플을 복제하여 다른 클래스의 샘플이 균형을 이룬다는 것입니다. 암시적인 방식으로 오버 샘플링하는 것입니다.
Shawn TIAN

8
@MiNdFrEaK 및 Shawn Tian : SV 기반 분류기 '균형'을 사용할 때 더 작은 클래스의 샘플을 더 많이 생성 하지 않습니다 . 그것은 말 그대로 소규모 수업에서 저지른 실수를 처벌합니다. 그렇지 않다고 말하는 것은 실수이며 오해의 소지가 있습니다. 특히 더 많은 샘플을 생성 할 여유가없는 대규모 데이터 세트에서 그렇습니다. 이 답변은 수정해야합니다.
Pablo Rivas

4
scikit-learn.org/dev/glossary.html#term-class-weight 클래스 가중치는 알고리즘에 따라 다르게 사용됩니다. 선형 모델 (예 : 선형 SVM 또는 로지스틱 회귀)의 경우 클래스 가중치는 다음과 같이 손실 함수를 변경합니다. 등급 가중치에 따라 각 샘플의 손실에 가중치를 부여합니다. 트리 기반 알고리즘의 경우 분할 기준의 가중치를 다시 지정하는 데 클래스 가중치가 사용됩니다. 그러나이 재조정은 각 클래스의 샘플 가중치를 고려하지 않습니다.
prashanth

2

첫 번째 대답은 작동 방식을 이해하는 데 좋습니다. 하지만 실제로 어떻게 사용해야하는지 이해하고 싶었습니다.

요약

  • 노이즈가없는 중간 정도의 불균형 데이터의 경우 클래스 가중치 적용에 큰 차이가 없습니다.
  • 노이즈가 있고 불균형이 심한 데이터의 경우 클래스 가중치를 적용하는 것이 좋습니다.
  • param class_weight="balanced"은 수동으로 최적화하지 않아도 괜찮습니다.
  • 함께 class_weight="balanced"하면 더 진정한 이벤트 (높은 TRUE 리콜)을 캡처 할뿐 아니라 거짓 경고를받을 가능성이 더 높습니다 (TRUE 정밀도를 낮출)
    • 결과적으로 모든 거짓 긍정으로 인해 총 TRUE %가 실제보다 높을 수 있습니다.
    • AUC는 잘못된 경보가 문제인 경우 여기에서 잘못 안내 할 수 있습니다.
  • 불균형이 심한 경우에도 결정 임계 값을 불균형 %로 변경할 필요가 없습니다. 0.5를 유지하는 것이 좋습니다 (또는 필요에 따라 그 주변 어딘가).

NB

RF 또는 GBM을 사용할 때 결과가 다를 수 있습니다. sklearn에는 class_weight="balanced" GBM이 없지만 lightgbm 에는LGBMClassifier(is_unbalance=False)

암호

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.