Experimenting with the MNIST Dataset¶
from pathlib import Path
import torch
from torch import Tensor
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
data_root = Path('../data')
train_dataset = datasets.MNIST(
    root=data_root, train=True, download=True, transform=transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True
)
print(f'We have {len(train_loader)} training examples')
train_iter = iter(train_loader)
n = 10
fig, axes = plt.subplots(1, n, figsize=(n, 4))
for i in range(n):
    images, labels = next(train_iter)
    image: Tensor = images[0]
    label: int = labels[0].item()
    axes[i].imshow(image.squeeze())
    axes[i].set_title(f'{label}')
    axes[i].set_xticks([])
    axes[i].set_yticks([])
fig.subplots_adjust(top=1.4)
fig.suptitle(f'{n} Random Training Examples')
plt.show()
Shown below is the shape of the images. They are tensors of shape $(1, 1, 28, 28)$. The first dimension is created by the train_loader, the next three dimensions are created by the ToTensor dataset transform.
- The first dimension is the batch dimension. We set it to $1$ in the train_loader above, but we will increase it when training.
- The second dimension is the channel dimension. Since MNIST images are grayscale, they have only 1 channel (instead of the usual 3 channels for R, G, B).
- The third dimension is height: MNIST images are 28 pixels high.
- The fourth dimension is width: MNIST images are 28 pixels wide.
next(train_iter)[0].shape
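As a quick check, the four dimensions can also be unpacked by name (a small sketch, assuming the loader above with batch_size=1):
batch, channels, height, width = next(train_iter)[0].shape
print(batch, channels, height, width)  # 1 1 28 28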
Next we'll calculate the mean and standard deviation of the dataset so that we can normalize it properly before training.
def calculate_mean(loader: torch.utils.data.DataLoader) -> float:
    total_sum: float = 0.0
    total_pixels: int = 0
    x: Tensor
    for x, _ in loader:
        total_sum += x.sum()
        total_pixels += x.numel()  # Alternatively can just do `+= 28*28`
    return (total_sum / total_pixels).item()
def calculate_std(loader: torch.utils.data.DataLoader, mean: float) -> float:
    squared_diff_sum: float = 0.0
    total_pixels: int = 0
    x: Tensor
    for x, _ in loader:
        squared_diff_sum += ((x - mean)**2).sum()
        total_pixels += x.numel()  # Alternatively can just do `+= 28*28`
    return (squared_diff_sum / total_pixels).sqrt().item()
x_mean = calculate_mean(train_loader)
x_std = calculate_std(train_loader, x_mean)
x_mean, x_std
Now that we know the mean and standard deviation of the dataset, we can normalize it properly. After this transformation, the dataset should have a mean of 0 and standard deviation of 1.
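Concretely, the Normalize transform standardizes each pixel as $x' = \frac{x - \mu}{\sigma}$ using the per-channel mean $\mu$ and standard deviation $\sigma$ we just computed, which is why the re-computed statistics below come out at approximately $0$ and $1$.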
transform = transforms.Compose([
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
train_dataset = datasets.MNIST(
    root=data_root, train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True
)
normalized_x_mean = calculate_mean(train_loader)
normalized_x_std = calculate_std(train_loader, normalized_x_mean)
normalized_x_mean, normalized_x_std
Training¶
import random
import torch
import torch.nn.functional as F
import mnist
from mnist.model.dataloader import DataLoaderScheduler
from mnist.model.train import estimate_loss, train, train_basic_sdg
from mnist.utils.plot import show_images_with_labels
from mnist.utils.preprocessing import tensor2img
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device}')
We use training and validation datasets with a batch size of 64.
train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
Small MLP¶
First we train with a simple one-layer multilayer perceptron. The layer has $28 \cdot 28 = 784$ inputs (one per pixel in the image) and $10$ outputs (one for each class label).
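The MLP class comes from our mnist package and its definition is not shown in this notebook. As a rough sketch only (assuming it is simply a stack of nn.Linear layers with ReLU activations in between, which may differ from the real implementation), it could look something like this:
import torch.nn as nn

class SketchMLP(nn.Module):
    """Hypothetical stand-in for mnist.model.mlp.MLP (the real class may differ)."""
    def __init__(self, layer_in_outs: list[int]):
        super().__init__()
        layers = []
        for n_in, n_out in zip(layer_in_outs, layer_in_outs[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the final (logit) layer
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Expects a flattened (batch, 784) input, hence flatten=True when training
        return self.net(x)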
small_mlp_layer_in_outs = [28*28, 10]
mlp_small = mnist.model.mlp.MLP(small_mlp_layer_in_outs)
mlp_small.to(device)
diagnostics = train_basic_sdg(mlp_small, train_loader, val_loader, epochs=10, max_lr=0.01, min_lr=0.001, flatten=True, device=device)
Observe that after 10 training epochs we achieve:
small_mlp_last_epoch = diagnostics.loc[diagnostics['Epoch'].idxmax()]
small_mlp_last_epoch
Large MLP¶
Next, we train with a larger multilayer perceptron. Its layer sizes are $[784,\ 784 \cdot 4,\ 784 \cdot 16,\ 784 \cdot 4,\ 10]$, i.e. four linear layers, starting with $784$ inputs and finally reaching $10$ outputs.
large_mlp_layer_in_outs = [28*28, 28*28*4, 28*28*16, 28*28*4, 10]
mlp_large = mnist.model.mlp.MLP(large_mlp_layer_in_outs)
mlp_large.to(device)
diagnostics = train_basic_sdg(mlp_large, train_loader, val_loader, epochs=10, max_lr=0.01, min_lr=0.001, flatten=True, device=device)
Observe that after 10 training epochs we achieve:
large_mlp_last_epoch = diagnostics.loc[diagnostics['Epoch'].idxmax()]
large_mlp_last_epoch
Calculating the increase in synapses (excluding biases) from the small MLP to the large MLP and comparing it to the increase in accuracy shows that this architecture is not scaling well.
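For a rough sense of scale before running the cell below: the small MLP has $784 \cdot 10 = 7{,}840$ synapses, while the large MLP has $784 \cdot 3136 + 3136 \cdot 12544 + 12544 \cdot 3136 + 3136 \cdot 10 = 81{,}165{,}952$, an increase of a bit over $10{,}000\times$.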
small_mlp_synapses = sum(n_in * n_out for n_in, n_out in zip(small_mlp_layer_in_outs, small_mlp_layer_in_outs[1:]))
large_mlp_synapses = sum(n_in * n_out for n_in, n_out in zip(large_mlp_layer_in_outs, large_mlp_layer_in_outs[1:]))
print(
    f'Synapses increased by a factor of {large_mlp_synapses:,}/{small_mlp_synapses:,} '
    f'= {large_mlp_synapses / small_mlp_synapses:,.2f}, '
    f'validation accuracy increased by a factor of '
    f'{large_mlp_last_epoch["Validation Accuracy"]:.4f}/{small_mlp_last_epoch["Validation Accuracy"]:.4f} '
    f'= {large_mlp_last_epoch["Validation Accuracy"]/small_mlp_last_epoch["Validation Accuracy"]:,.4f}'
)
CNN¶
Next we try a convolutional neural network.
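The CNN class also lives in our mnist package and its exact architecture is not reproduced here. Purely as a sketch (hypothetical layer sizes; the real network may differ), a small MNIST CNN of this kind might look like:
import torch.nn as nn

class SketchCNN(nn.Module):
    """Hypothetical stand-in for mnist.model.cnn.CNN (the real architecture may differ)."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # (16, 14, 14)
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # (32, 7, 7)
        )
        self.classifier = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        # Input keeps its image shape (batch, 1, 28, 28), hence flatten=False when training
        x = self.features(x)
        return self.classifier(x.flatten(1))  # logits of shape (batch, 10)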
cnn = mnist.model.cnn.CNN()
cnn.to(device)
diagnostics = train_basic_sdg(cnn, train_loader, val_loader, epochs=10, max_lr=0.01, min_lr=0.001, flatten=False, device=device)
Observe that after 10 training epochs we achieve:
cnn_last_epoch = diagnostics.loc[diagnostics['Epoch'].idxmax()]
cnn_last_epoch
The CNN is significantly more accurate than the large MLP while using far fewer parameters.
mlp_large_parameters = sum(p.numel() for p in mlp_large.parameters())
cnn_parameters = sum(p.numel() for p in cnn.parameters())
print(
    f'CNN has {cnn_parameters:,}/{mlp_large_parameters:,} = {100 * cnn_parameters / mlp_large_parameters:,.2f}% '
    f'of the parameters of the large MLP, yet it was '
    f'{cnn_last_epoch["Validation Accuracy"]:.4f}/{large_mlp_last_epoch["Validation Accuracy"]:.4f} '
    f'= {cnn_last_epoch["Validation Accuracy"]/large_mlp_last_epoch["Validation Accuracy"]:,.4f} times more accurate on the validation set'
)
Adding in Random Image Transformations to Improve Model Accuracy¶
Before we fully train the CNN, we can apply some random transformations to the training dataset before running it through the network. Shown below are some training images together with random transformations applied to them:
- rotations between $-30^\circ$ and $30^\circ$
- horizontal shifts between $0\%$ and $25\%$ of the image width and vertical shifts between $0\%$ and $25\%$ of the image height
- scaling between $0.75$ and $1.25$
- shear between $-25^\circ$ and $25^\circ$
This will better reflect the variability of human handwriting.
transform1 = transforms.Compose([
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
transform2 = transforms.Compose([
    transforms.RandomAffine(
        degrees=30,
        translate=(0.25, 0.25),
        scale=(0.75, 1.25),
        shear=25
    ),
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
untransformed_train_dataset = datasets.MNIST(root=data_root, train=True, download=True)
n = 10
fig, axes = plt.subplots(2, n, figsize=(n, 2))
for i in range(n):
    image, label = random.choice(untransformed_train_dataset)
    axes[0][i].imshow(tensor2img(transform1(image)).squeeze())
    axes[0][i].set_title(f'{label}')
    axes[0][i].set_xticks([])
    axes[0][i].set_yticks([])
    axes[1][i].imshow(tensor2img(transform2(image)).squeeze())
    axes[1][i].set_xticks([])
    axes[1][i].set_yticks([])
fig.subplots_adjust(top=0.75)
fig.suptitle(f'{n} Random Training Examples (top), Transformed Examples (bottom)')
plt.show()
However, it is important not to introduce these transforms too early, as the model may overfit early on to details that are not fundamentally part of the digit. We will start by training for a few epochs on the untransformed data, then slowly ramp up the magnitude of the random transformations.
Conveniently for us, the transforms.Compose object is mutable, allowing us to change the parameters for the random transformations during training. This is what is implemented in our custom DataLoaderScheduler class.
transform2 = transforms.Compose([
    transforms.RandomAffine(
        degrees=30,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        shear=5
    ),
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
transformed_train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform2)
train_dataset_loader = torch.utils.data.DataLoader(transformed_train_dataset, batch_size=1, shuffle=True)
n = 10
fig, axes = plt.subplots(1, n, figsize=(n, 2))
i = 0
for images, labels in train_dataset_loader:
    if i == n:
        break
    if i >= n // 2:
        transform2.transforms[0].shear = [-30, 30]
        transform2.transforms[0].degrees = [90, 270]
        transform2.transforms[0].scale = [0.01, 2.9]
    image = images[0]
    label = labels[0]
    axes[i].imshow(tensor2img(image).squeeze())
    axes[i].set_title(label.item())
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    i += 1
fig.subplots_adjust(top=0.75)
fig.suptitle('Note how the magnitude of the transformations increases halfway through the examples')
plt.show()
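The scheduling idea itself can be sketched in a few lines (a simplified, hypothetical version of what our DataLoaderScheduler classes do; the real implementation lives in mnist.model.dataloader and may differ): after a short warmup on untransformed data, the RandomAffine magnitude is ramped up towards its target over the epochs.
def ramp_affine(affine: transforms.RandomAffine, frac: float,
                max_degrees: float = 30.0, max_translate: float = 0.25) -> None:
    """Hypothetical helper: scale the random-affine magnitude by frac in [0, 1]."""
    affine.degrees = [-max_degrees * frac, max_degrees * frac]
    affine.translate = (max_translate * frac, max_translate * frac)

affine = transform2.transforms[0]  # the RandomAffine inside the Compose above
total_epochs, warmup = 30, 1
for epoch in range(total_epochs):
    frac = 0.0 if epoch < warmup else (epoch - warmup + 1) / (total_epochs - warmup)
    ramp_affine(affine, frac)
    # ... run one training epoch with a DataLoader built on transform2 ...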
Actual Training Time 😎¶
Here is the heavy-duty training setup for the convolutional neural network. We use the data loader scheduler, which progressively increases the magnitude of the random transformations on the training data, along with AdamW as the optimizer and a cosine annealing learning rate scheduler, training for 30 epochs.
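For reference, the cosine annealing schedule decays the learning rate from its initial value $\eta_{\max}$ down to $\eta_{\min}$ according to
$$\eta_t = \eta_{\min} + \tfrac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\frac{t \, \pi}{T_{\max}}\right)\right),$$
where $t$ is the current epoch and $T_{\max}$ is the total number of epochs.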
from pathlib import Path
import torch
import matplotlib.pyplot as plt
import mnist
from mnist.model.dataloader import StepDataLoaderScheduler
from mnist.model.train import train
epochs = 30
x_mean, x_std = (0.13066114485263824, 0.30810731649398804)
data_root = Path('../data')
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device}')
data_loader = StepDataLoaderScheduler(
    root=data_root,
    x_mean=x_mean,
    x_std=x_std,
    batch_size=64,
    epochs=epochs,
    degrees=5,
    translate=0.25,
    scale=0.25,
    shear=5,
    warmup_steps=1,
    step_size=10
)
cnn = mnist.model.cnn.CNN()
cnn.to(device)
optimizer = torch.optim.AdamW(cnn.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
diagnostics = train(cnn, data_loader, epochs, optimizer, scheduler, flatten=False, device=device, quiet=True, show_progress=True)
diagnostics.head()
diagnostics.tail()
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].plot(diagnostics['Epoch'], diagnostics['Training Loss'], label='Training Loss')
axes[0].plot(diagnostics['Epoch'], diagnostics['Validation Loss'], label='Validation Loss')
axes[0].set_title('Loss during training')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss (cross entropy)')
axes[0].set_yscale('log')
axes[0].legend()
axes[1].plot(diagnostics['Epoch'], diagnostics['Training Accuracy'] * 100, label='Training Accuracy')
axes[1].plot(diagnostics['Epoch'], diagnostics['Validation Accuracy'] * 100, label='Validation Accuracy')
axes[1].set_title('Accuracy during training')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (percent)')
axes[1].legend()
axes[2].plot(diagnostics['Epoch'], diagnostics['Learning Rate'], label='Learning Rate')
axes[2].set_title('Learning rate during training')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning rate')
plt.tight_layout()
plt.show()
Evaluating the Model with Validation Data¶
val_iter = iter(data_loader.loader('val'))
from mnist.utils.plot import show_images_with_labels
classes = list(map(str, range(10)))
images, labels = next(val_iter)
images = images[:4].to(device)
labels = labels[:4].to(device)
output = cnn(images)
_, predicted = torch.max(output, 1)
show_images_with_labels(images.to('cpu'), classes, labels.to('cpu'), predicted.to('cpu'))
Here we measure the loss and accuracy on the original validation set.
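dataset_loss comes from our mnist package; conceptually it is just a no-grad pass over a loader that returns the mean loss and the accuracy. A minimal sketch of such a helper (assuming that return signature, which the calls below rely on; the real implementation may differ):
import torch.nn.functional as F

@torch.no_grad()
def sketch_dataset_loss(model, loader, flatten: bool, device: str) -> tuple[float, float]:
    """Hypothetical sketch of a (mean loss, accuracy) evaluation pass."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        if flatten:
            x = x.flatten(1)  # (batch, 784) for the MLPs
        logits = model(x)
        total_loss += F.cross_entropy(logits, y, reduction='sum').item()
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += y.numel()
    return total_loss / total, correct / total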
from mnist.model.train import dataset_loss
from torchvision import transforms, datasets
transform = transforms.Compose([
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
val_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
batch_size = 64
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
estimated_loss = dataset_loss(cnn, val_loader, flatten=False, device=device)
print(f'Model achieved loss of {estimated_loss[0]:.4f} on the untransformed validation set (accuracy of {estimated_loss[1]*100:.2f}%).')
And here we measure the loss on the validation set with random affine transformations applied.
transform = transforms.Compose([
    transforms.RandomAffine(
        degrees=5,
        translate=(0.25, 0.25),
        scale=(0.75, 1.25),
        shear=5
    ),
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
val_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
batch_size = 64
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
estimated_loss = dataset_loss(cnn, val_loader, flatten=False, device=device)
print(f'Model achieved loss of {estimated_loss[0]:.4f} on the transformed validation set (accuracy of {estimated_loss[1]*100:.2f}%).')
We also have some hand-drawn digits that are not part of the MNIST dataset (I drew these myself :P). These images better reflect the kind of input the model will receive in production, so it is important that the model performs well on them.
from torchvision import transforms, datasets
test_data_root = Path('../custom-data')
x_mean, x_std = (0.13066114485263824, 0.30810731649398804)
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
test_dataset = datasets.ImageFolder(root=test_data_root, transform=transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
from mnist.utils.plot import tensor2img
images, labels = next(iter(test_dataloader))
plt.imshow(tensor2img(images[0]))
plt.title(str(labels[0].item()))
plt.xticks([])
plt.yticks([])
plt.show()
estimated_loss = dataset_loss(cnn, test_dataloader, flatten=False, device=device)
print(f'Model achieved loss of {estimated_loss[0]:.4f} on the test set (accuracy of {estimated_loss[1]*100:.2f}%).')
Saving the Model¶
Now we're going to save the model, both as a PyTorch state dict (so we can reload the model in our training setup), and as an ONNX model (which we will be using in production).
torch.save(cnn.state_dict(), './thousand-iteration-model.pth')
print('Model saved')
Here, we load the model from the state dict that we saved above and export it to ONNX.
cnn = mnist.model.cnn.CNN()
state_dict = torch.load('thousand-iteration-model.pth')
cnn.load_state_dict(state_dict)
cnn.eval()
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    cnn,
    dummy_input,
    'model.onnx',
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
And here we load the ONNX model that we exported above and run inference on it.
import torch
import onnxruntime as ort
from torchvision import transforms, datasets
from mnist.utils.plot import show_images_with_labels
import numpy as np
x_mean, x_std = (0.13066114485263824, 0.30810731649398804)
transform = transforms.Compose([
    transforms.ToTensor(),                     # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,))  # (C, H, W)
])
transformed_train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
train_dataset_loader = torch.utils.data.DataLoader(transformed_train_dataset, batch_size=4, shuffle=True)
session = ort.InferenceSession('model.onnx')
inputs, labels = next(iter(train_dataset_loader))
outputs = session.run(None, {'input': inputs.numpy()})
predicted = np.argmax(outputs[0], axis=1) # Equivalent to `torch.max(outputs, 1)`
show_images_with_labels(inputs, list(map(str, range(10))), labels, predicted, 'ONNX')
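As an optional sanity check (not part of the pipeline above), the ONNX session's logits can be compared against the PyTorch model loaded from the state dict; in eval mode on the same inputs they should agree to within numerical tolerance:
with torch.no_grad():
    torch_logits = cnn(inputs)  # cnn is the CPU model loaded from the state dict above
print(np.allclose(outputs[0], torch_logits.numpy(), atol=1e-4))  # expect True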