Importância relativa de um conjunto de preditores na classificação de florestas aleatórias em R

31

Eu gostaria de determinar a importância relativa de conjuntos de variáveis ​​para um randomForestmodelo de classificação em R. A importancefunção fornece a MeanDecreaseGinimétrica para cada preditor individual - é tão simples quanto resumir isso em cada preditor de um conjunto?

Por exemplo:

# Assumes df has variables a1, a2, b1, b2, and outcome
rf <- randomForest(outcome ~ ., data=df)
importance(rf)
# To determine whether the "a" predictors are more important than the "b"s,
# can I sum the MeanDecreaseGini for a1 and a2 and compare to that of b1+b2?
Max Ghenis
fonte

Respostas:

46

Primeiro, eu gostaria de esclarecer qual a métrica de importância realmente mede.

MeanDecreaseGini é uma medida de importância variável com base no índice de impureza de Gini usado para o cálculo de divisões durante o treinamento. Um equívoco comum é que a métrica de importância variável se refere ao Gini usado para afirmar o desempenho do modelo que está intimamente relacionado à AUC, mas isso está errado. Aqui está a explicação do pacote randomForest escrito por Breiman e Cutler:

Importância de Gini
Toda vez que uma divisão de um nó é feita na variável m, o critério de impureza de gini para os dois nós descendentes é menor que o nó pai. A soma das diminuições de gini para cada variável individual em todas as árvores da floresta fornece uma importância variável rápida que geralmente é muito consistente com a medida de importância da permutação.

O índice de impureza Gini é definido como Onde é o número de classes na variável de destino e é a razão dessa classe.

G=Eu=1 1ncpEu(1 1-pEu)
ncpEu

Para um problema de duas classes, isso resulta na seguinte curva que é maximizada para a amostra 50-50 e minimizada para os conjuntos homogêneos: Gini impureza para 2 classes

A importância é então calculada como calculada a média sobre todas as divisões na floresta que envolvem o preditor em questão. Como essa é uma média, ela pode ser facilmente estendida para ser calculada a média de todas as divisões nas variáveis ​​contidas em um grupo.

Eu=Gpumarent-GspeuEut1 1-GspeuEut2

Olhando mais de perto, sabemos que a importância de cada variável é uma condição condicional à variável usada e a médiaDez decréscimoGini do grupo seria apenas a média dessas importâncias ponderadas no compartilhamento que essa variável é usada na floresta em comparação com as outras variáveis ​​do mesmo grupo. Isso ocorre porque a propriedade da torre

E[E[X|Y]]=E[X]

Agora, para responder sua pergunta diretamente, não é tão simples como resumir todas as importâncias em cada grupo para obter o MeanDecreaseGini combinado, mas calcular a média ponderada fornecerá a resposta que você está procurando. Nós apenas precisamos encontrar as frequências variáveis ​​dentro de cada grupo.

Aqui está um script simples para obtê-los de um objeto de floresta aleatório no R:

var.share <- function(rf.obj, members) {
  count <- table(rf.obj$forest$bestvar)[-1]
  names(count) <- names(rf.obj$forest$ncat)
  share <- count[members] / sum(count[members])
  return(share)
}

Basta passar os nomes das variáveis ​​no grupo como o parâmetro members.

Espero que isso responda à sua pergunta. Posso escrever uma função para obter as importâncias do grupo diretamente, se for de seu interesse.

EDIT:
Aqui está uma função que dá importância ao grupo, dado um randomForestobjeto e uma lista de vetores com nomes de variáveis. Ele usa var.sharecomo definido anteriormente. Eu não fiz nenhuma verificação de entrada, portanto, você precisa se certificar de usar os nomes de variáveis ​​corretos.

group.importance <- function(rf.obj, groups) {
  var.imp <- as.matrix(sapply(groups, function(g) {
    sum(importance(rf.obj, 2)[g, ]*var.share(rf.obj, g))
  }))
  colnames(var.imp) <- "MeanDecreaseGini"
  return(var.imp)
}

Exemplo de uso:

library(randomForest)                                                          
data(iris)

rf.obj <- randomForest(Species ~ ., data=iris)

groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
               Petal=c("Petal.Width", "Petal.Length"))

group.importance(rf.obj, groups)

>

      MeanDecreaseGini
Sepal         6.187198
Petal        43.913020

Também funciona para grupos sobrepostos:

overlapping.groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
                           Petal=c("Petal.Width", "Petal.Length"),
                           Width=c("Sepal.Width", "Petal.Width"), 
                           Length=c("Sepal.Length", "Petal.Length"))

group.importance(rf.obj, overlapping.groups)

>

       MeanDecreaseGini
Sepal          6.187198
Petal         43.913020
Width          30.513776
Length        30.386706
enquanto
fonte
Obrigado pela resposta clara e rigorosa! Se você não se importasse em adicionar uma função para importância de grupo, isso seria ótimo.
precisa saber é o seguinte
Obrigado por essa resposta! Duas perguntas, se você tiver um minuto: (1) A importância é então calculada como ... : com relação à definição de Breiman, eu sou a "diminuição de gini" lá, e a importância seria a soma das diminuições, correto ? (2) calculada a média de todas as divisões na floresta que envolvem o preditor em questão : Posso substituir isso por todos os nós que envolvem uma divisão nesse recurso específico ? Para ter certeza que eu compreendo perfeitamente;)
Remi Mélisson
11
Seu comentário me fez pensar um pouco mais nas definições, então eu procurei no código randomForest usado no R para respondê-lo corretamente. Eu tenho sido um pouco fora para ser honesto. A média é feita sobre todas as árvores e nem todos os nós. Atualizarei a resposta assim que tiver tempo. Aqui estão as respostas para sua pergunta por enquanto: (1) sim. É assim que é definido no nível da árvore. A soma das diminuições é então calculada sobre todas as árvores. (2) Sim, era isso que eu queria dizer, mas na verdade não é válido.
enquanto
4

A função definida acima como G = soma das classes [pi (1-pi)] é na verdade a entropia, que é outra maneira de avaliar uma divisão. A diferença entre a entropia nos nós filhos e o nó pai é o ganho de informação. A função de impureza GINI é G = 1 - soma nas classes [pi ^ 2].

Sowmya Iyer
fonte