皆さんこんにちは
お元気でしょうか。最近は宅配スーパーによりますます外出しなくなっています。
本日はTPUをPyTorchで使ってみます。
GPUと比較して、TPUは汎用性をなくした代わりによりDeepLearningに必要な演算を高速にできるようにしたものです。
Why TPU?
TPUとは?
TPUはニューラルネットワークの演算専用のアーキテクチャです。
一言で使うモチベーションをお伝えするのであれば「速いから」の一言です
他の用途には利用できない分、大規模な乗算と加算を高速に演算できます。
詳細は次のサイトを見ていただくのが良いと思います。
www.atmarkit.co.jp
cloud.google.com
どの程度高速なのか
Cloud TPU | Google Cloud(次の図は左記から引用)によれば、GPUと比較して27倍の高速化と38%のコストを抑えました。
TPUを有効活用すれば、高速な計算を実現しつつ低コストな計算が可能です。
どうやって使うのか
現在は3つ選択肢があります。
- Google Colaboratory
- Google Cloud Computing
- Kaggle Kernel
Google Cloud Computingは一応使えますが、制限があります。また、Kaggle Kernelも週30hと時間制限があります。
まずは、今回はGoogle Colaboratoryで試すことにします。
※確か性能はKaggle Kernelのほうが良かったはずです。
PyTorch-XLA
Google社の製品であるため、当然ですが、Tensorflowに対応しています。
これをPyTorchで動作させることも可能です。そのライブラリがPyTorch-XLAです。
PyTorch-XLAのXLAはAccelerated Linear Algebraです。
XLAはTensorflowで利用されていますが、PyTorchのInterfaceでCloud TPUを利用するようにしたものです。
インストール
!pip install
実装
以下の実装を利用します。
colab.research.google.com
インストール部
インストールは次の通りです。
この実装を利用すれば、問題ありません。
VERSION = "20200325" #@param ["1.5" , "20200325", "nightly"] !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py !python pytorch-xla-env-setup.py --version $VERSION
結果可視化
# Result Visualization Helper from matplotlib import pyplot as plt M, N = 4, 6 RESULT_IMG_PATH = '/tmp/test_result.jpg' CIFAR10_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] def plot_results(images, labels, preds): images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N] inv_norm = transforms.Normalize( mean=(-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), std=(1/0.2023, 1/0.1994, 1/0.2010)) num_images = images.shape[0] fig, axes = plt.subplots(M, N, figsize=(16, 9)) fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)') for i, ax in enumerate(fig.axes): ax.axis('off') if i >= num_images: continue img, label, prediction = images[i], labels[i], preds[i] img = inv_norm(img) img = img.permute(1, 2, 0) # (C, M, N) -> (M, N, C) label, prediction = label.item(), prediction.item() if label == prediction: ax.set_title(u'\u2713', color='blue', fontsize=22) else: ax.set_title( 'X {}/{}'.format(CIFAR10_LABELS[label], CIFAR10_LABELS[prediction]), color='red') ax.imshow(img) plt.savefig(RESULT_IMG_PATH, transparent=True) ||< *** モデル構築部 モデルの構成は標準的なResNet18です。 ここで唯一変わるのは、torch_xlaのモジュールをインポートしていることです。 >|python| import numpy as np import os import time import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torch_xla import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met import torch_xla.distributed.parallel_loader as pl import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.utils.utils as xu import torchvision from torchvision import datasets, transforms class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1): super(BasicBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( nn.Conv2d( in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion * planes)) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = F.relu(out) return out class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super(ResNet, self).__init__() self.in_planes = 64 self.conv1 = nn.Conv2d( 3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = torch.flatten(out, 1) out = self.linear(out) return F.log_softmax(out, dim=1) def ResNet18(): return ResNet(BasicBlock, [2, 2, 2, 2])
学習
さて、TPUを使う上で大きく変わってくるのはここからです。
次の通りです。
def train_resnet18(): torch.manual_seed(1) # Get and shard dataset into dataloaders norm = transforms.Normalize( mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)) transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), norm, ]) transform_test = transforms.Compose([ transforms.ToTensor(), norm, ]) train_dataset = datasets.CIFAR10( root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())), train=True, download=True, transform=transform_train) test_dataset = datasets.CIFAR10( root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())), train=False, download=True, transform=transform_test) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=FLAGS['batch_size'], sampler=train_sampler, num_workers=FLAGS['num_workers'], drop_last=True) test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=FLAGS['batch_size'], shuffle=False, num_workers=FLAGS['num_workers'], drop_last=True) # Scale learning rate to num cores learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size() # (2) Tensor Coreのコア数 # Get loss function, optimizer, and model device = xm.xla_device() # (3) 実行デバイスの取得 model = ResNet18().to(device) optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=FLAGS['momentum'], weight_decay=5e-4) loss_fn = nn.NLLLoss() def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, (data, target) in enumerate(loader): optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() xm.optimizer_step(optimizer) tracker.add(FLAGS['batch_size']) if x % FLAGS['log_steps'] == 0: print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format( xm.get_ordinal(), x, loss.item(), tracker.rate(), tracker.global_rate(), time.asctime()), flush=True) # get_ordinalは動作しているプロセスの番号 def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() data, pred, target = None, None, None for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum().item() total_samples += data.size()[0] accuracy = 100.0 * correct / total_samples print('[xla:{}] Accuracy={:.2f}%'.format( xm.get_ordinal(), accuracy), flush=True) return accuracy, data, pred, target # Train and eval loops accuracy = 0.0 data, pred, target = None, None, None for epoch in range(1, FLAGS['num_epochs'] + 1): para_loader = pl.ParallelLoader(train_loader, [device]) train_loop_fn(para_loader.per_device_loader(device)) xm.master_print("Finished training epoch {}".format(epoch)) para_loader = pl.ParallelLoader(test_loader, [device]) accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device)) if FLAGS['metrics_debug']: xm.master_print(met.metrics_report(), flush=True) return accuracy, data, pred, target # Start training processes def _mp_fn(rank, flags): global FLAGS FLAGS = flags torch.set_default_tensor_type('torch.FloatTensor') accuracy, data, pred, target = train_resnet18() if rank == 0: # Retrieve tensors that are on TPU core 0 and plot. plot_results(data.cpu(), pred.cpu(), target.cpu()) xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork') # (1) 実行用のプロセス起動します。
#### (1) xmp.spawnによるTPU上での実行
xmp.spawnはXLAのデバイス(TPU)で実行するためのプロセスを起動します。
#### (2) Tensor Coreのコア数
xm.xrt_world_size()はコア数(並列数)を取得できます。
この部分の実装はバッチのサイズに応じて最適化手法(SGD)の学習率を調整しています。
#### (3) 実行デバイスの取得
xm.xla_device()の利用により、デバイス情報を取得でき、TPU/CPUのどちらかが返ってきます。
#### (4) master_printを使った出力
xlaのデバイス上ではmaster_printを利用して出力しています。
masterのデバイスで出力するAPIとなります。(この手の内容ちゃんとした説明が見つからない‥)
### 利用後の感想
肌感として非常にGPUよりも高速ではあります。
概ね実装はほぼ同じで、少しの差分を修正すれば動作します。
ただ、時々メモリエラーを起こすなど不安定になることがあります。
この原因の特定が難しく今の所プロファイラが見当たらないため、まだまだ挙動が読めておらず、エラーの原因がわからないことも多くあります。
最後に
TPUを有効活用できれば、生産性が上がるのかなと思っています。
ただ、まだまだ私自身、使いこなすにはほど遠いと思っているところです。
少しずつ使いこなしていき、生産性を上げられるようにしたいと思っています。