のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

PyTorch XLAでTPUを操作する

Sponsored Links

皆さんこんにちは
お元気でしょうか。最近は宅配スーパーによりますます外出しなくなっています。

本日はTPUをPyTorchで使ってみます。
GPUと比較して、TPUは汎用性をなくした代わりによりDeepLearningに必要な演算を高速にできるようにしたものです。

Why TPU?

TPUとは?

TPUはニューラルネットワークの演算専用のアーキテクチャです。
一言で使うモチベーションをお伝えするのであれば「速いから」の一言です
他の用途には利用できない分、大規模な乗算と加算を高速に演算できます。

詳細は次のサイトを見ていただくのが良いと思います。
www.atmarkit.co.jp
cloud.google.com

どの程度高速なのか

Cloud TPU  |  Google Cloud(次の図は左記から引用)によれば、GPUと比較して27倍の高速化と38%のコストを抑えました。
TPUを有効活用すれば、高速な計算を実現しつつ低コストな計算が可能です。

f:id:tereka:20200330235102p:plain

どうやって使うのか

現在は3つ選択肢があります。

  1. Google Colaboratory
  2. Google Cloud Computing
  3. Kaggle Kernel

Google Cloud Computingは一応使えますが、制限があります。また、Kaggle Kernelも週30hと時間制限があります。
まずは、今回はGoogle Colaboratoryで試すことにします。
※確か性能はKaggle Kernelのほうが良かったはずです。

PyTorch-XLA

Google社の製品であるため、当然ですが、Tensorflowに対応しています。
これをPyTorchで動作させることも可能です。そのライブラリがPyTorch-XLAです。

github.com

PyTorch-XLAのXLAはAccelerated Linear Algebraです。
XLAはTensorflowで利用されていますが、PyTorchのInterfaceでCloud TPUを利用するようにしたものです。

準備

Colaboratoryでは、アクセラレータ(CPUorGPUorTPU)を選択できます。
ハードウェアアクセラレータをTPUに変更します。

インストール

!pip install 

実装

以下の実装を利用します。
colab.research.google.com

インストール部

インストールは次の通りです。
この実装を利用すれば、問題ありません。

VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION
結果可視化
# Result Visualization Helper
from matplotlib import pyplot as plt

M, N = 4, 6
RESULT_IMG_PATH = '/tmp/test_result.jpg'
CIFAR10_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']

def plot_results(images, labels, preds):
  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]
  inv_norm = transforms.Normalize(
      mean=(-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010),
      std=(1/0.2023, 1/0.1994, 1/0.2010))

  num_images = images.shape[0]
  fig, axes = plt.subplots(M, N, figsize=(16, 9))
  fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')

  for i, ax in enumerate(fig.axes):
    ax.axis('off')
    if i >= num_images:
      continue
    img, label, prediction = images[i], labels[i], preds[i]
    img = inv_norm(img)
    img = img.permute(1, 2, 0) # (C, M, N) -> (M, N, C)
    label, prediction = label.item(), prediction.item()
    if label == prediction:
      ax.set_title(u'\u2713', color='blue', fontsize=22)
    else:
      ax.set_title(
          'X {}/{}'.format(CIFAR10_LABELS[label],
                          CIFAR10_LABELS[prediction]), color='red')
    ax.imshow(img)
  plt.savefig(RESULT_IMG_PATH, transparent=True)
||< 

*** モデル構築部
モデルの構成は標準的なResNet18です。
ここで唯一変わるのは、torch_xlaのモジュールをインポートしていることです。

>|python|
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms

class BasicBlock(nn.Module):
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = nn.Sequential(
          nn.Conv2d(
              in_planes,
              self.expansion * planes,
              kernel_size=1,
              stride=stride,
              bias=False), nn.BatchNorm2d(self.expansion * planes))

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)
    out = F.relu(out)
    return out


class ResNet(nn.Module):

  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_planes = 64

    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(self.in_planes, planes, stride))
      self.in_planes = planes * block.expansion
    return nn.Sequential(*layers)

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = torch.flatten(out, 1)
    out = self.linear(out)
    return F.log_softmax(out, dim=1)

def ResNet18():
  return ResNet(BasicBlock, [2, 2, 2, 2])
学習

さて、TPUを使う上で大きく変わってくるのはここからです。
次の通りです。

def train_resnet18():
  torch.manual_seed(1)

  # Get and shard dataset into dataloaders
  norm = transforms.Normalize(
      mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
  transform_train = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      norm,
  ])
  transform_test = transforms.Compose([
      transforms.ToTensor(),
      norm,
  ])
  train_dataset = datasets.CIFAR10(
      root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())),
      train=True,
      download=True,
      transform=transform_train)
  test_dataset = datasets.CIFAR10(
      root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())),
      train=False,
      download=True,
      transform=transform_test)
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True) 
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to num cores
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size() # (2) Tensor Coreのコア数

  # Get loss function, optimizer, and model
  device = xm.xla_device() # (3) 実行デバイスの取得
  model = ResNet18().to(device)
  optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True) # get_ordinalは動作しているプロセスの番号

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target

# Start training processes
def _mp_fn(rank, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target = train_resnet18()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork') # (1) 実行用のプロセス起動します。

#### (1) xmp.spawnによるTPU上での実行
xmp.spawnはXLAのデバイス(TPU)で実行するためのプロセスを起動します。

#### (2) Tensor Coreのコア数
xm.xrt_world_size()はコア数(並列数)を取得できます。
この部分の実装はバッチのサイズに応じて最適化手法(SGD)の学習率を調整しています。

#### (3) 実行デバイスの取得
xm.xla_device()の利用により、デバイス情報を取得でき、TPU/CPUのどちらかが返ってきます。

#### (4) master_printを使った出力
xlaのデバイス上ではmaster_printを利用して出力しています。
masterのデバイスで出力するAPIとなります。(この手の内容ちゃんとした説明が見つからない‥)

### 利用後の感想
肌感として非常にGPUよりも高速ではあります。
概ね実装はほぼ同じで、少しの差分を修正すれば動作します。
ただ、時々メモリエラーを起こすなど不安定になることがあります。
この原因の特定が難しく今の所プロファイラが見当たらないため、まだまだ挙動が読めておらず、エラーの原因がわからないことも多くあります。

最後に

TPUを有効活用できれば、生産性が上がるのかなと思っています。
ただ、まだまだ私自身、使いこなすにはほど遠いと思っているところです。
少しずつ使いこなしていき、生産性を上げられるようにしたいと思っています。