Experimenting with the MNIST Dataset

In [1]:
from pathlib import Path

import torch
from torch import Tensor
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
In [2]:
data_root = Path('../data')
train_dataset = datasets.MNIST(
    root=data_root, train=True, download=True, transform=transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True
)
print(f'We have {len(train_dataset)} training examples')
We have 60000 training examples
In [3]:
train_iter = iter(train_loader)

n = 10

fig, axes = plt.subplots(1, n, figsize=(n, 4))
for i in range(n):
    images, labels = next(train_iter)
    image: Tensor = images[0]
    label: int = labels[0].item()  # convert the 0-d tensor to a plain int
    axes[i].imshow(image.squeeze())
    axes[i].set_title(f'{label}')
    axes[i].set_xticks([])
    axes[i].set_yticks([])
fig.subplots_adjust(top=1.4)
fig.suptitle(f'{n} Random Training Examples')
plt.show()
[Figure: 10 random training examples, each titled with its label]

Shown below is the shape of one batch of images: a tensor of shape $(1, 1, 28, 28)$. The first dimension is added by the train_loader; the remaining three come from the ToTensor dataset transform.

  • The first dimension is the batch dimension. We set it to $1$ in the train_loader above, but we will increase it when training.

  • The second dimension is the channel dimension. Since MNIST images are grayscale, they have only 1 channel (instead of the usual 3 channels for RGB).

  • The third dimension is height. MNIST images are 28 pixels high.

  • The fourth dimension is width. MNIST images are 28 pixels wide.

In [4]:
next(train_iter)[0].shape
Out[4]:
torch.Size([1, 1, 28, 28])

Next we'll calculate the mean and standard deviation of the dataset so that we can normalize it properly before training.

In [5]:
def calculate_mean(loader: torch.utils.data.DataLoader) -> float:
    """Return the mean pixel value over every image in `loader`."""
    total_sum: Tensor = torch.tensor(0.0)
    total_pixels: int = 0

    x: Tensor
    for x, _ in loader:
        total_sum += x.sum()
        total_pixels += x.numel()  # pixels in this batch: batch_size * 28 * 28

    return (total_sum / total_pixels).item()

def calculate_std(loader: torch.utils.data.DataLoader, mean: float) -> float:
    """Return the (population) standard deviation of pixel values around `mean`."""
    squared_diff_sum: Tensor = torch.tensor(0.0)
    total_pixels: int = 0

    x: Tensor
    for x, _ in loader:
        squared_diff_sum += ((x - mean)**2).sum()
        total_pixels += x.numel()  # pixels in this batch: batch_size * 28 * 28

    return (squared_diff_sum / total_pixels).sqrt().item()
In [6]:
x_mean = calculate_mean(train_loader)
x_std = calculate_std(train_loader, x_mean)
x_mean, x_std
Out[6]:
(0.13066068291664124, 0.3081083595752716)
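
The two functions above walk the dataset twice, once for the mean and once for the standard deviation. A single pass is enough if both the sum and the sum of squares are accumulated, since $\mathrm{Var}[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2$. The sketch below is a hypothetical alternative, not the method used in this notebook: the name `mean_and_std_one_pass` is made up for illustration, and plain Python lists stand in for tensors so that it runs without the dataset.

```python
import math
from typing import Iterable, Sequence, Tuple


def mean_and_std_one_pass(batches: Iterable[Sequence[float]]) -> Tuple[float, float]:
    """Population mean and standard deviation of every value in `batches`."""
    total = 0.0     # running sum of x
    total_sq = 0.0  # running sum of x**2
    count = 0       # number of values seen so far
    for batch in batches:
        for value in batch:
            total += value
            total_sq += value * value
            count += 1
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clamp at 0 to guard against rounding error.
    variance = max(total_sq / count - mean * mean, 0.0)
    return mean, math.sqrt(variance)


# Toy stand-in for pixel batches (values in [0, 1], like ToTensor output).
toy_batches = [[0.0, 0.0, 1.0], [1.0, 0.5, 0.5]]
toy_mean, toy_std = mean_and_std_one_pass(toy_batches)
print(toy_mean, toy_std)
```

With tensors, the same idea would accumulate `x.sum()` and `(x**2).sum()` per batch. The subtraction can lose precision when the mean is large relative to the spread, which is not a concern for pixel values in $[0, 1]$.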

Now that we know the mean and standard deviation of the dataset, we can normalize it properly. After this transformation, the dataset should have a mean of approximately 0 and a standard deviation of approximately 1.

In [7]:
transform = transforms.Compose([
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])
In [8]:
train_dataset = datasets.MNIST(
    root=data_root, train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True
)
normalized_x_mean = calculate_mean(train_loader)
normalized_x_std = calculate_std(train_loader, normalized_x_mean)
normalized_x_mean, normalized_x_std
Out[8]:
(-6.353487265187141e-07, 0.9999954700469971)
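
To see why the result lands near 0 and 1: Normalize applies $z = (x - \mu)/\sigma$ to every pixel, where $\mu$ and $\sigma$ are the dataset mean and standard deviation computed earlier. Taking the mean and variance over all pixels gives:

```latex
\begin{aligned}
\mathbb{E}[z] &= \frac{\mathbb{E}[x] - \mu}{\sigma} = \frac{\mu - \mu}{\sigma} = 0, \\
\operatorname{Var}[z] &= \frac{\operatorname{Var}[x]}{\sigma^2} = \frac{\sigma^2}{\sigma^2} = 1 .
\end{aligned}
```

The small deviations in Out[8] are plausibly floating-point rounding accumulated over roughly 47 million pixels, although that is an interpretation and not something the notebook checks.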

Training

In [9]:
import random

import torch
import torch.nn.functional as F

import mnist
from mnist.model.dataloader import DataLoaderScheduler
from mnist.model.train import estimate_loss, train, train_basic_sdg
from mnist.utils.plot import show_images_with_labels
from mnist.utils.preprocessing import tensor2img
In [10]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device}')
Using cuda

We load the training set and a validation set (MNIST's held-out split, train=False), both with a batch size of 64.

In [11]:
train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)

batch_size = 64

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
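
With a batch size above 1, `len(train_loader)` counts batches, not examples. The arithmetic below is a small sketch of how many batches each loader yields per epoch, assuming the standard MNIST split sizes (60,000 training and 10,000 held-out images) and DataLoader's default of keeping the final partial batch.

```python
import math

train_examples = 60_000  # MNIST training split
val_examples = 10_000    # MNIST held-out split, used here for validation
batch_size = 64

# DataLoader keeps the final, smaller batch by default (drop_last=False),
# so the number of batches is the ceiling of examples / batch_size.
train_batches = math.ceil(train_examples / batch_size)
val_batches = math.ceil(val_examples / batch_size)
last_train_batch = train_examples - (train_batches - 1) * batch_size

print(train_batches, val_batches, last_train_batch)  # 938 157 32
```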

Small MLP

First we train the simplest possible MLP: a single layer with $28 \cdot 28 = 784$ inputs (one per pixel in the image) and $10$ outputs (one for each class label).

In [12]:
small_mlp_layer_in_outs = [28*28, 10]
mlp_small = mnist.model.mlp.MLP(small_mlp_layer_in_outs)
mlp_small.to(device)
diagnostics = train_basic_sdg(mlp_small, train_loader, val_loader, epochs=10, max_lr=0.01, min_lr=0.001, flatten=True, device=device)
epoch: 0, lr: 1.0000e-02, train loss: 2.5722, train accuracy: 0.0845, val loss: 2.5652, val accuracy: 0.0811
epoch: 1, lr: 7.9433e-03, train loss: 0.3471, train accuracy: 0.9005, val loss: 0.3385, val accuracy: 0.9078
epoch: 2, lr: 6.3096e-03, train loss: 0.3171, train accuracy: 0.9077, val loss: 0.3148, val accuracy: 0.9105
epoch: 3, lr: 5.0119e-03, train loss: 0.3220, train accuracy: 0.9078, val loss: 0.2876, val accuracy: 0.9184
epoch: 4, lr: 3.9811e-03, train loss: 0.3075, train accuracy: 0.9094, val loss: 0.2946, val accuracy: 0.9172
epoch: 5, lr: 3.1623e-03, train loss: 0.2952, train accuracy: 0.9145, val loss: 0.2950, val accuracy: 0.9173
epoch: 6, lr: 2.5119e-03, train loss: 0.2912, train accuracy: 0.9163, val loss: 0.2894, val accuracy: 0.9141
epoch: 7, lr: 1.9953e-03, train loss: 0.2825, train accuracy: 0.9186, val loss: 0.2877, val accuracy: 0.9186
epoch: 8, lr: 1.5849e-03, train loss: 0.2966, train accuracy: 0.9158, val loss: 0.2859, val accuracy: 0.9198
epoch: 9, lr: 1.2589e-03, train loss: 0.2893, train accuracy: 0.9206, val loss: 0.2991, val accuracy: 0.9133

After 10 training epochs, the final-epoch metrics are:

In [13]:
small_mlp_last_epoch = diagnostics.loc[diagnostics['Epoch'].idxmax()]
small_mlp_last_epoch
Out[13]:
Epoch                  9.000000
Learning Rate          0.001259
Training Loss          0.289251
Training Accuracy      0.920625
Validation Loss        0.299050
Validation Accuracy    0.913281
Name: 9, dtype: float64
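
The `lr` values in the training log above are consistent with a geometric decay from `max_lr` toward `min_lr`, with `min_lr` reached only at the (never executed) epoch 10. This is inferred from the logged numbers, not from the source of `train_basic_sdg`, so treat the formula below as an assumption; `geometric_lr` is a hypothetical name.

```python
def geometric_lr(epoch: int, epochs: int, max_lr: float, min_lr: float) -> float:
    """Learning rate that decays geometrically from max_lr toward min_lr."""
    return max_lr * (min_lr / max_lr) ** (epoch / epochs)


lrs = [geometric_lr(e, epochs=10, max_lr=0.01, min_lr=0.001) for e in range(10)]
for epoch, lr in enumerate(lrs):
    print(f'epoch: {epoch}, lr: {lr:.4e}')
```

Each epoch multiplies the learning rate by the same factor, $10^{-0.1} \approx 0.794$, so the rate falls tenfold over ten epochs.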

Large MLP

Next, we train with a larger multilayer perceptron. This MLP has four layers (five layer sizes), starting with $784$ inputs and finally reaching $10$ outputs.

In [14]:
large_mlp_layer_in_outs = [28*28, 28*28*4, 28*28*16, 28*28*4, 10]
mlp_large = mnist.model.mlp.MLP(large_mlp_layer_in_outs)
mlp_large.to(device)
diagnostics = train_basic_sdg(mlp_large, train_loader, val_loader, epochs=10, max_lr=0.01, min_lr=0.001, flatten=True, device=device)
epoch: 0, lr: 1.0000e-02, train loss: 2.3317, train accuracy: 0.0741, val loss: 2.3307, val accuracy: 0.0753
epoch: 1, lr: 7.9433e-03, train loss: 0.3142, train accuracy: 0.9111, val loss: 0.3007, val accuracy: 0.9150
epoch: 2, lr: 6.3096e-03, train loss: 0.2830, train accuracy: 0.9186, val loss: 0.2885, val accuracy: 0.9178
epoch: 3, lr: 5.0119e-03, train loss: 0.2708, train accuracy: 0.9228, val loss: 0.2823, val accuracy: 0.9208
epoch: 4, lr: 3.9811e-03, train loss: 0.2848, train accuracy: 0.9197, val loss: 0.2833, val accuracy: 0.9194
epoch: 5, lr: 3.1623e-03, train loss: 0.2766, train accuracy: 0.9225, val loss: 0.2736, val accuracy: 0.9233
epoch: 6, lr: 2.5119e-03, train loss: 0.2648, train accuracy: 0.9241, val loss: 0.2787, val accuracy: 0.9205
epoch: 7, lr: 1.9953e-03, train loss: 0.2575, train accuracy: 0.9248, val loss: 0.2705, val accuracy: 0.9247
epoch: 8, lr: 1.5849e-03, train loss: 0.2407, train accuracy: 0.9328, val loss: 0.2596, val accuracy: 0.9252
epoch: 9, lr: 1.2589e-03, train loss: 0.2693, train accuracy: 0.9247, val loss: 0.2748, val accuracy: 0.9222

After 10 training epochs, the final-epoch metrics are:

In [15]:
large_mlp_last_epoch = diagnostics.loc[diagnostics['Epoch'].idxmax()]
large_mlp_last_epoch
Out[15]:
Epoch                  9.000000
Learning Rate          0.001259
Training Loss          0.269296
Training Accuracy      0.924687
Validation Loss        0.274788
Validation Accuracy    0.922188
Name: 9, dtype: float64

Comparing the increase in synapses (weights, excluding biases) from the small MLP to the large MLP with the increase in accuracy shows that this architecture scales poorly: roughly 10,000 times as many weights buys less than one percentage point of validation accuracy.

In [16]:
small_mlp_synapses = sum(n_in * n_out for n_in, n_out in zip(small_mlp_layer_in_outs, small_mlp_layer_in_outs[1:]))
large_mlp_synapses = sum(n_in * n_out for n_in, n_out in zip(large_mlp_layer_in_outs, large_mlp_layer_in_outs[1:]))
print(f'Synapses increased by a factor of {large_mlp_synapses:,}/{small_mlp_synapses:,} = {large_mlp_synapses / small_mlp_synapses:,.2f}, validation accuracy increased by a factor of {large_mlp_last_epoch["Validation Accuracy"]}/{small_mlp_last_epoch["Validation Accuracy"]} = {large_mlp_last_epoch["Validation Accuracy"]/small_mlp_last_epoch["Validation Accuracy"]:,.4f}')
Synapses increased by a factor of 81,165,952/7,840 = 10,352.80, validation accuracy increased by a factor of 0.9221875/0.91328125 = 1.0098

CNN

Next we try a convolutional neural network.

In [17]:
cnn = mnist.model.cnn.CNN()
cnn.to(device)
diagnostics = train_basic_sdg(cnn, train_loader, val_loader, epochs=10, max_lr=0.01, min_lr=0.001, flatten=False, device=device)
epoch: 0, lr: 1.0000e-02, train loss: 2.3015, train accuracy: 0.0950, val loss: 2.3027, val accuracy: 0.1000
epoch: 1, lr: 7.9433e-03, train loss: 0.0874, train accuracy: 0.9734, val loss: 0.0866, val accuracy: 0.9742
epoch: 2, lr: 6.3096e-03, train loss: 0.0687, train accuracy: 0.9786, val loss: 0.0670, val accuracy: 0.9781
epoch: 3, lr: 5.0119e-03, train loss: 0.0550, train accuracy: 0.9834, val loss: 0.0538, val accuracy: 0.9817
epoch: 4, lr: 3.9811e-03, train loss: 0.0473, train accuracy: 0.9856, val loss: 0.0496, val accuracy: 0.9841
epoch: 5, lr: 3.1623e-03, train loss: 0.0369, train accuracy: 0.9892, val loss: 0.0472, val accuracy: 0.9831
epoch: 6, lr: 2.5119e-03, train loss: 0.0337, train accuracy: 0.9908, val loss: 0.0417, val accuracy: 0.9864
epoch: 7, lr: 1.9953e-03, train loss: 0.0369, train accuracy: 0.9900, val loss: 0.0426, val accuracy: 0.9853
epoch: 8, lr: 1.5849e-03, train loss: 0.0293, train accuracy: 0.9919, val loss: 0.0466, val accuracy: 0.9841
epoch: 9, lr: 1.2589e-03, train loss: 0.0305, train accuracy: 0.9911, val loss: 0.0456, val accuracy: 0.9844

After 10 training epochs, the final-epoch metrics are:

In [18]:
cnn_last_epoch = diagnostics.loc[diagnostics['Epoch'].idxmax()]
cnn_last_epoch
Out[18]:
Epoch                  9.000000
Learning Rate          0.001259
Training Loss          0.030516
Training Accuracy      0.991094
Validation Loss        0.045632
Validation Accuracy    0.984375
Name: 9, dtype: float64

The CNN is significantly more accurate than the large MLP while using far fewer parameters.

In [19]:
mlp_large_parameters = sum(p.numel() for p in mlp_large.parameters())
cnn_parameters = sum(p.numel() for p in cnn.parameters())
print(f'CNN has {cnn_parameters:,}/{mlp_large_parameters:,} = {100 * cnn_parameters / mlp_large_parameters:,.2f}% of the parameters of the large MLP, yet it was {cnn_last_epoch["Validation Accuracy"]:.4f}/{large_mlp_last_epoch["Validation Accuracy"]:.4f} = {cnn_last_epoch["Validation Accuracy"]/large_mlp_last_epoch["Validation Accuracy"]:,.4f} times more accurate on the validation set')
CNN has 67,562/81,184,778 = 0.08% of the parameters of the large MLP, yet it was 0.9844/0.9222 = 1.0674 times more accurate on the validation set
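
The 81,184,778 figure for the large MLP can be reproduced from the layer sizes alone. The sketch below assumes every layer of `mnist.model.mlp.MLP` is a fully connected layer with a bias vector; that is an assumption about the custom class, though it agrees with the printed count. `dense_parameter_count` is a hypothetical helper.

```python
from typing import List


def dense_parameter_count(layer_sizes: List[int]) -> int:
    """Weights plus biases for a stack of fully connected layers."""
    pairs = zip(layer_sizes, layer_sizes[1:])
    return sum(n_in * n_out + n_out for n_in, n_out in pairs)


large = [28 * 28, 28 * 28 * 4, 28 * 28 * 16, 28 * 28 * 4, 10]
small = [28 * 28, 10]
print(f'{dense_parameter_count(large):,}')  # 81,184,778
print(f'{dense_parameter_count(small):,}')  # 7,850
```

The difference from the synapse count (81,165,952) is the 18,826 bias terms, one per output unit of each layer.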

Adding in Random Image Transformations to Improve Model Accuracy

Before training the CNN in earnest, we can augment the training data by applying random transformations to each image before it goes through the network. Shown below are some training images (top) and the same images after the following random transformations (bottom):

  • rotations between $-30^\circ$ and $30^\circ$
  • horizontal shifts of up to $25\%$ of the image width and vertical shifts of up to $25\%$ of the image height, in either direction
  • scaling between $0.75$ and $1.25$
  • shear between $-25^\circ$ and $25^\circ$

Training on these variations should better reflect the variability of human handwriting.
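
To make the ranges concrete, here is a rough sketch of how one set of affine parameters could be drawn for a 28 by 28 image. It mirrors the documented behaviour of RandomAffine (symmetric ranges for a single `degrees` or `shear` number, translation as a fraction of the image size) but uses Python's `random` module and is not torchvision's code; `sample_affine_params` is a hypothetical helper.

```python
import random
from typing import Optional, Tuple


def sample_affine_params(
    degrees: float,
    translate: Tuple[float, float],
    scale: Tuple[float, float],
    shear: float,
    image_size: Tuple[int, int] = (28, 28),  # (width, height)
    rng: Optional[random.Random] = None,
) -> Tuple[float, Tuple[int, int], float, float]:
    """Draw (angle, (dx, dy), zoom, x-shear) for one random affine transform."""
    rng = rng or random.Random()
    width, height = image_size
    angle = rng.uniform(-degrees, degrees)  # rotation in degrees
    max_dx = translate[0] * width           # largest horizontal shift in pixels
    max_dy = translate[1] * height          # largest vertical shift in pixels
    shift = (round(rng.uniform(-max_dx, max_dx)), round(rng.uniform(-max_dy, max_dy)))
    zoom = rng.uniform(scale[0], scale[1])  # isotropic scale factor
    shear_x = rng.uniform(-shear, shear)    # shear parallel to the x-axis, in degrees
    return angle, shift, zoom, shear_x


rng = random.Random(0)
samples = [
    sample_affine_params(30, (0.25, 0.25), (0.75, 1.25), 25, rng=rng)
    for _ in range(1000)
]
print(samples[0])
```

On a 28-pixel-wide image, a 25% translation limit means shifts of at most 7 pixels left or right.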

In [20]:
transform1 = transforms.Compose([
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])
transform2 = transforms.Compose([
    transforms.RandomAffine(
        degrees=30,
        translate=(0.25, 0.25),
        scale=(0.75, 1.25),
        shear=25
    ),
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])

untransformed_train_dataset = datasets.MNIST(root=data_root, train=True, download=True)

n = 10

fig, axes = plt.subplots(2, n, figsize=(n, 2))
for i in range(n):
    image, label = random.choice(untransformed_train_dataset)
    axes[0][i].imshow(tensor2img(transform1(image)).squeeze())
    axes[0][i].set_title(f'{label}')
    axes[0][i].set_xticks([])
    axes[0][i].set_yticks([])
    axes[1][i].imshow(tensor2img(transform2(image)).squeeze())
    axes[1][i].set_xticks([])
    axes[1][i].set_yticks([])
fig.subplots_adjust(top=0.75)
fig.suptitle(f'{n} Random Training Examples (top), Transformed Examples (bottom)')
plt.show()
[Figure: 10 random training examples (top row) and the same examples after random affine transformations (bottom row)]

However, it is important not to introduce these transforms too early, as the model may overfit early on to details that are not fundamentally part of the digit. We start with a few epochs on the untransformed data, then gradually ramp up the magnitude of the random transformations.

Conveniently for us, the transforms.Compose object is mutable, allowing us to change the parameters for the random transformations during training. This is what is implemented in our custom DataLoaderScheduler class.
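
How quickly to ramp up is a design choice. One simple option, sketched below in plain Python, keeps the augmentation off for a warm-up period and then raises it in equal steps. The function name, its parameters, and the stepping rule are all assumptions made for illustration; they are not taken from `DataLoaderScheduler`. In a setup like this notebook's, the resulting fraction would scale the RandomAffine attributes (for example `degrees` or `shear`) between epochs.

```python
import math


def ramp_fraction(epoch: int, total_epochs: int, warmup_epochs: int, step_size: int) -> float:
    """Fraction (0.0 to 1.0) of the full augmentation strength to use at `epoch`."""
    if epoch < warmup_epochs:
        return 0.0  # train on untransformed data first
    # Number of equal-sized increases that fit in the remaining epochs.
    n_steps = max(1, math.ceil((total_epochs - warmup_epochs) / step_size))
    step_index = (epoch - warmup_epochs) // step_size + 1
    return min(1.0, step_index / n_steps)


max_degrees = 30.0  # full-strength rotation range, in degrees
for epoch in (0, 1, 10, 11, 21, 29):
    fraction = ramp_fraction(epoch, total_epochs=30, warmup_epochs=1, step_size=10)
    print(epoch, round(fraction * max_degrees, 1))
```

With these hypothetical settings the rotation range is 0 degrees for the first epoch, then 10, 20, and finally 30 degrees, changing every 10 epochs.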

In [21]:
transform2 = transforms.Compose([
    transforms.RandomAffine(
        degrees=30,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        shear=5
    ),
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])

transformed_train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform2)
train_dataset_loader = torch.utils.data.DataLoader(transformed_train_dataset, batch_size=1, shuffle=True)

n = 10

fig, axes = plt.subplots(1, n, figsize=(n, 2))
i = 0
for images, labels in train_dataset_loader:
    if i == n:
        break
    if i >= int(n/2):
        transform2.transforms[0].shear = [-30, 30]
        transform2.transforms[0].degrees = [90, 270]
        transform2.transforms[0].scale = [0.01, 2.9]
    image = images[0]
    label = labels[0]
    axes[i].imshow(tensor2img(image).squeeze())
    axes[i].set_title(label.item())
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    i += 1
fig.subplots_adjust(top=0.75)
fig.suptitle('Note how the magnitude of the transformations increases halfway through the row')
plt.show()
[Figure: 10 transformed training examples; the transformation magnitude increases halfway through]

Actual Training Time 😎

Here is the heavy-duty training setup for the convolutional neural network. We use the data loader scheduler, which progressively increases the magnitude of the random transformations on the training data, together with the AdamW optimizer and a cosine annealing learning rate scheduler, training for 30 epochs.

In [22]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
import mnist
from mnist.model.dataloader import StepDataLoaderScheduler
from mnist.model.train import train

epochs = 30
x_mean, x_std = (0.13066114485263824, 0.30810731649398804)
data_root = Path('../data')

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'Using {device}')

data_loader = StepDataLoaderScheduler(
    root=data_root,
    x_mean=x_mean,
    x_std=x_std,
    batch_size=64,
    epochs=epochs,
    degrees=5,
    translate=0.25,
    scale=0.25,
    shear=5,
    warmup_steps=1,
    step_size=10
)

cnn = mnist.model.cnn.CNN()
cnn.to(device)

optimizer = torch.optim.AdamW(cnn.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

diagnostics = train(cnn, data_loader, epochs, optimizer, scheduler, flatten=False, device=device, quiet=True, show_progress=True)
Using cuda
100%|██████████| 30/30 [03:07<00:00,  6.26s/it]
In [23]:
diagnostics.head()
Out[23]:
   Epoch  Learning Rate  Training Loss  Training Accuracy  Validation Loss  Validation Accuracy
0      0       0.000100       2.318000           0.094844         2.316937             0.095625
1      1       0.000100       0.102108           0.972187         0.101802             0.970938
2      2       0.000099       0.060110           0.982344         0.063548             0.979219
3      3       0.000098       0.052783           0.985156         0.051296             0.983125
4      4       0.000096       0.032408           0.991094         0.046141             0.985313
In [24]:
diagnostics.tail()
Out[24]:
    Epoch  Learning Rate  Training Loss  Training Accuracy  Validation Loss  Validation Accuracy
25     25       0.000008       0.002774           0.999844         0.053981             0.987031
26     26       0.000005       0.003418           0.999687         0.046401             0.987656
27     27       0.000003       0.003525           0.999687         0.048357             0.988125
28     28       0.000002       0.002838           0.999687         0.053036             0.987344
29     29       0.000001       0.002661           0.999844         0.051914             0.987031
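
The Learning Rate column follows the closed-form cosine annealing curve, $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min})(1 + \cos(\pi t / T_{\max}))$. The sketch below evaluates that formula with the settings from In [22], assuming the scheduler is stepped once per epoch and the logged value is the rate in effect during that epoch; `cosine_annealing_lr` is a hypothetical helper, not a PyTorch function.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, eta_min: float, t_max: int) -> float:
    """Closed-form cosine annealing from base_lr (epoch 0) to eta_min (epoch t_max)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in (0, 1, 2, 3, 4, 25, 26, 27, 28, 29):
    lr = cosine_annealing_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_max=30)
    print(f'{epoch:2d}  {lr:.6f}')
```

Rounded to six decimals, these values match the logged rates for the first and last five epochs.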
In [25]:
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].plot(diagnostics['Epoch'], diagnostics['Training Loss'], label='Training Loss')
axes[0].plot(diagnostics['Epoch'], diagnostics['Validation Loss'], label='Validation Loss')
axes[0].set_title('Loss during training')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss (cross entropy)')
axes[0].set_yscale('log')
axes[0].legend()
axes[1].plot(diagnostics['Epoch'], diagnostics['Training Accuracy'] * 100, label='Training Accuracy')
axes[1].plot(diagnostics['Epoch'], diagnostics['Validation Accuracy'] * 100, label='Validation Accuracy')
axes[1].set_title('Accuracy during training')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (percent)')
axes[1].legend()
axes[2].plot(diagnostics['Epoch'], diagnostics['Learning Rate'], label='Learning Rate')
axes[2].set_title('Learning rate during training')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning rate')
plt.tight_layout()
plt.show()
[Figure: three panels showing loss (log scale), accuracy (percent), and learning rate against epoch]

Evaluating the Model with Validation Data

In [26]:
val_iter = iter(data_loader.loader('val'))
In [27]:
from mnist.utils.plot import show_images_with_labels
classes = list(map(str, range(10)))
images, labels = next(val_iter)
images = images[:4].to(device)
labels = labels[:4].to(device)
output = cnn(images)
_, predicted = torch.max(output, 1)
show_images_with_labels(images.to('cpu'), classes, labels.to('cpu'), predicted.to('cpu'))
[Figure: four validation images with their true and predicted labels]

Here we measure the loss and accuracy on the original validation set.

In [28]:
from mnist.model.train import dataset_loss
from torchvision import transforms, datasets
transform = transforms.Compose([
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])
val_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
batch_size = 64
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
estimated_loss = dataset_loss(cnn, val_loader, flatten=False, device=device)
print(f'Model achieved loss of {estimated_loss[0]:.4f} on the untransformed validation set (accuracy of {estimated_loss[1]*100:.2f}%).')
Model achieved loss of 0.0495 on the untransformed validation set (accuracy of 98.80%).

And here we measure the loss and accuracy on the validation set with random affine transformations applied.

In [29]:
transform = transforms.Compose([
    transforms.RandomAffine(
        degrees=5,
        translate=(0.25, 0.25),
        scale=(0.75, 1.25),
        shear=5
    ),
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])
val_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
batch_size = 64
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
estimated_loss = dataset_loss(cnn, val_loader, flatten=False, device=device)
print(f'Model achieved loss of {estimated_loss[0]:.4f} on the transformed validation set (accuracy of {estimated_loss[1]*100:.2f}%).')
Model achieved loss of 9.1229 on the transformed validation set (accuracy of 34.97%).
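
For scale, it helps to know what cross-entropy looks like for an uninformed guess. A model that assigns equal probability to all ten digits scores $\ln 10 \approx 2.3026$, which is about what the untrained networks logged at epoch 0. A mean loss of 9.12 is far above that, which suggests the model is not merely unsure on the transformed images but often confidently wrong. The sketch below assumes `dataset_loss` reports mean cross-entropy in nats, which is consistent with the plot's axis label and the epoch-0 values but is not confirmed by its source; the logits are made-up illustrations.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy loss (natural log) of one example, computed from raw logits."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]


# No preference among the 10 classes.
uniform = cross_entropy([0.0] * 10, target=3)
# Strongly prefers class 0 while the true class is 3.
confident_wrong = cross_entropy([12.0] + [0.0] * 9, target=3)
print(f'{uniform:.4f}', f'{confident_wrong:.4f}')
```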

We also have some hand-drawn digits that are not part of the MNIST dataset (I drew these myself :P). These images better reflect the kind of images that the model will receive in production, so it is important that it performs well on these.

In [30]:
from torchvision import transforms, datasets
test_data_root = Path('../custom-data')
x_mean, x_std = (0.13066114485263824, 0.30810731649398804)
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])
test_dataset = datasets.ImageFolder(root=test_data_root, transform=transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
In [31]:
from mnist.utils.plot import tensor2img
images, labels = next(iter(test_dataloader))
plt.imshow(tensor2img(images[0]))
plt.title(str(labels[0].item()))
plt.xticks([])
plt.yticks([])
plt.show()
[Figure: one hand-drawn digit titled with its class index]
In [32]:
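# Evaluate on the hand-drawn digits. The cell originally passed `val_loader` (the
# transformed validation set from In [29]), which is what the output below reflects.
estimated_loss = dataset_loss(cnn, test_dataloader, flatten=False, device=device)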
print(f'Model achieved loss of {estimated_loss[0]:.4f} on the test set (accuracy of {estimated_loss[1]*100:.2f}%).')
Model achieved loss of 9.2684 on the test set (accuracy of 34.04%).

Saving the Model

Now we're going to save the model, both as a PyTorch state dict (so we can reload the model in our training setup), and as an ONNX model (which we will be using in production).

In [33]:
torch.save(cnn.state_dict(), './thousand-iteration-model.pth')
print('Model saved')
Model saved

Here, we load the model from the state dict that we saved above and export it to ONNX.

In [34]:
cnn = mnist.model.cnn.CNN()
state_dict = torch.load('thousand-iteration-model.pth', map_location='cpu')  # load onto CPU for export
cnn.load_state_dict(state_dict)
cnn.eval()
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    cnn,
    dummy_input,
    'model.onnx',
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

And here we load the ONNX model that we exported above and run inference on it.

In [35]:
import torch
import onnxruntime as ort
from torchvision import transforms, datasets
from mnist.utils.plot import show_images_with_labels
import numpy as np

x_mean, x_std = (0.13066114485263824, 0.30810731649398804)

transform = transforms.Compose([
    transforms.ToTensor(),                    # (C, H, W)
    transforms.Normalize((x_mean,), (x_std,)) # (C, H, W)
])

transformed_train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
train_dataset_loader = torch.utils.data.DataLoader(transformed_train_dataset, batch_size=4, shuffle=True)

session = ort.InferenceSession('model.onnx')

inputs, labels = next(iter(train_dataset_loader))

outputs = session.run(None, {'input': inputs.numpy()})

predicted = np.argmax(outputs[0], axis=1) # Index of the largest logit per row, like `torch.max(output, 1)[1]`

show_images_with_labels(inputs, list(map(str, range(10))), labels, predicted, 'ONNX')
[Figure: four training images with their true labels and ONNX-predicted labels]