Como um modelo de regressão logística simples atinge uma precisão de classificação de 92% no MNIST?

70

Embora todas as imagens no conjunto de dados MNIST estejam centralizadas, em uma escala semelhante e com a face para cima sem rotações, elas têm uma variação significativa de manuscrito que me intriga como um modelo linear atinge uma precisão de classificação tão alta.

Tanto quanto eu consigo visualizar, dada a variação significativa da caligrafia, os dígitos devem ser linearmente inseparáveis ​​em um espaço dimensional de 784, ou seja, deve haver um limite não linear pouco complexo (embora não muito complexo) que separa os dígitos diferentes , semelhante ao exemplo bem citado, em que classes positivas e negativas não podem ser separadas por nenhum classificador linear. Parece-me desconcertante como a regressão logística multi-classe produz uma precisão tão alta com características inteiramente lineares (sem características polinomiais).XOR

Como exemplo, dado qualquer pixel na imagem, diferentes variações manuscritas dos dígitos e podem tornar esse pixel iluminado ou não. Portanto, com um conjunto de pesos aprendidos, cada pixel pode fazer com que um dígito pareça um e um . Somente com uma combinação de valores de pixel é possível dizer se um dígito é ou . Isso é verdade para a maioria dos pares de dígitos. Então, como a regressão logística, que cega sua decisão de maneira independente em todos os valores de pixel (sem considerar nenhuma dependência entre pixels), é capaz de alcançar essas altas precisões.232323

Sei que estou errado em algum lugar ou estou superestimando a variação nas imagens. No entanto, seria ótimo se alguém pudesse me ajudar com uma intuição sobre como os dígitos são "quase" linearmente separáveis.

Nitish Agarwal
fonte
Veja o livro Statistical Learning with Sparsity: the Lasso and Generalizations 3.3.1 Exemplo: Dígitos manuscritos web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian
Fiquei curioso: como é que algo como um modelo linear penalizado (ou seja, glmnet) se sai sobre o problema? Se bem me lembro, o que você está relatando é a precisão fora da amostra não-penalizada.
Cliff AB

Respostas:

88

tl; dr Mesmo que este é um conjunto de dados de classificação de imagem, ele continua a ser um muito fácil tarefa, para a qual se pode facilmente encontrar um mapeamento direto de entradas para previsões.


Responda:

Essa é uma pergunta muito interessante e, graças à simplicidade da regressão logística, você pode encontrar a resposta.

O que a regressão logística faz é que cada imagem aceite entradas e multiplique-as com pesos para gerar sua previsão. O interessante é que, devido ao mapeamento direto entre entrada e saída (ou seja, nenhuma camada oculta), o valor de cada peso corresponde ao quanto cada uma das entradas é levada em consideração ao calcular a probabilidade de cada classe. Agora, pegando os pesos de cada classe e remodelando-os em (ou seja, a resolução da imagem), podemos dizer quais pixels são mais importantes para o cálculo de cada classe .78478428×28

Note, novamente, que esses são os pesos .

Agora, dê uma olhada na imagem acima e foque nos dois primeiros dígitos (ou seja, zero e um). Os pesos azuis significam que a intensidade desse pixel contribui muito para essa classe e os valores vermelhos significam que contribui negativamente.

Agora imagine como uma pessoa desenha um ? Ele desenha uma forma circular vazia no meio. Isso é exatamente o que os pesos captaram. De fato, se alguém desenha o meio da imagem, conta negativamente como um zero. Portanto, para reconhecer zeros, você não precisa de filtros sofisticados e recursos de alto nível. Você pode apenas olhar para os locais dos pixels desenhados e julgar de acordo com isso.0

A mesma coisa para o . Sempre tem uma linha vertical reta no meio da imagem. Tudo o resto conta negativamente.1

O resto dos dígitos é um pouco mais complicado, mas com pouca imaginação, você pode ver o , o , o e o . O restante dos números é um pouco mais difícil, que é o que realmente limita a regressão logística de atingir os anos 90.2378

Com isso, você pode ver que a regressão logística tem uma chance muito boa de acertar muitas imagens e é por isso que é tão alta.


O código para reproduzir a figura acima é um pouco datado, mas aqui está:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)
Djib2011
fonte
12
Obrigado pela ilustração. Essas imagens de peso tornam mais claro como a precisão é tão alta. A multiplicação de pontos de uma imagem de dígito manuscrita com a imagem de peso correspondente ao rótulo verdadeiro da imagem 'parece' ser a mais alta em comparação com o produto de ponto com outros rótulos de peso para a maioria (ainda 92% me parece muito) das imagens no MNIST. Ainda assim, é um pouco surpreendente que e ou e raramente sejam classificados incorretamente um ao outro ao examinar a matriz de confusão. Enfim, é isso que é. Os dados nunca mentem. :)2378
Nitish Agarwal
13
Obviamente, ajuda que as amostras MNIST sejam centralizadas, dimensionadas e normalizadas por contraste antes que o classificador as veja. Você não precisa responder perguntas como "e se a margem do zero realmente passar pelo meio da caixa?" porque o pré-processador já percorreu um longo caminho para fazer com que todos os zeros pareçam iguais.
hobbs
11
@EricDuminil Adicionei um elogio ao script com sua sugestão. Muito obrigado pela contribuição! : D
Djib2011 13/09
11
@NitishAgarwal, se você acha que esta resposta é a resposta à sua pergunta, considere marcá-la como tal.
sintax
13
Para alguém interessado, mas não particularmente familiarizado com esse tipo de processamento, esta resposta fornece um exemplo intuitivo fantástico da mecânica.
chrylis -on strike- 15/09