Source code for deep_bottleneck.datasets.mnist

import numpy as np
from tensorflow import keras

from deep_bottleneck.datasets.base_dataset import Dataset


def _flatten_and_scale(images):
    """Flatten each sample to a vector and rescale its values to [-1, 1].

    Args:
        images: Array whose first axis indexes samples. Values are expected
            to be pixel intensities in [0, 255]; they are divided by 255.

    Returns:
        A float32 array of shape ``(n_samples, n_features)``. For inputs in
        [0, 255] the values lie in [-1.0, 1.0].
    """
    flat = np.reshape(images, [images.shape[0], -1]).astype('float32') / 255.0
    # Map [0, 1] onto [-1, 1]. This keeps the same operation order as the
    # original inline code, so the resulting float32 values are identical.
    return flat * 2.0 - 1.0


def load():
    """Load the MNIST handwritten digits dataset.

    Each image is flattened into a vector and its pixel intensities are
    rescaled from [0, 255] to [-1, 1]. The same preprocessing is applied to
    the train and test splits.

    Returns:
        The MNIST dataset, built by ``Dataset.from_labelled_subset`` with
        10 classes.
    """
    n_classes = 10
    (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
    X_train = _flatten_and_scale(X_train)
    X_test = _flatten_and_scale(X_test)
    return Dataset.from_labelled_subset(X_train, y_train, X_test, y_test, n_classes)