Qual função de perda usar para classes desequilibradas (usando PyTorch)?

17

Eu tenho um conjunto de dados com 3 classes com os seguintes itens:

  • Classe 1: 900 elementos
  • Classe 2: 15000 elementos
  • Classe 3: 800 elementos

Preciso prever as classes 1 e 3, que sinalizam desvios importantes da norma. A classe 2 é o caso "normal" padrão com o qual não me importo.

Que tipo de função de perda eu usaria aqui? Eu estava pensando em usar CrossEntropyLoss, mas como há um desequilíbrio de classe, isso precisaria ser ponderado, suponho? Como isso funciona na prática? Assim (usando PyTorch)?

summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

Ou o peso deve ser invertido? ou seja, 1 / peso?

Essa é a abordagem correta para começar ou existem outros / melhores métodos que eu poderia usar?

obrigado

Muppet
fonte

Respostas:

13

Que tipo de função de perda eu usaria aqui?

A entropia cruzada é a função de perda para tarefas de classificação, equilibradas ou desequilibradas. É a primeira opção quando nenhuma preferência é criada a partir do conhecimento do domínio.

Isso precisaria ser ponderado, suponho? Como isso funciona na prática?

Sim. Peso da classec é o tamanho da maior classe dividida pelo tamanho da classe c.

Por exemplo, se a classe 1 tiver 900, a classe 2 tiver 15000 e a classe 3 tiver 800 amostras, seus pesos serão 16,67, 1,0 e 18,75, respectivamente.

Você também pode usar a classe menor como nominador, que fornece 0,889, 0,053 e 1,0, respectivamente. Isso é apenas uma re-escala, os pesos relativos são os mesmos.

Essa é a abordagem correta para começar ou existem outros / melhores métodos que eu poderia usar?

Sim, esta é a abordagem correta.

EDIT :

Graças ao @Muppet, também podemos usar sobre-amostragem de classe, o que equivale a usar pesos de classe . Isso é realizado WeightedRandomSamplerno PyTorch, usando os mesmos pesos mencionados acima.

Esmailiano
fonte
2
Eu só queria acrescentar que o uso do WeightedRandomSampler do PyTorch também ajudou, no caso de alguém estar olhando para isso.
Muppet
0

Quando você diz: Você também pode usar a classe mais pequena como nomeador, que fornece 0,889, 0,053 e 1,0, respectivamente. Isso é apenas uma redimensionamento, os pesos relativos são os mesmos.

Mas essa solução está em contradição com a primeira que você deu, como funciona?

Georges Matar
fonte