A Laid-Back Engineer's Diary

I write about software and other engineering topics.

Building a Neural Network with Lasagne

Hello, everyone.
How are you? I'm doing well.

Today, I'd like to introduce Lasagne.

What is Lasagne?

Lasagne is a library that makes it easy to build neural networks.
Lasagne/Lasagne · GitHub

  1. Supports CNNs and RNNs (including LSTMs)
  2. Supports multiple inputs and multiple outputs
  3. Ships with a range of optimizers (e.g., Adam and RMSProp)
  4. Builds on Theano, so you can define any loss function you like, and you never have to differentiate it by hand

On Kaggle and the like, the loss function differs from competition to competition, so being able to define your own loss easily is a big plus.
Lasagne feels better suited to research and competitions than to building production systems.
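
As a taste of point 4: a loss is just a Theano expression, so you can write whatever you need and let Theano handle the gradients. A minimal sketch (the symbolic variables pred and target here are hypothetical, standing in for a network's output and one-hot labels):

import theano
import theano.tensor as T

# Hypothetical symbolic variables for illustration.
pred = T.matrix('pred')      # network output, shape (batch, classes)
target = T.matrix('target')  # one-hot targets, same shape

# A hand-rolled loss: mean squared error plus a small L1 term.
loss = T.mean((pred - target) ** 2) + 0.01 * T.mean(abs(pred))

# Theano differentiates the expression symbolically; no manual gradients.
grad = theano.grad(loss, wrt=pred)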

Now, let's actually try using it.

How to install

pip install -r https://raw.githubusercontent.com/Lasagne/Lasagne/master/requirements.txt
pip install https://github.com/Lasagne/Lasagne/archive/master.zip
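
If the installation succeeded, a quick import check should print the version (assuming the installed snapshot exposes __version__, which the GitHub master does):

python -c "import lasagne; print(lasagne.__version__)"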

Lasagne Example

Of course, we'll go with the usual MNIST.
I'll base the code on the official example, though there are a few things I want to change, so I'll modify those parts.

Building the Dataset

This is MNIST, almost exactly the code used in the tutorial.

def load_dataset():
    import os
    import numpy as np
    from urllib import urlretrieve  # Python 2; use urllib.request on Python 3
    import cPickle as pickle        # Python 2; use the built-in pickle on Python 3

    def pickle_load(f, encoding):
        # cPickle.load ignores `encoding`; the parameter just keeps the call
        # compatible with Python 3's pickle.load.
        return pickle.load(f)

    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
    filename = 'mnist.pkl.gz'
    if not os.path.exists(filename):
        print("Downloading MNIST dataset...")
        urlretrieve(url, filename)

    import gzip
    with gzip.open(filename, 'rb') as f:
        data = pickle_load(f, encoding='latin-1')

    X_train, y_train = data[0]
    X_val, y_val = data[1]
    X_test, y_test = data[2]

    return (X_train, y_train.astype(np.uint8),
            X_val, y_val.astype(np.uint8),
            X_test, y_test.astype(np.uint8))
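
For reference, this pickle splits MNIST into 50,000 training, 10,000 validation, and 10,000 test samples, with each 28x28 image flattened to 784 floats:

X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()
print(X_train.shape, X_val.shape, X_test.shape)  # (50000, 784) (10000, 784) (10000, 784)
print(y_train.dtype, y_train[:5])                # uint8, e.g. [5 0 4 1 9]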

Building the Model

Now we build the model: a network with two hidden layers (512 and 64 ReLU units) followed by a 10-way softmax output.

    neural_network = lasagne.layers.InputLayer(
        shape=(batch_size, input_dim),
    )
    neural_network = lasagne.layers.DenseLayer(
        neural_network,
        num_units=512,
        nonlinearity=lasagne.nonlinearities.rectify,
    )
    neural_network = lasagne.layers.DenseLayer(
        neural_network,
        num_units=64,
        nonlinearity=lasagne.nonlinearities.rectify,
    )
    neural_network = lasagne.layers.DenseLayer(
        neural_network,
        num_units=n_classes,
        nonlinearity=lasagne.nonlinearities.softmax,
    )
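
Since each layer wraps the one before it, the final neural_network object represents the whole stack, and Lasagne's helper functions can inspect it:

    # Walk the stack from input to output and count the trainable parameters.
    for layer in lasagne.layers.get_all_layers(neural_network):
        print(type(layer).__name__, layer.output_shape)
    print("total parameters: %d" % lasagne.layers.count_params(neural_network))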

Setting the Objective Function

This time, we'll use the cross-entropy loss function that Lasagne provides.

    objective = lasagne.objectives.Objective(
        neural_network,
        loss_function=lasagne.objectives.categorical_crossentropy)
    loss = objective.get_loss(input_var, target=target_var)
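
One caveat: the Objective class used here comes from a pre-release snapshot of Lasagne and was removed in the released versions (0.1 and later). A rough equivalent under the newer API would be:

    # Sketch for Lasagne >= 0.1, where lasagne.objectives.Objective no longer exists.
    prediction = lasagne.layers.get_output(neural_network, input_var)
    loss = lasagne.objectives.categorical_crossentropy(prediction, target_var).mean()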

Setting Up the Training Parameters and Functions

To train the network, we need to gather all of its trainable parameters and specify how to update them.
Lasagne provides helpers for both: get_all_params collects the parameters, and nesterov_momentum builds the update rules.

    params = lasagne.layers.get_all_params(neural_network, trainable=True)
    updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01,
                                                momentum=0.9)
    train_fn = theano.function([input_var, target_var], loss, updates=updates)
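
The training loop below also calls a test function to evaluate on held-out data. Its definition (it also appears in the full source at the end of this post):

    loss_eval = objective.get_loss(input_var, target=target_var,
                                   deterministic=True)
    pred = T.argmax(
        lasagne.layers.get_output(neural_network, input_var, deterministic=True),
        axis=1)
    accuracy = T.mean(T.eq(pred, target_var), dtype=theano.config.floatX)
    test = theano.function([input_var, target_var], [loss_eval, accuracy])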

Training

Finally, the training itself. Let's run it.

    nlist = np.arange(N)
    for epoch in xrange(100):
        # Reshuffle the sample indices each epoch.
        np.random.shuffle(nlist)

        # Train on consecutive minibatches of the shuffled indices.
        for j in xrange(N // batch_size):
            ns = nlist[batch_size*j:batch_size*(j+1)]
            train_loss = train_fn(X_train[ns], y_train[ns])
        # Evaluate on the test set after each epoch.
        loss, acc = test(X_test, y_test)
        print("%d: train_loss=%.4f, test_loss=%.4f, test_accuracy=%.4f" % (epoch+1, train_loss, loss, acc))

Standard Output

1: train_loss=0.3850, test_loss=0.2214, test_accuracy=0.9369
2: train_loss=0.1198, test_loss=0.1573, test_accuracy=0.9531
3: train_loss=0.1432, test_loss=0.1224, test_accuracy=0.9642
4: train_loss=0.0814, test_loss=0.1071, test_accuracy=0.9678
5: train_loss=0.0478, test_loss=0.1040, test_accuracy=0.9690
6: train_loss=0.1020, test_loss=0.0853, test_accuracy=0.9756
7: train_loss=0.0914, test_loss=0.0795, test_accuracy=0.9757
8: train_loss=0.0298, test_loss=0.0786, test_accuracy=0.9742
9: train_loss=0.0038, test_loss=0.0759, test_accuracy=0.9768
10: train_loss=0.0281, test_loss=0.0724, test_accuracy=0.9776
11: train_loss=0.0482, test_loss=0.0702, test_accuracy=0.9786
12: train_loss=0.0191, test_loss=0.0719, test_accuracy=0.9775
13: train_loss=0.0474, test_loss=0.0699, test_accuracy=0.9792
14: train_loss=0.0104, test_loss=0.0700, test_accuracy=0.9799
15: train_loss=0.0085, test_loss=0.0696, test_accuracy=0.9788
16: train_loss=0.0450, test_loss=0.0688, test_accuracy=0.9800
17: train_loss=0.0092, test_loss=0.0677, test_accuracy=0.9798
18: train_loss=0.0104, test_loss=0.0709, test_accuracy=0.9798
19: train_loss=0.0084, test_loss=0.0682, test_accuracy=0.9799
20: train_loss=0.0035, test_loss=0.0707, test_accuracy=0.9791
21: train_loss=0.0235, test_loss=0.0701, test_accuracy=0.9803
22: train_loss=0.0033, test_loss=0.0693, test_accuracy=0.9812
23: train_loss=0.0014, test_loss=0.0694, test_accuracy=0.9804
24: train_loss=0.0027, test_loss=0.0700, test_accuracy=0.9804
25: train_loss=0.0021, test_loss=0.0709, test_accuracy=0.9804
26: train_loss=0.0032, test_loss=0.0715, test_accuracy=0.9801
27: train_loss=0.0012, test_loss=0.0722, test_accuracy=0.9797
28: train_loss=0.0043, test_loss=0.0710, test_accuracy=0.9809
29: train_loss=0.0021, test_loss=0.0714, test_accuracy=0.9810
30: train_loss=0.0010, test_loss=0.0727, test_accuracy=0.9802
31: train_loss=0.0051, test_loss=0.0724, test_accuracy=0.9807
32: train_loss=0.0022, test_loss=0.0726, test_accuracy=0.9808
33: train_loss=0.0018, test_loss=0.0734, test_accuracy=0.9806
34: train_loss=0.0012, test_loss=0.0732, test_accuracy=0.9810
35: train_loss=0.0013, test_loss=0.0738, test_accuracy=0.9805
36: train_loss=0.0016, test_loss=0.0741, test_accuracy=0.9805
37: train_loss=0.0009, test_loss=0.0735, test_accuracy=0.9809
38: train_loss=0.0013, test_loss=0.0747, test_accuracy=0.9803
39: train_loss=0.0013, test_loss=0.0751, test_accuracy=0.9809
40: train_loss=0.0004, test_loss=0.0746, test_accuracy=0.9810
41: train_loss=0.0005, test_loss=0.0758, test_accuracy=0.9807
42: train_loss=0.0011, test_loss=0.0767, test_accuracy=0.9803
43: train_loss=0.0007, test_loss=0.0767, test_accuracy=0.9805
44: train_loss=0.0005, test_loss=0.0761, test_accuracy=0.9805
45: train_loss=0.0016, test_loss=0.0755, test_accuracy=0.9810
46: train_loss=0.0006, test_loss=0.0766, test_accuracy=0.9804
47: train_loss=0.0019, test_loss=0.0769, test_accuracy=0.9805
48: train_loss=0.0004, test_loss=0.0771, test_accuracy=0.9804
49: train_loss=0.0007, test_loss=0.0770, test_accuracy=0.9807
50: train_loss=0.0009, test_loss=0.0775, test_accuracy=0.9807
51: train_loss=0.0002, test_loss=0.0779, test_accuracy=0.9806
52: train_loss=0.0007, test_loss=0.0774, test_accuracy=0.9813
53: train_loss=0.0006, test_loss=0.0777, test_accuracy=0.9807
54: train_loss=0.0006, test_loss=0.0779, test_accuracy=0.9809
55: train_loss=0.0003, test_loss=0.0779, test_accuracy=0.9806
56: train_loss=0.0017, test_loss=0.0782, test_accuracy=0.9806
57: train_loss=0.0005, test_loss=0.0782, test_accuracy=0.9808
58: train_loss=0.0005, test_loss=0.0793, test_accuracy=0.9810
59: train_loss=0.0007, test_loss=0.0788, test_accuracy=0.9809
60: train_loss=0.0009, test_loss=0.0795, test_accuracy=0.9805
61: train_loss=0.0007, test_loss=0.0791, test_accuracy=0.9813
62: train_loss=0.0007, test_loss=0.0797, test_accuracy=0.9808
63: train_loss=0.0004, test_loss=0.0803, test_accuracy=0.9809
64: train_loss=0.0010, test_loss=0.0799, test_accuracy=0.9806
65: train_loss=0.0004, test_loss=0.0804, test_accuracy=0.9808
66: train_loss=0.0006, test_loss=0.0801, test_accuracy=0.9809
67: train_loss=0.0005, test_loss=0.0798, test_accuracy=0.9814
68: train_loss=0.0003, test_loss=0.0803, test_accuracy=0.9807
69: train_loss=0.0003, test_loss=0.0806, test_accuracy=0.9809
70: train_loss=0.0004, test_loss=0.0809, test_accuracy=0.9807
71: train_loss=0.0007, test_loss=0.0809, test_accuracy=0.9807
72: train_loss=0.0006, test_loss=0.0808, test_accuracy=0.9808
73: train_loss=0.0004, test_loss=0.0817, test_accuracy=0.9804
74: train_loss=0.0004, test_loss=0.0816, test_accuracy=0.9808
75: train_loss=0.0005, test_loss=0.0814, test_accuracy=0.9808
76: train_loss=0.0003, test_loss=0.0814, test_accuracy=0.9807
77: train_loss=0.0003, test_loss=0.0816, test_accuracy=0.9809
78: train_loss=0.0002, test_loss=0.0821, test_accuracy=0.9811
79: train_loss=0.0005, test_loss=0.0819, test_accuracy=0.9807
80: train_loss=0.0003, test_loss=0.0819, test_accuracy=0.9808
81: train_loss=0.0002, test_loss=0.0821, test_accuracy=0.9810
82: train_loss=0.0008, test_loss=0.0819, test_accuracy=0.9807
83: train_loss=0.0007, test_loss=0.0825, test_accuracy=0.9808
84: train_loss=0.0003, test_loss=0.0822, test_accuracy=0.9809
85: train_loss=0.0005, test_loss=0.0822, test_accuracy=0.9812
86: train_loss=0.0002, test_loss=0.0830, test_accuracy=0.9809
87: train_loss=0.0010, test_loss=0.0828, test_accuracy=0.9809
88: train_loss=0.0003, test_loss=0.0830, test_accuracy=0.9808
89: train_loss=0.0003, test_loss=0.0829, test_accuracy=0.9808
90: train_loss=0.0004, test_loss=0.0830, test_accuracy=0.9808
91: train_loss=0.0003, test_loss=0.0835, test_accuracy=0.9809
92: train_loss=0.0005, test_loss=0.0833, test_accuracy=0.9810
93: train_loss=0.0004, test_loss=0.0834, test_accuracy=0.9809
94: train_loss=0.0003, test_loss=0.0836, test_accuracy=0.9809
95: train_loss=0.0002, test_loss=0.0835, test_accuracy=0.9811
96: train_loss=0.0002, test_loss=0.0837, test_accuracy=0.9808
97: train_loss=0.0002, test_loss=0.0839, test_accuracy=0.9808
98: train_loss=0.0005, test_loss=0.0844, test_accuracy=0.9805
99: train_loss=0.0001, test_loss=0.0843, test_accuracy=0.9807
100: train_loss=0.0004, test_loss=0.0845, test_accuracy=0.9807

Full Source Code

import os

import numpy as np
import theano
import theano.tensor as T

import lasagne

def load_dataset():
    from urllib import urlretrieve
    import cPickle as pickle

    def pickle_load(f, encoding):
        # cPickle.load ignores `encoding`; the parameter just keeps the call
        # compatible with Python 3's pickle.load.
        return pickle.load(f)

    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
    filename = 'mnist.pkl.gz'
    if not os.path.exists(filename):
        print("Downloading MNIST dataset...")
        urlretrieve(url, filename)

    import gzip
    with gzip.open(filename, 'rb') as f:
        data = pickle_load(f, encoding='latin-1')

    X_train, y_train = data[0]
    X_val, y_val = data[1]
    X_test, y_test = data[2]

    return (X_train, y_train.astype(np.uint8),
            X_val, y_val.astype(np.uint8),
            X_test, y_test.astype(np.uint8))


def main():

    X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()
    N, input_dim = X_train.shape
    input_var = T.matrix('x')
    target_var = T.ivector('y')
    batch_size = 100

    n_classes = 10

    print(X_train.shape, y_train.shape)

    neural_network = lasagne.layers.InputLayer(
        shape=(batch_size, input_dim),
    )
    neural_network = lasagne.layers.DenseLayer(
        neural_network,
        num_units=512,
        nonlinearity=lasagne.nonlinearities.rectify,
    )
    neural_network = lasagne.layers.DenseLayer(
        neural_network,
        num_units=64,
        nonlinearity=lasagne.nonlinearities.rectify,
    )
    neural_network = lasagne.layers.DenseLayer(
        neural_network,
        num_units=n_classes,
        nonlinearity=lasagne.nonlinearities.softmax,
    )

    objective = lasagne.objectives.Objective(
        neural_network,
        loss_function=lasagne.objectives.categorical_crossentropy)
    loss = objective.get_loss(input_var, target=target_var)

    params = lasagne.layers.get_all_params(neural_network, trainable=True)
    updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01,
                                                momentum=0.9)
    train_fn = theano.function([input_var, target_var], loss, updates=updates)

    loss_eval = objective.get_loss(input_var, target=target_var,
                                   deterministic=True)
    pred = T.argmax(
        lasagne.layers.get_output(neural_network, input_var, deterministic=True),
        axis=1)
    accuracy = T.mean(T.eq(pred, target_var), dtype=theano.config.floatX)
    test = theano.function([input_var, target_var], [loss_eval, accuracy])

    nlist = np.arange(N)
    for epoch in xrange(100):
        np.random.shuffle(nlist)

        for j in xrange(N // batch_size):
            ns = nlist[batch_size*j:batch_size*(j+1)]
            train_loss = train_fn(X_train[ns], y_train[ns])
        loss, acc = test(X_test, y_test)
        print("%d: train_loss=%.4f, test_loss=%.4f, test_accuracy=%.4f" % (epoch+1, train_loss, loss, acc))

if __name__ == '__main__':
    main()