読者です 読者をやめる 読者になる 読者になる

のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

MNISTデータを読み込むプログラム(C++)

C++
Sponsored Links

皆さんこんにちは。
お元気ですか?私は中間審査が近くて泣きそうでございます。

今日はMNISTをC++で読み込んでみます。
MNISTとは、0~9まである手書き文字認識のデータセットです。MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burgesよりダウンロードが出来ます。一つ一つのデータは28x28で、機械学習のベンチマークでよく使われます。

こんなの
f:id:tereka:20140918095618g:plain

さて、今日はそれを読み込んでみます。

MNIST読み込み

Training用画像フォーマット

まず、大事なのはMNISTのデータ・フォーマットです。上記のホームページによると以下のようになっているようです。

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel

1行目は何やらようわかりませんが、マジックナンバーとやらで、その次はデータの数、行数、列数となっています。
言い換えると最初の4つを読み込み残りを順々に読み込んでいけばいいということです。
行数も列のピクセル数もわかっているので、それを区切れば良いということです。

Training用ラベルフォーマット

次に、Labelです。まぁ正解ラベルがわからないと困っちゃいますね。ということでこちらもフォーマットを確認

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

ほぼ同じ要領で作ることができますね。

これをプログラムにすると以下のようになります。
StackOverFlowを参考にしました。結局ほぼサンプルコードと同じになってるような…

ソースコード

Dataset.hpp
#include <iostream>
#include <fstream>
#include <vector>

using namespace std;

class Mnist{
public:
	vector<vector<double> > readTrainingFile(string filename);
	vector<double> readLabelFile(string filename);
};
Dataset.cpp
#include "Dataset.hpp"

//バイト列からintへの変換
int reverseInt (int i) 
{
    unsigned char c1, c2, c3, c4;

    c1 = i & 255;
    c2 = (i >> 8) & 255;
    c3 = (i >> 16) & 255;
    c4 = (i >> 24) & 255;

    return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
}

vector<vector<double> > Mnist::readTrainingFile(string filename){
	ifstream ifs(filename.c_str(),std::ios::in | std::ios::binary);
	int magic_number = 0;
	int number_of_images = 0;
	int rows = 0;
	int cols = 0;

	//ヘッダー部より情報を読取る。
	ifs.read((char*)&magic_number,sizeof(magic_number)); 
	magic_number= reverseInt(magic_number);
	ifs.read((char*)&number_of_images,sizeof(number_of_images));
	number_of_images= reverseInt(number_of_images);
	ifs.read((char*)&rows,sizeof(rows));
	rows= reverseInt(rows);
	ifs.read((char*)&cols,sizeof(cols));
	cols= reverseInt(cols);

	vector<vector<double> > images(number_of_images);
	cout << magic_number << " " << number_of_images << " " << rows << " " << cols << endl;

	for(int i = 0; i < number_of_images; i++){
		images[i].resize(rows * cols);

		for(int row = 0; row < rows; row++){
			for(int col = 0; col < cols; col++){
				unsigned char temp = 0;
				ifs.read((char*)&temp,sizeof(temp));
				images[i][rows*row+col] = (double)temp;
			}
		}
	}
	return images;
}

vector<double> Mnist::readLabelFile(string filename){
	ifstream ifs(filename.c_str(),std::ios::in | std::ios::binary);
	int magic_number = 0;
	int number_of_images = 0;

	//ヘッダー部より情報を読取る。
	ifs.read((char*)&magic_number,sizeof(magic_number)); 
	magic_number= reverseInt(magic_number);
	ifs.read((char*)&number_of_images,sizeof(number_of_images));
	number_of_images= reverseInt(number_of_images);

	vector<double> label(number_of_images);

	cout << number_of_images << endl;

	for(int i = 0; i < number_of_images; i++){
		unsigned char temp = 0;
		ifs.read((char*)&temp,sizeof(temp));
		label[i] = (double)temp;
	}
	return label;
}
main.cpp
#include <boost/numeric/ublas/vector.hpp>
#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/io.hpp>
#include "Dataset.hpp"

using namespace boost::numeric::ublas;

int main (void) {
    //絶対パスを入力してください。
    Mnist mnist;
    mnist.readTrainingFile("/train-images-idx3-ubyte");
    mnist.readLabelFile("train-labels-idx1-ubyte");
}
広告を非表示にする