SVDによるMNISTの分類
やりたかったこと
SVDを用いてmnistの分類をする
※ mnist: 手書き数字のデータセット.1枚の画像のサイズは28x28
※ 参考文献: Eldén, Lars. Matrix methods in data mining and pattern recognition. Vol. 4. SIAM, 2007.
アルゴリズム
- 0だけが書かれた画像を 枚集める.
- 画像をベクトルにして並べて を作る
- を SVDして左特異ベクトルを 本抜いてくる(今回の実装では )
- 1~9についても同様にして左特異ベクトルを抜いてくる
- テストデータと 本の基底の線形結合で最小二乗法を解いて,残差を求める
- 残差が小さい基底のセットが予測値
結果
- 90%で識別できた
import numpy as np
import numpy.linalg as nl
import pandas as pd
from matplotlib import pyplot as plt
# データの前処理
# データは"mnist csv"でググってダウンロードしてくる
# mnistの読み込み
mnist = pd.read_csv("mnist_train.csv")
# ラベルを作った
# ",".join(["label"]+[str(i) for i in range(784)])
# "label"キーでソート
mnist = mnist.sort_values(by="label")
# csvとして保存
mnist.to_csv("mnist_train_sorted.csv", index=False)
## ソート済みのトレーニングデータを読み込んで,ラベルとデータに分けた
mnist = pd.read_csv("mnist_train_sorted.csv")
label, data = np.split(mnist, [1], axis=1)
## 0から9までの数字がどこで区切れているかのインデックス
digit_index = [0]+[np.where(label==i)[0][-1] for i in range(9)] + [len(label)]
# singular image(基底)を3枚作る
## 画像のサイズ,基底の本数
n = 28; m = 3
## 基底づくり
base = []
for d in range(10):
A = data.iloc[digit_index[d]:digit_index[d+1]+1].T
U, S, Vᵀ = np.linalg.svd(A)
# Ã = (U[:,:m] * S[:m]) @ Vᵀ[:m,:]
fig, ax = plt.subplots(nrows=1,ncols=3,figsize=(12,4))
for nᵢ in range(3):
ax[nᵢ].imshow(U[:,nᵢ].reshape(n,n), cmap="gray")
base.append(U[:,nᵢ])
plt.show(); plt.close()
## 基底の保存
# np.savetxt("mnist_base.txt",np.array(base).T)
# 基底の読み込み
base = np.loadtxt("mnist_base.txt")
# テストデータの読み込み
test = pd.read_csv("mnist_test.csv", header=None)
label_test, data_test = np.split(test, [1], axis=1)
label_test = list(label_test[0])
TESTSET_LENGTH = len(label_test)
def 検証(tᵢ):
正解 = label_test[tᵢ]
残差リスト = []
y = data_test.iloc[tᵢ].values
y = y / nl.norm(y)
for dᵢ in range(10):
A = base[:,3*dᵢ:3*dᵢ+3]
x = nl.solve(A.T @ A, A.T @ y)
残差リスト.append(nl.norm(y - A@x)/nl.norm(y))
予測 = np.argmin(残差リスト)
if 予測 == 正解:
return 1
else:
return 0
# テストデータに0から9が何枚含まれているか
分母 = [label_test.count(i) for i in range(10)]
正解数 = np.zeros(10,np.int)
for tᵢ in range(TESTSET_LENGTH):
正解 = label_test[tᵢ]
正解数[正解] += 検証(tᵢ)
正解率 = [正解数[i]/分母[i] for i in range(10)]
print("正解率(総合) %3.1f%" % (sum(正解数)*100/sum(分母)))
for i in range(10):
print("%d | 正解率 %3.1f%" % (i, 正解率[i]*100))
plt.plot(range(10), 正解率)
plt.title("正解率")
plt.ylim(0,1)
plt.show(); plt.close()
正解率(総合) 90.2%
0, 正解率 96.5%
1, 正解率 99.2%
2, 正解率 89.5%
3, 正解率 89.2%
4, 正解率 87.6%
5, 正解率 82.2%
6, 正解率 95.1%
7, 正解率 86.8%
8, 正解率 85.7%
9, 正解率 88.9%
n = 28
for tᵢ in range(3):
res = []
y = data_test.iloc[tᵢ].values
y = y / nl.norm(y)
for dᵢ in range(10):
A = base[:,3*dᵢ:3*dᵢ+3]
x = nl.solve(A.T @ A, A.T @ y)
res.append(nl.norm(y - A@x)/nl.norm(y))
fig, ax = plt.subplots(nrows=1,ncols=2,figsize=(8,4))
ax[0].imshow(data_test.iloc[tᵢ].values.reshape(n,n), cmap="gray")
ax[1].plot(range(10), res)
print("正解は%d,予測は%d" % (label_test.iloc[tᵢ], np.argmin(res)))
plt.show(); plt.close()
正解は7,予測は7
正解は2,予測は2
正解は1,予測は1














コメント
コメントを投稿