간단한 로지스틱 회귀 모델은 어떻게 MNIST에서 92 %의 분류 정확도를 달성합니까?


64

MNIST 데이터 세트의 모든 이미지가 비슷한 스케일로 중심을 맞추고 회전하지 않고 위로 향하더라도 선형 모델이 이러한 높은 분류 정확도를 달성하는 방법을 의아해하는 중요한 필기 변형이 있습니다.

내가 볼 수있는 한, 상당한 필기 변형을 고려할 때, 숫자는 784 차원 공간에서 선형으로 분리 할 수 ​​없어야합니다. 즉, 서로 다른 숫자를 분리하는 약간 복잡한 (매우 복잡하지는 않지만) 비선형 경계가 있어야합니다 , 양성 클래스와 음성 클래스를 선형 분류기로 분리 할 수없는 잘 인용 된 예제 와 유사합니다 . 다중 클래스 로지스틱 회귀 분석이 완전히 선형적인 특징 (다항식 특징 없음)으로 어떻게 이러한 높은 정확도를 생성하는지는 당황 스럽습니다.XOR

예를 들어, 이미지의 임의의 픽셀이 주어지면, 숫자 및 의 다른 필기 변형은 해당 픽셀을 조명하거나 조명하지 않을 수 있습니다. 따라서 학습 된 가중치 세트를 사용하면 각 픽셀 이 뿐만 아니라 처럼 숫자를 볼 수 있습니다 . 픽셀 값의 조합으로 만 숫자가 인지 인지를 말할 수 있어야합니다 . 이것은 대부분의 숫자 쌍에 해당됩니다. 따라서 로지스틱 회귀 분석은 픽셀 간 종속성을 전혀 고려하지 않고 모든 픽셀 값에 독립적으로 결정을 내리는 방식으로 어떻게 높은 정확도를 얻을 수 있습니까?232323

어딘가에 잘못되었거나 이미지의 변형을 과대 평가하고 있음을 알고 있습니다. 그러나 누군가가 숫자가 선형 적으로 분리 가능한 '거의'직관에 대해 도움을 줄 수 있다면 좋을 것입니다.


희소성 통계 학습 : 올가미 및 일반화 교과서 3.3.1 예 : 필기 숫자 web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian

궁금한 점이 있습니다. 형벌 된 선형 모델 (예 : glmnet)과 같은 것이 문제에 얼마나 잘 작용합니까? 내가 기억한다면, 당신이보고있는 것은 처벌되지 않은 샘플 외부 정확도입니다.
Cliff AB

답변:


82

tl; dr 비록 이것이 이미지 분류 데이터 셋이지만, 입력에서 예측까지 의 직접적인 매핑 을 쉽게 찾을 수 있는 매우 쉬운 작업으로 남아 있습니다 .


대답:

이것은 매우 흥미로운 질문이며 로지스틱 회귀의 단순성 덕분에 실제로 답을 찾을 수 있습니다.

로지스틱 회귀가하는 것은 각 이미지가 입력을 받아 가중치를 곱하여 예측을 생성하는 것입니다. 흥미로운 점은 입력과 출력 사이의 직접 매핑 (즉, 숨겨진 계층 없음)으로 인해 각 가중치의 값 이 각 클래스의 확률을 계산할 때 각 입력 중 하나가 얼마나 많이 고려되는지에 해당한다는 것입니다 . 이제 각 클래스의 가중치를 가져 와서 (즉, 이미지 해상도) 로 재구성하여 각 클래스의 계산에 어떤 픽셀이 가장 중요한지 알 수 있습니다 .78478428×28

다시, 이것들은 가중치 입니다.

이제 위의 이미지를보고 처음 두 자리 (즉, 0과 1)에 초점을 맞추십시오. 파란색 가중치는이 픽셀의 강도가 해당 클래스에 많은 기여를하고 빨간색 값은 부정적인 기여를 의미합니다.

이제 사람이 어떻게 그리는지 상상해보십시오 . 그는 비어있는 원형을 그립니다. 그것이 정확히 무게가 들어온 것입니다. 실제로 누군가가 이미지의 중간을 그리면 음수 로 0으로 계산 됩니다. 따라서 0을 인식하기 위해 정교한 필터와 고급 기능이 필요하지 않습니다. 그려진 픽셀 위치를보고 이에 따라 판단 할 수 있습니다.0

마찬가지입니다 . 이미지 중간에 항상 수직선이 있습니다. 다른 모든 것은 부정적으로 계산됩니다.1

나머지 자릿수는 조금 더 복잡하지만 상상력이 거의 없으면 , , 및 있습니다. 나머지 숫자는 조금 더 어려워서 로지스틱 회귀가 90 년대에 도달하는 것을 실제로 제한합니다.2378

이를 통해 로지스틱 회귀 분석에서 많은 이미지를 얻을 수있는 가능성이 매우 높아서 점수가 너무 높아지는 것을 알 수 있습니다.


위의 그림을 재현하는 코드는 약간 오래되었지만 여기에 있습니다.

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

9
일러스트레이션 주셔서 감사합니다. 이러한 중량 이미지는 정확도가 얼마나 높은지를 더 명확하게합니다. 이미지의 실제 레이블에 해당하는 무게 이미지와 필기 숫자 이미지의 도트 곱셈은 다른 무게 레이블이있는 도트 제품과 비교할 때 가장 높은 것으로 간주됩니다 (여전히 92 %가 나에게 많이 보입니다) MNIST의 이미지 중 하나입니다. 그럼에도 불구 하고 혼동 행렬을 조사 할 때 와 또는 과 이 서로 잘못 분류되는 경우는 거의 없습니다. 어쨌든, 이것이 바로 그 것입니다. 데이터는 절대 거짓말하지 않습니다. :)2378
Nitish Agarwal

13
물론 MNIST 샘플을 분류자가 볼 수 있기 전에 중심을 맞추고 크기를 조정하고 대비를 표준화하는 데 도움이됩니다. "0의 가장자리가 실제로 상자의 가운데를 통과하면 어떻게됩니까?"와 같은 질문을 해결할 필요가 없습니다. 프리 프로세서는 이미 모든 0을 동일하게 보이도록 먼 길을 가고 있기 때문입니다.
홉스

1
@EricDuminil 나는 당신의 제안과 함께 스크립트에 칭찬을 추가했습니다. 입력 해 주셔서 감사합니다! : D
Djib2011

1
@NitishAgarwal,이 답변이 귀하의 질문에 대한 답변이라고 생각되면, 그렇게 표시하십시오.
sintax

7
이러한 종류의 처리에 관심이 있지만 특별히 익숙하지 않은 사람에게는이 답변이 역학의 환상적인 직관적 예를 제공합니다.
chrylis
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.