Collect and pre-process data

Here we will use a small handwritten-digit dataset in the style of MNIST, which ships with the scikit-learn package. The MNIST (Modified National Institute of Standards and Technology) database is a large database of handwritten digits that is commonly used for training various image processing systems. The MNIST dataset consists of 70 000 images of size \( 28\times 28 \) pixels, each labeled from 0 to 9. The scikit-learn dataset we will use is smaller: it consists of 1797 images of size \( 8\times 8 \), also labeled from 0 to 9. Strictly speaking it is not a subset of MNIST itself but a copy of the test set of the UCI Optical Recognition of Handwritten Digits data, which was likewise derived from NIST handwriting samples.

To feed data into a feed-forward neural network we need to represent the inputs as a feature matrix \( X \) of shape \( (n_{inputs}, n_{features}) \). Each row represents an input, in this case a handwritten digit, and each column represents a feature, in this case a pixel. The correct answers, also known as labels or targets, are represented as a 1D array of integers \( Y \) of shape \( (n_{inputs}) \), for example \( Y = (5, 3, 1, 8,...) \).

As an example, say we want to build a neural network using supervised learning to predict Body-Mass Index (BMI) from measurements of height (in m) and weight (in kg). If we have measurements of 5 people, the feature matrix could be, for example: $$ X = \begin{bmatrix} 1.85 & 81\\ 1.71 & 65\\ 1.95 & 103\\ 1.55 & 42\\ 1.63 & 56 \end{bmatrix} ,$$

and the targets would be: $$ Y = (23.7, 22.2, 27.1, 17.5, 21.1) $$
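
The BMI example can be sketched in NumPy as follows. This is only an illustration of the shapes involved: the variable names `height` and `weight` are chosen here for readability, and the targets are computed from the standard formula BMI = weight / height², rounded to one decimal to match the values above.

```python
import numpy as np

# feature matrix: one row per person, columns are height (m) and weight (kg)
X = np.array([[1.85, 81.0],
              [1.71, 65.0],
              [1.95, 103.0],
              [1.55, 42.0],
              [1.63, 56.0]])

# targets: BMI = weight / height^2, rounded to one decimal
height, weight = X[:, 0], X[:, 1]
Y = np.round(weight / height**2, 1)

print("X = (n_inputs, n_features) = " + str(X.shape))
print("Y = (n_inputs) = " + str(Y.shape))
print("Y = " + str(Y))
```

In a real supervised learning setting the targets would of course be measured or given, not computed from the features; the formula is used here only to reproduce the numbers in the example.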

Since each input image is a 2D matrix, we need to flatten the image (i.e. "unravel" the 2D matrix into a 1D array) to turn the data into a feature matrix. This means we lose all spatial information in the image, such as locality and translational invariance. More complicated architectures such as Convolutional Neural Networks can take advantage of such information, and are most commonly applied when analyzing images.
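
To see what flattening does to locality, consider the following minimal sketch on a toy \( 3\times 3 \) "image" (an assumption made here for illustration, not part of the dataset). Two pixels that are vertical neighbours in the image end up a full row width apart in the flattened array.

```python
import numpy as np

# toy 3x3 "image" whose pixel values are their row-major positions
image = np.arange(9).reshape(3, 3)

# flatten row by row (C order), the same ordering reshape(n_inputs, -1) uses per image
flat = image.reshape(-1)

# pixels (0, 1) and (1, 1) are vertical neighbours in the 2D image
idx_top = np.ravel_multi_index((0, 1), image.shape)
idx_bottom = np.ravel_multi_index((1, 1), image.shape)

print("flattened image:", flat)
print("distance between vertical neighbours:", idx_bottom - idx_top)
```

The flattening itself is reversible (reshaping `flat` back to `(3, 3)` recovers the image), but a feed-forward network treats each entry of `flat` as an independent feature and has no built-in notion of which entries were adjacent.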

# import necessary packages
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets


# ensure the same random numbers appear every time
np.random.seed(0)

# display images in notebook (IPython/Jupyter magic; remove the next line when running as a plain script)
%matplotlib inline
plt.rcParams['figure.figsize'] = (12,12)


# load the scikit-learn digits dataset (bundled with the package, no download needed)
digits = datasets.load_digits()

# define inputs and labels
inputs = digits.images
labels = digits.target

print("inputs = (n_inputs, pixel_height, pixel_width) = " + str(inputs.shape))
print("labels = (n_inputs) = " + str(labels.shape))


# flatten the image
# the value -1 means dimension is inferred from the remaining dimensions: 8x8 = 64
n_inputs = len(inputs)
inputs = inputs.reshape(n_inputs, -1)
print("X = (n_inputs, n_features) = " + str(inputs.shape))


# choose some random images to display
indices = np.arange(n_inputs)
# sample without replacement so the same image is not shown twice
random_indices = np.random.choice(indices, size=5, replace=False)

for i, image in enumerate(digits.images[random_indices]):
    plt.subplot(1, 5, i+1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title("Label: %d" % digits.target[random_indices[i]])
plt.show()