SVDによるMNISTの分類

SVDによるMNISTの分類

やりたかったこと

SVDを用いてmnistの分類をする

※ mnist: 手書き数字のデータセット.1枚の画像のサイズは28x28

※ 参考文献: Eldén, Lars. Matrix methods in data mining and pattern recognition. Vol. 4. SIAM, 2007.

アルゴリズム

  • 0だけが書かれた画像を MM 枚集める.
  • 画像をベクトルにして並べて AR282×MA\in\mathbb{R}^{28^2\times M} を作る
  • AA を SVDして左特異ベクトルを mm 本抜いてくる(今回の実装では m=3m=3
  • 1~9についても同様にして左特異ベクトルを抜いてくる
  • テストデータと mm 本の基底の線形結合で最小二乗法を解いて,残差を求める
  • 残差が小さい基底のセットが予測値

結果

  • 90%で識別できた
import numpy as np
import numpy.linalg as nl
import pandas as pd
from matplotlib import pyplot as plt
# データの前処理
# データは"mnist csv"でググってダウンロードしてくる
# mnistの読み込み
mnist = pd.read_csv("mnist_train.csv")
# ラベルを作った
# ",".join(["label"]+[str(i) for i in range(784)])
# "label"キーでソート
mnist = mnist.sort_values(by="label")
# csvとして保存
mnist.to_csv("mnist_train_sorted.csv", index=False)
## ソート済みのトレーニングデータを読み込んで,ラベルとデータに分けた
mnist = pd.read_csv("mnist_train_sorted.csv")
label, data = np.split(mnist, [1], axis=1)

## 0から9までの数字がどこで区切れているかのインデックス
digit_index = [0]+[np.where(label==i)[0][-1] for i in range(9)] + [len(label)]
# singular image(基底)を3枚作る
## 画像のサイズ,基底の本数
n = 28; m = 3

## 基底づくり
base = []
for d in range(10):
    A = data.iloc[digit_index[d]:digit_index[d+1]+1].T
    U, S, Vᵀ = np.linalg.svd(A)
    # Ã = (U[:,:m] * S[:m]) @ Vᵀ[:m,:]
    fig, ax = plt.subplots(nrows=1,ncols=3,figsize=(12,4))
    for nᵢ in range(3):
        ax[nᵢ].imshow(U[:,nᵢ].reshape(n,n), cmap="gray")
        base.append(U[:,nᵢ])
    plt.show(); plt.close()

## 基底の保存
# np.savetxt("mnist_base.txt",np.array(base).T)

# 基底の読み込み
base = np.loadtxt("mnist_base.txt")
# テストデータの読み込み
test = pd.read_csv("mnist_test.csv", header=None)
label_test, data_test = np.split(test, [1], axis=1)
label_test = list(label_test[0])
TESTSET_LENGTH = len(label_test)
def 検証(tᵢ):
    正解 = label_test[tᵢ]
    残差リスト = []
    y = data_test.iloc[tᵢ].values
    y = y / nl.norm(y)
    for dᵢ in range(10):
        A = base[:,3*dᵢ:3*dᵢ+3]
        x = nl.solve(A.T @ A, A.T @ y)
        残差リスト.append(nl.norm(y - A@x)/nl.norm(y))
    予測 = np.argmin(残差リスト)
    if 予測 == 正解:
        return 1
    else:
        return 0
# テストデータに0から9が何枚含まれているか
分母 = [label_test.count(i) for i in range(10)]
正解数 = np.zeros(10,np.int)
for tᵢ in range(TESTSET_LENGTH):
    正解 = label_test[tᵢ]
    正解数[正解] += 検証(tᵢ)
正解率 = [正解数[i]/分母[i] for i in range(10)]
print("正解率(総合) %3.1f%" % (sum(正解数)*100/sum(分母)))
for i in range(10):
    print("%d | 正解率 %3.1f%" % (i, 正解率[i]*100))
plt.plot(range(10), 正解率)
plt.title("正解率")
plt.ylim(0,1)
plt.show(); plt.close()
正解率(総合) 90.2%
0, 正解率 96.5%
1, 正解率 99.2%
2, 正解率 89.5%
3, 正解率 89.2%
4, 正解率 87.6%
5, 正解率 82.2%
6, 正解率 95.1%
7, 正解率 86.8%
8, 正解率 85.7%
9, 正解率 88.9%

n = 28
for tᵢ in range(3):
    res = []
    y = data_test.iloc[tᵢ].values
    y = y / nl.norm(y)
    for dᵢ in range(10):
        A = base[:,3*dᵢ:3*dᵢ+3]
        x = nl.solve(A.T @ A, A.T @ y)
        res.append(nl.norm(y - A@x)/nl.norm(y))
    fig, ax = plt.subplots(nrows=1,ncols=2,figsize=(8,4))
    ax[0].imshow(data_test.iloc[tᵢ].values.reshape(n,n), cmap="gray")
    ax[1].plot(range(10), res)
    print("正解は%d,予測は%d" % (label_test.iloc[tᵢ], np.argmin(res)))
    plt.show(); plt.close()
正解は7,予測は7

正解は2,予測は2

正解は1,予測は1

コメント