Sklearnを使った機械学習
皆さんこんにちは
お元気ですか。私は元気です。
さて、今日はSklearnを使ってみたいと思います。
Sklearnとは?
Pythonで使える機械学習のライブラリです。
インストールですが、以下のとおり
sudo pip install scikit-learn
How to Use
基本的に殆どの機械学習アルゴリズムの流れは以下のとおりです。これを抑えれば使えるでしょう。
1.学習させたいデータを用意する
2.学習させます
3.予測したいデータを入れてみます
4.できあがり
実際にやってみましょう。面倒なので、用意されているデータセットを使います。
1.学習させたいデータセットを用意する
今回は試しにやってみるだけなので、適当に準備します。
試験用手書き文字認識データセットがあるので利用させていただきましょう。
from sklearn.datasets import load_digits from sklearn.cross_validation import train_test_split digits = load_digits() print(digits.data.shape) data_train, data_test, label_train, label_test = train_test_split(digits.data, digits.target)
printすると(1797, 64)と出ます。つまり、1797データ64データということですね。
sklearn.cross_validationにトレーニングとテストの自動分割があるので、それを利用させて頂きます。
2.学習
さて、学習しましょう。アルゴリズムの設定、手法をまずは考えないといけないのですが、今回はSVMを使って、Classificationを行います。
クラスは0~9までの10通りです。(本来SVMは2値分類を行いますが、One vs Restと呼ばれるアルゴリズムを使って複数の分類を行っているようです)
from sklearn import svm lin_svc = svm.LinearSVC() lin_svc.fit(data_train, label_train)
3.予測です。
lin_svc.predict(data_test)
基本的にはこれだけです。しかし、正答率とかもちろんみたいですよね?
簡単な正答率はこのような形式で見れます。
from sklearn.metrics import classification_report print accuracy_score(label_test, predict)
正答率
0.946666666667
しかし、この状態だとどれがどのくらい正解かどうかわかりません。そこで、以下のようなメソッドを実行するとどの程度どれが正しかったかどうかわかります。
from sklearn.metrics import classification_report print classification_report(label_test, predict)
詳細な表
precision recall f1-score support 0 1.00 0.98 0.99 47 1 0.91 0.85 0.88 47 2 1.00 0.98 0.99 45 3 0.94 1.00 0.97 46 4 0.98 0.98 0.98 49 5 0.90 0.98 0.94 54 6 0.93 0.95 0.94 39 7 1.00 0.98 0.99 41 8 0.85 0.90 0.88 39 9 0.97 0.86 0.91 43 avg / total 0.95 0.95 0.95 450
4.できあがり
これだけあれば、ぶっちゃけたいていのことはできます。
ソースコード全文
from sklearn.datasets import load_digits from sklearn.cross_validation import train_test_split from sklearn import svm from sklearn.metrics import accuracy_score,classification_report digits = load_digits() print(digits.data.shape) data_train, data_test, label_train, label_test = train_test_split(digits.data, digits.target) lin_svc = svm.LinearSVC() lin_svc.fit(data_train, label_train) predict = lin_svc.predict(data_test) print accuracy_score(label_test, predict) print classification_report(label_test, predict)