皆さんこんにちは
お元気ですか。ついにクーラーが必要になってきました。
電気代が心配ですがなんとかなるでしょう。
本日はPyTorchの研究開発を加速する「pytorch-pfn-extras」を紹介します。
pytorch-pfn-extras
pytorch-pfn-extrasとは
PyTorchを使った研究開発の促進のために開発されているライブラリです。
こちらの開発元はChainerを開発していたPreferred Networks社によるものです。
Chainerの頃にはあったTrainerに似ている構成(厳密にはクラス構成が異なる)やIgnite連携が用意されており、便利に使えるのでは?と思っています。
MNISTサンプル
実装はこちらを参考にしてください、必要なポイントを解説します。
github.com
ニューラルネットワークの定義
ニューラルネットワークの定義はPyTorchの実装とほぼ同じです。
しかし、一つ違う点として、LazyConv2DやLazyLinearと呼ばれるモジュールをppeは独自実装しています。
PyTorchのConv2DやLinearの場合、入力するチャネルや次元数の指定が必要ですが、Lazy-の場合はその指定を省けます。
class Net(nn.Module): def __init__(self, lazy): super().__init__() if lazy: self.conv1 = ppe.nn.LazyConv2d(None, 20, 5, 1) self.conv2 = ppe.nn.LazyConv2d(None, 50, 5, 1) self.fc1 = ppe.nn.LazyLinear(None, 500) self.fc2 = ppe.nn.LazyLinear(None, 10) else: self.conv1 = nn.Conv2d(1, 20, 5, 1) self.conv2 = nn.Conv2d(20, 50, 5, 1) self.fc1 = nn.Linear(4*4*50, 500) self.fc2 = nn.Linear(500, 10)
学習部
ExtensionsManagerと呼ばれるクラスに拡張機能を管理するものがあります。
このManagerに設定することで、いつ、学習を止めるか、レポーティングを行うかなど、設定が可能です。
また、共通的に書かないといけないiteration回数を止める処理などを書く必要がほぼなくなります。
Chainer時代のTrainerだと、学習のメソッドを拡張するなりしなければなりませんが昨今の学習時にいじる系統(MixUpなど)だと不便になってくると思っていました。
train関数に学習中の処理も実装できるのでより一層、研究向きとしては便利になった印象です。
manager = ppe.training.ExtensionsManager( model, optimizer, args.epochs, extensions=my_extensions, iters_per_epoch=len(train_loader), stop_trigger=trigger) def train(manager, args, model, device, train_loader, optimizer): while not manager.stop_trigger: model.train() for batch_idx, (data, target) in enumerate(train_loader): with manager.run_iteration(): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) ppe.reporting.report({'train/loss': loss.item()}) loss.backward() optimizer.step()
拡張機能
extensionsのモジュールでは指定すれば様々なことが可能です。
このあたりの名前はChainerを継承しているので、元Chainer Userとしてはわかりやすい印象。
重みのスナップショット、統計情報の計算、ログの記録などを可能としています。
個人的にMLFlowあたりと連携してくれるととても嬉しいのですが、対応していただける日は来るのだろうか(「作って」みたいな要望が来そうだ‥)。
my_extensions = [ extensions.LogReport(), extensions.ProgressBar(), extensions.observe_lr(optimizer=optimizer), extensions.ParameterStatistics(model, prefix='model'), extensions.VariableStatisticsPlot(model), extensions.Evaluator( test_loader, model, eval_func=lambda data, target: test(args, model, device, data, target), progress_bar=True), extensions.PlotReport( ['train/loss', 'val/loss'], 'epoch', filename='loss.png'), extensions.PrintReport(['epoch', 'iteration', 'train/loss', 'lr', 'model/fc2.bias/grad/min', 'val/loss', 'val/acc']), extensions.snapshot(), ]
最後に
PyTorchの便利モジュールが存在すること。
また、色々な拡張機能により、Kaggleの実装など、よりシンプルにかけそうだといった印象で実利用時(これから)に期待を持っています。
研究開発がメインな身としてはより開発が進むと嬉しいと感じています。