Estou trabalhando em um problema de análise de sentimento, os dados se parecem com este:
label instances
5 1190
4 838
3 239
1 204
2 127
Portanto, meus dados estão desequilibrados, visto que 1190 instances
estão marcados com 5
. Para a classificação, estou usando o SVC do scikit . O problema é que eu não sei como equilibrar meus dados da maneira certa para calcular com precisão a precisão, recuperação, exatidão e pontuação f1 para o caso multiclasse. Então, tentei as seguintes abordagens:
Primeiro:
wclf = SVC(kernel='linear', C= 1, class_weight={1: 10})
wclf.fit(X, y)
weighted_prediction = wclf.predict(X_test)
print 'Accuracy:', accuracy_score(y_test, weighted_prediction)
print 'F1 score:', f1_score(y_test, weighted_prediction,average='weighted')
print 'Recall:', recall_score(y_test, weighted_prediction,
average='weighted')
print 'Precision:', precision_score(y_test, weighted_prediction,
average='weighted')
print '\n clasification report:\n', classification_report(y_test, weighted_prediction)
print '\n confussion matrix:\n',confusion_matrix(y_test, weighted_prediction)
Segundo:
auto_wclf = SVC(kernel='linear', C= 1, class_weight='auto')
auto_wclf.fit(X, y)
auto_weighted_prediction = auto_wclf.predict(X_test)
print 'Accuracy:', accuracy_score(y_test, auto_weighted_prediction)
print 'F1 score:', f1_score(y_test, auto_weighted_prediction,
average='weighted')
print 'Recall:', recall_score(y_test, auto_weighted_prediction,
average='weighted')
print 'Precision:', precision_score(y_test, auto_weighted_prediction,
average='weighted')
print '\n clasification report:\n', classification_report(y_test,auto_weighted_prediction)
print '\n confussion matrix:\n',confusion_matrix(y_test, auto_weighted_prediction)
Terceiro:
clf = SVC(kernel='linear', C= 1)
clf.fit(X, y)
prediction = clf.predict(X_test)
from sklearn.metrics import precision_score, \
recall_score, confusion_matrix, classification_report, \
accuracy_score, f1_score
print 'Accuracy:', accuracy_score(y_test, prediction)
print 'F1 score:', f1_score(y_test, prediction)
print 'Recall:', recall_score(y_test, prediction)
print 'Precision:', precision_score(y_test, prediction)
print '\n clasification report:\n', classification_report(y_test,prediction)
print '\n confussion matrix:\n',confusion_matrix(y_test, prediction)
F1 score:/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:676: DeprecationWarning: The default `weighted` averaging is deprecated, and from version 0.18, use of precision, recall or F-score with multiclass or multilabel data or pos_label=None will result in an exception. Please set an explicit value for `average`, one of (None, 'micro', 'macro', 'weighted', 'samples'). In cross validation use, for instance, scoring="f1_weighted" instead of scoring="f1".
sample_weight=sample_weight)
/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:1172: DeprecationWarning: The default `weighted` averaging is deprecated, and from version 0.18, use of precision, recall or F-score with multiclass or multilabel data or pos_label=None will result in an exception. Please set an explicit value for `average`, one of (None, 'micro', 'macro', 'weighted', 'samples'). In cross validation use, for instance, scoring="f1_weighted" instead of scoring="f1".
sample_weight=sample_weight)
/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:1082: DeprecationWarning: The default `weighted` averaging is deprecated, and from version 0.18, use of precision, recall or F-score with multiclass or multilabel data or pos_label=None will result in an exception. Please set an explicit value for `average`, one of (None, 'micro', 'macro', 'weighted', 'samples'). In cross validation use, for instance, scoring="f1_weighted" instead of scoring="f1".
sample_weight=sample_weight)
0.930416613529
No entanto, estou recebendo avisos como este:
/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:1172:
DeprecationWarning: The default `weighted` averaging is deprecated,
and from version 0.18, use of precision, recall or F-score with
multiclass or multilabel data or pos_label=None will result in an
exception. Please set an explicit value for `average`, one of (None,
'micro', 'macro', 'weighted', 'samples'). In cross validation use, for
instance, scoring="f1_weighted" instead of scoring="f1"
Como posso lidar corretamente com meus dados desequilibrados para calcular as métricas do classificador da maneira certa?
python
machine-learning
nlp
artificial-intelligence
scikit-learn
new_with_python
fonte
fonte
average
parâmetro no terceiro caso?Respostas:
Acho que há muita confusão sobre quais pesos são usados para quê. Não tenho certeza se sei exatamente o que o incomoda, então vou abordar vários tópicos, tenha paciência;).
Pesos de classe
Os pesos do
class_weight
parâmetro são usados para treinar o classificador . Eles não são usados no cálculo de nenhuma das métricas que você está usando : com pesos de classe diferentes, os números serão diferentes simplesmente porque o classificador é diferente.Basicamente, em cada classificador scikit-learn, os pesos das classes são usados para dizer ao seu modelo o quão importante é uma classe. Isso significa que durante o treinamento, o classificador fará um esforço extra para classificar adequadamente as classes com pesos elevados.
Como eles fazem isso é específico do algoritmo. Se quiser detalhes sobre como ele funciona para o SVC e o documento não fizer sentido para você, fique à vontade para mencioná-lo.
As métricas
Depois de ter um classificador, você deseja saber como está seu desempenho. Aqui você pode usar as métricas que você mencionou:
accuracy
,recall_score
,f1_score
...Normalmente, quando a distribuição de classes é desequilibrada, a precisão é considerada uma escolha ruim, pois dá pontuações altas aos modelos que apenas prevêem a classe mais frequente.
Não vou detalhar todas essas métricas, mas observo que, com exceção de
accuracy
, elas são naturalmente aplicadas no nível da classe: como você pode ver nesteprint
relatório de classificação, elas são definidas para cada classe. Eles se baseiam em conceitos comotrue positives
oufalse negative
que exigem a definição de qual classe é a positiva .O aviso
Você recebe este aviso porque está usando o f1-score, recall e precisão sem definir como eles devem ser calculados! A pergunta poderia ser reformulada: a partir do relatório de classificação acima, como você produz um número global para a pontuação f1? Você poderia:
avg / total
resultado acima. Também é chamado de média de macro .'weighted'
no scikit-learn pesará a pontuação f1 pelo apoio da classe: quanto mais elementos uma classe tiver, mais importante será a pontuação f1 para essa classe no cálculo.Essas são 3 das opções no scikit-learn; o aviso está lá para dizer que você deve escolher uma . Portanto, você deve especificar um
average
argumento para o método de pontuação.Qual você escolher depende de como você deseja medir o desempenho do classificador: por exemplo, a macro-média não leva em conta o desequilíbrio da classe e a pontuação f1 da classe 1 será tão importante quanto a pontuação f1 da classe 5. Se você usar a média ponderada, no entanto, terá mais importância para a classe 5.
Toda a especificação do argumento nessas métricas não é superclara no scikit-learn agora, ela ficará melhor na versão 0.18 de acordo com os documentos. Eles estão removendo alguns comportamentos padrão não óbvios e emitindo avisos para que os desenvolvedores os percebam.
Pontuação de computação
A última coisa que quero mencionar (sinta-se à vontade para pular se você estiver ciente disso) é que as pontuações só são significativas se forem calculadas em dados que o classificador nunca viu . Isso é extremamente importante, pois qualquer pontuação obtida nos dados que foram usados para ajustar o classificador é completamente irrelevante.
Esta é uma maneira de fazer isso usando
StratifiedShuffleSplit
, que fornece a você uma divisão aleatória de seus dados (após embaralhamento) que preserva a distribuição do rótulo.Espero que isto ajude.
fonte
class_weight={1:10}
significa um dado que tem 3 classes?ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of labels for any class cannot be less than 2.
. Está funcionando bem com a divisão de teste de trem, mas alguém pode me ajudar por que estou recebendo esse erro com SSS? Obrigado.Muitas respostas detalhadas aqui, mas não acho que você está respondendo às perguntas certas. Pelo que entendi a pergunta, existem duas preocupações:
1
Você pode usar a maioria das funções de pontuação no scikit-learn tanto com problemas multiclasse quanto com problemas de classe única. Ex.:
Dessa forma, você acaba com números tangíveis e interpretáveis para cada uma das classes.
Então...
2
... você pode dizer se os dados desequilibrados são mesmo um problema. Se a pontuação para as classes menos representadas (classes 1 e 2) for menor do que para as classes com mais amostras de treinamento (classes 4 e 5), então você sabe que os dados desequilibrados são de fato um problema e você pode agir de acordo, como descrito em algumas das outras respostas neste tópico. No entanto, se a mesma distribuição de classe estiver presente nos dados que você deseja prever, seus dados de treinamento desequilibrado são um bom representante dos dados e, portanto, o desequilíbrio é uma coisa boa.
fonte
precision_recall_fscore_support
? As etiquetas são impressas por encomenda?average=None
e defina os rótulos, então você obtém a métrica que está procurando, para cada um dos rótulos especificados.Questão colocada
Respondendo à pergunta 'qual métrica deve ser usada para classificação multiclasse com dados desequilibrados': Macro-F1-medir. Macro Precisão e Macro Recall também podem ser usados, mas eles não são tão facilmente interpretáveis como para a classificação binária, eles já estão incorporados na medida F e métricas em excesso complicam a comparação de métodos, ajuste de parâmetros e assim por diante.
A micro média é sensível ao desequilíbrio de classe: se o seu método, por exemplo, funciona bem para os rótulos mais comuns e bagunça totalmente os outros, a micro média métrica mostra bons resultados.
A média de ponderação não é adequada para dados desequilibrados, porque ela pondera por contagens de rótulos. Além disso, é dificilmente interpretável e impopular: por exemplo, não há menção de tal média na seguinte pesquisa muito detalhada que eu recomendo fortemente que você analise:
Pergunta específica do aplicativo
No entanto, voltando à sua tarefa, eu pesquisaria 2 tópicos:
Métricas comumente usadas. Como posso inferir depois de olhar a literatura, existem 2 métricas de avaliação principais:
( link ) - observe que os autores trabalham com quase a mesma distribuição de classificações, consulte a Figura 5.
( link )
( link ) - eles exploram tanto a precisão quanto o MSE, considerando o último ser melhor
( link ) - eles utilizam o scikit-learn para avaliação e abordagens de linha de base e afirmam que seu código está disponível; entretanto, não consigo encontrar, então se precisar, escreva uma carta aos autores, o trabalho é bem novo e parece ter sido escrito em Python.
Custo de erros diferentes . Se você se preocupa mais em evitar erros grosseiros, por exemplo, atribuir uma crítica de 1 a 5 estrelas ou algo assim, dê uma olhada no MSE; se a diferença importa, mas não tanto, tente MAE, uma vez que não ao quadrado diff; caso contrário, fique com Precisão.
Sobre abordagens, não métricas
Tente abordagens de regressão, por exemplo, SVR , uma vez que geralmente superam classificadores Multiclass como SVC ou OVA SVM.
fonte
Em primeiro lugar, é um pouco mais difícil usar apenas a análise de contagem para saber se seus dados estão desequilibrados ou não. Por exemplo: 1 em 1000 observação positiva é apenas um ruído, erro ou um avanço na ciência? Nunca se sabe.
Portanto, é sempre melhor usar todo o seu conhecimento disponível e escolher seu status com toda sabedoria.
Ok, e se estiver realmente desequilibrado?
Mais uma vez - observe seus dados. Às vezes, você pode encontrar uma ou duas observações multiplicadas por cem vezes. Às vezes é útil criar essas observações falsas de uma classe.
Se todos os dados estiverem limpos, a próxima etapa é usar pesos de classe no modelo de previsão.
E quanto às métricas multiclasse?
Na minha experiência, nenhuma de suas métricas é normalmente usada. Há duas razões principais.
Primeiro: é sempre melhor trabalhar com probabilidades do que com predição sólida (porque de que outra forma você poderia separar modelos com predição de 0,9 e 0,6 se ambos fornecem a mesma classe?)
E segundo: é muito mais fácil comparar seus modelos de predição e construir novos aqueles dependendo de apenas uma boa métrica. Pela
minha experiência, eu poderia recomendar logloss ou MSE (ou apenas significar erro quadrático).
Como consertar os avisos do sklearn?
Simplesmente (como yangjie notou) sobrescrever o
average
parâmetro com um destes valores:'micro'
(calcular métricas globalmente),'macro'
(calcular métricas para cada rótulo) ou'weighted'
(igual ao macro, mas com pesos automáticos).Todos os seus avisos vieram depois de chamar funções de métricas com
average
valor padrão'binary'
que é impróprio para previsão multiclasse.Boa sorte e divirta-se com o aprendizado de máquina!
Edit:
eu encontrei outra recomendação do respondente para mudar para abordagens de regressão (por exemplo, SVR) com a qual não posso concordar. Pelo que me lembro, não existe nem mesmo regressão multiclasse. Sim, existe uma regressão multilabel que é muito diferente e sim, é possível, em alguns casos, alternar entre regressão e classificação (se as classes forem classificadas de alguma forma), mas é muito raro.
O que eu recomendaria (no escopo do scikit-learn) é tentar outras ferramentas de classificação muito poderosas: aumento de gradiente , floresta aleatória (minha favorita), KNeighbors e muito mais.
Depois disso, você pode calcular a média aritmética ou geométrica entre as previsões e, na maioria das vezes, você obterá resultados ainda melhores.
fonte