のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

新しくKaggle用のマシンを調達しました(2020年ver)

皆さんこんにちは
tereka114です。本日はKaggle Advent Calendar24日目の記事です。
今回はKaggle用マシンで購入したものとパーツ選びの基準をご紹介します。

qiita.com

Kaggle用マシン

Kaggle用マシンがなぜ必要か

Kaggleをするためにはノートパソコン一台でも十分です。
しかし、勝つには、最近の巨大データ、画像、NLPコンペの増加に伴いハイスペックPCが必要になってきました。
特に画像やNLPでは、CNN、BERTと呼ばれるニューラルネットワークの手法がソリューションを締めており、
普通の特徴量抽出の手法では勝てなくなってきています。

そのため、それなりの人たちの多くはGPUがあるマシンを持っていることが多くなりました。

クラウドがあるのではないか?といった疑問を持たれた方もいると思いますが、
はっきりいってオンプレ運用より高くなるので緊急時以外おすすめしません。

どのような要件のマシンが必要か

高性能なCPU、メモリ、GPUです。といっても元も子もないので、
どのようなところに影響するのかはパーツの選定の箇所で説明します。

パーツの選定

ここでは実際に組み上げたPCのパーツの紹介をしつつ、選定基準を記載していきます。

基本方針

CPU、GPU、メモリは良いものを選定するのは基本です。
しかし、意外に見落としがちなのはマザボと電源です。
この2パーツはとにかく、故障したときに原因がわかりにくいもので相当曲者です。
ここで失敗すると精神衛生が悲惨です。

各々のパーツ

CPU

lightgbmとテーブル加工の処理、画像の前処理で利用します。
画像処理のあるあるですが、良いGPUだけを準備してもCPUがボトルネックになれば高速化されないので、GPUに合わせてCPUも良くする必要があります。
可能な限りスペックをよくしたいなーと思っていたので、次のCPUを選択しました。
10コア20スレッドです。

メモリ

Kaggleでメモリが必要になるケースは概ね2パタンです。

  • 巨大なテーブルデータの場合、ほとんどのデータをメモリに載せる必要があるので
  • 高速化のためのキャッシュ、画像コンペだとすれば画像を読む、加工、廃棄をすることになりますが、廃棄せずメモリに持てると性能が向上します。

GPU

RTX3090です。私が参加するコンペの多くは画像、もしくは、ニューラルネットワーク
勝つソリューションが多いので、ここは譲れないところでした。

MSI GeForce RTX 3090 GAMING X TRIO 24G グラフィックスボード VD7347

MSI GeForce RTX 3090 GAMING X TRIO 24G グラフィックスボード VD7347

  • 発売日: 2020/09/24
  • メディア: Personal Computers

GPUの選定ですが、基本は性能とメモリです。
Kaggle界ではメモリの方が重要視されている気がします。

巨大モデルである程度のBatchsizeを担保しながら動かさないといけないので、11GBでは足りなくなってきていると感じています(MixedPrecisionも使っていますが厳しい)
そのため、24GBある3090を首を長くしてお待ちしていました。

ちなみに、普通の機械学習の勉強とかであれば、ColaboratoryにもGPUあるので、そういったところにGPUはおまかせするのも一つの手だと思います。

SSD

あまり考えずに選定していますが、最近のデータセットの規模から2TBは最低あっても良いと思います。
少ないとストレスなので、可能であれば大きめのものを買うと良いでしょう。

電源

先程簡単に説明しましたが、電源とマザボは最重要パーツの一つです。
というのも、電源とマザボが故障すると原因の調査が難しいものの一つで、問題が起こると、
一度総解体する必要も出てくるぐらい厄介なパーツです。

電源において重要なポイントとしては消費電力のギリギリのW数の電源を購入するのはいけません。
電源自体が劣化すると出力も下がり、運用していて落ちるようになります。
後々の平穏のためにも決してケチってはいけません。

そのため、200-300Wほどは余裕を持っておくことをおすすめします。

マザーボード

電源の箇所で紹介しましたが、最重要パーツの一つです。
可能な限り高めのもの(1万超え-3万程度)のものを購入するようにしています。
私のマザボ購入時のチェックはCPUの型番に対応しているか(これ対応していないと購入し直しです)。
メモリのスロットが何枚あるか、PCI-Eのスロットはどの程度あるかは確認しています。

私はよくASUSのものを購入しています。電源がついているかどうかなどわかりやすいので、気に入っています。

ASUS INTEL Z490 搭載 LGA1200 対応 ROG STRIX Z490-E GAMING 【 ATX 】

ASUS INTEL Z490 搭載 LGA1200 対応 ROG STRIX Z490-E GAMING 【 ATX 】

  • 発売日: 2020/05/21
  • メディア: Personal Computers

PCケース

タワー、ミドルタワーそしてフルタワーの種類があります。
GPU指す場合はフルタワーのものをおすすめです。でかいので余裕があります。

ファン

Intelの標準装備はあまり性能が良くないので必ず外部で購入するようにしています。
巨大ですが、それなりにコスパが良いものを使っています。
私は虎徹を購入するのが標準的なので今回もそれを買いました。

サイズ オリジナルCPUクーラー 虎徹 Mark II

サイズ オリジナルCPUクーラー 虎徹 Mark II

  • 発売日: 2017/06/02
  • メディア: 付属品

最後に

Kaggleを真面目にやりたいと思っている人は購入されることをおすすめします。
多分完成品を購入するよりも自作の方が安いです。

Kaggle Notebookで使えるトリックを紹介します

皆さんこんにちは。
お元気ですか。私はもう小麦そのものは食べる専門です。
本日はKaggle Notebookに関するトリックを紹介しようと思います。

Kaggle Notebook

Kaggle NotebookはKaggle上で動作する実行環境みたいなものです。
Kaggle上で動いているJupyter Notebookと解釈しても差し支えないでしょう。

昨年からこのKaggle Notebookで推論のみ提出コンペが増加しています。
推論のみ提出の嬉しいこととして、実行環境による差分がでないことがあります。
画像や自然言語処理は無限のモデルを生成するためのマシンリソースで殴ると勝ちやすいことがありますが、この場合、アンサンブルのみではなくある程度は工夫が必要になります。
反対に提出物作るのが格段に面倒になります。(ファイルのアップロード、推論コードの修正、コミット、サブミットを手動で行う)

黒魔術紹介

ここからは私が使っているトリックを紹介していきます。

  • コードを呼び出すほうがリソース管理が楽
  • 外部パッケージインストール
  • 例外
  • 小数点以下の結果に注意
  • コミットまでの速度を早くする

コードを呼び出すほうがリソース管理が楽

Kaggle KernelからはJupyter Notebookと同様の方式で実装を呼び出せます。
実装上はただの「!python predict.py」です。
通常のJupyter Notebookだと、コードを書いていくのが一番オーソドックスであり、本来の使い方だと思います。

近年のコンペだとチームマージした場合に問題が起こりやすいと思います。
よくあるのは、Keras/PyTorchの実装が混合するケースです。
その場合、PyTorchやTensorFlowのGPUメモリ開放の方法を実装せねばなりません。
また、GPUの稀に開放されないこともあるので面倒になるため、プロセスの終了をトリガーとすると開放もれがなくなるので便利です。
デバッグや更新が面倒なため、GPUメモリがシビアなコンペで利用します。

※ただし、実装の更新は面倒なので、コマンドライン引数の設定の工夫が必要です。

外部パッケージインストール

コンペによっては外部パッケージをインストールしたほうが良いものもあります。
私がKernelコンペで利用したもののひとつとしてapexがあります。
例えばapexを利用する場合のインストールコマンドは次のとおりです。
その後再起動もいらないので、長時間実行するコンペだと利用可能でしょう。

!cd ../input/apexpytorch && pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . && cd -

例外

Kaggle Kernelの提出時に例外が出るケースがあります。
しかし、例外の内容がさっぱりわからないので、解析が難しいです。
例外が見えすぎてもコンペ参加者がTestを例外を使ってハックするという事が起こるため、どこまで参加者に見せるかは悩ましい問題です。
私が見つけたものだと、このあたりです。

Exception 説明
Notebook Exceeded Allowed Compute メモリのリークなどが原因でプロセスが落ちた場合
Submission Scoring Error 提出形式に不備があった場合、スコアが計算できなくて例外が発生
Notebook Timeout 計算時間オーバー
Notebook Threw Exception 例外発生時の問題(よくわからない)

小数点以下の結果に注意

Kaggle Kernelを利用して提出する場合、基本例外が発生してもよくわかりません。
Google QAのコンペ(Spearman's rank correlation coefficient)では、出力結果によりかなりの例外が発生することを確認していました。
数式的には均等な値(1しかない)でなければ例外が発生しうる要因はなかったので、疑問に思っていました。

調べたら小数点の桁をうまく調整すると回避できたので、それが原因でないかと推測し、対処できました。
殆ど使う機会はないのかなとは思っていますが、頭の片隅に残しておくと役に立つ日が来るでしょう。

www.kaggle.com

コミットまでの速度を早くする

週利用の制限時間が42hしかないため、実行時間の無駄使いはコンペとしては致命傷になりうります。
そのため、一部の処理をスキップしたり早めることで、コミット完了までの時間を早められます。

提出前後(submit)でそれはデータセットの長さに変化があります。
つまり、データセットの長さで変化するような実装を入れておけば短くできるということです。
次の例では、sample_submission.csvを使っていますが、この長さが変わらないこともあるようです。
その場合は別の方法でデータ数を取得するのが望ましいでしょう。

例えば、こんなコードを書けば良いです。

import pandas as pd
submission_df = pd.read_csv("sample_submission.csv")
if len(submission_df) < 11: # データ数で判断する。他の方式でも可能
     submission_df.to_csv("submission.csv", index=False)
     exit()
# 以降後続の処理を記載する。

最後に

Kaggle Notebookは不便なところも多いため、工夫して使わないと厳しいです。
ただ、工夫しがいはあるので、個人的には好きなコンペの開催形式の一つです。
懲りずに参加しようかと思います。

種類が豊富な画像モデルライブラリpytorch-image-models(timm)の紹介

皆さんこんにちは
お元気ですか。私は小麦を暫く食べるだけにしたいです。(畑で見たくない‥)

さて、本日は最近勢いのあるモデルのライブラリ(pytorch-image-models)を紹介します。

pytorch-image-modelsとは

通称timmと呼ばれるパッケージです。
PyTorchの画像系のモデルや最適化手法(NAdamなど)が実装されています。
Kagglerもこのパッケージを利用することが増えています。

github.com

従来まで利用していたpretrained-models.pytorchは更新が止まっており、最新のモデルに追従できていないところがありました。
このモデルは例えば、EfficientNetやResNeStなどの実装もあります
モデルの検証も豊富でImageNetの様々なパタンで行われているのでこの中で最適なものを選択すると良いでしょう。詳しくはこちらへ。

github.com

インストールはpipから可能です。

pip install timm

モデルの利用方法

モデルの利用方法は簡単です。
timmモジュールのcreate_modelを用いることで、指定されたモデルを取得できます。
各モデルには、次のメソッドが実装されています。

forward・・・モデル全体の出力を獲得する
forward_features・・・特徴量出力部分(Pooling前)の出力を獲得する

どのようなモデルが実装されているかは、先述のResultsのページをご確認ください。

import timm
import torch
m = timm.create_model('efficientnet_b3', pretrained=True)
m.eval()
matrix = torch.rand(2, 3,224,224)

m(matrix).size() # (2, 1000)
m.forward_features(matrix).size() # (2, 1536, 7, 7)

最後に

非常にモデルの種類が豊富で使いやすいです。
これを利用すればほんの数行でResNeStなど新しいモデルを試すことができるので、これから使っていこうと思っています。

atmaCup#5へ参加しました!

皆さんこんにちは。
お元気ですか?私は元気です。

本日までatmaCup#5に参加し、Public12th, Private18thでした。

f:id:tereka:20200606200537p:plain

問題設定

装置に関連する問題で、テーブルと信号が提供されました。
評価はPR-AUCで、Positive sampleが少ないこともあり、スコアが安定しない問題でした。

取り組み

序盤

1日目から信号の問題だったので、
1DCNNに最初から取り組んでいました。

その地点でPublic 1位になっていたのですが、CVがうまく安定せず、
ここから、なかなかスコアが上がらなくなっていました。

中盤

CNNのチューニングや様々な謎トリック(interpolate mixup など)を持ち出して試していました。
あんまり効果がなくCV/LBも相関が取れなくて非常に頭を抱えていました。

終盤

最終日2日間ぐらいであまりに抜かれすぎたのもあり、
LGBMに手を出し始めましたがなかなか上がりませんでした。

LGBM、最後、Publicで少しブーストかけるのには役立ったのでその点良かったと思っています。

ソリューション

1DCNNを中心で進めていました。
最後の2日LGBMを試しており、PublicはNNがベストでしたが、
PrivateはLGBMも含めたほうが精度が高かったです。

今回、安定しないPR-AUCだったので、アンサンブルしたほうが良かったということでしょうか。

最後に

運営の皆様ありがとうございました!
参加者が多く、激戦で、非常に面白かったです。
(日々、Gotoさんが要望に合わせてguruguru更新していくのが凄かったです!)
今後もチャンスが有れば、参加させていただきたいと思います。

PyTorch XLAでTPUを操作する

皆さんこんにちは
お元気でしょうか。最近は宅配スーパーによりますます外出しなくなっています。

本日はTPUをPyTorchで使ってみます。
GPUと比較して、TPUは汎用性をなくした代わりによりDeepLearningに必要な演算を高速にできるようにしたものです。

Why TPU?

TPUとは?

TPUはニューラルネットワークの演算専用のアーキテクチャです。
一言で使うモチベーションをお伝えするのであれば「速いから」の一言です
他の用途には利用できない分、大規模な乗算と加算を高速に演算できます。

詳細は次のサイトを見ていただくのが良いと思います。
www.atmarkit.co.jp
cloud.google.com

どの程度高速なのか

Cloud TPU  |  Google Cloud(次の図は左記から引用)によれば、GPUと比較して27倍の高速化と38%のコストを抑えました。
TPUを有効活用すれば、高速な計算を実現しつつ低コストな計算が可能です。

f:id:tereka:20200330235102p:plain

どうやって使うのか

現在は3つ選択肢があります。

  1. Google Colaboratory
  2. Google Cloud Computing
  3. Kaggle Kernel

Google Cloud Computingは一応使えますが、制限があります。また、Kaggle Kernelも週30hと時間制限があります。
まずは、今回はGoogle Colaboratoryで試すことにします。
※確か性能はKaggle Kernelのほうが良かったはずです。

PyTorch-XLA

Google社の製品であるため、当然ですが、Tensorflowに対応しています。
これをPyTorchで動作させることも可能です。そのライブラリがPyTorch-XLAです。

github.com

PyTorch-XLAのXLAはAccelerated Linear Algebraです。
XLAはTensorflowで利用されていますが、PyTorchのInterfaceでCloud TPUを利用するようにしたものです。

準備

Colaboratoryでは、アクセラレータ(CPUorGPUorTPU)を選択できます。
ハードウェアアクセラレータをTPUに変更します。

インストール

!pip install 

実装

以下の実装を利用します。
colab.research.google.com

インストール部

インストールは次の通りです。
この実装を利用すれば、問題ありません。

VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION
結果可視化
# Result Visualization Helper
from matplotlib import pyplot as plt

M, N = 4, 6
RESULT_IMG_PATH = '/tmp/test_result.jpg'
CIFAR10_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']

def plot_results(images, labels, preds):
  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]
  inv_norm = transforms.Normalize(
      mean=(-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010),
      std=(1/0.2023, 1/0.1994, 1/0.2010))

  num_images = images.shape[0]
  fig, axes = plt.subplots(M, N, figsize=(16, 9))
  fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')

  for i, ax in enumerate(fig.axes):
    ax.axis('off')
    if i >= num_images:
      continue
    img, label, prediction = images[i], labels[i], preds[i]
    img = inv_norm(img)
    img = img.permute(1, 2, 0) # (C, M, N) -> (M, N, C)
    label, prediction = label.item(), prediction.item()
    if label == prediction:
      ax.set_title(u'\u2713', color='blue', fontsize=22)
    else:
      ax.set_title(
          'X {}/{}'.format(CIFAR10_LABELS[label],
                          CIFAR10_LABELS[prediction]), color='red')
    ax.imshow(img)
  plt.savefig(RESULT_IMG_PATH, transparent=True)
||< 

*** モデル構築部
モデルの構成は標準的なResNet18です。
ここで唯一変わるのは、torch_xlaのモジュールをインポートしていることです。

>|python|
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms

class BasicBlock(nn.Module):
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = nn.Sequential(
          nn.Conv2d(
              in_planes,
              self.expansion * planes,
              kernel_size=1,
              stride=stride,
              bias=False), nn.BatchNorm2d(self.expansion * planes))

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)
    out = F.relu(out)
    return out


class ResNet(nn.Module):

  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_planes = 64

    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(self.in_planes, planes, stride))
      self.in_planes = planes * block.expansion
    return nn.Sequential(*layers)

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = torch.flatten(out, 1)
    out = self.linear(out)
    return F.log_softmax(out, dim=1)

def ResNet18():
  return ResNet(BasicBlock, [2, 2, 2, 2])
学習

さて、TPUを使う上で大きく変わってくるのはここからです。
次の通りです。

def train_resnet18():
  torch.manual_seed(1)

  # Get and shard dataset into dataloaders
  norm = transforms.Normalize(
      mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
  transform_train = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      norm,
  ])
  transform_test = transforms.Compose([
      transforms.ToTensor(),
      norm,
  ])
  train_dataset = datasets.CIFAR10(
      root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())),
      train=True,
      download=True,
      transform=transform_train)
  test_dataset = datasets.CIFAR10(
      root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())),
      train=False,
      download=True,
      transform=transform_test)
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True) 
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to num cores
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size() # (2) Tensor Coreのコア数

  # Get loss function, optimizer, and model
  device = xm.xla_device() # (3) 実行デバイスの取得
  model = ResNet18().to(device)
  optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True) # get_ordinalは動作しているプロセスの番号

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target

# Start training processes
def _mp_fn(rank, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target = train_resnet18()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork') # (1) 実行用のプロセス起動します。

#### (1) xmp.spawnによるTPU上での実行
xmp.spawnはXLAのデバイス(TPU)で実行するためのプロセスを起動します。

#### (2) Tensor Coreのコア数
xm.xrt_world_size()はコア数(並列数)を取得できます。
この部分の実装はバッチのサイズに応じて最適化手法(SGD)の学習率を調整しています。

#### (3) 実行デバイスの取得
xm.xla_device()の利用により、デバイス情報を取得でき、TPU/CPUのどちらかが返ってきます。

#### (4) master_printを使った出力
xlaのデバイス上ではmaster_printを利用して出力しています。
masterのデバイスで出力するAPIとなります。(この手の内容ちゃんとした説明が見つからない‥)

### 利用後の感想
肌感として非常にGPUよりも高速ではあります。
概ね実装はほぼ同じで、少しの差分を修正すれば動作します。
ただ、時々メモリエラーを起こすなど不安定になることがあります。
この原因の特定が難しく今の所プロファイラが見当たらないため、まだまだ挙動が読めておらず、エラーの原因がわからないことも多くあります。

最後に

TPUを有効活用できれば、生産性が上がるのかなと思っています。
ただ、まだまだ私自身、使いこなすにはほど遠いと思っているところです。
少しずつ使いこなしていき、生産性を上げられるようにしたいと思っています。