画像認識で、少し大きめなネットワークモデルを使って学習を行うときには、ゼロから学習を行うことはまずありません。学習済みのウェイトは広く公開されていますので、こうしたウェイト(pretrained weight)を初期値としてロードしてから学習を始めることが多いでしょう。

ちなみにFacebookの研究所から、初期値として既学習モデルを使っても、ランダムな値を用いても、最終精度はそんなに変わらないという論文(https://arxiv.org/abs/1811.08883)も出ているのですが、 初期段階での収束が早いことは学習状況の確認の観点からはありがたく、個人的にはいつも初期値を入れています。

初期値として最も一般的なのは、VGG16やResNet50といったモデルのImageNet学習ウェイトを使うものです。元々VGGやResNetは画像分類を行うモデルですが、物体検知やセグメンテーションでもベースネットとしてよく用いられます。また、入力画素数や分類クラス数はImageNetと一致しないのが普通なので、最終段の全結合層以外の部分のウェイトを使うことがほとんどです。

これ、前にも書いたかもだし、考えたら当然なのですが、「畳み込み層のウェイトは入力画素数が変わっても使えます」。ImageNetの画素数は224x224ですが、もっと大きな入力画素数を持つネットワークの初期ウェイトとしても十分役に立ってくれます。

さて、このImageNetによる学習、一度は自分でやってみたいと思いませんか?いつも他人が学習した魔法の初期値を使うというのはすっきりしない。

ところがこのImageNet学習、やってみると結構大変です。(第一にImageNetデータを入手する必要がありますが、これは個人でもちょっと調べれば手に入ると思います)

ImageNetは1000クラスに対して約1000枚ずつの学習画像があります。つまり約100万枚が1エポック。それなりの計算リソースがないといつまでたっても学習が進まない。GPUが必須なのは言うまでもないですが、意外と忘れがちなのがストレージの速度。ImageNet学習の場合は画像データをSSDなどの高速ストレージに置かないとIOが詰まってしまいます。

そういうことを諸々乗り越え、さぁ学習するぞと思ってもそれでもうまくいかないこともあります。特にVGGのようなパラメータの多いモデルのImageNet学習は、うまく収束させられずに何度か挫折した記憶があります。(^^;;

しかし我らに救う神あり。PyTorchにはImageNetの学習コードがexampleとして用意されているではないですか!

https://github.com/pytorch/examples/tree/master/imagenet

ここを参考にすることがImageNetを使った学習の近道に違いない。

PyTorchにはTorchVisionという画像パッケージがあって、いろいろな分類モデルの学習済みウェイトも提供されているのですが、MobileNetV1については提供が無いようです。なので今回はMobileNetV1のImageNet学習を行ってみることにします。

PyTorchのImageNet学習コードにMobileNetV1のモデルを追加し、optimizerや、学習率の変移、ウェイトの初期化、ウェイトの保存などを変更したコードおよび学習したウェイトを評価するコードをGitHubに置いておきます。

https://github.com/ponta256/train-mobilenet-w-imagenet

以下は4GPUのサーバでの実行例なので環境に応じて引数は調整してください。

root@5494fd53ca8e:~# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_basenet.py -a mobilenet /mnt/ssd/imagenet/ --batch-size=2048 --lr=0.001 --workers=10
 Epoch: [0][  0/626]     Time 44.410 (44.410)    Data 28.982 (28.982)    Loss 6.9200e 00 (6.9200e 00)   Acc@1   0.15 (  0.15)   Acc@5   0.54 (  0.54)
 Epoch: [0][ 10/626]     Time  1.189 ( 5.240)    Data  0.000 ( 2.635)    Loss 6.8801e 00 (6.9207e 00)   Acc@1   0.20 (  0.17)   Acc@5   1.12 (  0.79)
 Epoch: [0][ 20/626]     Time 11.564 ( 3.812)    Data 11.140 ( 1.950)    Loss 6.8466e 00 (6.8963e 00)   Acc@1   0.29 (  0.19)   Acc@5   0.93 (  0.87)
 Epoch: [0][ 30/626]     Time 15.662 ( 3.438)    Data 15.232 ( 1.840)    Loss 6.8108e 00 (6.8733e 00)   Acc@1   0.34 (  0.22)   Acc@5   0.98 (  0.92)
 Epoch: [0][ 40/626]     Time 15.529 ( 3.242)    Data 15.106 ( 1.779)    Loss 6.6936e 00 (6.8450e 00)   Acc@1   0.54 (  0.26)   Acc@5   1.90 (  1.07)
~ snip ~

学習中はGPUやCPUがきちんと回っていることを確認します。

root@5494fd53ca8e:~# top
top - 03:28:54 up 41 days, 23:38,  1 user,  load average: 6.91, 10.19, 9.30
Tasks:  16 total,  11 running,   5 sleeping,   0 stopped,   0 zombie
%Cpu(s):  7.1 us,  2.1 sy,  0.0 ni, 90.6 id,  0.2 wa,  0.0 hi,  0.0 si,  0.0 st
KiB Mem : 13194715 total, 49648152 free,  7063464 used, 75235536 buff/cache
KiB Swap: 13412659 total, 13403124 free,    95348 used. 12403100 avail Mem 
  PID USER      PR  NI    VIRT    RES    SHR S  %CPU %MEM     TIME  COMMAND                    
 2734 root      20   0 40.932g 2.542g  97584 R 100.0  2.0   0:06.13 python3                    
 2736 root      20   0 40.926g 2.535g  97584 R 100.0  2.0   0:06.11 python3                    
 2737 root      20   0 40.942g 2.552g  97584 R 100.0  2.0   0:06.13 python3                    
 2739 root      20   0 40.911g 2.521g  97584 R 100.0  2.0   0:06.13 python3                    
 2740 root      20   0 40.912g 2.521g  97648 R 100.0  2.0   0:06.13 python3                    
 2741 root      20   0 40.893g 2.502g  97584 R 100.0  2.0   0:06.13 python3                    <
 2742 root      20   0 40.858g 2.468g  97584 R 100.0  2.0   0:06.13 python3                    
 2743 root      20   0 40.898g 2.507g  97648 R 100.0  2.0   0:06.12 python3                    
 2735 root      20   0 40.897g 2.507g  97648 R  93.3  2.0   0:06.14 python3                    
 2738 root      20   0 40.930g 2.540g  97584 R  93.3  2.0   0:06.12 python3     
~ snip ~
  root@5494fd53ca8e:~# nvidia-smi 
  Sun Jun  2 03:12:20 2019       
   ----------------------------------------------------------------------------- 
  | NVIDIA-SMI 410.78       Driver Version: 410.78       CUDA Version: 10.0     |
  |------------------------------- ---------------------- ---------------------- 
  | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
  | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
  |=============================== ====================== ======================|
  |   0  XXXXXX XXXXX        Off  | 00000000:05:00.0  On |                  Off |
  | 33%   72C    P0   184W / 250W |  24150MiB / 24448MiB |    100%      Default |
   ------------------------------- ---------------------- ---------------------- 
  |   1  XXXXXX XXXXX        Off  | 00000000:06:00.0 Off |                  Off |
  | 32%   72C    P0   176W / 250W |  23971MiB / 24449MiB |    100%      Default |
   ------------------------------- ---------------------- ---------------------- 
  |   2  XXXXXX XXXXX        Off  | 00000000:09:00.0 Off |                  Off |
  | 27%   68C    P0   161W / 250W |  23971MiB / 24449MiB |    100%      Default |
   ------------------------------- ---------------------- ---------------------- 
  |   3  XXXXXX XXXXX        Off  | 00000000:0A:00.0 Off |                  Off |
  | 25%   58C    P0   155W / 250W |  23971MiB / 24449MiB |    100%      Default |
   ------------------------------- ---------------------- ---------------------- 
  

学習は延々と続きますが、47エポック時点のウェイトを取り出して評価してみます。

Epoch: [47][610/626]	Time 11.078 ( 2.622)	Data 10.604 ( 1.644)	Loss 1.4328e 00 (1.4181e 00)	Acc@1  66.94 ( 66.38)	Acc@5  85.35 ( 85.66)
Epoch: [47][620/626]	Time  9.104 ( 2.616)	Data  8.685 ( 1.639)	Loss 1.4567e 00 (1.4186e 00)	Acc@1  65.62 ( 66.37)	Acc@5  84.57 ( 85.65)
Test: [ 0/25]	Time 31.851 (31.851)	Loss 1.0282e 00 (1.0282e 00)	Acc@1  74.41 ( 74.41)	Acc@5  91.70 ( 91.70)
Test: [10/25]	Time 24.389 ( 5.405)	Loss 1.7012e 00 (1.1890e 00)	Acc@1  60.16 ( 69.95)	Acc@5  83.35 ( 90.27)
Test: [20/25]	Time 17.614 ( 3.875)	Loss 1.8932e 00 (1.4850e 00)	Acc@1  56.64 ( 64.87)	Acc@5  78.56 ( 85.71)
 * Acc@1 64.728 Acc@5 85.690
$ python eval_basenet.py -a mobilenet --weight=model_best_weight.pth /mnt/ssd/imagenet/
~ snip ~
Test: [195/196]	Time  0.257 ( 0.692)	Loss 3.1799e 00 (1.4985e 00)	Acc@1  30.00 ( 64.79)	Acc@5  66.25 ( 85.88)
  * Acc@1 64.792 Acc@5 85.882

Top1スコアが64.8。MobileNetV1のオリジナル論文でのTop1スコアは70.6なので、学習経過の値としては良好と思います。

MobileNetはVGG16などに比べるとはるかに収束が速く学習しやすいと言えます。モデルによっては簡単に収束せずに苦労することもあると思いますが、基本は同じでいけるはず。

今回はImageNetのデータでモデルをスクラッチから学習させてみました。それほど頻度の高い作業ではないと思いますが、学習済みのウェイトが使えないときにはこんな風にImageNetデータで学習できるという参考にしていただければと思います。