Chainerの学習の様子をリモートで確認するExtensionを作った
皆さんこんにちは
お元気ですか。私はGWでリフレッシュして、生き返りました。
Kaggleをやっているとき(特に画像などの長い場合)にリモートで
今学習されているかどうか、誤差はどうかなどのモデルの
様子が気になることはありませんか?
私は画像認識系のコンペを実際に行っている時に、気になることがあります。
これどうしようかと考えていたのですが、歩いている時にふと思いついたので実装しました。
このアイデアの実装のために、新しいChainerのExtensionを開発しました。(Trainerを使う想定です)
アイデア
Slackであれば外出中も見れると考えました。
そのため、学習の途中経過(lossなど)を投稿すれば見れる!
実装イメージは次の図に掲載しました。
コードを見た限りだと、Extensionで実装できそうだったので、トライしました。
Extensionの実装方法
Extensionの実装ですが、先に必要な情報を__init__に実装します。
そして、__call__が呼び出され、処理をする仕組みとなっています。
今回は既に実装されているPrintReportを参考に実装します。
前準備
SlackのWeb APIのIncoming Webhooksと投稿するChannelとusernameが必要です。
SlackのAPIに必要な情報は以下のurlから遷移して取得してください。
コード
早速Extensionを実装しました。
表示情報は既に実装されているPrintReportと同じにしました。
_throw_slackにSlackに投稿する部分を定義しています。
__init__に初期設定で必要な情報、__call__に表示する部分を記載しました。
# coding:utf-8 import os import sys from chainer.training import extension from chainer.training.extensions import log_report as log_report_module import requests import json class SlackReport(extension.Extension): def __init__(self, entries, log_report='LogReport', username="", url="",channel="",out=sys.stdout): self._entries = entries self._log_report = log_report self._log_len = 0 # number of observations already printed self._out = out # format information entry_widths = [max(10, len(s)) for s in entries] header = ' '.join(('{:%d}' % w for w in entry_widths)).format( *entries) + '\n' self._header = header # printed at the first call templates = [] for entry, w in zip(entries, entry_widths): templates.append((entry, '{:<%dg} ' % w, ' ' * (w + 2))) self._templates = templates self.username = username self.url = url self.channel = channel def __call__(self, trainer): if self._header: self._throw_slack(self._header) self._header = None log_report = self._log_report if isinstance(log_report, str): log_report = trainer.get_extension(log_report) elif isinstance(log_report, log_report_module.LogReport): log_report(trainer) # update the log report else: raise TypeError('log report has a wrong type %s' % type(log_report)) log = log_report.log log_len = self._log_len while len(log) > log_len: self._observation_throw_slack(log[log_len]) log_len += 1 self._log_len = log_len def _throw_slack(self, text): try: payload_dic = { "text": text, "username": self.username, "channel": self.channel, } requests.post(self.url, data=json.dumps(payload_dic)) except: self._out.write("error!") def serialize(self, serializer): log_report = self._log_report if isinstance(log_report, log_report_module.LogReport): log_report.serialize(serializer['_log_report']) def _observation_throw_slack(self, observation): text = "" for entry, template, empty in self._templates: if entry in observation: text += template.format(observation[entry]) else: text += empty self._throw_slack(text)
使い方
APIに必要な情報の定義とExtensionにSlackReportを追加するのみです。
例えば、公式のMNISTサンプルでは次のExtensionを追記してください。
username,channel,urlはそれぞれ自分の設定を指定してください。
trainer.extend(SlackReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'], username="YOUR USER NAME", channel="YOUR SLACK CHANNEL", url="SLACK API URL", ))
Slackへの投稿結果
ExampleのMNISTで実験したら、こんな感じになります。ちょっとガタガタなのはご愛嬌。
外出時に気になる場合は皆さんも使ってみましょう。
設定が間違っているときの処理は…、結構適当です。