Floresta aleatória e previsão

13

Estou tentando entender como a Random Forest funciona. Tenho uma idéia de como as árvores são construídas, mas não consigo entender como a Random Forest faz previsões em amostras fora do saco. Alguém poderia me dar uma explicação simples, por favor? :)

user1665355
fonte

Respostas:

16

Cada árvore na floresta é construída a partir de uma amostra de inicialização das observações em seus dados de treinamento. Essas observações na amostra de bootstrap constroem a árvore, enquanto as que não estão na amostra de bootstrap formam as amostras prontas para uso (ou OOB).

Deve ficar claro que as mesmas variáveis ​​estão disponíveis para casos nos dados usados ​​para construir uma árvore e para os casos na amostra OOB. Para obter previsões para a amostra OOB, cada uma é passada para baixo na árvore atual e as regras para a árvore são seguidas até que chegue ao nó do terminal. Isso produz as previsões OOB para essa árvore específica.

Esse processo é repetido várias vezes, cada árvore treinada em uma nova amostra de autoinicialização a partir dos dados de treinamento e previsões para as novas amostras OOB derivadas.

À medida que o número de árvores cresce, qualquer uma das amostras estará nas amostras OOB mais de uma vez, portanto, a "média" das previsões sobre as N árvores em que uma amostra está no OOB será usada como previsão OOB para cada amostra de treinamento para trees 1, ..., N. Por "média", usamos a média das previsões para uma resposta contínua, ou o voto da maioria pode ser usado para uma resposta categórica (o voto da maioria é a classe com mais votos no conjunto de árvores 1, ..., N).

Por exemplo, suponha que tivéssemos as seguintes previsões OOB para 10 amostras em um conjunto de treinamento em 10 árvores

set.seed(123)
oob.p <- matrix(rpois(100, lambda = 4), ncol = 10)
colnames(oob.p) <- paste0("tree", seq_len(ncol(oob.p)))
rownames(oob.p) <- paste0("samp", seq_len(nrow(oob.p)))
oob.p[sample(length(oob.p), 50)] <- NA
oob.p

> oob.p
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA     7     8     2     1    NA     5     3      2
samp2      6    NA     5     7     3    NA    NA    NA    NA     NA
samp3      3    NA     5    NA    NA    NA     3     5    NA     NA
samp4      6    NA    10     6    NA    NA     3    NA     6     NA
samp5     NA     2    NA    NA     2    NA     6     4    NA     NA
samp6     NA     7    NA     4    NA     2     4     2    NA     NA
samp7     NA    NA    NA     5    NA    NA    NA     3     9      5
samp8      7     1     4    NA    NA     5     6    NA     7     NA
samp9      4    NA    NA     3    NA     7     6     3    NA     NA
samp10     4     8     2     2    NA    NA     4    NA    NA      4

Onde NAsignifica que a amostra estava nos dados de treinamento para essa árvore (em outras palavras, não estava na amostra OOB).

A média dos não NAvalores para cada linha fornece a previsão OOB para cada amostra, para toda a floresta

> rowMeans(oob.p, na.rm = TRUE)
 samp1  samp2  samp3  samp4  samp5  samp6  samp7  samp8  samp9 samp10 
  4.00   5.25   4.00   6.20   3.50   3.80   5.50   5.00   4.60   4.00

À medida que cada árvore é adicionada à floresta, podemos calcular o erro OOB até incluir essa árvore. Por exemplo, abaixo estão as médias cumulativas para cada amostra:

FUN <- function(x) {
  na <- is.na(x)
  cs <- cumsum(x[!na]) / seq_len(sum(!na))
  x[!na] <- cs
  x
}
t(apply(oob.p, 1, FUN))

> print(t(apply(oob.p, 1, FUN)), digits = 3)
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA  7.00  7.50  5.67  4.50    NA   4.6  4.33    4.0
samp2      6    NA  5.50  6.00  5.25    NA    NA    NA    NA     NA
samp3      3    NA  4.00    NA    NA    NA  3.67   4.0    NA     NA
samp4      6    NA  8.00  7.33    NA    NA  6.25    NA  6.20     NA
samp5     NA     2    NA    NA  2.00    NA  3.33   3.5    NA     NA
samp6     NA     7    NA  5.50    NA  4.33  4.25   3.8    NA     NA
samp7     NA    NA    NA  5.00    NA    NA    NA   4.0  5.67    5.5
samp8      7     4  4.00    NA    NA  4.25  4.60    NA  5.00     NA
samp9      4    NA    NA  3.50    NA  4.67  5.00   4.6    NA     NA
samp10     4     6  4.67  4.00    NA    NA  4.00    NA    NA    4.0

Dessa maneira, vemos como a previsão é acumulada sobre as N árvores da floresta até uma determinada iteração. Se você ler nas linhas, o lado direito mais à direitaNA valor é o que eu mostro acima para a previsão do OOB. É assim que os traços do desempenho do OOB podem ser feitos - um RMSEP pode ser calculado para as amostras de OOB com base nas previsões de OOB acumuladas cumulativamente sobre as N árvores.

Observe que o código R mostrado não é obtido dos internos do código randomForest no randomForest pacote para R - acabei de um código simples para que você possa acompanhar o que está acontecendo quando as previsões de cada árvore são determinadas.

É porque cada árvore é construída a partir de uma amostra de bootstrap e que há um grande número de árvores em uma floresta aleatória, de modo que cada observação do conjunto de treinamento esteja na amostra OOB para uma ou mais árvores, que as previsões OOB podem ser fornecidas para todas as amostras nos dados de treinamento.

Eu encobri questões como dados ausentes para alguns casos de OOB, etc., mas essas questões também pertencem a uma única árvore de regressão ou classificação. Observe também que cada árvore em uma floresta usa apenas mtryvariáveis ​​selecionadas aleatoriamente.

Restabelecer Monica - G. Simpson
fonte
Ótima resposta Gavin! Quando você escreve "To get predictions for the OOB sample, each one is passed down the current tree and the rules for the tree followed until it arrives in a terminal node", você tem uma explicação simples do que rules for the treesão? E entendo samplecomo uma linha corretamente se entender que as amostras são groupsde observações nas quais as árvores dividem os dados?
user1665355
@ user1665355 Presumi que você entendeu como as árvores de regressão ou classificação foram construídas? As árvores na RF não são diferentes (exceto talvez nas regras de parada). Cada árvore divide os dados de treinamento em grupos de amostras com "valores" semelhantes para a resposta. A variável e a localização da divisão (por exemplo, pH> 4,5) que melhor prevê (ou seja, minimiza o "erro") formam a primeira divisão ou regra na árvore. Cada ramificação dessa divisão é considerada por sua vez e novas divisões / regras são identificadas, minimizando o "erro" da árvore. Este é o algoritmo de particionamento recursivo binário. As divisões são as regras.
Reintegrar Monica - G. Simpson
@ user1665355 Sim, desculpe, eu venho de um campo em que uma amostra é uma observação, uma linha no conjunto de dados. Mas quando você começa a falar sobre uma amostra de autoinicialização, esse é um conjunto de N observações, extraídas com a substituição dos dados de treinamento e, portanto, possui N linhas ou observações. Vou tentar limpar minha terminologia mais tarde.
Reintegrar Monica - G. Simpson
Obrigado! Sou muito novo na RF, desculpe-me por perguntas estúpidas :) Acho que entendi quase tudo o que você escreveu, uma explicação muito boa! Eu apenas me pergunto sobre A variável e a localização da divisão (por exemplo, pH> 4,5) que melhor prevê (ou seja, minimiza o "erro") formam a primeira divisão ou regra da árvore ... Não consigo entender qual é o erro. : / Estou lendo e tentando entender http://www.ime.unicamp.br/~ra109078/PED/Data%20Minig%20with%20R/Data%20Mining%20with%20R.pdf. Na página 115-116, os autores usam a RF para escolher variable importanceindicadores técnicos.
user1665355
O "erro" depende de que tipo de árvore está sendo ajustada. O desvio é a medida usual para respostas contínuas (gaussianas). No pacote rpart, o coeficiente de Gini é o padrão para respostas categóricas, mas existem outros para modelos diferentes etc. Você deve utilizar um bom livro sobre Árvores e RF se quiser implantá-lo com sucesso. As medidas de desvio variável são diferentes - elas medem a "importância" de cada variável no conjunto de dados, ver quanto algo muda quando essa variável é usada para ajustar uma árvore e quando essa variável não é usada.
Reinstate Monica - G. Simpson