NNFS
Neural network library from scratch
Loading...
Searching...
No Matches
mnist.cpp
Go to the documentation of this file.
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <system_error>
#include <tuple>
#include <utility>
#include <vector>

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include <Eigen/Core>

#include <curl/curl.h>
#include <zlib.h>

// Remote directory holding the gzip-compressed MNIST IDX files. File names are
// appended to it directly, so the trailing slash is required.
// NOTE(review): this host has been reported to refuse anonymous downloads in
// recent years; a mirror may be needed - confirm before relying on it.
const std::string BASE_URL{"http://yann.lecun.com/exdb/mnist/"};

/**
 * @brief Check whether a file exists.
 *
 * The check is done by trying to open the file for reading. An existing file
 * that cannot be opened (e.g. due to permissions) therefore reports false.
 *
 * @param filename Path of the file to probe.
 * @return true if the file could be opened for reading.
 */
bool file_exists(const std::string &filename)
{
    const std::ifstream probe{filename};
    return probe.good();
}

/**
 * @brief libcurl write callback that appends each received chunk to a FILE*.
 *
 * @param ptr    Received data.
 * @param size   Size of one item (libcurl documents this as always 1).
 * @param nmemb  Number of items in @p ptr.
 * @param stream The FILE* registered via CURLOPT_WRITEDATA.
 * @return Number of items written, as reported by fwrite. A short count
 *         signals an error to libcurl.
 */
static size_t write_data(void *ptr, size_t size, size_t nmemb, void *stream)
{
    FILE *const destination = static_cast<FILE *>(stream);
    return fwrite(ptr, size, nmemb, destination);
}

44void download_file(const std::string &url, const std::string &path)
45{
46 CURL *curl_handle;
47 FILE *out;
48
49 curl_global_init(CURL_GLOBAL_ALL);
50
51 /* init the curl session */
52 curl_handle = curl_easy_init();
53
54 /* set URL to get here */
55 curl_easy_setopt(curl_handle, CURLOPT_URL, url.c_str());
56
57 /* Switch on full protocol/debug output while testing */
58 curl_easy_setopt(curl_handle, CURLOPT_VERBOSE, 1L);
59
60 /* disable progress meter, set to 0L to enable and disable debug output */
61 curl_easy_setopt(curl_handle, CURLOPT_NOPROGRESS, 1L);
62
63 /* send all data to this function */
64 curl_easy_setopt(curl_handle, CURLOPT_WRITEFUNCTION, write_data);
65
66 /* open the file */
67 out = fopen(path.c_str(), "wb");
68 if (out)
69 {
70
71 /* write the page body to this file handle */
72 curl_easy_setopt(curl_handle, CURLOPT_WRITEDATA, out);
73
74 /* get it! */
75 curl_easy_perform(curl_handle);
76
77 /* close the header file */
78 fclose(out);
79 }
80
81 /* cleanup curl stuff */
82 curl_easy_cleanup(curl_handle);
83
84 curl_global_cleanup();
85}
86
93bool unzip_file(const std::string &gz_path, const std::string &out_path)
94{
95 gzFile gz = gzopen(gz_path.c_str(), "rb");
96 if (!gz)
97 {
98 std::cerr << "Error opening gz file: " << gz_path << std::endl;
99 return false;
100 }
101
102 std::ofstream outfile(out_path, std::ofstream::binary);
103 if (!outfile.is_open())
104 {
105 std::cerr << "Error opening output file: " << out_path << std::endl;
106 return false;
107 }
108
109 char buffer[1024];
110 int uncompressed_bytes;
111 while ((uncompressed_bytes = gzread(gz, buffer, sizeof(buffer))) > 0)
112 {
113 outfile.write(buffer, uncompressed_bytes);
114 }
115
116 outfile.close();
117 gzclose(gz);
118
119 return true;
120}
121
128Eigen::MatrixXd read_mnist_images(const std::string &filename)
129{
130 std::ifstream file(filename, std::ios::binary);
131 if (!file)
132 {
133 std::cout << "Error: Failed to open file: " << filename.c_str() << std::endl;
134 return Eigen::MatrixXd();
135 }
136
137 int magic_number = 0;
138 int num_images = 0;
139 int num_rows = 0;
140 int num_cols = 0;
141
142 file.read((char *)&magic_number, sizeof(magic_number));
143 magic_number = ntohl(magic_number);
144
145 if (magic_number != 2051)
146 {
147 std::cout << "Error: Invalid magic number in file: " << filename.c_str() << std::endl;
148 return Eigen::MatrixXd();
149 }
150
151 file.read((char *)&num_images, sizeof(num_images));
152 file.read((char *)&num_rows, sizeof(num_rows));
153 file.read((char *)&num_cols, sizeof(num_cols));
154
155 num_images = ntohl(num_images);
156 num_rows = ntohl(num_rows);
157 num_cols = ntohl(num_cols);
158
159 Eigen::MatrixXd images(num_rows * num_cols, num_images);
160
161 for (int i = 0; i < num_images; ++i)
162 {
163 for (int j = 0; j < num_rows * num_cols; ++j)
164 {
165 unsigned char pixel = 0;
166 file.read((char *)&pixel, sizeof(pixel));
167 images(j, i) = static_cast<double>(pixel) / 255.;
168 }
169 }
170 return images.transpose();
171}
172
179Eigen::MatrixXd read_mnist_labels(const std::string &filename)
180{
181 std::ifstream file(filename, std::ios::binary);
182 if (!file)
183 {
184 std::cout << "Error: Failed to open file: " << filename.c_str() << std::endl;
185 return Eigen::MatrixXd();
186 }
187
188 int magic_number = 0;
189 int num_labels = 0;
190
191 file.read((char *)&magic_number, sizeof(magic_number));
192 magic_number = ntohl(magic_number);
193
194 if (magic_number != 2049)
195 {
196 std::cout << "Error: Invalid magic number in file: " << filename.c_str();
197 return Eigen::MatrixXd();
198 }
199
200 file.read((char *)&num_labels, sizeof(num_labels));
201 num_labels = ntohl(num_labels);
202
203 Eigen::MatrixXd labels = Eigen::MatrixXd::Zero(10, num_labels);
204
205 for (int i = 0; i < num_labels; ++i)
206 {
207 unsigned char label = 0;
208 file.read((char *)&label, sizeof(label));
209 int index = static_cast<int>(label);
210 labels(index, i) = 1.0;
211 }
212 return labels.transpose();
213}
214
221std::tuple<Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd>
222fetch_mnist(const std::string &data_dir)
223{
224 const std::string train_images_url = BASE_URL + "train-images-idx3-ubyte.gz";
225 const std::string train_images_path = data_dir + "/train-images-idx3-ubyte.gz";
226 const std::string train_images_file = data_dir + "/train-images-idx3-ubyte";
227 const std::string train_labels_url = BASE_URL + "train-labels-idx1-ubyte.gz";
228 const std::string train_labels_path = data_dir + "/train-labels-idx1-ubyte.gz";
229 const std::string train_labels_file = data_dir + "/train-labels-idx1-ubyte";
230 const std::string test_images_url = BASE_URL + "t10k-images-idx3-ubyte.gz";
231 const std::string test_images_path = data_dir + "/t10k-images-idx3-ubyte.gz";
232 const std::string test_images_file = data_dir + "/t10k-images-idx3-ubyte";
233 const std::string test_labels_url = BASE_URL + "t10k-labels-idx1-ubyte.gz";
234 const std::string test_labels_path = data_dir + "/t10k-labels-idx1-ubyte.gz";
235 const std::string test_labels_file = data_dir + "/t10k-labels-idx1-ubyte";
236
237 if (!std::filesystem::exists(train_images_file) ||
238 !std::filesystem::exists(train_labels_file) ||
239 !std::filesystem::exists(test_images_file) ||
240 !std::filesystem::exists(test_labels_file))
241 {
242 std::filesystem::create_directories(data_dir);
243
244 std::cout << "Downloading MNIST training images..." << std::endl;
245 download_file(train_images_url, train_images_path);
246
247 std::cout << "Downloading MNIST training labels..." << std::endl;
248 download_file(train_labels_url, train_labels_path);
249
250 std::cout << "Downloading MNIST test images..." << std::endl;
251 download_file(test_images_url, test_images_path);
252
253 std::cout << "Downloading MNIST test labels..." << std::endl;
254 download_file(test_labels_url, test_labels_path);
255
256 std::cout << "Extracting MNIST training images..." << std::endl;
257 unzip_file(train_images_path, train_images_file);
258
259 std::cout << "Extracting MNIST training labels..." << std::endl;
260 unzip_file(train_labels_path, train_labels_file);
261
262 std::cout << "Extracting MNIST test images..." << std::endl;
263 unzip_file(test_images_path, test_images_file);
264
265 std::cout << "Extracting MNIST test labels..." << std::endl;
266 unzip_file(test_labels_path, test_labels_file);
267 }
268
269 Eigen::MatrixXd x_train = read_mnist_images(train_images_file);
270 Eigen::MatrixXd y_train = read_mnist_labels(train_labels_file);
271
272 Eigen::MatrixXd x_test = read_mnist_images(test_images_file);
273 Eigen::MatrixXd y_test = read_mnist_labels(test_labels_file);
274
275 return std::make_tuple(x_train, y_train, x_test, y_test);
276};
const std::string BASE_URL
Definition mnist.cpp:20
bool unzip_file(const std::string &gz_path, const std::string &out_path)
Decompress a gzip (.gz) file and save the result to a local path.
Definition mnist.cpp:93
Eigen::MatrixXd read_mnist_images(const std::string &filename)
Read a MNIST image file.
Definition mnist.cpp:128
Eigen::MatrixXd read_mnist_labels(const std::string &filename)
Read a MNIST label file.
Definition mnist.cpp:179
std::tuple< Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd > fetch_mnist(const std::string &data_dir)
Fetch the MNIST dataset, downloading it into a local directory if it is not already there.
Definition mnist.cpp:222
void download_file(const std::string &url, const std::string &path)
Download a file from a URL and save it to a local path.
Definition mnist.cpp:44
bool file_exists(const std::string &filename)
Check if file exists.
Definition mnist.cpp:27