R의 랜덤 포레스트 분류에서 예측 변수 세트의 상대적 중요성


31

randomForestR의 분류 모델에 대한 변수 집합의 상대적인 중요성을 결정하고 싶습니다 .이 importance함수는 MeanDecreaseGini각 개별 예측 변수에 대한 메트릭을 제공합니다 . 집합의 각 예측 변수를 합산하는 것만 큼 간단합니까?

예를 들면 다음과 같습니다.

# Assumes df has variables a1, a2, b1, b2, and outcome
rf <- randomForest(outcome ~ ., data=df)
importance(rf)
# To determine whether the "a" predictors are more important than the "b"s,
# can I sum the MeanDecreaseGini for a1 and a2 and compare to that of b1+b2?

답변:


46

먼저 중요성 측정 항목이 실제로 측정하는 내용을 명확히하고 싶습니다.

MeanDecreaseGini 는 훈련 중 분할 계산에 사용되는 Gini 불순물 지수를 기반으로하는 변수 중요도의 측정 값입니다. 일반적인 오해는 변수 중요도 메트릭이 AUC와 밀접한 관련이있는 모델 성능을 주장하는 데 사용되는 Gini를 참조하지만 이는 잘못된 것입니다. Breiman과 Cutler가 작성한 randomForest 패키지의 설명은 다음과 같습니다.

중요도
변수 m에서 노드 분할이 이루어질 때마다 두 하위 노드에 대한 지니 불순물 기준은 상위 노드보다 작습니다. 포리스트의 모든 트리에서 각 개별 변수에 대해 gini 감소를 더하면 순열 중요도 측정과 매우 일치하는 빠른 변수 중요도가 제공됩니다.

G=i=1ncpi(1pi)
ncpi

2 클래스 문제의 경우 50-50 샘플에 대해 최대화되고 균질 세트에 대해 최소화되는 다음 곡선이 생성됩니다. 2 종의 지니 불순물

I=GparentGsplit1Gsplit2

E[E[X|Y]]=E[X]

지금, 당신의 질문에 대답하는 것은 직접적으로는 하지 만 결합 MeanDecreaseGini를 얻기 위해 각 그룹의 모든 importances을 합산하지만, 당신에게 당신이 찾고있는 해답을 얻을 것이다 가중 평균을 계산 한 간단하게. 각 그룹 내에서 가변 주파수를 찾아야합니다.

다음은 R의 임의 포리스트 개체에서이를 가져 오는 간단한 스크립트입니다.

var.share <- function(rf.obj, members) {
  count <- table(rf.obj$forest$bestvar)[-1]
  names(count) <- names(rf.obj$forest$ncat)
  share <- count[members] / sum(count[members])
  return(share)
}

그룹의 변수 이름을 members 매개 변수로 전달하십시오.

이것이 귀하의 질문에 답변되기를 바랍니다. 관심이 있다면 그룹 중요도를 직접 얻는 함수를 작성할 수 있습니다.

편집 :
다음은 randomForest객체와 변수 이름이있는 벡터 목록이 주어진 그룹 중요도를 제공하는 함수입니다 . var.share이전에 정의한대로 사용합니다 . 입력 검사를 수행하지 않았으므로 올바른 변수 이름을 사용해야합니다.

group.importance <- function(rf.obj, groups) {
  var.imp <- as.matrix(sapply(groups, function(g) {
    sum(importance(rf.obj, 2)[g, ]*var.share(rf.obj, g))
  }))
  colnames(var.imp) <- "MeanDecreaseGini"
  return(var.imp)
}

사용 예 :

library(randomForest)                                                          
data(iris)

rf.obj <- randomForest(Species ~ ., data=iris)

groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
               Petal=c("Petal.Width", "Petal.Length"))

group.importance(rf.obj, groups)

>

      MeanDecreaseGini
Sepal         6.187198
Petal        43.913020

또한 그룹이 겹치는 경우에도 작동합니다.

overlapping.groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
                           Petal=c("Petal.Width", "Petal.Length"),
                           Width=c("Sepal.Width", "Petal.Width"), 
                           Length=c("Sepal.Length", "Petal.Length"))

group.importance(rf.obj, overlapping.groups)

>

       MeanDecreaseGini
Sepal          6.187198
Petal         43.913020
Width          30.513776
Length        30.386706

명확하고 엄격한 답변에 감사드립니다! 그룹 중요도에 대한 기능을 추가하지 않으려면 좋을 것입니다.
Max Ghenis

그 답변에 감사드립니다! 두 분의 질문이 있습니다. (1) 중요도는 다음과 같이 계산됩니다. : Breiman의 정의와 관련하여, 나는 거기에서 "지니 감소"이며, 중요도는 감소의 합이 될 것입니다. ? (2) 해당 예측 변수가 포함 된 포리스트의 모든 스플릿에 대한 평균 :이 특정 기능에 대한 스플릿 이 포함 된 모든 노드로 이를 대체 할 수 있습니까 ? 내가 완전히 이해할 수 있도록;)
Remi Mélisson

1
당신은 내가 정의에 대해 조금 더 생각하고있어서 R에 사용 된 randomForest 코드를 통해 제대로 대답했습니다. 나는 정직하기 위해 약간 벗어났다. 평균은 모든 노드가 아닌 모든 트리에 적용됩니다. 시간이 되 자마자 답변을 업데이트하겠습니다. 다음은 귀하의 질문에 대한 답변입니다 : (1) 예. 이것이 트리 수준에서 정의되는 방식입니다. 감소의 합은 모든 나무에 대해 평균됩니다. (2) 그렇습니다. 그것이 제가 말하고자하는 것이었지만 실제로는 그렇지 않습니다.
동안

4

클래스들에 대한 G = sum으로 정의 된 함수 [pi (1−pi)]는 실제로 엔트로피이며, 이는 분할을 평가하는 또 다른 방법입니다. 자식 노드와 부모 노드의 엔트로피의 차이점은 정보 획득입니다. GINI 불순물 함수는 G = 1- 클래스에 대한 합계 [pi ^ 2]입니다.

당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.