scikit의 다중 레이블 분류 지표


19

scikit을 사용하여 기존 문서에 주제를 할당하기 위해 다중 레이블 분류기를 작성하려고합니다.

나는 통해 전달 내 문서를 처리하고 TfidfVectorizer하여 통해 라벨 MultiLabelBinarizer과를 만들어 OneVsRestClassifierSGDClassifier추있다.

그러나 내 분류기를 테스트 할 때 나는 .29 까지의 점수 만 얻습니다.이 점수 는 비슷한 문제에 대해 꽤 낮습니다. 나는 스톱 워드, 유니 그램, 형태소 분석과 같은 TfidfVectorizer에서 여러 옵션을 시도했지만 그 결과를 크게 변화시키지 않는 것 같습니다.

나는 또한 GridSearchCV내 추정기에 가장 적합한 매개 변수를 얻는 데 사용 했으며 현재 다음에 무엇을 시도 해야하는지에 대한 아이디어가 없습니다.

동시에,에서 내가 사용할 수 없습니다 이해 무엇 scikit.metrics으로 OneVsRestClassifier어떻게 내가 그렇게 잘못이 무엇인지 알아 내기 위해 일부 메트릭 (F1, 정밀, 등 리콜)을받을 수 있나요?

내 데이터 모음에 문제가있을 수 있습니까?

업데이트 : 나는 또한 사용 해봤 CountVectorizerHashingVectorizer과에 파이프 라이닝 TfidfTransformer하지만 결과는 비슷합니다. 그래서 나는 bag-of-words 접근법이 토큰 화 도메인에서 최선을 다하고 있으며 나머지는 분류 기준에 달려 있다고 추측합니다 ...


1
0.29 측정이란 무엇입니까? 정확성? 다른 것?
Sycorax는 Reinstate Monica

@GeneralAbrial score분류기에서 실행되는 scikit 설명서에 따르면Returns the mean accuracy on the given test data and labels. In multi-label classification, this is the subset accuracy which is a harsh metric since you require for each sample that each label set be correctly predicted.
mobius

그게 당신이 한 일입니까? 귀하의 질문에서 이것이 확실하다는 것은 분명하지 않으므로 완벽하게 합리적인 질문입니다.
Sycorax는 Reinstate Monica

@GeneralAbrial 네, 이것이 제가 한 일입니다. 혼란을 드려 죄송합니다. 나는 질문을 개발이 아닌 더 이론적 인 모드로 유지하려고 노력했습니다.
mobius

여기에 코드를 추가 할 수 있습니까? 특히 SGD에 sample_weight = "balanced"를 사용하고 있습니까? 그러나 코드가 확인되면 다른 사항이있을 수 있습니다.
Diego

답변:


21

서브셋 정확도는 실제로 가혹한 메트릭입니다. 0.29가 얼마나 좋은지 나쁜지를 알기 위해서는 다음과 같은 아이디어가 필요합니다.

  • 각 샘플에 대해 평균적으로 몇 개의 레이블이 있는지 확인
  • 사용 가능한 경우 어노 테이터 간 계약을 검토하십시오 (그렇지 않은 경우 분류 자일 때 획득 한 서브 세트 정확도를 확인하십시오)
  • 주제가 잘 정의되어 있는지 생각
  • 각 라벨에 몇 개의 샘플이 있는지 살펴보십시오

해밍 점수를 계산하여 분류자가 실마리가 없는지 또는 그 대신 괜찮은지 확인하지만 모든 레이블을 올바르게 예측하는 데 문제가있을 수 있습니다. 해밍 점수를 계산하려면 아래를 참조하십시오.

동시에, 내가 이해 한 것으로부터 OneVsRestClassifier와 함께 scikit.metrics를 사용할 수 없으므로 어떤 메트릭 (F1, Precision, Recall 등)을 가져 와서 무엇이 잘못되었는지 알아낼 수 있습니까?

멀티 클래스 멀티 라벨 분류에 대한 정밀도 / 호출을 계산하는 방법을 참조하십시오 . . sklearn이 지원하는지 여부를 잊어 버렸습니다. sklearn은 혼란 매트릭스에 대해 다중 레이블을 지원하지 않습니다 . 이 숫자를 실제로 보는 것이 좋습니다.


해밍 점수 :

A의 다중 레벨 분류 설정 sklearn.metrics.accuracy_score만을 계산 부분 집합의 정확성 즉, 정확히 y_true에서 라벨의 해당 설정과 일치해야합니다 샘플에 대한 예측 라벨 세트 (3).

정확도를 계산하는이 방법은 언젠가는 정확히 일치 비율 (1)로 불 립니다 .

여기에 이미지 설명을 입력하십시오

정확도를 계산하는 또 다른 일반적인 방법은 (1) 및 (2)에 정의되어 있으며 해밍 점수 라고 불명확하게 나타납니다. (4) (해밍 손실과 밀접하게 관련되어 있기 때문에) 또는 레이블 기반 정확도 . 다음과 같이 계산됩니다.

여기에 이미지 설명을 입력하십시오

다음은 해밍 점수를 계산하는 파이썬 방법입니다.

# Code by /programming//users/1953100/william
# Source: /programming//a/32239764/395857
# License: cc by-sa 3.0 with attribution required

import numpy as np

y_true = np.array([[0,1,0],
                   [0,1,1],
                   [1,0,1],
                   [0,0,1]])

y_pred = np.array([[0,1,1],
                   [0,1,1],
                   [0,1,0],
                   [0,0,0]])

def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    '''
    Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case
    /programming//q/32239577/395857
    '''
    acc_list = []
    for i in range(y_true.shape[0]):
        set_true = set( np.where(y_true[i])[0] )
        set_pred = set( np.where(y_pred[i])[0] )
        #print('\nset_true: {0}'.format(set_true))
        #print('set_pred: {0}'.format(set_pred))
        tmp_a = None
        if len(set_true) == 0 and len(set_pred) == 0:
            tmp_a = 1
        else:
            tmp_a = len(set_true.intersection(set_pred))/\
                    float( len(set_true.union(set_pred)) )
        #print('tmp_a: {0}'.format(tmp_a))
        acc_list.append(tmp_a)
    return np.mean(acc_list)

if __name__ == "__main__":
    print('Hamming score: {0}'.format(hamming_score(y_true, y_pred))) # 0.375 (= (0.5+1+0+0)/4)

    # For comparison sake:
    import sklearn.metrics

    # Subset accuracy
    # 0.25 (= 0+1+0+0 / 4) --> 1 if the prediction for one sample fully matches the gold. 0 otherwise.
    print('Subset accuracy: {0}'.format(sklearn.metrics.accuracy_score(y_true, y_pred, normalize=True, sample_weight=None)))

    # Hamming loss (smaller is better)
    # $$ \text{HammingLoss}(x_i, y_i) = \frac{1}{|D|} \sum_{i=1}^{|D|} \frac{xor(x_i, y_i)}{|L|}, $$
    # where
    #  - \\(|D|\\) is the number of samples  
    #  - \\(|L|\\) is the number of labels  
    #  - \\(y_i\\) is the ground truth  
    #  - \\(x_i\\)  is the prediction.  
    # 0.416666666667 (= (1+0+3+1) / (3*4) )
    print('Hamming loss: {0}'.format(sklearn.metrics.hamming_loss(y_true, y_pred))) 

출력 :

Hamming score: 0.375
Subset accuracy: 0.25
Hamming loss: 0.416666666667

(1) Sorower, Mohammad S. " 멀티 라벨 학습 알고리즘에 관한 문헌 조사. "Oregon State University, Corvallis (2010).

(2) Tsoumakas, Grigorios 및 Ioannis Katakis. " 다중 레이블 분류 : 개요. "그리스 테살로니키 아리스토텔레스 대학교 정보학과 (2006).

(3) Ghamrawi, Nadia 및 Andrew McCallum. " 집단적 다중 레이블 분류. "정보 및 지식 관리에 관한 제 14 차 ACM 국제 회의의 절차. ACM, 2005.

(4) Godbole, Shantanu 및 Sunita Sarawagi. " 다중 레이블 분류를위한 판별 방법. "지식 발견 및 데이터 마이닝의 발전. Springer Berlin Heidelberg, 2004. 22-30.


큰 대답, 그것은 단지 나를 더 좋게 만들었습니다 :) 나는 그것을 더 철저히 읽고, 해밍 점수를 시도하고 당신에게 돌아갑니다!
mobius

솔직히 말해서, 정확히 서브 세트 정확도 (정확한 일치 비율)가 무엇인지는 명확하지 않습니다. 좀 더 설명해 주시겠습니까? 멀티 클래스의 경우 이는 회상과 동일합니다.
Poete Maudit

hamming_scoreKeras상에서 기능 에러 : <ipython 입력-34-16066d66dfdd> hamming_score에서 (y_true, y_pred 정상화, sample_weight) 60 ''= 61 acc_list [] ---> 범위 내가 62 (y_true.shape [ 0]) : 63 set_true = set (np.where (y_true [i]) [0]) 64 set_pred = set (np.where (y_pred [i]) [0]) TypeError : 인덱스가 아닌 정수를 반환했습니다 (유형 없음 유형 )
rjurney

0

0.29 점수가 충분하지 않습니까? 혼동 행렬은 어떻게 생겼습니까? 단어 내용 만보고 분리 할 수없는 주제가 있습니까?

그렇지 않으면 문제를 해결해보십시오. 낮은 점수가 실제로 분류자가 데이터에서 수행 할 수있는 최선의 결과라는 가설. 이는이 방법을 사용하여 문서를 분류 할 수 없음을 의미합니다.

이 가설을 테스트하려면 알려진 단어 특성을 가진 일련의 테스트 문서가 필요합니다. 100 % 점수를 받아야합니다.

그렇지 않으면 버그가있는 것입니다. 그렇지 않으면 문서를 분류하기위한 다른 접근 방식이 필요합니다. 다른 학급의 문서는 어떻게 다릅니 까? 내 문서 등의 다른 기능을 봐야합니까?


숫자 외에도 0.29가 낮다는 것을 느낍니다. 훈련 된 모델을 사용하여 분류기에 수동으로 테스트하기 위해 이미 교육에 사용한 문서의 주제를 예측합니다. 사용자가 문서에 수동으로 입력 한 것과 같은 수의 주제를 얻을 수 없었습니다. 나는 보통 그것들의 부분 집합을 얻는다. 또한 혼란 매트릭스 질문에 관해서는, 나는 내가 scikit.metrics를 사용하여 OneVsRestClassifier에 혼란 행렬을 얻을 수 있다고 생각하지 않습니다 ... 그래도 난 그것을 확인합니다
뫼비우스의
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.