のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

ChainerのTrainerを使ってみた

Sponsored Links

皆さんこんにちは
お元気ですか。最近、Chainer便利でびっくりしたような頃合いです。

頻繁に更新することで有名なChainerですが、久々にupgradeすると以前よりも
シンプルなタスクについて、簡単に学習ができます。

Trainer

Chainer version 1.11.0よりTrainerと呼ばれる機能が実装されています。
以前まで学習用バッチ処理を自前で書くようなことが
必要でしたが、これを使うことによってバッチ処理を書く必要がなくなります。

実際の機能としてはある処理をhockしたり、グラフを出力したり
レポートを表示したりと学習中に確認したいグラフは沢山あります。

それらのグラフを可視化したいといったことは往々にしてあります。

Trainerの基本的な使い方

Trainerを使うと、Progress Barやlogを自動的に吐き出すことができます。
通常のモードでは、Trainerを基本的に使うことができます。

Extensionsを使うことにより、Trainerを使えます。
殆どExample通りですが、以下が最低限のコードとなります。

# coding:utf-8
from __future__ import absolute_import
from __future__ import unicode_literals
import chainer
import chainer.datasets
from chainer import training
from chainer.training import extensions
import chainer.links as L
import chainer.functions as F


class MLP(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_out),
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


train, test = chainer.datasets.get_mnist()
train_iter = chainer.iterators.SerialIterator(train, 32)
test_iter = chainer.iterators.SerialIterator(test, 32,
                                             repeat=False, shuffle=False)
model = L.Classifier(MLP(784, 10))
optimizer = chainer.optimizers.SGD()
optimizer.setup(model)
updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (10, 'epoch'), out="result")

trainer.extend(extensions.Evaluator(test_iter, model, device=10))
trainer.extend(extensions.dump_graph('main/loss'))
trainer.extend(extensions.snapshot(), trigger=(10, 'epoch'))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss',
     'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()

Trainerにはupdate方法を宣言します。
Trainer#extendを使うことで、一定の条件の元で起動します。

Extension 概要
Evaluator 一定の期間で評価する。(validation)
dump_graph グラフを表示する。
snapshot 一定の間隔(ユーザ指定)でモデルを保存する
LogReport ログとして出力する。
PrintReport print文を使って現状をprintする。(以下に例あり)
ProgressBar Progress Barを表示する。

上記の場合の出力例は次のとおりです。

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
1           0.624464    0.306581              0.850083       0.913538
2           0.282575    0.240019              0.919283       0.932608
     total [##########........................................] 21.87%
this epoch [#########.........................................] 18.67%
      4100 iter, 2 epoch / 10 epochs
    61.825 iters/sec. Estimated time to finish: 0:03:56.958209.

また、結果として、出力されるresult配下のディレクトリは次のとおりです。

-rw-r--r--  1 Tereka  staff     2250 10 24 23:33 cg.dot
-rw-------  1 Tereka  staff     2590 10 24 23:39 log
-rw-------  1 Tereka  staff  4775680 10 24 23:39 snapshot_iter_18750

DatasetMixinを使った拡張

ImageNetのサンプルにありますが、Real Time Augmentationを行うことができます。
これを応用すると様々な用途で使うことができて非常に便利です。

例えば、ファイルを順次読み出したい場合に
ファイルをデータとして渡しておき、それを処理するタイミングで順次読み出すことができます。
また、データを加工することも自由にできるため、自由にデータに変換を加えることができます。

chainer.dataset.DatasetMixinを使って以下のような拡張が可能です。
以下の拡張はシンプルです。chainer.dataset.DatasetMixin#get_exampleを使うと実現できます。
このメソッド内部にファイルを読み込む処理を作ると、
実際にファイルを読み込むことが可能となります。
例は次の通りです。

import skimage.io
class DatasetExampleMixin(chainer.dataset.DatasetMixin):
    def __init__(self,X,y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(X)

    def get_example(self,i):
        """
        Fileを読み出す処理
        """
        return skimage.io.imread(X[i]), y[i]

最後に

Trainer凄く便利!これを使いこなしてTrainerを使えるChainer使いになりましょう。