MNISTデータを読み込むプログラム(C++)
Sponsored Links
皆さんこんにちは。
お元気ですか?私は中間審査が近くて泣きそうでございます。
今日はMNISTをC++で読み込んでみます。
MNISTとは、0~9まである手書き文字認識のデータセットです。MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burgesよりダウンロードが出来ます。一つ一つのデータは28x28で、機械学習のベンチマークでよく使われます。
こんなの
さて、今日はそれを読み込んでみます。
MNIST読み込み
Training用画像フォーマット
まず、大事なのはMNISTのデータ・フォーマットです。上記のホームページによると以下のようになっているようです。
[offset] [type] [value] [description] 0000 32 bit integer 0x00000803(2051) magic number 0004 32 bit integer 60000 number of images 0008 32 bit integer 28 number of rows 0012 32 bit integer 28 number of columns 0016 unsigned byte ?? pixel 0017 unsigned byte ?? pixel ........ xxxx unsigned byte ?? pixel
1行目は何やらようわかりませんが、マジックナンバーとやらで、その次はデータの数、行数、列数となっています。
言い換えると最初の4つを読み込み残りを順々に読み込んでいけばいいということです。
行数も列のピクセル数もわかっているので、それを区切れば良いということです。
Training用ラベルフォーマット
次に、Labelです。まぁ正解ラベルがわからないと困っちゃいますね。ということでこちらもフォーマットを確認
[offset] [type] [value] [description] 0000 32 bit integer 0x00000801(2049) magic number (MSB first) 0004 32 bit integer 60000 number of items 0008 unsigned byte ?? label 0009 unsigned byte ?? label ........ xxxx unsigned byte ?? label The labels values are 0 to 9.
ほぼ同じ要領で作ることができますね。
これをプログラムにすると以下のようになります。
StackOverFlowを参考にしました。結局ほぼサンプルコードと同じになってるような…
ソースコード
Dataset.hpp
#include <iostream> #include <fstream> #include <vector> using namespace std; class Mnist{ public: vector<vector<double> > readTrainingFile(string filename); vector<double> readLabelFile(string filename); };
Dataset.cpp
#include "Dataset.hpp" //バイト列からintへの変換 int reverseInt (int i) { unsigned char c1, c2, c3, c4; c1 = i & 255; c2 = (i >> 8) & 255; c3 = (i >> 16) & 255; c4 = (i >> 24) & 255; return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4; } vector<vector<double> > Mnist::readTrainingFile(string filename){ ifstream ifs(filename.c_str(),std::ios::in | std::ios::binary); int magic_number = 0; int number_of_images = 0; int rows = 0; int cols = 0; //ヘッダー部より情報を読取る。 ifs.read((char*)&magic_number,sizeof(magic_number)); magic_number= reverseInt(magic_number); ifs.read((char*)&number_of_images,sizeof(number_of_images)); number_of_images= reverseInt(number_of_images); ifs.read((char*)&rows,sizeof(rows)); rows= reverseInt(rows); ifs.read((char*)&cols,sizeof(cols)); cols= reverseInt(cols); vector<vector<double> > images(number_of_images); cout << magic_number << " " << number_of_images << " " << rows << " " << cols << endl; for(int i = 0; i < number_of_images; i++){ images[i].resize(rows * cols); for(int row = 0; row < rows; row++){ for(int col = 0; col < cols; col++){ unsigned char temp = 0; ifs.read((char*)&temp,sizeof(temp)); images[i][rows*row+col] = (double)temp; } } } return images; } vector<double> Mnist::readLabelFile(string filename){ ifstream ifs(filename.c_str(),std::ios::in | std::ios::binary); int magic_number = 0; int number_of_images = 0; //ヘッダー部より情報を読取る。 ifs.read((char*)&magic_number,sizeof(magic_number)); magic_number= reverseInt(magic_number); ifs.read((char*)&number_of_images,sizeof(number_of_images)); number_of_images= reverseInt(number_of_images); vector<double> label(number_of_images); cout << number_of_images << endl; for(int i = 0; i < number_of_images; i++){ unsigned char temp = 0; ifs.read((char*)&temp,sizeof(temp)); label[i] = (double)temp; } return label; }
main.cpp
#include <boost/numeric/ublas/vector.hpp> #include <boost/numeric/ublas/matrix.hpp> #include <boost/numeric/ublas/io.hpp> #include "Dataset.hpp" using namespace boost::numeric::ublas; int main (void) { //絶対パスを入力してください。 Mnist mnist; mnist.readTrainingFile("/train-images-idx3-ubyte"); mnist.readLabelFile("train-labels-idx1-ubyte"); }