Illuminating the Latent Space of an MNIST GAN

One of the most popular applications of Generative Adversarial Networks (GANs) is generating fake images. In particular, websites like “This Person Does Not Exist” serve a GAN that generates fake images of people (“This X Does Not Exist” provides a comprehensive list of such websites). Such websites are entertaining, especially when one is asked to figure out which face is real.

Usually, these websites produce fake images by randomly sampling the GAN’s latent space. For those unfamiliar with GANs, this means that each image is associated with a real-valued vector of \(n\) components, and the generator maps that vector to an image. But since these vectors are typically generated randomly, these websites are of little use when we wish to search for a specific image.

For instance, suppose that instead of fake faces, we want to generate fake handwriting, specifically the digit eight (8). We could train a GAN on the MNIST dataset to produce a generator network that generates fake digits, then repeatedly sample the latent space until an eight appears. Alternatively, we could search the latent space directly with an optimizer such as CMA-ES. To ensure that we generate eights, we could use the classification prediction of a LeNet-5 classifier as the objective (see Bontrager 2018).(1) But notice that the latent space likely contains many examples of the digit eight, and they might vary in the weight of the pen stroke or the lightness of the ink color. If we make these properties our behavior characteristics, we could search the latent space for many different examples of eight in a single run!

Fontaine 2021 takes exactly this approach when generating new levels for the classic video game Super Mario Bros. They term this approach “Latent Space Illumination” (LSI): they use quality diversity (QD) algorithms, including CMA-ME, to search the latent space of a GAN trained on video game levels and illuminate the behavior space of possible level mechanics. In this tutorial, we illuminate the latent space of the MNIST GAN described above by mimicking the approach taken in Fontaine 2021.

(1) Since the discriminator of the GAN is only trained to evaluate how realistic an image is, it cannot detect specific digits. Hence, we need the LeNet-5 to check that the digit is an 8.


First, we install pyribs, PyTorch, and several utilities.

%pip install ribs torch==1.7 torchvision==0.8 numpy matplotlib
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

Below, we check what device is available for PyTorch. On Colab, activate the GPU by clicking “Runtime” in the toolbar at the top. Then, click “Change Runtime Type”, and select “GPU”.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading the GAN and Classifier

For this tutorial, we pretrained a GAN that generates MNIST digits using the code from a beginner GAN tutorial. We also pretrained a LeNet-5 classifier on the MNIST dataset. Below, we define the network structures.

class Generator(nn.Module):
    """Generator network for the GAN."""

    def __init__(self, nz):
        super(Generator, self).__init__()

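        # Size of the latent space (number of dimensions).
        self.nz = nz
        self.main = nn.Sequential(
            nn.Linear(self.nz, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh(),  # Outputs pixel values in [-1, 1].
        )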

    def forward(self, x):
        return self.main(x).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    """Discriminator network for the GAN."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_input = 784
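        self.main = nn.Sequential(
            nn.Linear(self.n_input, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # Probability that the input image is real.
        )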

    def forward(self, x):
        x = x.view(-1, 784)
        return self.main(x)

LENET5 = nn.Sequential(
    nn.Conv2d(1, 6, (5, 5), stride=1, padding=0),  # (1,28,28) -> (6,24,24)
    nn.MaxPool2d(2),  # (6,24,24) -> (6,12,12)
    nn.ReLU(),
    nn.Conv2d(6, 16, (5, 5), stride=1, padding=0),  # (6,12,12) -> (16,8,8)
    nn.MaxPool2d(2),  # (16,8,8) -> (16,4,4)
    nn.ReLU(),
    nn.Flatten(),  # (16,4,4) -> (256,)
    nn.Linear(256, 120),  # (256,) -> (120,)
    nn.ReLU(),
    nn.Linear(120, 84),  # (120,) -> (84,)
    nn.ReLU(),
    nn.Linear(84, 10),  # (84,) -> (10,)
    nn.LogSoftmax(dim=1),  # (10,) log probabilities
).to(device)

# Standard MNIST mean and standard deviation, used to normalize LeNet-5 inputs.
LENET5_MEAN_TRANSFORM = 0.1307
LENET5_STD_DEV_TRANSFORM = 0.3081

Next, we load the pretrained weights for each network.

import os
from urllib.request import urlretrieve
from pathlib import Path

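LOCAL_DIR = Path("lsi_mnist_weights")
LOCAL_DIR.mkdir(exist_ok=True)  # Create the directory if it does not exist.
WEB_DIR = ""

# Download the model files to LOCAL_DIR.
for filename in [
        "mnist_generator.pth",
        "mnist_discriminator.pth",
        "mnist_classifier.pth",
]:
    model_path = LOCAL_DIR / filename
    if not model_path.is_file():
        urlretrieve(WEB_DIR + filename, str(model_path))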

# Load the weights of each network from its file.
g_state_dict = torch.load(
    str(LOCAL_DIR / "mnist_generator.pth"),
    map_location=device,
)
d_state_dict = torch.load(
    str(LOCAL_DIR / "mnist_discriminator.pth"),
    map_location=device,
)
c_state_dict = torch.load(
    str(LOCAL_DIR / "mnist_classifier.pth"),
    map_location=device,
)

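# Instantiate networks and insert the weights.
generator = Generator(nz=128).to(device)
discriminator = Discriminator().to(device)
# Evaluation mode disables the discriminator's dropout layers, so its scores
# are deterministic.
generator.eval()
discriminator.eval()
LENET5.eval()
generator.load_state_dict(g_state_dict)
discriminator.load_state_dict(d_state_dict)
LENET5.load_state_dict(c_state_dict)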
<All keys matched successfully>


After loading the GAN and the classifier, we can begin exploring the latent space of the GAN with the pyribs implementation of CMA-ME. To do so, we import and initialize the GridArchive, ImprovementEmitter, and Optimizer classes from pyribs.

For the GridArchive, we choose a 2D behavior space with “boldness” and “lightness” as the behavior characteristics. We approximate “boldness” of a digit by counting the number of white pixels in the image, and we approximate “lightness” by averaging the values of the white pixels in the image. We define a “white” pixel as a pixel with value at least 0.5 (pixels are bounded to the range \([0,1]\)). Since there are 784 pixels in an image, boldness is bounded to the range \([0, 784]\). Meanwhile, lightness is bounded to the range \([0.5, 1]\), as that is the range of a white pixel.
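The two behavior characteristics can be computed for a whole batch of flattened images at once. The NumPy sketch below is a standalone illustration with a hypothetical function name; unlike the search loop later in this tutorial (which adds 1 to the denominator), it divides by the exact number of white pixels and assigns a lightness of 0 to images with no white pixels.

```python
import numpy as np


def boldness_and_lightness(flat_imgs, threshold=0.5):
    """Returns an array of [boldness, lightness] rows for flat_imgs.

    flat_imgs: array of shape (batch, 784) with pixel values in [0, 1].
    Boldness is the number of pixels >= threshold ("white" pixels).
    Lightness is the exact mean value of the white pixels (0 if there are none).
    """
    white = flat_imgs >= threshold
    boldness = np.count_nonzero(white, axis=1)
    # Sum only the white pixels; non-white pixels contribute 0.
    white_sum = np.sum(np.where(white, flat_imgs, 0.0), axis=1)
    # Divide only where at least one pixel is white to avoid division by zero.
    lightness = np.divide(white_sum,
                          boldness,
                          out=np.zeros_like(white_sum),
                          where=boldness > 0)
    return np.stack([boldness, lightness], axis=1)


# Toy batch of two "images" with 784 pixels each.
toy_imgs = np.zeros((2, 784))
toy_imgs[0, :10] = 1.0  # 10 fully white pixels.
toy_imgs[1, :4] = [0.5, 0.7, 0.9, 0.3]  # 3 white pixels; 0.3 is below 0.5.
print(boldness_and_lightness(toy_imgs))
```

Because the search loop adds 1 to the denominator, its lightness values are slightly lower than the exact mean, especially for thin digits with few white pixels.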

from ribs.archives import GridArchive

archive = GridArchive(
    [200, 200],  # 200 bins in each dimension.
    [(0, 784), (0.5, 1)],  # Boldness range, lightness range.
)
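To see how a pair of BCs lands in one of the 200-by-200 cells, the sketch below maps BCs to integer cell indices by splitting each range into equal-width bins. It is an illustrative reimplementation, not pyribs code: the names are hypothetical, and it assumes uniform bins with out-of-range values clipped into the edge cells, which may differ from how GridArchive handles boundary cases.

```python
import numpy as np

# Same layout as the archive above: 200 bins per dimension.
DIMS = np.array([200, 200])
LOWER = np.array([0.0, 0.5])  # Boldness and lightness lower bounds.
UPPER = np.array([784.0, 1.0])  # Boldness and lightness upper bounds.


def cell_index(bcs):
    """Maps rows of [boldness, lightness] to (x, y) cell indices.

    Assumes equal-width bins; values outside a range are clipped into the
    first or last bin of that dimension.
    """
    bcs = np.asarray(bcs, dtype=float)
    # Fraction of the way through each range, scaled to the number of bins.
    scaled = (bcs - LOWER) / (UPPER - LOWER) * DIMS
    idx = np.floor(scaled).astype(int)
    return np.clip(idx, 0, DIMS - 1)


# Bin widths: 784 / 200 = 3.92 white pixels and 0.5 / 200 = 0.0025 lightness.
print((UPPER - LOWER) / DIMS)
print(cell_index([[392.0, 0.75], [784.0, 1.0], [0.0, 0.3]]))
```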

Next, we use 5 instances of ImprovementEmitter, each with a batch size of 30. Each emitter starts its search at the zero vector, which has the same dimensionality as the latent space, with an initial step size \(\sigma=0.2\).

from ribs.emitters import ImprovementEmitter

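emitters = [
    ImprovementEmitter(
        archive,
        np.zeros(generator.nz),  # Start at the origin of the latent space.
        0.2,  # Initial step size (sigma).
        batch_size=30,
    ) for _ in range(5)
]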
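Before CMA-ES has adapted anything, an emitter's first batch is drawn from an isotropic Gaussian centered on its starting point, so with \(x_0 = \mathbf{0}\) and \(\sigma = 0.2\) each latent coordinate starts with a standard deviation of 0.2. The NumPy sketch below only illustrates that initial sampling distribution, assuming the 128-dimensional latent space used above and ignoring bounds and restarts; it is not the emitter's implementation, and later batches come from the adapted mean, step size, and covariance.

```python
import numpy as np

latent_dim = 128  # nz of the generator above.
batch_size = 30
sigma0 = 0.2
x0 = np.zeros(latent_dim)

rng = np.random.default_rng(seed=42)
# With an identity covariance, the first CMA-ES batch is x0 + sigma0 * N(0, I).
first_batch = x0 + sigma0 * rng.standard_normal((batch_size, latent_dim))

print(first_batch.shape)
print(first_batch.std())
print(np.linalg.norm(first_batch, axis=1).mean())
```

The mean norm printed at the end should be close to 0.2 times the square root of 128, or about 2.26.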

Finally, we construct the optimizer to connect the archive and emitters together.

from ribs.optimizers import Optimizer

optimizer = Optimizer(archive, emitters)
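The optimizer drives an ask/tell cycle: it asks the emitters for candidate solutions, the caller evaluates them, and the results are told back so that the archive can keep the best solution in each cell. The toy sketch below imitates that cycle on a 2D problem. It is a simplified MAP-Elites-style loop with hypothetical names (`ToyGridArchive`, `toy_ask`, `toy_evaluate`), a dictionary as the archive, and plain Gaussian mutation of randomly chosen elites. It is not pyribs' implementation; CMA-ME instead generates candidates with CMA-ES-based emitters such as ImprovementEmitter.

```python
import numpy as np


class ToyGridArchive:
    """Keeps the best (highest-objective) solution found in each grid cell."""

    def __init__(self, dims, ranges):
        self.dims = np.array(dims)
        self.lower = np.array([r[0] for r in ranges], dtype=float)
        self.upper = np.array([r[1] for r in ranges], dtype=float)
        self.elites = {}  # Maps a cell index tuple to (solution, objective).

    def index(self, bc):
        """Maps a BC vector to a cell, clipping out-of-range values."""
        scaled = (bc - self.lower) / (self.upper - self.lower) * self.dims
        cell = np.clip(np.floor(scaled).astype(int), 0, self.dims - 1)
        return tuple(int(c) for c in cell)

    def add(self, sol, obj, bc):
        """Inserts sol if its cell is empty or if it beats the current elite."""
        cell = self.index(bc)
        if cell not in self.elites or obj > self.elites[cell][1]:
            self.elites[cell] = (sol, obj)


def toy_ask(grid, rng, batch_size, sigma):
    """Proposes a batch: random samples at first, then mutated elites."""
    if not grid.elites:
        return rng.standard_normal((batch_size, 2))
    parents = [sol for sol, _ in grid.elites.values()]
    picks = rng.integers(len(parents), size=batch_size)
    noise = sigma * rng.standard_normal((batch_size, 2))
    return np.array([parents[i] for i in picks]) + noise


def toy_evaluate(sols):
    """Objective peaks at the origin; the BCs are simply the two coordinates."""
    objs = -np.sum(sols**2, axis=1)
    return objs, sols


toy_rng = np.random.default_rng(0)
toy_archive = ToyGridArchive([10, 10], [(-2, 2), (-2, 2)])
for _ in range(200):
    toy_sols = toy_ask(toy_archive, toy_rng, batch_size=30, sigma=0.2)  # Ask.
    toy_objs, toy_bcs = toy_evaluate(toy_sols)  # Evaluate.
    for sol, obj, bc in zip(toy_sols, toy_objs, toy_bcs):  # Tell.
        toy_archive.add(sol, obj, bc)

print(f"Cells filled: {len(toy_archive.elites)} of 100")
```

Each cell the toy loop reaches holds exactly one elite, mirroring how the 200-by-200 archive above can hold at most 40,000 latent vectors.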

With the components created, we now generate latent vectors. As we use 5 emitters with batch size of 30 and run 30,000 iterations, we evaluate 30,000 * 30 * 5 = 4,500,000 latent vectors in total. This loop should take 15-30 min to run.

total_itrs = 30_000
flat_img_size = 784  # 28 * 28
start_time = time.time()

for itr in range(1, total_itrs + 1):
    sols = optimizer.ask()

    with torch.no_grad():
        tensor_sols = torch.tensor(sols, dtype=torch.float32, device=device)

        # Shape: len(sols) x 1 x 28 x 28
        generated_imgs = generator(tensor_sols)

        # Normalize the images from [-1,1] to [0,1].
        normalized_imgs = (generated_imgs + 1.0) / 2.0

        # We optimize the score of the digit being 8. Other digits may also be
        # used.
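        lenet5_normalized = ((normalized_imgs - LENET5_MEAN_TRANSFORM) /
                             LENET5_STD_DEV_TRANSFORM)
        # LENET5 outputs log probabilities, so exp() gives the probability
        # that each image is an 8 (a value in [0, 1]).
        objs = torch.exp(LENET5(lenet5_normalized)[:, 8]).cpu().numpy()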

        # Shape: len(sols) x 784
        flattened_imgs = normalized_imgs.cpu().numpy().reshape(
            (-1, flat_img_size))

        # The first bc is the "boldness" of the digit (i.e. number of white
        # pixels). We consider pixels with values larger than or equal to 0.5
        # to be "white".
        # Shape: len(sols) x 1
        boldness = np.count_nonzero(flattened_imgs >= 0.5,
                                    axis=1,
                                    keepdims=True)

        # The second bc is the "lightness" of the digit (i.e. the mean value of
        # the white pixels).
        # Shape: len(sols) x 1
        flattened_imgs[flattened_imgs < 0.5] = 0  # Set non-white pixels to 0.
        # Add 1 to avoid dividing by zero.
        lightness = (np.sum(flattened_imgs, axis=1, keepdims=True) /
                     (boldness + 1))

        # Each BC entry is [boldness, lightness].
        bcs = np.concatenate([boldness, lightness], axis=1)

    optimizer.tell(objs, bcs)

    if itr % 1000 == 0:
        print(
            f"Iteration {itr} complete after {time.time() - start_time}s - "
            f"Archive size: {len(archive.as_pandas(include_solutions=False))}")
Iteration 1000 complete after 39.17980217933655s - Archive size: 13171
Iteration 2000 complete after 72.20655941963196s - Archive size: 15408
Iteration 3000 complete after 107.56191420555115s - Archive size: 15640
Iteration 4000 complete after 138.75481128692627s - Archive size: 15986
Iteration 5000 complete after 173.41253542900085s - Archive size: 16378
Iteration 6000 complete after 202.84771490097046s - Archive size: 16476
Iteration 7000 complete after 231.2869167327881s - Archive size: 16927
Iteration 8000 complete after 258.17751288414s - Archive size: 17018
Iteration 9000 complete after 285.6783549785614s - Archive size: 17149
Iteration 10000 complete after 315.4162104129791s - Archive size: 17196
Iteration 11000 complete after 342.89622044563293s - Archive size: 17284
Iteration 12000 complete after 370.1462724208832s - Archive size: 17340
Iteration 13000 complete after 401.8246293067932s - Archive size: 17502
Iteration 14000 complete after 436.1385066509247s - Archive size: 17771
Iteration 15000 complete after 471.6374089717865s - Archive size: 17920
Iteration 16000 complete after 499.57357358932495s - Archive size: 17984
Iteration 17000 complete after 535.9386947154999s - Archive size: 18248
Iteration 18000 complete after 568.7088708877563s - Archive size: 18356
Iteration 19000 complete after 600.0366671085358s - Archive size: 18405
Iteration 20000 complete after 631.5517485141754s - Archive size: 18533
Iteration 21000 complete after 663.398030757904s - Archive size: 18620
Iteration 22000 complete after 695.3690409660339s - Archive size: 18676
Iteration 23000 complete after 727.6523280143738s - Archive size: 19271
Iteration 24000 complete after 760.9145286083221s - Archive size: 19338
Iteration 25000 complete after 796.7973093986511s - Archive size: 20390
Iteration 26000 complete after 829.9189705848694s - Archive size: 20586
Iteration 27000 complete after 862.0929872989655s - Archive size: 20600
Iteration 28000 complete after 894.1377077102661s - Archive size: 20617
Iteration 29000 complete after 926.0865910053253s - Archive size: 20649
Iteration 30000 complete after 958.2865278720856s - Archive size: 20674


Below, we visualize the archive after all evaluations. The x-axis is the boldness and the y-axis is the lightness. The color indicates the objective value, that is, the probability that LeNet-5 assigns to the digit being an eight. We can see that we found many images that the classifier strongly believed to be eights.

from ribs.visualize import grid_archive_heatmap

plt.figure(figsize=(8, 6))
grid_archive_heatmap(archive, vmin=0.0, vmax=1.0)
plt.title("LSI MNIST")
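One common way to summarize a grid archive is its coverage, the fraction of cells that hold an elite. Using the final archive size reported in the log above (20,674 elites) and the 200-by-200 grid, the arithmetic is:

```python
num_cells = 200 * 200  # 40,000 cells in the grid archive.
num_elites = 20_674  # Final "Archive size" from the log above.
coverage = num_elites / num_cells
print(f"Coverage: {coverage:.1%}")
```

In a live session, the same number can be computed from the archive itself by dividing `len(archive.as_pandas(include_solutions=False))`, as used in the search loop, by 40,000.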

Next, we display a grid of digits generated from a selected set of latent vectors in the archive.

from torchvision.utils import make_grid


def show_grid_img(x_start,
                  x_num,
                  x_step_size,
                  y_start,
                  y_num,
                  y_step_size,
                  archive,
                  figsize=(8, 6)):
    """Displays a grid of images from the archive.

    Args:
        x_start (int): Starting index along x-axis.
        x_num (int): Number of images to generate along x-axis.
        x_step_size (int): Index step size along x-axis.
        y_start (int): Starting index along y-axis.
        y_num (int): Number of images to generate along y-axis.
        y_step_size (int): Index step size along y-axis.
        archive (GridArchive): Archive with results from CMA-ME.
        figsize ((int, int)): Size of the figure for the image.
    """
    x_range = np.arange(x_start, x_start + x_step_size * x_num, x_step_size)
    y_range = np.arange(y_start, y_start + y_step_size * y_num, y_step_size)
    # Row-major order, with the highest lightness in the top row.
    grid_indices = [(x, y) for y in np.flip(y_range) for x in x_range]

    imgs = []
    img_size = (28, 28)
    df = archive.as_pandas()
    solutions, indices = df.batch_solutions(), df.batch_indices()
    for index in grid_indices:
        try:
            sol = solutions[indices.index(index)]
        except ValueError:
            print(f"There is no solution at index {index}.")
            return

        with torch.no_grad():
            img = generator(
                torch.tensor(sol.reshape(1, generator.nz),
                             dtype=torch.float32,
                             device=device))
            # Normalize images to [0,1].
            normalized = (img.reshape(1, *img_size) + 1) / 2
            imgs.append(normalized)

    plt.figure(figsize=figsize)
    img_grid = make_grid(imgs, nrow=x_num, padding=0)
    plt.imshow(np.transpose(img_grid.cpu().numpy(), (1, 2, 0)),
               interpolation="nearest",
               cmap="gray")

    # Change labels to be BC values.
    plt.xlabel("Boldness")
    plt.ylabel("Lightness")
    x_ticklabels = [
        round(archive.boundaries[0][i])
        for i in [x_start + x_step_size * k for k in range(x_num + 1)]
    ]
    y_ticklabels = [
        round(archive.boundaries[1][i], 2) for i in [
            y_start + y_step_size * y_num - y_step_size * k
            for k in range(y_num + 1)
        ]
    ]
    plt.xticks([img_size[0] * x for x in range(x_num + 1)], x_ticklabels)
    plt.yticks([img_size[0] * x for x in range(y_num + 1)], y_ticklabels)

As we can see below, digits get bolder as we move right along the x-axis. Meanwhile, as we move up the y-axis, the digits get brighter. For instance, the image in the bottom right corner is grey and bold, while the image in the top left corner is white and thin.

show_grid_img(10, 8, 7, 105, 6, 15, archive)
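To check which part of the behavior space a call to show_grid_img covers, the sketch below repeats its index arithmetic for the call above and converts cell indices to BC values. It assumes, as the tick labels do, that the archive's cell edges are evenly spaced: 201 edges from 0 to 784 for boldness and from 0.5 to 1 for lightness.

```python
import numpy as np

# Arguments of the first show_grid_img call above.
x_start, x_num, x_step_size = 10, 8, 7
y_start, y_num, y_step_size = 105, 6, 15

x_range = np.arange(x_start, x_start + x_step_size * x_num, x_step_size)
y_range = np.arange(y_start, y_start + y_step_size * y_num, y_step_size)

# Assumed evenly spaced cell edges of the 200 x 200 archive.
boldness_edges = np.linspace(0, 784, 201)
lightness_edges = np.linspace(0.5, 1.0, 201)

print("x cells:", x_range.tolist())
print("y cells (top row first):", np.flip(y_range).tolist())
print("boldness from", boldness_edges[x_range[0]], "to",
      boldness_edges[x_start + x_step_size * x_num])
print("lightness from", lightness_edges[y_range[0]], "to",
      lightness_edges[y_start + y_step_size * y_num])
```

So the first grid spans boldness of roughly 39 to 259 white pixels and lightness of roughly 0.76 to 0.99.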

Here we display images from a wider range of boldness values in the archive. Note that in order to generate images with high boldness values, CMA-ME generated images that do not look realistic (see the bottom right corner in particular).

show_grid_img(10, 8, 15, 90, 6, 15, archive)

To determine how realistic all of the images in the archive are, we can evaluate them with the discriminator network of the GAN. Below, we create a new archive where the objective value of each solution is the discriminator score. BCs remain the same.

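discriminator_archive = GridArchive(
    [200, 200],  # 200 bins in each dimension.
    [(0, 784), (0.5, 1)],  # Boldness range, lightness range.
)
# The Optimizer initializes the first archive; this one is initialized manually.
discriminator_archive.initialize(generator.nz)

# Evaluate each solution in the archive and insert it into the new archive.
for elite in archive:
    # No need to normalize to [0, 1] since the discriminator takes in images in
    # the range [-1, 1].
    with torch.no_grad():
        img = generator(
            torch.tensor(elite.sol.reshape(1, generator.nz),
                         dtype=torch.float32,
                         device=device))
        obj = discriminator(img).item()
    discriminator_archive.add(elite.sol, obj, elite.beh)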

Now, we can plot a heatmap of the archive with the discriminator score. The large regions of low score (in black) show that many images in the archive are not realistic, even though LeNet-5 had high confidence that these images showed the digit eight.

plt.figure(figsize=(8, 6))
grid_archive_heatmap(discriminator_archive, vmin=0.0, vmax=1.0)
plt.title("Discriminator Evaluation")


By searching the latent space of an MNIST GAN, CMA-ME found images of the digit eight that varied in boldness and lightness. Even though the LeNet-5 network had high confidence that these images were eights, many of them turned out to be highly unrealistic: when we evaluated them with the GAN’s discriminator network, most of them received low scores.

In short, we found that large portions of the GAN’s latent space produce unrealistic images. This is not surprising: during training, the GAN generates fake images by sampling latent vectors from a fixed Gaussian distribution, so the generator is rarely trained on latent vectors from low-probability regions of that distribution. Thus, we have the following questions, which we leave open for future exploration:

  • How can we ensure that CMA-ME searches for realistic eights?

  • While searching for realistic eights, can we also search for other digits at the same time?
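One quick way to probe the explanation above is to check whether the archive's latent vectors lie at distances from the origin that training samples rarely reach. The sketch assumes the GAN was trained on standard normal latent vectors, N(0, I) in 128 dimensions (the text above only says "a fixed Gaussian distribution"), whose norms concentrate near the square root of 128, about 11.3. The helper is hypothetical and is demonstrated on synthetic data; in the notebook it could be applied to `archive.as_pandas().batch_solutions()`.

```python
import numpy as np


def fraction_atypical_norms(latents, n_reference=10_000, seed=0):
    """Fraction of latent vectors with atypical norms under N(0, I).

    Returns the share of rows in `latents` whose Euclidean norm falls outside
    the central 99% interval of norms of standard normal samples drawn in the
    same number of dimensions.
    """
    latents = np.asarray(latents, dtype=float)
    dim = latents.shape[1]
    rng = np.random.default_rng(seed)
    reference = np.linalg.norm(rng.standard_normal((n_reference, dim)), axis=1)
    low, high = np.percentile(reference, [0.5, 99.5])
    norms = np.linalg.norm(latents, axis=1)
    return np.mean((norms < low) | (norms > high))


demo_rng = np.random.default_rng(1)
typical = demo_rng.standard_normal((1000, 128))  # Resembles training samples.
far_out = 3.0 * demo_rng.standard_normal((1000, 128))  # Norms near 34.
print(fraction_atypical_norms(typical))
print(fraction_atypical_norms(far_out))
```

For typical samples the fraction should be near 0.01, and for the scaled-up samples it should be 1.0. A large fraction for the archive's solutions would be consistent with, though not proof of, the explanation above.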


If you find this tutorial useful, please cite it as:

  title   = {Illuminating the Latent Space of an MNIST GAN},
  author  = {Yulun Zhang and Bryon Tjanaka and Matthew C. Fontaine and Stefanos Nikolaidis},
  journal = {},
  year    = {2021},
  url     = {}