Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author romang <romang@amu.edu.pl> 2016-09-13 19:08:45 +0300
committer romang <romang@amu.edu.pl> 2016-09-13 19:13:23 +0300
commit cc7a48310f19423f269f19330c32090b866d3c90 (patch)
tree 3e009f0e7c009f7f0f914e60e7dfa6b5294c3b1f
parent 803a562d4b564ad3b241a07e397b82d7add77082 (diff)
add functions loading MNIST dataset
-rw-r--r--.gitignore1
-rw-r--r--examples/mnist/Makefile7
-rw-r--r--src/mnist.h94
-rw-r--r--src/test.cu4
4 files changed, 104 insertions, 2 deletions
diff --git a/.gitignore b/.gitignore
index 4dfd397b..53468680 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,4 @@ build
# Examples
examples/*/*.gz
+examples/mnist/*ubyte
diff --git a/examples/mnist/Makefile b/examples/mnist/Makefile
index 7e4e812f..26f65554 100644
--- a/examples/mnist/Makefile
+++ b/examples/mnist/Makefile
@@ -2,9 +2,12 @@
all: download
-download: train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz t10k-images-idx3-ubyte.gz t10k-labels-idx3-ubyte.gz
+download: train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
-%.gz:
+%-ubyte: %-ubyte.gz
+ gzip -d < $^ > $@
+
+# Download one compressed MNIST file. With the pattern `%-ubyte.gz` the stem
+# `$*` no longer contains "-ubyte", so the remote file name is taken from the
+# full target name `$@` instead of `$*.gz`.
+%-ubyte.gz:
-	wget http://yann.lecun.com/exdb/mnist/$*.gz -O $@
+	wget http://yann.lecun.com/exdb/mnist/$@ -O $@
clean:
diff --git a/src/mnist.h b/src/mnist.h
new file mode 100644
index 00000000..7727bacc
--- /dev/null
+++ b/src/mnist.h
@@ -0,0 +1,94 @@
+#pragma once
+
+#include <cstddef>
+#include <fstream>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace datasets {
+namespace mnist {
+
+typedef unsigned char uchar;
+
+/**
+ * Reverses the byte order of a 32-bit integer.
+ *
+ * Declared `inline` because this header may be included from several
+ * translation units. The swap is done on an unsigned copy so that moving a
+ * byte >= 128 into the top position never overflows a signed int.
+ *
+ * @param i value whose four bytes should be reversed
+ * @return the value with bytes 0..3 stored in the order 3..0
+ */
+inline int reverseInt(int i) {
+  const unsigned int bits = static_cast<unsigned int>(i);
+  const unsigned int swapped = ((bits & 0x000000FFu) << 24)
+                             | ((bits & 0x0000FF00u) << 8)
+                             | ((bits & 0x00FF0000u) >> 8)
+                             | ((bits >> 24) & 0x000000FFu);
+  return static_cast<int>(swapped);
+}
+
+/**
+ * Loads every image from an uncompressed MNIST image file.
+ *
+ * The file starts with four 32-bit integers stored most-significant byte
+ * first: magic number (2051), image count, row count and column count. They
+ * are followed by one unsigned byte per pixel, image after image.
+ *
+ * @param full_path path to the image file
+ * @return one vector per image holding rows*cols pixel values, each scaled
+ *         from [0, 255] to [0.0, 1.0]
+ * @throws std::runtime_error if the file cannot be opened, the magic number
+ *         is wrong, a header field is missing or negative, or the pixel data
+ *         ends before all announced images were read
+ */
+inline std::vector<std::vector<float>> ReadImages(const std::string& full_path) {
+  // Binary mode: pixel bytes must not pass through newline translation.
+  std::ifstream file(full_path, std::ios::binary);
+
+  if (!file.is_open())
+    throw std::runtime_error("Cannot open file `" + full_path + "`!");
+
+  // Header fields are decoded byte by byte so the result does not depend on
+  // the byte order of the machine running this code.
+  auto readHeaderInt = [&file](int& value) -> bool {
+    unsigned char bytes[4] = {0, 0, 0, 0};
+    if (!file.read(reinterpret_cast<char*>(bytes), sizeof(bytes)))
+      return false;
+    const unsigned int decoded = (static_cast<unsigned int>(bytes[0]) << 24)
+                               | (static_cast<unsigned int>(bytes[1]) << 16)
+                               | (static_cast<unsigned int>(bytes[2]) << 8)
+                               | static_cast<unsigned int>(bytes[3]);
+    value = static_cast<int>(decoded);
+    return true;
+  };
+
+  int magic_number = 0;
+  if (!readHeaderInt(magic_number) || magic_number != 2051)
+    throw std::runtime_error("Invalid MNIST image file!");
+
+  int number_of_images = 0, n_rows = 0, n_cols = 0;
+  if (!readHeaderInt(number_of_images) || !readHeaderInt(n_rows) || !readHeaderInt(n_cols)
+      || number_of_images < 0 || n_rows < 0 || n_cols < 0)
+    throw std::runtime_error("Invalid MNIST image file!");
+
+  // Widen before multiplying so rows*cols cannot overflow int.
+  const std::size_t image_size = static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(n_cols);
+  std::vector<std::vector<float>> images(static_cast<std::size_t>(number_of_images),
+                                         std::vector<float>(image_size));
+
+  // One read per image instead of one per pixel; the buffer is reused.
+  std::vector<unsigned char> raw(image_size);
+  for (std::size_t i = 0; i < images.size(); ++i) {
+    if (image_size > 0
+        && !file.read(reinterpret_cast<char*>(raw.data()), static_cast<std::streamsize>(image_size)))
+      throw std::runtime_error("Truncated MNIST image file!");
+    for (std::size_t j = 0; j < image_size; ++j)
+      images[i][j] = raw[j] / 255.0f;
+  }
+  return images;
+}
+
+/**
+ * Loads every label from an uncompressed MNIST label file.
+ *
+ * The file starts with two 32-bit integers stored most-significant byte
+ * first: magic number (2049) and label count. They are followed by one
+ * unsigned byte per label.
+ *
+ * @param full_path path to the label file
+ * @return the labels in file order, each widened from one byte to int
+ * @throws std::runtime_error if the file cannot be opened, the magic number
+ *         is wrong, the label count is missing or negative, or the file ends
+ *         before all announced labels were read
+ */
+inline std::vector<int> ReadLabels(const std::string& full_path) {
+  // Binary mode: label bytes must not pass through newline translation.
+  std::ifstream file(full_path, std::ios::binary);
+
+  if (!file.is_open())
+    throw std::runtime_error("Cannot open file `" + full_path + "`!");
+
+  // Header fields are decoded byte by byte so the result does not depend on
+  // the byte order of the machine running this code.
+  auto readHeaderInt = [&file](int& value) -> bool {
+    unsigned char bytes[4] = {0, 0, 0, 0};
+    if (!file.read(reinterpret_cast<char*>(bytes), sizeof(bytes)))
+      return false;
+    const unsigned int decoded = (static_cast<unsigned int>(bytes[0]) << 24)
+                               | (static_cast<unsigned int>(bytes[1]) << 16)
+                               | (static_cast<unsigned int>(bytes[2]) << 8)
+                               | static_cast<unsigned int>(bytes[3]);
+    value = static_cast<int>(decoded);
+    return true;
+  };
+
+  int magic_number = 0;
+  if (!readHeaderInt(magic_number) || magic_number != 2049)
+    throw std::runtime_error("Invalid MNIST label file!");
+
+  int number_of_labels = 0;
+  if (!readHeaderInt(number_of_labels) || number_of_labels < 0)
+    throw std::runtime_error("Invalid MNIST label file!");
+
+  // Read the bytes into a byte buffer first; writing a single byte straight
+  // into an int element only gives the right value on little-endian hosts.
+  std::vector<unsigned char> raw(static_cast<std::size_t>(number_of_labels));
+  if (!raw.empty()
+      && !file.read(reinterpret_cast<char*>(raw.data()), static_cast<std::streamsize>(raw.size())))
+    throw std::runtime_error("Truncated MNIST label file!");
+
+  return std::vector<int>(raw.begin(), raw.end());
+}
+
+} // namespace mnist
+} // namespace datasets
+
+
+//int main(int argc, const char *argv[]) {
+ //auto images = datasets::mnist::ReadImages("t10k-images-idx3-ubyte");
+ //auto labels = datasets::mnist::ReadLabels("t10k-labels-idx1-ubyte");
+
+ //std::cout
+ //<< "Number of images: " << images.size() << std::endl
+ //<< "Image size: " << images[0].size() << std::endl;
+
+ //for (int i = 0; i < 3; i++) {
+ //for (int j = 0; j < images[i].size(); j++) {
+ //std::cout << images[i][j] << ",";
+ //}
+ //std::cout << " label=" << (int)labels[i] << std::endl;
+ //}
+ //return 0;
+//}
diff --git a/src/test.cu b/src/test.cu
index 4a2445fd..c2b0d62e 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -1,9 +1,13 @@
#include "marian.h"
+#include "mnist.h"
using namespace std;
int main(int argc, char** argv) {
+ /*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte");*/
+ /*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");*/
+ /*std::cerr << images.size() << " " << images[0].size() << std::endl;*/
using namespace marian;
using namespace keywords;