のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

PyTorchでCIFAR10を既存のCIFAR10のDataset Classを使わずに分類する

Sponsored Links

皆さんこんにちは。
お元気ですか。雨天が増えてきて、出かけるのが億劫になっています。

PyTorchを使って画像認識データセットCIFAR10を分類しました。
KaggleでPyTorchユーザが増えてきたこともあり、勉強しました。

最近、この手のチュートリアルやExamplesに良しなにできる
データ処理専用クラスを予め作っていることがあります。

この状態は新しいデータセットを試したい場合に不便なので、
そのような内容が含まれないCIFAR10のコードを記述しました。

PyTorch

PyTorchとは

Deep Learningフレームワークです。
柔軟なネットワーク構築やGPUを利用した高速な行列演算を得意としています。
Chainerをforkして作られたので、実装方法が非常に似ています。
噂ではございますが、Chainerより高速だったりそうでなかったり。

http://pytorch.org/

インストー

上記の公式サイトを参考にしてください。
OS、Pythonのversion、CUDAの有無、バージョンによってコマンド変化しますが
公式にいけばそれらの選択ができ、その環境に応じたコマンドを提示します。

そのコマンドを入力すれば、環境の構築が完了します。
例えば、OSX, Python 2.7, pip, CUDAなしの場合、次のコマンドを実行します。

pip install http://download.pytorch.org/whl/torch-0.1.12.post2-cp27-none-macosx_10_7_x86_64.whl 
pip install torchvision 

CIFAR10を分類するコードを作ってみる。

ニューラルネットワークで設定する項目は
他のフレームワーク(TensorFlow, Chianerなど)と同じです。
Chainerと比較的似ている実装になっています。
ソースコード全体は次のgistに記載しています。

cifar10_example.py · GitHub

Import

はじめにimport文です。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import cPickle
import numpy as np
import os

CIFAR10の読み込み

CIFAR10を読み込むコードです。
同じディレクトリにCIFAR10のPython versionをダウンロードし、展開すれば完了です。
ソースコードは次を参考にさせていただきました。
今回用に少々修正を入れています。

qiita.com

def unpickle(file):
    fo = open(file, 'rb')
    dict = cPickle.load(fo)
    fo.close()
    return dict


def conv_data2image(data):
    return np.rollaxis(data.reshape((3, 32, 32)), 0, 3)


def load_cirar10(folder):
    """
    load cifar10

    :return: train_X, train_y, test_X, test_y
    """
    for i in range(1, 6):
        fname = os.path.join(folder, "%s%d" % ("data_batch_", i))
        data_dict = unpickle(fname)
        if i == 1:
            train_X = data_dict['data']
            train_y = data_dict['labels']
        else:
            train_X = np.vstack((train_X, data_dict['data']))
            train_y = np.hstack((train_y, data_dict['labels']))

    data_dict = unpickle(os.path.join(folder, 'test_batch'))
    test_X = data_dict['data']
    test_y = np.array(data_dict['labels'])

    train_X = [conv_data2image(x) for x in train_X]
    test_X = [conv_data2image(x) for x in test_X]

    return train_X, train_y, test_X, test_y

入力データの扱い

DataLoaderを使った方法はDatasetクラスを継承し、必要なメソッドを
作成したクラスに継承させる必要があります。

必要なメソッドは__getitem__と__len__です。
__getitem__にindexが与えられた時のデータの処理、
__len__にはデータの数を計算する処理を記載します。

class DataSet(Dataset):
    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.transform = transform

    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.x)

それでは、DataLoaderを作ります。第一引数に先程作成したDataset、
第二引数にbatch_size, 第三引数にデータの順番をshuffleするかどうかを与えます。
DataLoaderの利用方法は後述します。
PyTorchには、transformsパッケージがあり、
このtransformsを利用すると0-1へのスケーリング、
正規化、ランダムで切り抜きを行うなどの処理を記載できます。

train_d_loader = DataLoader(
    DataSet(
        x=train_X,
        y=train_y,
        transform=transforms.Compose(
            [
                transforms.ToTensor()
            ]
        )
    ), batch_size=64, shuffle=True
)

test_d_loader = DataLoader(
    DataSet(
        x=test_X,
        y=test_y,
        transform=transforms.Compose(
            [
                transforms.ToTensor()
            ]
        )
    ), batch_size=64, shuffle=False
)

Neural Networkの作り方

Convolutional Neural Networkを作りました。下記のコードを参考にしてください。
Chainerをベースにしているので、Chainerに類似している実装になります。
Chainerでは、__call__に宣言していましたが、PyTorchはforwardになります。

class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self):
        super(ConvolutionalNeuralNetwork, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(512 * 4, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.max_pool2d(h, 2)
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = F.max_pool2d(h, 2)
        h = F.relu(self.conv5(h))
        h = F.relu(self.conv6(h))
        h = F.max_pool2d(h, 2)

        h = h.view(-1, 512 * 4)
        h = F.relu(self.fc1(h))
        h = F.dropout(h, training=self.training)
        h = F.log_softmax(self.fc2(h))
        return h

その他学習のためのOptimizerの準備

Optimizerの定義が必要です。特筆すべきところはありません。

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

学習

Chainerに非常に似ているコードです。
Variableを用いているのも、Chainerそっくりです。
少々前で宣言したDataLoaderを使っています。forで使うことで、
指定したバッチの数ずつ、データを取り出せます。

学習コード

for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for index, (batch_train_x, batch_train_y) in enumerate(train_d_loader):
        train_variable = torch.autograd.Variable(batch_train_x)
        output = model(train_variable)
        loss = F.nll_loss(output, torch.autograd.Variable(batch_train_y))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]

    print ("training epoch: {} loss: {}".format(epoch, train_loss / len(train_d_loader)))

評価コード

学習部分とそこまで変わらない評価コードです。
Variableを使い、誤差と正答率を計算するコードです。
モデルの学習用と評価用の挙動は.train(), .eval()で切り替えられます。

    model.eval()
    test_loss = 0.0
    correct = 0
    for index, (batch_test_x, batch_test_y) in enumerate(test_d_loader):
        output = model(torch.autograd.Variable(batch_test_x))
        test_loss += F.nll_loss(output, torch.autograd.Variable(batch_test_y)).data[0]

        pred = output.data.max(1)[1]
        correct += pred.eq(batch_test_y).cpu().sum()

    print ("testing epoch: {} loss: {} accuracy: {}".format(epoch, test_loss / len(test_d_loader),
                                                            float(correct) / float(len(test_d_loader.dataset))))

標準出力

途中まで出力しました。後は省略しております。

training epoch: 0 loss: 2.19472245665
testing epoch: 0 loss: 1.85379125782 accuracy: 0.3254
training epoch: 1 loss: 1.68773823439
testing epoch: 1 loss: 1.42734462385 accuracy: 0.4772
training epoch: 2 loss: 1.36909840554
testing epoch: 2 loss: 1.21328892304 accuracy: 0.5669
training epoch: 3 loss: 1.1273400925
testing epoch: 3 loss: 0.970653118036 accuracy: 0.6634
training epoch: 4 loss: 0.941422916153
testing epoch: 4 loss: 0.885185660931 accuracy: 0.6911
training epoch: 5 loss: 0.811293978132
testing epoch: 5 loss: 0.770014607583 accuracy: 0.7351
training epoch: 6 loss: 0.713843686805
testing epoch: 6 loss: 0.762587684126 accuracy: 0.7432
training epoch: 7 loss: 0.640705172528
testing epoch: 7 loss: 0.716414518535 accuracy: 0.7554
training epoch: 8 loss: 0.59412489799
testing epoch: 8 loss: 0.745007351374 accuracy: 0.7571
training epoch: 9 loss: 0.559192490898
testing epoch: 9 loss: 0.786619110325 accuracy: 0.7396
training epoch: 10 loss: 0.522159374087
testing epoch: 10 loss: 0.766681446102 accuracy: 0.7504
training epoch: 11 loss: 0.500029139583
testing epoch: 11 loss: 0.739471812218 accuracy: 0.7623
training epoch: 12 loss: 0.489205267016
testing epoch: 12 loss: 0.755763524399 accuracy: 0.7664
training epoch: 13 loss: 0.479558758423
testing epoch: 13 loss: 0.801054440939 accuracy: 0.7527
training epoch: 14 loss: 0.4715372192
testing epoch: 14 loss: 0.862275337164 accuracy: 0.731
training epoch: 15 loss: 0.476616473731
testing epoch: 15 loss: 0.895243975301 accuracy: 0.7489
training epoch: 16 loss: 0.489420878252
testing epoch: 16 loss: 0.818035978717 accuracy: 0.7625
training epoch: 17 loss: 0.476847741119
testing epoch: 17 loss: 0.85643605626 accuracy: 0.7505
training epoch: 18 loss: 0.490220740435
testing epoch: 18 loss: 0.871974576729 accuracy: 0.7412
training epoch: 19 loss: 0.482988453422
testing epoch: 19 loss: 0.795684698529 accuracy: 0.7606
training epoch: 20 loss: 0.500881570677
testing epoch: 20 loss: 0.783078642747 accuracy: 0.7542
training epoch: 21 loss: 0.522443722843
testing epoch: 21 loss: 0.903285436023 accuracy: 0.7377
training epoch: 22 loss: 0.532127012151
testing epoch: 22 loss: 0.875951717123 accuracy: 0.739
training epoch: 23 loss: 0.544996414693
testing epoch: 23 loss: 0.914666230258 accuracy: 0.7358
training epoch: 24 loss: 0.555458197228
testing epoch: 24 loss: 0.860163988682 accuracy: 0.7438
training epoch: 25 loss: 0.577152646749
testing epoch: 25 loss: 0.969825881072 accuracy: 0.7164
training epoch: 26 loss: 0.574859146698
testing epoch: 26 loss: 0.866168110325 accuracy: 0.7408
training epoch: 27 loss: 0.574894333672
testing epoch: 27 loss: 0.920369815046 accuracy: 0.7318
training epoch: 28 loss: 0.603822341083
testing epoch: 28 loss: 0.969683260583 accuracy: 0.7406
training epoch: 29 loss: 0.630050289644
testing epoch: 29 loss: 0.859408840299 accuracy: 0.7507
training epoch: 30 loss: 0.642022043729
testing epoch: 30 loss: 0.890207239328 accuracy: 0.7272

終わりに

Chainerと類似しているので非常に学びやすく、
正直学習コストは殆どなかったかなと思います。

今後、どっちを使うかはケースバイケースですが、時々使っていこうと思います。