改善 CNN 辨識率
以 mnist 為例
目的
學習影像辨識,多半會從手寫數字辨識開始,因為只有 0 — 9 十個數字需要辨識,辨識效果通常不會太差,程式寫起來也不複雜,再加上又有 mnist 這個現成的資料庫可以直接上手,所以 mnist 簡直可以說是學習人工智慧的 Hello World。
所以我的作法是先以最簡單的 NN 網路開始,然後是簡單的 CNN 網路,再來針對 CNN 作進一步的優化,最後 submit 到 Kaggle 的 Digit Recognizer competition。結果還不錯,在 Kaggle 的 leaderboard 上獲得 0.99857 的分數,總排名 39 名。
使用 Fully Connected NN 網路
一樣使用 keras 來開發,透過 keras 的函式下載 mnist 資料庫,mnist 的資料都是 28x28 的灰階圖形,所以設計 Fully Connected NN 網路時,會有 784 個輸入,然後中間使用 256 個 node 的隱藏層,最後是 10 個 node 的輸出。
訓練網路時使用 Cross Entropy 作為 loss 函式,Adam 作為 optimizer:
將 training set 分做兩部分,80% 作為訓練用,剩下的 20% 用來作驗證訓練結果,然後跑 10 次 epoch 得到訓練結果:
最後 val_acc 大概可以到 0.97 左右,相當不錯了,不過如果看曲線圖,你會發現已經 overfitting 了。
最後以 mnist 的測試資料來驗證,會得到大概 0.9766 的準確率:
不過 NN 不是重點,來試試 CNN 會不會得到比較好的結果。
NN 的程式碼放在這裡:
使用 CNN 網路
一開始設計一個 8 層的 CNN 網路,兩層捲積層,分別是 16 跟 32 個 filter:
訓練個 10 次 epoch 後,大概會得到 0.9899 的 val_acc,看起來已經比 NN 網路好了,訓練也不會在幾個 epoch後就 overfitting 了:
對測試資料的準確率也比 NN 網路高出 1.4 個百分點:
再偷偷看一下混淆矩陣:
程式碼放在這裡:
優化 CNN 網路
0.991 的準確率看起來也不錯,但有沒有辦法再提升一點。一開始學可以在 Kaggle competition 的 Kernels 參考其他人提出的方法,例如先試試用更深的 CNN 以及比較多的 epoch 試試:
總共 12 層的 CNN 網路,裡頭有 4 層的 Convolution Layer 來提取特徵,然後跑 20 epoch。不過這個比較深的 CNN 網路,我實測的結果是沒有比較好,我忘記確切數據了,但是用在 mnist 的判斷上增進效果有限。
另外一件可以作的事,是使用 Keras 中的 ImageDataGenerator 來多產生一些額外的影像資料作訓練。ImageDataGenerator 可以幫你的圖像作一些變化,例如旋轉、偏移中心、放大縮小等,多提供訓練資料的一些變異性,好處是讓訓練出來的模型可以抗旋轉、偏移等情況。
不過要作這件事,我們要先手動將 mnist 的訓練資料依 80%/20% 的比例分做訓練及驗證資料:
然後產生 generator。
參數如下:
rotation_range: 定義影像旋轉的角度區間,generator 會亂數來決定選轉角度
width_shift_range: 定義影像橫向偏移位置多寡,generator 會亂數決定偏移量
height_shift_range: 定義影像縱向偏移位置多寡,generator 會亂數決定偏移量
shear_range: 定義錯切角度
zoom_range: 定義放大縮小多寡
在訓練時改用 fit_generator,並且使用上述的 train_generator 作為訓練資料的來源 (先移除最後一行的 callbacks=[learning_rate_function] 參數)
:
在執行訓練的過程,大概會發現一個狀況:
從 epoch 6 開始,val_acc 一開始有提升,但是後面卻呈現震盪,有時候好有時候則比較差,所以猜測大概在使用梯度下降訓練參數時,因為固定 learning rate 沒有如預期的往最低點走,反而來回震盪的結果。關於這個問題可以參考 Andrew Ng 的解釋,透過下面這張圖也可以秒懂為什麼。
在訓練的過程中,梯度下降法 (Gradient Descent) 的目標就是要往最低點去優化,理想的結果應該是像圖中綠色點一樣,在幾次 iteration 後走到最低點。但如果我們的學習速率是固定的,結果會像紅色線一樣在兩邊來回震盪反而沒達到優化效果。但是如果一開始 learning rate 設的太小,收斂所需的時間會比較長,所以比較好的作法是根據執行的結果逐漸減少 learning rate。
這裡建立一個 ReduceLROnPlateau 物件,參數的設計是當 val_acc 在連續三個 epoch 都沒有下降時,將 Learning Rate 縮小一半,但最小不超過 0.00001。
然後將這個 Learning Rate Callback 指定給 fit_generator,在訓練的過程便會採用這個 Learning Rate 去作優化,部分執行結果如下:
從 epoch 17 開始,val_acc 不升反降,便在 epoch 19 時將 learning rate 調降一半,然後 val_acc 才又開始變好。
最後我們可以得到 0.996 的準確率,比一開始的 CNN 網路 0.991 高約 0.5 個百分點,雖然不多,但是已經越來越逼近 100% 的準確率了。
所有的程式碼與執行過程請參考:
結論
作這次的練習大致上可以掌握優化 CNN 的一些關鍵,其中包含增加訓練資料的變異性,以及透過 Learning Rate 的變動增加梯度下降法的效率。
還有沒有其他辦法能夠再更精進,或許可以再嘗試修改一些不同的參數,例如 batch size、epoch、optimizer 等等,多作一些實驗觀察結果看看。
這篇文章詳列一些可以參考的作法:
此外,我覺得多多參與 Kaggle 的 Competition 對於學習人工智慧很有幫助,你可以在 Kaggle 上參考別人的作法、思路來改善自己的程式,然後將辨識的結果上傳 Kaggle 有時會得到意料外的結果,像是這次我沒想到 submit 上去的結果可以排在前 50 名內 (雖然這並不是一個很 popular 的競賽),但是在剛學習的過程中能有這樣的結果會帶來不少成就感,也比較有動力會繼續下去。
目前為止寫的幾篇都是使用 Keras 作為開發框架,但我最近對 PyTorch 也蠻有興趣的,或許下一篇文章會試著用 PyTorch 來寫 mnist 的手寫數字辨識。
我現在也都還是在學習的階段,這篇文章花了我很多時間編寫,如果有任何錯誤的地方,歡迎幫我指正; 如果覺得我寫的好,那麼就給我一個「拍手」; 若是你覺得我寫的對你有幫助,可以給我五個「拍手」; 如果希望我繼續寫下去,那麼就給我超過五個以上的「拍手」吧,人畢竟需要靠一點點虛榮心才會有動力。