ROC médio para validação cruzada repetida em 10 vezes com estimativas de probabilidade

15

Estou planejando usar a validação cruzada estratificada de 10 vezes repetida (10 vezes) em cerca de 10.000 casos usando o algoritmo de aprendizado de máquina. Cada vez que a repetição será feita com sementes aleatórias diferentes.

Nesse processo, crio 10 instâncias de estimativas de probabilidade para cada caso. 1 instância de estimativa de probabilidade para cada uma das 10 repetições da validação cruzada de 10 vezes

Posso calcular em média 10 probabilidades para cada caso e criar uma nova curva ROC média (representando resultados de CV repetido 10 vezes), que pode ser comparada com outras curvas ROC por comparações emparelhadas?

user97953
fonte

Respostas:

13

A partir da sua descrição, parece fazer todo o sentido: não apenas você pode calcular a curva ROC média, mas também a variação em torno dela para criar intervalos de confiança. Deve dar uma idéia de quão estável é o seu modelo.

Por exemplo, assim:

insira a descrição da imagem aqui

Aqui eu coloquei curvas ROC individuais, bem como a curva média e os intervalos de confiança. Há áreas em que as curvas concordam, portanto, temos menos variações e há áreas em que elas discordam.

Para CV repetido, basta repeti-lo várias vezes e obter a média total de todas as dobras individuais:

insira a descrição da imagem aqui

É bastante semelhante à imagem anterior, mas fornece estimativas mais estáveis ​​(ou seja, confiáveis) da média e variância.

Aqui está o código para obter o enredo:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interp

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std


plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.axes().set_aspect('equal', 'datalim')
plt.show()

Para CV repetido:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Fonte de inspiração: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html

Alexey Grigorev
fonte
3

Não está correto para as probabilidades médias porque isso não representaria as previsões que você está tentando validar e envolve contaminação nas amostras de validação.

Observe que 100 repetições de validação cruzada de 10 vezes podem ser necessárias para obter precisão adequada. Ou use o bootstrap de otimismo do Efron-Gong, que requer menos iterações para a mesma precisão (consulte, por exemplo rms, validatefunções do pacote R ).

As curvas ROC não são de forma alguma esclarecedoras para este problema. Use uma pontuação de precisão adequada e acompanhe-a com oc-index (probabilidade de concordância; AUROC), que é muito mais fácil de lidar do que a curva, uma vez que é calculada com facilidade e rapidez usando a estatística Wilcoxon-Mann-Whitney.

Frank Harrell
fonte
Você poderia explicar melhor por que a média não está correta?
DataD'oh
Já indicado. Você precisa validar a medida que você usará no campo.
31717 Frank Fellowski em