// Base URL of the MNIST download location; the individual archive file names
// (e.g. "train-images-idx3-ubyte.gz") are appended to it to form download URLs.
20const std::string
BASE_URL =
"http://yann.lecun.com/exdb/mnist/";
29 std::ifstream infile(filename);
33static size_t write_data(
void *ptr,
size_t size,
size_t nmemb,
void *stream)
35 size_t written = fwrite(ptr, size, nmemb, (FILE *)stream);
49 curl_global_init(CURL_GLOBAL_ALL);
52 curl_handle = curl_easy_init();
55 curl_easy_setopt(curl_handle, CURLOPT_URL, url.c_str());
58 curl_easy_setopt(curl_handle, CURLOPT_VERBOSE, 1L);
61 curl_easy_setopt(curl_handle, CURLOPT_NOPROGRESS, 1L);
64 curl_easy_setopt(curl_handle, CURLOPT_WRITEFUNCTION, write_data);
67 out = fopen(path.c_str(),
"wb");
72 curl_easy_setopt(curl_handle, CURLOPT_WRITEDATA, out);
75 curl_easy_perform(curl_handle);
82 curl_easy_cleanup(curl_handle);
84 curl_global_cleanup();
93bool unzip_file(
const std::string &gz_path,
const std::string &out_path)
95 gzFile gz = gzopen(gz_path.c_str(),
"rb");
98 std::cerr <<
"Error opening gz file: " << gz_path << std::endl;
102 std::ofstream outfile(out_path, std::ofstream::binary);
103 if (!outfile.is_open())
105 std::cerr <<
"Error opening output file: " << out_path << std::endl;
110 int uncompressed_bytes;
111 while ((uncompressed_bytes = gzread(gz, buffer,
sizeof(buffer))) > 0)
113 outfile.write(buffer, uncompressed_bytes);
130 std::ifstream file(filename, std::ios::binary);
133 std::cout <<
"Error: Failed to open file: " << filename.c_str() << std::endl;
134 return Eigen::MatrixXd();
137 int magic_number = 0;
142 file.read((
char *)&magic_number,
sizeof(magic_number));
143 magic_number = ntohl(magic_number);
145 if (magic_number != 2051)
147 std::cout <<
"Error: Invalid magic number in file: " << filename.c_str() << std::endl;
148 return Eigen::MatrixXd();
151 file.read((
char *)&num_images,
sizeof(num_images));
152 file.read((
char *)&num_rows,
sizeof(num_rows));
153 file.read((
char *)&num_cols,
sizeof(num_cols));
155 num_images = ntohl(num_images);
156 num_rows = ntohl(num_rows);
157 num_cols = ntohl(num_cols);
159 Eigen::MatrixXd images(num_rows * num_cols, num_images);
161 for (
int i = 0; i < num_images; ++i)
163 for (
int j = 0; j < num_rows * num_cols; ++j)
165 unsigned char pixel = 0;
166 file.read((
char *)&pixel,
sizeof(pixel));
167 images(j, i) =
static_cast<double>(pixel) / 255.;
170 return images.transpose();
181 std::ifstream file(filename, std::ios::binary);
184 std::cout <<
"Error: Failed to open file: " << filename.c_str() << std::endl;
185 return Eigen::MatrixXd();
188 int magic_number = 0;
191 file.read((
char *)&magic_number,
sizeof(magic_number));
192 magic_number = ntohl(magic_number);
194 if (magic_number != 2049)
196 std::cout <<
"Error: Invalid magic number in file: " << filename.c_str();
197 return Eigen::MatrixXd();
200 file.read((
char *)&num_labels,
sizeof(num_labels));
201 num_labels = ntohl(num_labels);
203 Eigen::MatrixXd labels = Eigen::MatrixXd::Zero(10, num_labels);
205 for (
int i = 0; i < num_labels; ++i)
207 unsigned char label = 0;
208 file.read((
char *)&label,
sizeof(label));
209 int index =
static_cast<int>(label);
210 labels(index, i) = 1.0;
212 return labels.transpose();
221std::tuple<Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd>
224 const std::string train_images_url =
BASE_URL +
"train-images-idx3-ubyte.gz";
225 const std::string train_images_path = data_dir +
"/train-images-idx3-ubyte.gz";
226 const std::string train_images_file = data_dir +
"/train-images-idx3-ubyte";
227 const std::string train_labels_url =
BASE_URL +
"train-labels-idx1-ubyte.gz";
228 const std::string train_labels_path = data_dir +
"/train-labels-idx1-ubyte.gz";
229 const std::string train_labels_file = data_dir +
"/train-labels-idx1-ubyte";
230 const std::string test_images_url =
BASE_URL +
"t10k-images-idx3-ubyte.gz";
231 const std::string test_images_path = data_dir +
"/t10k-images-idx3-ubyte.gz";
232 const std::string test_images_file = data_dir +
"/t10k-images-idx3-ubyte";
233 const std::string test_labels_url =
BASE_URL +
"t10k-labels-idx1-ubyte.gz";
234 const std::string test_labels_path = data_dir +
"/t10k-labels-idx1-ubyte.gz";
235 const std::string test_labels_file = data_dir +
"/t10k-labels-idx1-ubyte";
237 if (!std::filesystem::exists(train_images_file) ||
238 !std::filesystem::exists(train_labels_file) ||
239 !std::filesystem::exists(test_images_file) ||
240 !std::filesystem::exists(test_labels_file))
242 std::filesystem::create_directories(data_dir);
244 std::cout <<
"Downloading MNIST training images..." << std::endl;
247 std::cout <<
"Downloading MNIST training labels..." << std::endl;
250 std::cout <<
"Downloading MNIST test images..." << std::endl;
253 std::cout <<
"Downloading MNIST test labels..." << std::endl;
256 std::cout <<
"Extracting MNIST training images..." << std::endl;
257 unzip_file(train_images_path, train_images_file);
259 std::cout <<
"Extracting MNIST training labels..." << std::endl;
260 unzip_file(train_labels_path, train_labels_file);
262 std::cout <<
"Extracting MNIST test images..." << std::endl;
263 unzip_file(test_images_path, test_images_file);
265 std::cout <<
"Extracting MNIST test labels..." << std::endl;
266 unzip_file(test_labels_path, test_labels_file);
275 return std::make_tuple(x_train, y_train, x_test, y_test);
const std::string BASE_URL
Definition mnist.cpp:20
bool unzip_file(const std::string &gz_path, const std::string &out_path)
Decompress a gzip (.gz) file and save the result to a local path.
Definition mnist.cpp:93
Eigen::MatrixXd read_mnist_images(const std::string &filename)
Read an MNIST image file.
Definition mnist.cpp:128
Eigen::MatrixXd read_mnist_labels(const std::string &filename)
Read an MNIST label file.
Definition mnist.cpp:179
std::tuple< Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd > fetch_mnist(const std::string &data_dir)
Fetch the MNIST dataset, downloading and extracting it into the local directory if any of the extracted files are missing.
Definition mnist.cpp:222
void download_file(const std::string &url, const std::string &path)
Download file from URL and save to local path.
Definition mnist.cpp:44
bool file_exists(const std::string &filename)
Check if file exists.
Definition mnist.cpp:27