のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

Chainerにおけるグラフ構造をループで書いてみる。

Sponsored Links

皆さんこんにちは
お元気ですか。私は元気です。

実は私、Chainerでのfor文でLinkとして作成できることを知らず、
今の複雑なネットワークにChainer使いにくいと思っていましたが、以下にサンプルがあって
こうすれば複雑なネットワークも組めるんだ。みたいなところがわかりました。

Deep Residual Network definition by Chainer · GitHub

ChainerのLink構造について

以下のスライドにChainer version1.5のチュートリアル解説があります。
このうち今回で必要な情報はChain、Link、Functionが何を示しているのかです。

www.slideshare.net

chainerの関数 概要
chainer.Function 関数
chainer.Link パラメータ付き関数
chainer.Chain パラメータ付き関数集合

これに基づいて、関数集合を構築していけば良いといったところです。

※上記の関数はv1.5以降です。1.4以前では多分異なるので気をつけてください。

グラフ構造を容易にかくには

部分構造を構築する。

通常に計算できるパラメータ付き関数集合を作りつつ、Linkに突っ込めば実装可能です。
後は、linkから必要な情報を取り出し、forwardを構築するのみです。

今回はVGGNetを用いて実施してみます。
VGGNetの部分構造であれば、以下のように書くことができます。
例えば、ニューラルネットワークのとある箇所を部分的に書くと以下のようになります。

class RoopBlock(chainer.Chain):
    def __init__(self,n_in,n_out,stride=1):
        super(RoopBlock,self).__init__(
            conv1 = L.Convolution2D(n_in,n_out,3,stride,1),
            conv2 = L.Convolution2D(n_out,n_in,3,stride,1),
            conv3 = L.Convolution2D(n_out,n_in,3,stride,1)
        )

    def __call__(self,x,t):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))

        return h

Chainと呼ばれる関数集合を構築することをまず行います。
これをforなどを使ってうまく書くと、VGGNetを以下のように記述することができます。

class VGGNet(chainer.Chain):
    def __init__(self):
        super(VGGNet,self).__init__()

        links = [("root0",RoopBlock(3,64))]
        n_in = 64
        n_out = 128
        for index in xrange(1,5,1):
            links += [("root{}".format(index),RoopBlock(n_in,n_out))]

            n_in *= 2
            n_out *= 2
        links += [("fc"),L.Linear(25088, 1000)]
        self.forward = links
        for link in links:
            self.add_link(*link)

        self.train = True

    def __call__(self, x, t):
        for name,func in self.forward:
            x = func(x)
        if self.train:
            self.loss = F.softmax_cross_entropy(x,t)
            self.accuracy = F.accuracy(x, t)
            return self.loss
        else:
            return F.softmax(x)

この書き方によるメリットは層を一つ増やしたいとなった場合に簡単に追加できることです。
add_linkを使うことで、パラメータをリンクとして登録しておきます。

ResNetを実際に実験するにあたって調べてて見つけた内容ですが、
この方法はチュートリアルにも掲載されていないので、あんまり見つけられないかもしれません。

参考文献は以下の通り。

GitHub - mitmul/chainer-cifar10: Various CNN models including Deep Residual Networks (ResNet) for CIFAR10 with Chainer (http://chainer.org)