勉強しないとな~blog

ちゃんと勉強せねば…な電気設計エンジニアです。

OpenCVやってみる - 36. SVMで数字判定

前回からの続きです。
Jupyter notebookのデータ等の状態も引き継いでいます。
前回は、春のパン祭りシール台紙の画像から点数数字の輪郭を取得、数字テンプレートと比較して一致度を出して、固定の閾値で判定してみましたが、いまいちきっちりとは判定できませんでした。

検討してみたところ、結果的にOpenCVに含まれているSVM(Support Vector Machine)のライブラリを使うことでうまくいくようになりました。

判定方法検討

前回は抽出した点数数字の輪郭から、各数字のテンプレートと比較して、一致度が閾値を超えたもの(の中で最大の一致度を取る数字)を選びました。
各数字ごとへの一致度をそれぞれ単純に使っているだけですが、それより他の数字への一致度も考慮して判定するとよさそうな気がします。

各数字への一致度のベクトルを特徴量として、機械学習的な手法が使えるかも。

OpenCVの中にも機械学習のライブラリがあって、OpenCVチュートリアルでもいくつか紹介されているので、どれかでやってみたいと思います。

https://docs.opencv.org/4.5.2/d6/de2/tutorial_py_table_of_contents_ml.html

  • K-Nearest-Neighbor (kNN)
    特徴量空間に既知のラベルを持ったデータをマッピングし、新しいデータに対して、k番目までの最近傍データを選び、その中で多数決を行う
  • Support Vector Machine (SVM)
    特徴量空間で2つのクラスに分類する最適な平面(2次元なら直線)を選ぶ
    多クラスなら、One-to-OneかOne-to-RestのSVMを必要分用意する形になるよう
    https://www.baeldung.com/cs/svm-multiclass-classification
    OpenCVのライブラリではそこを自分で実装する必要はなく、多クラス分類をしてくれるようでした。
  • K-Means
    教師データなしでデータを分類する。
    何クラスに分けるかを指定して、特徴量空間でクラスごとの代表点を反復的な手法で選ぶ。一番近い代表点が所属するクラスとなる。

今回やっていることとしては、K-Meansは合わないかなと。

kNNは学習のコストは低いが、推論時の処理が重くなりがちとのこと。(全教師データ点に対しての特徴量距離を求める必要があるため)

SVMのほうが推論が軽そうなので、こちらを使いたいと思います。
実際のアプリケーションのイメージとしては、

  • スマホのカメラで画像データを連続的に取得
  • この画像に対してリアルタイムで点数計算を実施
  • 計算した点数と、どのように点数を認識したか、というのを画面に表示
  • 撮影条件によっておそらく正しくない結果が出る
  • ユーザが数字が正しく認識できている、と思ったら確定ボタンを押す

というものなので、入力データに対してすぐに判定できるのが望ましいです。

まずデータの確認

SVMを実施する前に、一致度データと実際の数字がどのような関係にあるのか見てみたいと思います。

といっても、一致度データは6次元ベクトルなので、完全に図示するのは難しく。なので平面グラフで見られる範囲、つまり2種類の数字の輪郭データを、それぞれへの一致度を使って見てみたいと思います。

年ごとにある程度状況が変わるかもしれないので、年ごとに分けて見てみます。

sims = similarities1 + similarities2
labels = labels1 + labels2
one_vs_zero_2019 = [(sims[i][1], sims[i][0], label) for i,label in enumerate(labels) if label==1 or label==0]
one_vs_two_2019 = [(sims[i][1], sims[i][2], label) for i,label in enumerate(labels) if label==1 or label==2]
one_vs_three_2019 = [(sims[i][1], sims[i][3], label) for i,label in enumerate(labels) if label==1 or label==3]
one_vs_five_2019 = [(sims[i][1], sims[i][4], label) for i,label in enumerate(labels) if label==1 or label==5]
one_vs_else_2019 = [(sims[i][1], sims[i][0], label) for i,label in enumerate(labels) if label==1 or label==-1]

sims = similarities3 + similarities4
labels = labels3 + labels4
one_vs_zero_2020 = [(sims[i][1], sims[i][0], label) for i,label in enumerate(labels) if label==1 or label==0]
one_vs_two_2020 = [(sims[i][1], sims[i][2], label) for i,label in enumerate(labels) if label==1 or label==2]
one_vs_five_2020 = [(sims[i][1], sims[i][4], label) for i,label in enumerate(labels) if label==1 or label==5]
one_vs_else_2020 = [(sims[i][1], sims[i][0], label) for i,label in enumerate(labels) if label==1 or label==-1]

sims = similarities5 + similarities6 + similarities7
labels = labels5 + labels6 + labels7
one_vs_zero_2021 = [(sims[i][1], sims[i][0], label) for i,label in enumerate(labels) if label==1 or label==0]
one_vs_two_2021 = [(sims[i][1], sims[i][2], label) for i,label in enumerate(labels) if label==1 or label==2]
one_vs_five_2021 = [(sims[i][1], sims[i][4], label) for i,label in enumerate(labels) if label==1 or label==5]
one_vs_else_2021 = [(sims[i][1], sims[i][0], label) for i,label in enumerate(labels) if label==1 or label==-1]

one_vs_zero = [one_vs_zero_2019, one_vs_zero_2020, one_vs_zero_2021]
one_vs_two = [one_vs_two_2019, one_vs_two_2020, one_vs_two_2021]
one_vs_three = [one_vs_three_2019]
one_vs_five = [one_vs_five_2019, one_vs_five_2020, one_vs_five_2021]
one_vs_else = [one_vs_else_2019, one_vs_else_2020, one_vs_else_2021]

years = ['2019', '2020', '2021']

plt.figure(figsize=(6.4,2.4), dpi=100)
plt.suptitle('One vs Zero', y=1.1)
for i,a in enumerate(one_vs_zero):
    x = [b[0] for b in a]
    y = [b[1] for b in a]
    c = [float(b[2]) for b in a]
    plt.subplot(1,3,1+i), plt.scatter(x,y,c=c), plt.title(years[i])
plt.show()

plt.figure(figsize=(6.4,2.4), dpi=100)
plt.suptitle('One vs Two', y=1.1)
for i,a in enumerate(one_vs_two):
    x = [b[0] for b in a]
    y = [b[1] for b in a]
    c = [float(b[2]) for b in a]
    plt.subplot(1,3,1+i), plt.scatter(x,y,c=c), plt.title(years[i])
plt.show()

plt.figure(figsize=(6.4,2.4), dpi=100)
plt.suptitle('One vs Three', y=1.1)
for i,a in enumerate(one_vs_three):
    x = [b[0] for b in a]
    y = [b[1] for b in a]
    c = [float(b[2]) for b in a]
    plt.subplot(1,3,1+i), plt.scatter(x,y,c=c), plt.title(years[i])
plt.show()

plt.figure(figsize=(6.4,2.4), dpi=100)
plt.suptitle('One vs Five', y=1.1)
for i,a in enumerate(one_vs_five):
    x = [b[0] for b in a]
    y = [b[1] for b in a]
    c = [float(b[2]) for b in a]
    plt.subplot(1,3,1+i), plt.scatter(x,y,c=c), plt.title(years[i])
plt.show()

plt.figure(figsize=(6.4,2.4), dpi=100)
plt.suptitle('One vs else', y=1.1)
for i,a in enumerate(one_vs_else):
    x = [b[0] for b in a]
    y = [b[1] for b in a]
    c = [float(b[2]) for b in a]
    plt.subplot(1,3,1+i), plt.scatter(x,y,c=c), plt.title(years[i])
plt.show()

f:id:nokixa:20220224034730p:plain

f:id:nokixa:20220224034732p:plain

f:id:nokixa:20220224034734p:plain

f:id:nokixa:20220224034737p:plain

f:id:nokixa:20220224034739p:plain

"1"とその他の数字を比較してみましたが、だいたい数字ごとに特徴量ベクトルが固まって分布しているよう。
ただ、どの数字でもない輪郭の分布とは重なってしまっています。

SVMで、他の数字への一致度を使ってうまく識別できればいいなと。

ちなみに"5"で1つ変な位置にある点は、輪郭検出時に"点"の文字まで含まれてしまったものかと考えられます。
これは学習データから除外しておかないと。

SVM試し

まずは一度SVMでどんな結果が出てくるのか、一部のデータで試してみたいと思います。
"1"のデータと"2"のデータを使ってみます。
特徴量ベクトルも、"1"と"2"への一致度のみにしてみます。
あと学習データ数は各数字で10としています。

リストのコピーでは、デフォルトでは参照渡しになるようで、何かやっているうちに元のデータを変更してしまいそう。Jupyter notebook上で行ったり来たりして色々試しているので、これだと都合が悪いので、copyモジュールのdeepcopy()を使いました。

https://murashun.jp/article/programming/python/python-list-copy-deepcopy.html

あとはrandomモジュールのsample()を使ってランダムサンプルを行いました。

https://note.nkmk.me/python-random-choice-sample-choices/

import copy
import random

all_vectors = copy.deepcopy(similarities1 + similarities2 + similarities3
                             + similarities4 + similarities5 + similarities6 + similarities7)
all_labels = copy.deepcopy(labels1 + labels2 + labels3 + labels4 + labels5 + labels6 + labels7)

# Remove inadequate contour data in img1
del all_vectors[30]
del all_labels[30]

numbers = [0, 1, 2, 3, 5]
labels = [-1] + numbers
selected_labels = [1, 2]

# Select feature vector elements to use
all_vectors = [[d for i,d in enumerate(vec) if numbers[i] in selected_labels] for vec in all_vectors]

n_train_data = 10
train_data = []
train_labels = []
n_test_data = 10
test_data = []
test_labels = []

for lab in selected_labels:
    samples = [vec for i,vec in enumerate(all_vectors) if all_labels[i]==lab]
    n = min(n_train_data, len(samples))
    train_data += random.sample(samples, n)
    train_labels += [lab] * n
    n = min(n_test_data, len(samples))
    test_data += random.sample(samples, n)
    test_labels += [lab] * n

[print(np.array(train_data[i]), ', ', train_labels[i]) for i in range(len(train_data))]
svm = cv2.ml.SVM_create()
svm.setKernel(cv2.ml.SVM_LINEAR)
svm.setType(cv2.ml.SVM_C_SVC)
svm.setC(1)
svm.setGamma(1)
svm.train(np.array(train_data, 'float32'), cv2.ml.ROW_SAMPLE, np.array(train_labels));

result = svm.predict(np.array(test_data, 'float32'))
print('SVM predict result: ')
print(result)
print('Comparison: ')
for i in range(len(test_labels)):
    print(result[1][i], ' - ', test_labels[i])
[0.8993755  0.82889575] ,  1
[0.9571013 0.8108691] ,  1
[0.93051445 0.79986185] ,  1
[0.9202969 0.8210506] ,  1
[0.94315743 0.8116947 ] ,  1
[0.9537984 0.8174602] ,  1
[0.9233544 0.7887045] ,  1
[0.93338346 0.8063096 ] ,  1
[0.90709674 0.8227937 ] ,  1
[0.94739175 0.79242164] ,  1
[0.8452033 0.8994955] ,  2
[0.82478577 0.939384  ] ,  2
[0.8292844 0.9430692] ,  2
[0.83571035 0.9575302 ] ,  2
[0.8436054 0.9455434] ,  2
[0.8402701  0.93678236] ,  2
[0.83998835 0.8954112 ] ,  2
[0.8377397  0.93956023] ,  2
[0.83261627 0.9558522 ] ,  2
[0.83023584 0.93284434] ,  2
SVM predict result: 
(0.0, array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.]], dtype=float32))
Comparison: 
[1.]  -  1
[1.]  -  1
[1.]  -  1
[1.]  -  1
[1.]  -  1
[1.]  -  1
[1.]  -  1
[1.]  -  1
[1.]  -  1
[1.]  -  1
[2.]  -  2
[2.]  -  2
[2.]  -  2
[2.]  -  2
[2.]  -  2
[2.]  -  2
[2.]  -  2
[2.]  -  2
[2.]  -  2
[2.]  -  2

ひとまずきちんと分類できているようです。
predict()の返り値は、よく分からない"0"という値と、入力データごとに推定したラベルになっています。

SVM少し掘り下げ

どんなSVMの分類器が得られたのか、以下の関数で調べられます。

  • getSupportVectors()
  • getDecisionFunction()
  • getUncompressedSupportVectors()

一応公式ドキュメントに説明がありましたが、いまいちよく分からず。
試しにやってみて、どんなものなのか確認してみたいと思います。

SVM Class Reference

print('getSupportVectors: ')
print(svm.getSupportVectors())
print('getDecisionFunction: ')
print(svm.getDecisionFunction(0))
print('getUncompressedSupportVectors: ')
print(svm.getUncompressedSupportVectors())
getSupportVectors: 
[[ 0.95603085 -1.245411  ]]
getDecisionFunction: 
(-0.23778849840164185, array([[1.]]), array([[0]], dtype=int32))
getUncompressedSupportVectors: 
[[0.8993755  0.82889575]
 [0.9571013  0.8108691 ]
 [0.93051445 0.79986185]
 [0.9202969  0.8210506 ]
 [0.94315743 0.8116947 ]
 [0.9537984  0.8174602 ]
 [0.9233544  0.7887045 ]
 [0.93338346 0.8063096 ]
 [0.90709674 0.8227937 ]
 [0.94739175 0.79242164]
 [0.8452033  0.8994955 ]
 [0.82478577 0.939384  ]
 [0.8292844  0.9430692 ]
 [0.83571035 0.9575302 ]
 [0.8436054  0.9455434 ]
 [0.8402701  0.93678236]
 [0.83998835 0.8954112 ]
 [0.8377397  0.93956023]
 [0.83261627 0.9558522 ]
 [0.83023584 0.93284434]]

getDecisionFunction()では、引数として決定関数(という呼び方でいいか?)のインデックスを与える必要がありますが、今回は2クラスへの分類なので、決定関数は1つだけになり、今回は0を与えています。

  • getUncompressedSupportVector()では、実際の推論に使われる圧縮されたサポートベクタの元となるサポートベクタが得られる、と書かれています。
    上の結果を見ると、どうも学習に使ったデータがそのまま出てきているよう。
    OpenCVチュートリアルのSVMのページを見ると、決定境界を決めるには学習用データ全てが必要というわけではなく、境界近くのデータだけあればいいよう。
    学習用データをもっと増やした場合、必要なデータだけに絞られるのか?学習用データが固まり過ぎているから全ての学習用データが出てきてしまったのか?

  • getSupportVectors()でサポートベクタが得られる、と書かれていますが、SVMの説明を見ると、サポートベクタとは境界近くのデータ点と書かれています。ただ、結果を見る限りこれはデータ点ではなく、決定境界を示す重みベクトルになっているような。getDecisionFunction()では、(retval, alpha, svidx)という形の返り値が得られますが、retvalが決定関数のバイアス項になるようです。

決定関数について、もう一度確認してみます。

w = svm.getSupportVectors()[0]
ret,alpha,svidx = svm.getDecisionFunction(0)
b = ret
for i,d in enumerate(test_data):
    val = w @ d - b
    predicted = svm.predict(np.reshape(np.array(d, 'float32'), (1,-1)))
    print('label: ', test_labels[i]
          , ', Function output: ', val
          , ', Predicted: ', predicted[1][0])
label:  1 , Function output:  0.09859859943389893 , Predicted:  [1.]
label:  1 , Function output:  0.11231565475463867 , Predicted:  [1.]
label:  1 , Function output:  0.08623391389846802 , Predicted:  [1.]
label:  1 , Function output:  0.16496700048446655 , Predicted:  [1.]
label:  1 , Function output:  0.13821208477020264 , Predicted:  [1.]
label:  1 , Function output:  0.09507524967193604 , Predicted:  [1.]
label:  1 , Function output:  0.13989132642745972 , Predicted:  [1.]
label:  1 , Function output:  0.16536468267440796 , Predicted:  [1.]
label:  1 , Function output:  0.1658780574798584 , Predicted:  [1.]
label:  1 , Function output:  0.1524949073791504 , Predicted:  [1.]
label:  2 , Function output:  -0.0744127631187439 , Predicted:  [2.]
label:  2 , Function output:  -0.14605987071990967 , Predicted:  [2.]
label:  2 , Function output:  -0.15364772081375122 , Predicted:  [2.]
label:  2 , Function output:  -0.2124718427658081 , Predicted:  [2.]
label:  2 , Function output:  -0.1557653546333313 , Predicted:  [2.]
label:  2 , Function output:  -0.12564247846603394 , Predicted:  [2.]
label:  2 , Function output:  -0.15429812669754028 , Predicted:  [2.]
label:  2 , Function output:  -0.13328897953033447 , Predicted:  [2.]
label:  2 , Function output:  -0.15880388021469116 , Predicted:  [2.]
label:  2 , Function output:  -0.12945902347564697 , Predicted:  [2.]

やっぱりgetSupportVectors()で重みベクトルが得られて、getDecisionFunction()でバイアス項が得られるようです。
入力ベクトルに重みベクトルを掛けて、バイアス項を引いてやると、判定値が得られて、この正負で判定するものと思われます。

3クラスの分類もやってみます。"1"、"2"、"5"の数字を使います。
今度は特徴量ベクトルとしては5つの数字への一致度全てを使ってみます。

all_vectors = copy.deepcopy(similarities1 + similarities2 + similarities3
                             + similarities4 + similarities5 + similarities6 + similarities7)
all_labels = copy.deepcopy(labels1 + labels2 + labels3 + labels4 + labels5 + labels6 + labels7)

# Remove inadequate contour data in img1
del all_vectors[30]
del all_labels[30]

def get_random_sample(data_in, labels_in, selected_labels, n_samples):
    data_rtn = []
    labels_rtn = []
    for lab in selected_labels:
        samples = [d for i,d in enumerate(data_in) if labels_in[i]==lab]
        n = min(n_samples, len(samples))
        data_rtn += random.sample(samples, n)
        labels_rtn += [lab] * n
    return data_rtn, labels_rtn

train_data, train_labels = get_random_sample(all_vectors, all_labels, [1,2,5], 10)
test_data, test_labels = get_random_sample(all_vectors, all_labels, [1,2,5], 10)

[print(np.array(train_data[i]), ', ', train_labels[i]) for i in range(len(train_data))]
svm = cv2.ml.SVM_create()
svm.setKernel(cv2.ml.SVM_LINEAR)
svm.setType(cv2.ml.SVM_C_SVC)
svm.setC(1)
svm.setGamma(1)
svm.train(np.array(train_data, 'float32'), cv2.ml.ROW_SAMPLE, np.array(train_labels));

result = svm.predict(np.array(test_data, 'float32'))
print('Comparison: ')

# Dictionary containing number of correct answers and number of same labels
svm_results = {-1:[0,0], 0:[0,0], 1:[0,0], 2:[0,0], 3:[0,0], 5:[0,0]}
for i,lab in enumerate(test_labels):
    if result[1][i] == lab:
        svm_results[lab][0] += 1
    svm_results[lab][1] += 1
for k,v in svm_results.items():
    print(k, ': ', v[0], ' / ', v[1])
[0.8684619  0.93794453 0.82305175 0.80380815 0.7853995 ] ,  1
[0.863555   0.9393367  0.8205488  0.7917124  0.78287023] ,  1
[0.8419271  0.9571013  0.8108691  0.77983546 0.7980355 ] ,  1
[0.83179814 0.9788645  0.7952227  0.7923971  0.7836161 ] ,  1
[0.870909  0.9359166 0.8235449 0.8104057 0.7868414] ,  1
[0.8693162  0.9437623  0.82522047 0.8118913  0.77726704] ,  1
[0.871623   0.9311241  0.81550264 0.81565213 0.79558474] ,  1
[0.8750252  0.91742384 0.78483945 0.80920565 0.8159622 ] ,  1
[0.8761335  0.9543413  0.8270361  0.80594933 0.78908885] ,  1
[0.8319478 0.894798  0.7914398 0.7971767 0.7837865] ,  1
[0.7703252  0.8415919  0.9335647  0.81709546 0.7735364 ] ,  2
[0.7810648  0.83210677 0.9345461  0.8076761  0.80540264] ,  2
[0.7528915  0.83172077 1.         0.80888104 0.80121213] ,  2
[0.779265   0.83273137 0.92147803 0.81171924 0.77139753] ,  2
[0.73760915 0.83789265 0.93405205 0.7996441  0.79096395] ,  2
[0.7610472  0.8352113  0.93603534 0.7986501  0.78484255] ,  2
[0.77799076 0.8399149  0.9400889  0.80451334 0.80509204] ,  2
[0.7862375  0.82478577 0.939384   0.8073074  0.7773469 ] ,  2
[0.7736371  0.83998835 0.8954112  0.8008472  0.76628774] ,  2
[0.7727292  0.8402701  0.93678236 0.8075204  0.78538513] ,  2
[0.8218305  0.8154667  0.7384431  0.8331704  0.92029697] ,  5
[0.8530085  0.7790276  0.72586    0.7505195  0.92765796] ,  5
[0.84800655 0.7884924  0.74963397 0.8336336  0.9487194 ] ,  5
[0.8432178  0.80467755 0.763133   0.8533464  0.93024945] ,  5
[0.855599   0.8009213  0.7863985  0.83297044 0.91814077] ,  5
[0.83844346 0.78525823 0.7594693  0.83071625 0.9581587 ] ,  5
[0.85112804 0.82673436 0.7607567  0.83964443 0.93797547] ,  5
[0.825503   0.7873881  0.73949605 0.83388895 0.9183223 ] ,  5
[0.8455853  0.7865896  0.7510412  0.77967346 0.938463  ] ,  5
[0.8390334  0.78312147 0.7487837  0.8266298  0.9461081 ] ,  5
Comparison: 
-1 :  0  /  0
0 :  0  /  0
1 :  10  /  10
2 :  10  /  10
3 :  0  /  0
5 :  10  /  10
print('getSupportVectors: ')
print(svm.getSupportVectors())
print('getDecisionFunction: ')
[print(svm.getDecisionFunction(i)) for i in range(svm.getSupportVectors().shape[0])]
print('getUncompressedSupportVectors: ')
print(svm.getUncompressedSupportVectors())
getSupportVectors: 
[[ 0.74236786  1.1534866  -1.2011049  -0.05879921  0.15569824]
 [ 0.13425273  1.5126095   0.5482252  -0.10417938 -1.4019974 ]
 [-0.60811514  0.35912287  1.7493302  -0.04538018 -1.5576956 ]]
getDecisionFunction: 
(0.6581481695175171, array([[1.]]), array([[0]], dtype=int32))
(0.5590379238128662, array([[1.]]), array([[1]], dtype=int32))
(-0.08549034595489502, array([[1.]]), array([[2]], dtype=int32))
getUncompressedSupportVectors: 
[[0.8358994  0.97376585 0.80774623 0.7944305  0.79416436]
 [0.8387221  0.94739175 0.79242164 0.7958192  0.79904383]
 [0.8761335  0.9543413  0.8270361  0.80594933 0.78908885]
 [0.8625072  0.9391679  0.8229808  0.81084424 0.80457014]
 [0.8419271  0.9571013  0.8108691  0.77983546 0.7980355 ]
 [0.8842514  0.93677527 0.82215434 0.8002985  0.8105224 ]
 [0.82942396 0.9743748  0.7959507  0.798544   0.77946776]
 [0.84554917 0.9603941  0.80359674 0.800022   0.81536746]
 [0.8716829  0.9202969  0.8210506  0.8102947  0.7935237 ]
 [0.88022274 0.92476493 0.8195491  0.8112865  0.7874757 ]
 [0.77455294 0.8228092  0.93951803 0.8037939  0.79569256]
 [0.7871597  0.8369811  0.91662616 0.81502336 0.77930087]
 [0.7602673  0.83667743 0.92887866 0.80211806 0.7976724 ]
 [0.7736371  0.83998835 0.8954112  0.8008472  0.76628774]
 [0.76749974 0.83696306 0.9345461  0.8042174  0.80406123]
 [0.79241073 0.8335312  0.9546793  0.80434114 0.8065423 ]
 [0.78609586 0.8381242  0.93519616 0.80665094 0.7609726 ]
 [0.7798084  0.83231604 0.9123746  0.8087669  0.7738839 ]
 [0.7915582  0.82690877 0.9614248  0.81141484 0.76603115]
 [0.81096166 0.83058804 0.9458052  0.80894995 0.76511675]
 [0.825503   0.7873881  0.73949605 0.83388895 0.9183223 ]
 [0.8522377  0.75933874 0.72682977 0.7764713  0.894319  ]
 [0.85112804 0.82673436 0.7607567  0.83964443 0.93797547]
 [0.8366933  0.76728594 0.7461595  0.79035497 0.94660616]
 [0.83844346 0.78525823 0.7594693  0.83071625 0.9581587 ]
 [0.8390334  0.78312147 0.7487837  0.8266298  0.9461081 ]
 [0.84800655 0.7884924  0.74963397 0.8336336  0.9487194 ]
 [0.84835213 0.8208981  0.78807247 0.7563367  0.9346628 ]
 [0.84590787 0.83696854 0.76865834 0.85480416 0.9347266 ]
 [0.84676135 0.8202787  0.78727037 0.76902366 0.95365864]]
w = svm.getSupportVectors()
dfs = [svm.getDecisionFunction(i) for i in range(3)]
b = np.array([df[0] for df in  dfs])
for i,d in enumerate(test_data):
    val = w @ d - b
    predicted = svm.predict(np.reshape(np.array(d, 'float32'), (1,-1)))
    print('label: ', test_labels[i]
          , ', Function outputs: ', val
          , ', Predicted: ', predicted[1][0])
label:  1 , Function outputs:  [0.18518424 0.25113189 0.05232799] , Predicted:  [1.]
label:  1 , Function outputs:  [0.15800738 0.19747007 0.02584302] , Predicted:  [1.]
label:  1 , Function outputs:  [ 0.19115281  0.18169904 -0.02307355] , Predicted:  [1.]
label:  1 , Function outputs:  [0.2127431  0.26644951 0.04008663] , Predicted:  [1.]
label:  1 , Function outputs:  [0.12235677 0.22019732 0.08422077] , Predicted:  [1.]
label:  1 , Function outputs:  [0.18681097 0.24292523 0.04249454] , Predicted:  [1.]
label:  1 , Function outputs:  [0.15254593 0.21868181 0.05251586] , Predicted:  [1.]
label:  1 , Function outputs:  [0.13946629 0.23918605 0.0861001 ] , Predicted:  [1.]
label:  1 , Function outputs:  [0.20872855 0.28805488 0.06570649] , Predicted:  [1.]
label:  1 , Function outputs:  [0.15355039 0.24784851 0.08067822] , Predicted:  [1.]
label:  2 , Function outputs:  [-0.17680824  0.11696589  0.28015423] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.16753972  0.11126626  0.26518631] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.16002786  0.1760323   0.32244027] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.18941259  0.18797886  0.36377156] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.1927557   0.11770147  0.29683733] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.16306555  0.10350084  0.25294673] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.19990551  0.16659194  0.35287762] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.17447698  0.11759734  0.27845466] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.16249359  0.16773802  0.31661165] , Predicted:  [2.]
label:  2 , Function outputs:  [-0.11817181  0.14852113  0.25307298] , Predicted:  [2.]
label:  5 , Function outputs:  [ 0.07100189 -0.23229852 -0.31692028] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.05144411 -0.23986143 -0.30492544] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.09993356 -0.1874423  -0.30099571] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.07407373 -0.20341051 -0.29110408] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.10216659 -0.2469826  -0.36276901] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.05273694 -0.28652012 -0.35287714] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.06737787 -0.26389527 -0.34489298] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.10729289 -0.15760526 -0.27851796] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.06865722 -0.22615045 -0.30842745] , Predicted:  [5.]
label:  5 , Function outputs:  [ 0.06620628 -0.18782762 -0.26765358] , Predicted:  [5.]

今回は3クラスの分類なので、重みベクトルおよびバイアス項は3つずつあります。
重みベクトルの分布を見ると、なんとなくどの分類をするものか分かります。

  • 1行目: "1"、"2"の分類 (正の値->"1"、負の値->"2")
  • 2行目: "1"、"5"の分類 (正の値->"1"、負の値->"5")
  • 3行目: "2"、"5"の分類 (正の値->"2"、負の値->"5")

全体データでSVM

今度は全体のデータを対象にしてSVM学習、推論をやってみたいと思います。
学習データ数は上と同じく各数字10としておきます。

all_vectors = copy.deepcopy(similarities1 + similarities2 + similarities3
                             + similarities4 + similarities5 + similarities6 + similarities7)
all_labels = copy.deepcopy(labels1 + labels2 + labels3 + labels4 + labels5 + labels6 + labels7)

# Remove inadequate contour data in img1
del all_vectors[30]
del all_labels[30]

train_data, train_labels = get_random_sample(all_vectors, all_labels, [-1,0,1,2,3,5], 10)

svm = cv2.ml.SVM_create()
svm.setKernel(cv2.ml.SVM_LINEAR)
svm.setType(cv2.ml.SVM_C_SVC)
svm.setC(1)
svm.setGamma(1)
svm.train(np.array(train_data, 'float32'), cv2.ml.ROW_SAMPLE, np.array(train_labels));

result = svm.predict(np.array(all_vectors, 'float32'))

# Dictionary containing number of correct answers and number of same labels
svm_results = {-1:[0,0], 0:[0,0], 1:[0,0], 2:[0,0], 3:[0,0], 5:[0,0]}
for i,lab in enumerate(all_labels):
    if result[1][i] == lab:
        svm_results[lab][0] += 1
    svm_results[lab][1] += 1
for k,v in svm_results.items():
    print(k, ': ', v[0], ' / ', v[1])
-1 :  60  /  89
0 :  27  /  27
1 :  78  /  78
2 :  39  /  39
3 :  0  /  2
5 :  29  /  29

だいたい正確に推論できているようです。
ただし、

  • どの数字でない輪郭で、どれかの数字として認識されてしまっているものがある
  • "3"は正しく推論されていない

という問題点があります。

1つ目の問題は、もっと前段階で候補輪郭を絞ることで対応したいなと。

2つ目の問題は、いくつか試したところ、SVMの"C"の値を変更することで対応できました。
SVMで決定境界を決めるときに、コスト関数として決定境界のマージンの大きさと誤分類の数を考慮しますが、"C"は誤分類数へのウェイトになるようで、この2項目のバランスを決定します。

Understanding SVM

今回は"3"のサンプル数が少なく、軽視されてしまったものと考えられます。
なので、"C"を思い切り大きくしてみます。

Cs = [1, 5, 10, 50, 100, 200]

for C in Cs:
    svm.setC(C)
    svm.train(np.array(train_data, 'float32'), cv2.ml.ROW_SAMPLE, np.array(train_labels));
    result = svm.predict(np.array(all_vectors, 'float32'))

    # Dictionary containing number of correct answers and number of same labels
    svm_results = {-1:[0,0], 0:[0,0], 1:[0,0], 2:[0,0], 3:[0,0], 5:[0,0]}
    for i,lab in enumerate(all_labels):
        if result[1][i] == lab:
            svm_results[lab][0] += 1
        svm_results[lab][1] += 1
    print('C: ', C)
    for k,v in svm_results.items():
        print(k, ': ', v[0], ' / ', v[1])
    print('')
C:  1
-1 :  60  /  89
0 :  27  /  27
1 :  78  /  78
2 :  39  /  39
3 :  0  /  2
5 :  29  /  29

C:  5
-1 :  60  /  89
0 :  27  /  27
1 :  78  /  78
2 :  39  /  39
3 :  0  /  2
5 :  29  /  29

C:  10
-1 :  58  /  89
0 :  27  /  27
1 :  78  /  78
2 :  39  /  39
3 :  0  /  2
5 :  29  /  29

C:  50
-1 :  75  /  89
0 :  27  /  27
1 :  78  /  78
2 :  39  /  39
3 :  2  /  2
5 :  29  /  29

C:  100
-1 :  76  /  89
0 :  27  /  27
1 :  78  /  78
2 :  39  /  39
3 :  2  /  2
5 :  29  /  29

C:  200
-1 :  77  /  89
0 :  27  /  27
1 :  78  /  78
2 :  39  /  39
3 :  2  /  2
5 :  29  /  29

"C"が50以上で正しく"3"を判定できています。また、数字でない輪郭の判別精度も少し上がっているようです。
今回は"C"の値としては100を選んでおきたいと思います。

ここまで

SVMを使って数字判定ができるようになりました。
今回はここで一旦区切りたいと思います。

次回は残っている項目の検討になります。

  • ICP前の初期変換行列での判定
  • ICP収束条件