Generating Images to Fool an MNIST Classifier

Despite their high performance on classification tasks such as MNIST, neural networks like the LeNet-5 have a weakness: they are easy to fool. Namely, given images like the ones below, a classifier may confidently believe that it is seeing certain digits, even though the images look like random noise to humans. Naturally, this phenomenon raises some concerns, especially when the network in question is used in a safety-critical system like a self-driving car. Given such unrecognizable input, one would hope that the network at least has low confidence in its prediction.

fooling images example

To make matters worse for neural networks, generating such images is incredibly easy with QD algorithms. As shown in Nguyen 2015, one can use simple MAP-Elites to generate these images. In this tutorial, we will use the pyribs version of MAP-Elites to do just that.


First, we install pyribs and PyTorch.

%pip install ribs torch==1.7 torchvision==0.8

Here, we import PyTorch and some utilities.

import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

Below, we check what device is available for PyTorch. On Colab, activate the GPU by clicking “Runtime” in the toolbar at the top. Then, click “Change Runtime Type”, and select “GPU”.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Preliminary: MNIST Network

We have pretrained a high-performing LeNet-5 classifier (98.4% training set accuracy, 98.5% test set accuracy) for the MNIST dataset using the code here. This is the same network that we use in the LSI MNIST tutorial. Below, we define the network.

LENET5 = nn.Sequential(
    nn.Conv2d(1, 6, (5, 5), stride=1, padding=0),  # (1,28,28) -> (6,24,24)
    nn.MaxPool2d(2),  # (6,24,24) -> (6,12,12)
    nn.Conv2d(6, 16, (5, 5), stride=1, padding=0),  # (6,12,12) -> (16,8,8)
    nn.MaxPool2d(2),  # (16,8,8) -> (16,4,4)
    nn.Flatten(),  # (16,4,4) -> (256,)
    nn.Linear(256, 120),  # (256,) -> (120,)
    nn.Linear(120, 84),  # (120,) -> (84,)
    nn.Linear(84, 10),  # (84,) -> (10,)
    nn.LogSoftmax(dim=1),  # (10,) log probabilities
).to(device)
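As a quick sanity check on the shape comments above, the flattened size of 256 can be derived with a little arithmetic (a standalone sketch, not part of the tutorial's code): a valid 5x5 convolution with stride 1 shrinks each side by 4, and 2x2 max pooling halves it.

```python
# Valid convolution with stride 1 shrinks each side by (kernel - 1).
def conv_out(n, k):
    return n - k + 1

n = 28
n = conv_out(n, 5) // 2  # conv1 (5x5) + pool: 28 -> 24 -> 12
n = conv_out(n, 5) // 2  # conv2 (5x5) + pool: 12 -> 8 -> 4
flat = 16 * n * n        # 16 channels of 4x4 feature maps
print(flat)  # 256
```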

Now, we download the weights and load them into the network.

from urllib.request import urlretrieve
from pathlib import Path

LOCAL_DIR = Path("fooling_mnist_weights")
LOCAL_DIR.mkdir(exist_ok=True)
WEB_DIR = ""

# Download the model file to LOCAL_DIR.
filename = "mnist_classifier.pth"
model_path = LOCAL_DIR / filename
if not model_path.is_file():
    urlretrieve(WEB_DIR + filename, str(model_path))

# Load the weights of the network.
state_dict = torch.load(
    str(LOCAL_DIR / "mnist_classifier.pth"),
    map_location=device,
)

# Insert the weights into the network.
LENET5.load_state_dict(state_dict)

<All keys matched successfully>

Fooling the Classifier with MAP-Elites

In order to fool the classifier into seeing various digits, we use MAP-Elites. As we have 10 distinct digits (0-9), we have a discrete behavior space with 10 values. Note that while pyribs is designed for continuous search spaces, the behavior space can be either continuous or discrete.

Our classifier outputs a log probability vector with its belief that it is seeing each digit. Thus, our objective for each digit is to maximize the probability that the classifier assigns to the image associated with it. For instance, for digit 5, we want to generate an image that makes the classifier believe with high probability that it is seeing a 5.
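Concretely, suppose the classifier emitted the hypothetical log-probability vector below for one candidate image. The BC is the argmax digit and the objective is the probability assigned to it; this mirrors the evaluation we perform later with PyTorch.

```python
import numpy as np

# Hypothetical classifier belief over the 10 digits for one image.
probs = np.array([0.01, 0.02, 0.05, 0.60, 0.02,
                  0.10, 0.05, 0.05, 0.05, 0.05])
log_probs = np.log(probs)  # the network outputs log probabilities

bc = int(np.argmax(log_probs))      # digit the classifier "sees": 3
obj = float(np.exp(log_probs[bc]))  # its confidence in that digit: 0.6
```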

In pyribs, we implement MAP-Elites with a GridArchive and a GaussianEmitter. Below, we start by constructing the GridArchive. The archive has 10 bins and a range of (0,10). Since GridArchive was originally designed for continuous spaces, it does not directly support discrete spaces, but by using these settings, we have a bin for each digit from 0 to 9.

from ribs.archives import GridArchive

archive = GridArchive([10], [(0, 10)])
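To see why these settings give exactly one bin per digit, here is a simplified sketch of the uniform discretization that a grid archive performs (the real GridArchive implementation also handles out-of-range and boundary values, which we ignore here):

```python
def bin_index(b, lower=0.0, upper=10.0, bins=10):
    """Map a behavior value to a bin via uniform discretization."""
    return int((b - lower) / (upper - lower) * bins)

# Each integer digit 0-9 lands in its own bin.
print([bin_index(d) for d in range(10)])  # [0, 1, ..., 9]
```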

Next, we use a single Gaussian emitter with a batch size of 30. The emitter begins with an image filled with 0.5 (i.e. grey, since pixels are in the range \([0,1]\)) and has \(\sigma = 0.5\).

from ribs.emitters import GaussianEmitter

img_size = (28, 28)
flat_img_size = 784  # 28 * 28
emitters = [
    GaussianEmitter(
        archive,
        # Start with a grey image.
        np.full(flat_img_size, 0.5),
        # Standard deviation of the Gaussian noise.
        0.5,
        # Bound the generated images to the pixel range.
        bounds=[(0, 1)] * flat_img_size,
        batch_size=30,
    )
]
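Roughly speaking, the Gaussian emitter proposes candidates by adding Gaussian noise to existing solutions (initially the starting image, and later elites sampled from the archive) and clipping them to the bounds. A simplified NumPy sketch of one batch:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = np.full(784, 0.5)  # grey starting image
sigma = 0.5

# One batch of 30 candidates: Gaussian noise around x0, clipped to [0, 1].
batch = np.clip(x0 + sigma * rng.standard_normal((30, 784)), 0.0, 1.0)
print(batch.shape)  # (30, 784)
```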

Finally, we construct the optimizer to connect the archive and emitter together.

from ribs.optimizers import Optimizer

optimizer = Optimizer(archive, emitters)

With the components created, we now generate the images. As we use one emitter with a batch size of 30 and run 30,000 iterations, we evaluate 900,000 images in total.

total_itrs = 30_000
start_time = time.time()

for itr in range(1, total_itrs + 1):
    sols = optimizer.ask()

    with torch.no_grad():

        # Reshape and normalize the image and pass it through the network.
        imgs = sols.reshape((-1, 1, *img_size))
        imgs = torch.tensor(imgs, dtype=torch.float32, device=device)
        output = LENET5(imgs)

        # The BC is the digit that the network believes it is seeing, i.e. the
        # digit with the maximum probability. The objective is the probability
        # associated with that digit.
        scores, predicted = torch.max(output.to("cpu"), 1)
        scores = torch.exp(scores)
        objs = scores.numpy()
        bcs = predicted.numpy()

    optimizer.tell(objs, bcs)
    if itr % 1000 == 0:
        print(f"Iteration {itr} complete after {time.time() - start_time} s")
Iteration 1000 complete after 6.609215497970581 s
Iteration 2000 complete after 15.589641332626343 s
Iteration 3000 complete after 24.31386113166809 s
Iteration 4000 complete after 32.35465717315674 s
Iteration 5000 complete after 40.373658418655396 s
Iteration 6000 complete after 48.49511432647705 s
Iteration 7000 complete after 56.5163037776947 s
Iteration 8000 complete after 64.4883759021759 s
Iteration 9000 complete after 72.49214792251587 s
Iteration 10000 complete after 80.55756187438965 s
Iteration 11000 complete after 88.57506895065308 s
Iteration 12000 complete after 96.59975409507751 s
Iteration 13000 complete after 104.67273259162903 s
Iteration 14000 complete after 112.81262230873108 s
Iteration 15000 complete after 120.81725907325745 s
Iteration 16000 complete after 128.84538173675537 s
Iteration 17000 complete after 137.2752993106842 s
Iteration 18000 complete after 145.32959914207458 s
Iteration 19000 complete after 153.31980180740356 s
Iteration 20000 complete after 161.4025936126709 s
Iteration 21000 complete after 169.54731440544128 s
Iteration 22000 complete after 177.60430788993835 s
Iteration 23000 complete after 185.68359541893005 s
Iteration 24000 complete after 193.74732971191406 s
Iteration 25000 complete after 201.9666783809662 s
Iteration 26000 complete after 210.08716130256653 s
Iteration 27000 complete after 218.16455078125 s
Iteration 28000 complete after 226.33121705055237 s
Iteration 29000 complete after 234.52697253227234 s
Iteration 30000 complete after 242.6009750366211 s

Below, we display the images found. Interestingly, though the images look mostly like noise, we can occasionally make out traces of the original digit. Note that MAP-Elites may not find images for all the digits. Also keep in mind that this behavior space, with only 10 bins, is unusually small; QD algorithms usually run with much larger behavior spaces, which is worth remembering when tuning them.

fig, ax = plt.subplots(2, 5, figsize=(10, 4))
ax = ax.flatten()
found = set()

# Display images.
for elite in archive:
    digit = elite.idx[0]
    found.add(digit)

    # No need to normalize image because we want to see the original.
    ax[digit].imshow(elite.sol.reshape(28, 28), cmap="Greys")
    ax[digit].set_title(f"{digit} | Score: {elite.obj:.3f}", pad=8)

# Mark digits that we did not generate images for.
for digit in range(10):
    if digit not in found:
        ax[digit].set_title(f"{digit} | (no solution)", pad=8)


In this tutorial, we used MAP-Elites to generate images that fool a LeNet-5 MNIST classifier. For further exploration, we recommend referring to Nguyen 2015 and replicating or extending the other experiments described in the paper.


If you find this tutorial useful, please cite it as:

@article{fooling_mnist,
  title   = {Generating Images to Fool an MNIST Classifier},
  author  = {Bryon Tjanaka and Matthew C. Fontaine and Stefanos Nikolaidis},
  journal = {},
  year    = {2021},
  url     = {}
}