Generating Images to Fool an MNIST Classifier

Despite their high performance on classification tasks such as MNIST, neural networks like LeNet-5 have a weakness: they are easy to fool. In particular, given images like the ones below, a classifier may confidently believe it is seeing certain digits, even though the images look like random noise to humans. Naturally, this phenomenon raises concerns, especially when the network in question is used in a safety-critical system like a self-driving car. Given such unrecognizable input, one would hope that the network would at least have low confidence in its prediction.

[Figure: example fooling images]

To make matters worse for neural networks, generating such images is remarkably easy with quality diversity (QD) algorithms. As shown in Nguyen et al. 2015, one can use simple MAP-Elites to generate these images. In this tutorial, we will use the pyribs implementation of MAP-Elites to do just that.

Setup

First, we install pyribs and PyTorch.

%pip install ribs torch torchvision

Here, we import PyTorch and some utilities.

import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

Below, we check what device is available for PyTorch. On Colab, activate the GPU by clicking “Runtime” in the toolbar at the top. Then, click “Change Runtime Type”, and select “GPU”.

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
cuda

Preliminary: MNIST Network

For the classifier network, we train a LeNet-5 to classify MNIST. If you are not familiar with PyTorch, we recommend referring to the PyTorch 60-minute blitz. On the other hand, if you are familiar, feel free to skip to the next section, where we demonstrate how to fool the network.

Note: This section is adapted from the Training a Classifier tutorial in the 60-minute blitz.

Before training the network, we load and preprocess the MNIST dataset.

# Transform each image by turning it into a tensor and then
# normalizing the values.
MEAN_TRANSFORM = 0.1307
STD_DEV_TRANSFORM = 0.3081
mnist_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((MEAN_TRANSFORM,), (STD_DEV_TRANSFORM,))
])
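
For reference, Normalize applies \((x - \mu)/\sigma\) elementwise. We will reuse exactly this formula later, when normalizing generated images by hand before feeding them to the network. A minimal sketch of the equivalence:

x = torch.rand(1, 28, 28)  # a stand-in image with pixels in [0, 1]
x_norm = (x - MEAN_TRANSFORM) / STD_DEV_TRANSFORM  # same as the transform above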

TRAIN_BATCH_SIZE = 64
TRAINLOADER = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=mnist_transforms),
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
)

TEST_BATCH_SIZE = 1000
TESTLOADER = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data', train=False, transform=mnist_transforms),
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
)
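
As a quick sanity check, we can pull one batch from the training loader and confirm its shape (a small sketch; the sizes follow from the batch size above):

images, labels = next(iter(TRAINLOADER))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])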

This is our training function. We use negative log-likelihood loss and the Adam optimizer; NLLLoss expects log-probabilities, which is why the network we define below ends with a LogSoftmax layer.

def fit(net, epochs):
    """Trains net for the given number of epochs."""
    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(net.parameters())

    for epoch in range(epochs):
        print(f"=== Epoch {epoch + 1} ===")
        total_loss = 0.0

        # Iterate through batches in the shuffled training dataset.
        for batch_i, data in enumerate(TRAINLOADER):
            inputs = data[0].to(device)
            labels = data[1].to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Report the cumulative loss over the last 100 batches.
            if (batch_i + 1) % 100 == 0:
                print(f"Batch {batch_i + 1:5d}: {total_loss}")
                total_loss = 0.0

Now, we define the LeNet-5 and train it for 2 epochs. We have annotated the shapes of the data (excluding the batch dimension) as they pass through the network.

LENET5 = nn.Sequential(
    nn.Conv2d(1, 6, (5, 5), stride=1, padding=0),  # (1,28,28) -> (6,24,24)
    nn.MaxPool2d(2),  # (6,24,24) -> (6,12,12)
    nn.ReLU(),
    nn.Conv2d(6, 16, (5, 5), stride=1, padding=0),  # (6,12,12) -> (16,8,8)
    nn.MaxPool2d(2),  # (16,8,8) -> (16,4,4)
    nn.ReLU(),
    nn.Flatten(),  # (16,4,4) -> (256,)
    nn.Linear(256, 120),  # (256,) -> (120,)
    nn.ReLU(),
    nn.Linear(120, 84),  # (120,) -> (84,)
    nn.ReLU(),
    nn.Linear(84, 10),  # (84,) -> (10,)
    nn.LogSoftmax(dim=1),  # (10,) log probabilities
).to(device)
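
To verify the annotated shapes, we can pass a dummy batch through the (still untrained) network; this is just a quick check, not part of the tutorial's pipeline:

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28, device=device)
    print(LENET5(dummy).shape)  # torch.Size([1, 10]), i.e. 10 log probabilities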

fit(LENET5, 2)
=== Epoch 1 ===
Batch   100: 91.06456322968006
Batch   200: 27.53672430664301
Batch   300: 20.126135285943747
Batch   400: 17.020390208810568
Batch   500: 12.69742445461452
Batch   600: 11.237632440403104
Batch   700: 10.171053199097514
Batch   800: 9.923257116228342
Batch   900: 9.814844051375985
=== Epoch 2 ===
Batch   100: 7.419087789952755
Batch   200: 6.923573239240795
Batch   300: 7.54965728148818
Batch   400: 7.2184919700957835
Batch   500: 6.637612740509212
Batch   600: 6.85739607969299
Batch   700: 6.978088988456875
Batch   800: 5.98210142692551
Batch   900: 6.327374076470733

Finally, we evaluate the network on the train and test sets.

def evaluate(net, loader):
    """Evaluates the network's accuracy on the images in the dataloader."""
    correct_per_num = [0 for _ in range(10)]
    total_per_num = [0 for _ in range(10)]

    with torch.no_grad():
        for data in loader:
            images, labels = data
            outputs = net(images.to(device))
            _, predicted = torch.max(outputs.to("cpu"), 1)
            c = (predicted == labels).squeeze()
            for i in range(len(c)):
                label = labels[i]
                correct_per_num[label] += c[i].item()
                total_per_num[label] += 1

    for i in range(10):
        print(f"Class {i}: {correct_per_num[i] / total_per_num[i]:5.3f}"
              f" ({correct_per_num[i]} / {total_per_num[i]})")
    print(f"TOTAL  : {sum(correct_per_num) / sum(total_per_num):5.3f}"
          f" ({sum(correct_per_num)} / {sum(total_per_num)})")

evaluate(LENET5, TRAINLOADER)
Class 0: 0.988 (5854 / 5923)
Class 1: 0.996 (6714 / 6742)
Class 2: 0.983 (5858 / 5958)
Class 3: 0.985 (6037 / 6131)
Class 4: 0.983 (5741 / 5842)
Class 5: 0.986 (5347 / 5421)
Class 6: 0.996 (5894 / 5918)
Class 7: 0.992 (6215 / 6265)
Class 8: 0.972 (5690 / 5851)
Class 9: 0.979 (5823 / 5949)
TOTAL  : 0.986 (59173 / 60000)

evaluate(LENET5, TESTLOADER)
Class 0: 0.984 (964 / 980)
Class 1: 0.996 (1131 / 1135)
Class 2: 0.985 (1017 / 1032)
Class 3: 0.991 (1001 / 1010)
Class 4: 0.989 (971 / 982)
Class 5: 0.985 (879 / 892)
Class 6: 0.994 (952 / 958)
Class 7: 0.994 (1022 / 1028)
Class 8: 0.978 (953 / 974)
Class 9: 0.975 (984 / 1009)
TOTAL  : 0.987 (9874 / 10000)

Fooling the Classifier with MAP-Elites

Above, we trained a reasonably high-performing classifier. In order to fool the classifier into seeing various digits, we use MAP-Elites. As we have 10 distinct digits (0-9), we have a discrete behavior space with 10 values. Note that while pyribs is designed for continuous search spaces, the behavior space can be either continuous or discrete.

Our classifier outputs a vector of log-probabilities, one per digit, representing its belief that it is seeing each digit. Thus, our objective for each digit is to maximize the probability the classifier assigns to that digit. For instance, for digit 5, we want to generate an image that makes the classifier believe, with high probability, that it is seeing a 5.
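
To make this concrete, here is a toy sketch of how a log-probability vector becomes an objective and a behavior characteristic (BC); the random tensor stands in for a real network output:

fake_log_probs = torch.log_softmax(torch.randn(1, 10), dim=1)
score, digit = torch.max(fake_log_probs, 1)
objective = torch.exp(score).item()  # probability of the most likely digit
bc = digit.item()                    # the predicted digit is the BC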

In pyribs, we implement MAP-Elites with a GridArchive and a GaussianEmitter. Below, we start by constructing the GridArchive. The archive has 10 bins and a range of (0,10). Since GridArchive was originally designed for continuous spaces, it does not directly support discrete spaces, but by using these settings, we have a bin for each digit from 0 to 9.

from ribs.archives import GridArchive

archive = GridArchive([10], [(0, 10)])
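
With these settings, an integer BC \(d\) falls into bin \(d\), since the archive divides \((0, 10)\) into 10 uniform bins. The following sketch illustrates the uniform-binning arithmetic (not the pyribs internals):

def bin_of(bc, lower=0.0, upper=10.0, bins=10):
    """Maps a behavior value to its uniform-grid bin index."""
    return min(int((bc - lower) / (upper - lower) * bins), bins - 1)

for digit in range(10):
    assert bin_of(digit) == digit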

Next, we use a single Gaussian emitter with a batch size of 30. The emitter begins with an image filled with 0.5 (i.e. grey, since pixels are in the range \([0,1]\)) and uses \(\sigma = 0.5\).

from ribs.emitters import GaussianEmitter

img_size = (28, 28)
flat_img_size = 784  # 28 * 28
emitters = [
    GaussianEmitter(
        archive,
        # Start with a grey image.
        np.full(flat_img_size, 0.5),
        0.5,
        # Bound the generated images to the pixel range.
        bounds=[(0, 1)] * flat_img_size,
        batch_size=30,
    )
]
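
Conceptually, the Gaussian emitter first samples around the initial grey image, and once the archive contains elites, it proposes new solutions by adding Gaussian noise to a randomly chosen elite and clipping to the bounds. A rough sketch of that second phase (not the pyribs internals):

def gaussian_ask(elite, sigma=0.5, batch_size=30):
    """Roughly what the emitter does: perturb an elite, clip to [0, 1]."""
    rng = np.random.default_rng()
    sols = elite + sigma * rng.standard_normal((batch_size, elite.size))
    return np.clip(sols, 0.0, 1.0)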

Finally, we construct the optimizer to connect the archive and emitter together.

from ribs.optimizers import Optimizer

optimizer = Optimizer(archive, emitters)

With the components created, we now generate the images. As we use one emitter with a batch size of 30 and run 30,000 iterations, we evaluate 900,000 images in total.

total_itrs = 30_000
start_time = time.time()

for itr in range(1, total_itrs + 1):
    sols = optimizer.ask()

    with torch.no_grad():
        # Reshape and normalize the images, then pass them through the network.
        imgs = sols.reshape((-1, 1, *img_size))
        imgs = (imgs - MEAN_TRANSFORM) / STD_DEV_TRANSFORM
        imgs = torch.tensor(imgs, dtype=torch.float32, device=device)
        output = LENET5(imgs)

        # The BC is the digit that the network believes it is seeing, i.e. the
        # digit with the maximum probability. The objective is the probability
        # associated with that digit.
        scores, predicted = torch.max(output.to("cpu"), 1)
        scores = torch.exp(scores)
        objs = scores.numpy()
        bcs = predicted.numpy()

    optimizer.tell(objs, bcs)
    
    if itr % 1000 == 0:
        print(f"Iteration {itr} complete after {time.time() - start_time} s")
Iteration 1000 complete after 8.40089201927185 s
Iteration 2000 complete after 18.655755758285522 s
Iteration 3000 complete after 26.78686022758484 s
Iteration 4000 complete after 34.25956344604492 s
Iteration 5000 complete after 40.23374128341675 s
Iteration 6000 complete after 45.780840158462524 s
Iteration 7000 complete after 51.34646511077881 s
Iteration 8000 complete after 56.87763333320618 s
Iteration 9000 complete after 62.4202446937561 s
Iteration 10000 complete after 67.97641730308533 s
Iteration 11000 complete after 74.24591207504272 s
Iteration 12000 complete after 79.86771273612976 s
Iteration 13000 complete after 86.17090797424316 s
Iteration 14000 complete after 93.36901998519897 s
Iteration 15000 complete after 100.97218585014343 s
Iteration 16000 complete after 106.8744752407074 s
Iteration 17000 complete after 112.68033242225647 s
Iteration 18000 complete after 118.06590342521667 s
Iteration 19000 complete after 123.54357981681824 s
Iteration 20000 complete after 131.5358066558838 s
Iteration 21000 complete after 139.29581952095032 s
Iteration 22000 complete after 144.79404830932617 s
Iteration 23000 complete after 150.26445746421814 s
Iteration 24000 complete after 155.667062997818 s
Iteration 25000 complete after 161.09095001220703 s
Iteration 26000 complete after 166.48615026474 s
Iteration 27000 complete after 171.90232849121094 s
Iteration 28000 complete after 177.3490025997162 s
Iteration 29000 complete after 182.76266932487488 s
Iteration 30000 complete after 188.1841962337494 s

Below, we display the results found by MAP-Elites. The index_0 (equivalently, behavior_0) column shows the digit associated with each image, the objective column shows the network's belief that the image is that digit, and the solution columns hold the image's pixel values.

archive.as_pandas().sort_values("index_0")
index_0 behavior_0 objective solution_0 solution_1 solution_2 solution_3 solution_4 solution_5 solution_6 ... solution_774 solution_775 solution_776 solution_777 solution_778 solution_779 solution_780 solution_781 solution_782 solution_783
1 0 0.0 0.941922 1.000000 0.741815 0.768159 0.277410 0.781590 0.176740 0.000000 ... 0.000000 0.041704 0.000000 1.000000 1.000000 0.436551 1.000000 0.619011 0.036524 0.000000
9 1 1.0 0.967510 0.850452 0.348060 0.028136 0.025715 0.000000 0.710674 0.105554 ... 1.000000 0.000000 0.000000 1.000000 1.000000 1.000000 0.000000 0.000000 0.710068 0.000000
5 2 2.0 0.981887 0.000000 0.440409 0.647115 0.640636 0.700368 0.910001 0.000000 ... 0.000000 1.000000 1.000000 0.000000 0.069804 0.608800 0.502689 0.310673 0.000000 0.752838
4 3 3.0 0.972874 0.462327 0.000000 0.000000 1.000000 0.000000 0.723633 0.727103 ... 0.135687 0.356678 0.279045 0.068244 0.000000 0.688684 0.000000 0.000000 1.000000 0.498214
3 4 4.0 0.996900 0.286218 0.914999 1.000000 0.137627 0.089804 0.000000 0.488552 ... 0.268614 1.000000 1.000000 0.000000 0.012825 0.245794 1.000000 0.404129 0.000000 0.548628
7 5 5.0 0.991366 0.869267 0.000000 0.236837 0.089832 0.499572 0.262858 0.000000 ... 0.211596 0.000000 0.000000 0.771642 0.384867 0.904340 0.068203 0.469802 0.916816 0.898603
6 6 6.0 0.980629 1.000000 0.000000 0.942983 0.601392 0.250491 0.039754 1.000000 ... 1.000000 0.000000 0.670791 0.000000 1.000000 0.810804 0.625034 0.000000 0.208471 1.000000
2 7 7.0 0.988252 0.000000 0.000000 0.832956 0.282398 0.000000 0.724716 0.737155 ... 0.000000 0.657358 0.868062 1.000000 0.460378 0.403149 1.000000 0.000000 1.000000 0.878756
0 8 8.0 0.996376 0.699162 0.852378 1.000000 1.000000 0.848423 0.496602 1.000000 ... 1.000000 0.065986 0.000000 0.764668 0.000000 0.000000 1.000000 0.491454 0.000000 0.578239
8 9 9.0 0.985330 0.689702 1.000000 0.815496 1.000000 1.000000 0.163877 0.911765 ... 1.000000 0.095587 0.576414 0.303699 0.000000 0.396534 0.000000 1.000000 0.366522 0.000000

10 rows × 787 columns
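
For instance, to pull out the elite for a single digit, say 5, we can filter the dataframe and reshape the solution columns back into an image (a small sketch; the plotting code below uses the same column layout):

df = archive.as_pandas()
row_5 = df[df["index_0"] == 5].iloc[0]
img_5 = row_5.loc["solution_0":].to_numpy().reshape(28, 28)
print(row_5["objective"])  # the classifier's belief that img_5 is a 5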

Here, we display the images found. Interestingly, though the images look mostly like noise, we can occasionally make out traces of the original digit. Note that MAP-Elites may not find images for all the digits; this is mostly due to the small behavior space, which offers few stepping stones between niches. QD algorithms usually run with fairly large behavior spaces, which is something to keep in mind when tuning them.

fig, ax = plt.subplots(2, 5, figsize=(10, 4))
fig.tight_layout()
ax = ax.flatten()
found = set()

# Display images.
for _, row in archive.as_pandas().iterrows():
    i = int(row.loc["index_0"])
    found.add(i)
    obj = row.loc["objective"]
    ax[i].set_title(f"{i} | Score: {obj:.3f}", pad=8)
    img = row.loc["solution_0":].to_numpy().reshape(28, 28)

    # No need to normalize image because we want to see the original.
    ax[i].imshow(img, cmap="Greys")
    ax[i].set_axis_off()

# Mark digits that we did not generate images for.
for i in range(10):
    if i not in found:
        ax[i].set_title(f"{i} | (no solution)", pad=8)
        ax[i].set_axis_off()
[Figure: generated fooling images for digits 0-9, each titled with the digit and the classifier's score]

Conclusion

In this tutorial, we used MAP-Elites to generate images that fool a LeNet-5 MNIST classifier. For further exploration, we recommend reading Nguyen et al. 2015 and replicating or extending the other experiments described in the paper.