画像認識で、少し大きめなネットワークモデルを使って学習を行うときには、ゼロから学習を行うことはまずありません。学習済みのウェイトは広く公開されていますので、こうしたウェイト(pretrained weight)を初期値としてロードしてから学習を始めることが多いでしょう。
ちなみにFacebookの研究所から、初期値として既学習モデルを使っても、ランダムな値を用いても、最終精度はそんなに変わらないという論文(https://arxiv.org/abs/1811.08883)も出ているのですが、 初期段階での収束が早いことは学習状況の確認の観点からはありがたく、個人的にはいつも初期値を入れています。
初期値として最も一般的なのは、VGG16やResNet50といったモデルのImageNet学習ウェイトを使うものです。元々VGGやResNetは画像分類を行うモデルですが、物体検知やセグメンテーションでもベースネットとしてよく用いられます。また、入力画素数や分類クラス数はImageNetと一致しないのが普通なので、最終段の全結合層以外の部分のウェイトを使うことがほとんどです。
これ、前にも書いたかもだし、考えたら当然なのですが、「畳み込み層のウェイトは入力画素数が変わっても使えます」。ImageNetの画素数は224x224ですが、もっと大きな入力画素数を持つネットワークの初期ウェイトとしても十分役に立ってくれます。
さて、このImageNetによる学習、一度は自分でやってみたいと思いませんか?いつも他人が学習した魔法の初期値を使うというのはすっきりしない。
ところがこのImageNet学習、やってみると結構大変です。(第一にImageNetデータを入手する必要がありますが、これは個人でもちょっと調べれば手に入ると思います)
ImageNetは1000クラスに対して約1000枚ずつの学習画像があります。つまり約100万枚が1エポック。それなりの計算リソースがないといつまでたっても学習が進まない。GPUが必須なのは言うまでもないですが、意外と忘れがちなのがストレージの速度。ImageNet学習の場合は画像データをSSDなどの高速ストレージに置かないとIOが詰まってしまいます。
そういうことを諸々乗り越え、さぁ学習するぞと思ってもそれでもうまくいかないこともあります。特にVGGのようなパラメータの多いモデルのImageNet学習は、うまく収束させられずに何度か挫折した記憶があります。(^^;;
しかし我らに救う神あり。PyTorchにはImageNetの学習コードがexampleとして用意されているではないですか!
https://github.com/pytorch/examples/tree/master/imagenet
ここを参考にすることがImageNetを使った学習の近道に違いない。
PyTorchにはTorchVisionという画像パッケージがあって、いろいろな分類モデルの学習済みウェイトも提供されているのですが、MobileNetV1については提供が無いようです。なので今回はMobileNetV1のImageNet学習を行ってみることにします。
PyTorchのImageNet学習コードにMobileNetV1のモデルを追加し、optimizerや、学習率の変移、ウェイトの初期化、ウェイトの保存などを変更したコードおよび学習したウェイトを評価するコードをGitHubに置いておきます。
https://github.com/ponta256/train-mobilenet-w-imagenet
以下は4GPUのサーバでの実行例なので環境に応じて引数は調整してください。
root@5494fd53ca8e:~# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_basenet.py -a mobilenet /mnt/ssd/imagenet/ --batch-size=2048 --lr=0.001 --workers=10 Epoch: [0][ 0/626] Time 44.410 (44.410) Data 28.982 (28.982) Loss 6.9200e 00 (6.9200e 00) Acc@1 0.15 ( 0.15) Acc@5 0.54 ( 0.54) Epoch: [0][ 10/626] Time 1.189 ( 5.240) Data 0.000 ( 2.635) Loss 6.8801e 00 (6.9207e 00) Acc@1 0.20 ( 0.17) Acc@5 1.12 ( 0.79) Epoch: [0][ 20/626] Time 11.564 ( 3.812) Data 11.140 ( 1.950) Loss 6.8466e 00 (6.8963e 00) Acc@1 0.29 ( 0.19) Acc@5 0.93 ( 0.87) Epoch: [0][ 30/626] Time 15.662 ( 3.438) Data 15.232 ( 1.840) Loss 6.8108e 00 (6.8733e 00) Acc@1 0.34 ( 0.22) Acc@5 0.98 ( 0.92) Epoch: [0][ 40/626] Time 15.529 ( 3.242) Data 15.106 ( 1.779) Loss 6.6936e 00 (6.8450e 00) Acc@1 0.54 ( 0.26) Acc@5 1.90 ( 1.07) ~ snip ~
学習中はGPUやCPUがきちんと回っていることを確認します。
root@5494fd53ca8e:~# top top - 03:28:54 up 41 days, 23:38, 1 user, load average: 6.91, 10.19, 9.30 Tasks: 16 total, 11 running, 5 sleeping, 0 stopped, 0 zombie %Cpu(s): 7.1 us, 2.1 sy, 0.0 ni, 90.6 id, 0.2 wa, 0.0 hi, 0.0 si, 0.0 st KiB Mem : 13194715 total, 49648152 free, 7063464 used, 75235536 buff/cache KiB Swap: 13412659 total, 13403124 free, 95348 used. 12403100 avail Mem PID USER PR NI VIRT RES SHR S %CPU %MEM TIME COMMAND 2734 root 20 0 40.932g 2.542g 97584 R 100.0 2.0 0:06.13 python3 2736 root 20 0 40.926g 2.535g 97584 R 100.0 2.0 0:06.11 python3 2737 root 20 0 40.942g 2.552g 97584 R 100.0 2.0 0:06.13 python3 2739 root 20 0 40.911g 2.521g 97584 R 100.0 2.0 0:06.13 python3 2740 root 20 0 40.912g 2.521g 97648 R 100.0 2.0 0:06.13 python3 2741 root 20 0 40.893g 2.502g 97584 R 100.0 2.0 0:06.13 python3 < 2742 root 20 0 40.858g 2.468g 97584 R 100.0 2.0 0:06.13 python3 2743 root 20 0 40.898g 2.507g 97648 R 100.0 2.0 0:06.12 python3 2735 root 20 0 40.897g 2.507g 97648 R 93.3 2.0 0:06.14 python3 2738 root 20 0 40.930g 2.540g 97584 R 93.3 2.0 0:06.12 python3 ~ snip ~ root@5494fd53ca8e:~# nvidia-smi Sun Jun 2 03:12:20 2019 ----------------------------------------------------------------------------- | NVIDIA-SMI 410.78 Driver Version: 410.78 CUDA Version: 10.0 | |------------------------------- ---------------------- ---------------------- | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | |=============================== ====================== ======================| | 0 XXXXXX XXXXX Off | 00000000:05:00.0 On | Off | | 33% 72C P0 184W / 250W | 24150MiB / 24448MiB | 100% Default | ------------------------------- ---------------------- ---------------------- | 1 XXXXXX XXXXX Off | 00000000:06:00.0 Off | Off | | 32% 72C P0 176W / 250W | 23971MiB / 24449MiB | 100% Default | ------------------------------- ---------------------- ---------------------- | 2 XXXXXX XXXXX Off | 00000000:09:00.0 Off | Off | | 27% 68C P0 161W / 250W | 23971MiB / 24449MiB | 100% Default | ------------------------------- ---------------------- ---------------------- | 3 XXXXXX XXXXX Off | 00000000:0A:00.0 Off | Off | | 25% 58C P0 155W / 250W | 23971MiB / 24449MiB | 100% Default | ------------------------------- ---------------------- ----------------------
学習は延々と続きますが、47エポック時点のウェイトを取り出して評価してみます。
Epoch: [47][610/626] Time 11.078 ( 2.622) Data 10.604 ( 1.644) Loss 1.4328e 00 (1.4181e 00) Acc@1 66.94 ( 66.38) Acc@5 85.35 ( 85.66) Epoch: [47][620/626] Time 9.104 ( 2.616) Data 8.685 ( 1.639) Loss 1.4567e 00 (1.4186e 00) Acc@1 65.62 ( 66.37) Acc@5 84.57 ( 85.65) Test: [ 0/25] Time 31.851 (31.851) Loss 1.0282e 00 (1.0282e 00) Acc@1 74.41 ( 74.41) Acc@5 91.70 ( 91.70) Test: [10/25] Time 24.389 ( 5.405) Loss 1.7012e 00 (1.1890e 00) Acc@1 60.16 ( 69.95) Acc@5 83.35 ( 90.27) Test: [20/25] Time 17.614 ( 3.875) Loss 1.8932e 00 (1.4850e 00) Acc@1 56.64 ( 64.87) Acc@5 78.56 ( 85.71) * Acc@1 64.728 Acc@5 85.690
$ python eval_basenet.py -a mobilenet --weight=model_best_weight.pth /mnt/ssd/imagenet/ ~ snip ~ Test: [195/196] Time 0.257 ( 0.692) Loss 3.1799e 00 (1.4985e 00) Acc@1 30.00 ( 64.79) Acc@5 66.25 ( 85.88) * Acc@1 64.792 Acc@5 85.882
Top1スコアが64.8。MobileNetV1のオリジナル論文でのTop1スコアは70.6なので、学習経過の値としては良好と思います。
MobileNetはVGG16などに比べるとはるかに収束が速く学習しやすいと言えます。モデルによっては簡単に収束せずに苦労することもあると思いますが、基本は同じでいけるはず。
今回はImageNetのデータでモデルをスクラッチから学習させてみました。それほど頻度の高い作業ではないと思いますが、学習済みのウェイトが使えないときにはこんな風にImageNetデータで学習できるという参考にしていただければと思います。