Illuminating the Latent Space of an MNIST GAN

One of the most popular applications of Generative Adversarial Networks (GANs) is generating fake images. In particular, websites like This Person Does Not Exist serve a GAN that generates fake images of people (This X Does Not Exist provides a comprehensive list of such websites). These websites are entertaining, especially when one is asked to figure out which face is real.

Usually, these websites produce fake images by sampling the GAN’s latent space. For those unfamiliar with GANs, this means that each image is associated with a real-valued vector of \(n\) components. But since these vectors are typically generated randomly, the usefulness of such websites breaks down when we wish to search for a specific image.

For instance, suppose that instead of fake faces, we want to generate fake handwriting, specifically the digit eight (8). We could train a GAN on the MNIST dataset to produce a generator network that generates fake digits. We could then repeatedly sample the latent space until an eight appears. However, rather than sampling randomly, we could search the latent space directly with CMA-ES. To ensure that we generate eights, we could use the classification probability output by a LeNet-5 classifier as the objective (see Bontrager 2018).1 But notice that the latent space likely contains many examples of the digit eight, varying in the weight of the pen stroke or the lightness of the ink color. If we make these properties our behavior characteristics, we could search the latent space for many different examples of eight in a single run!

Fontaine 2021 takes exactly this approach to generate new levels for the classic video game Super Mario Bros. They term the approach “Latent Space Illumination” (LSI): they apply quality diversity (QD) algorithms, including CMA-ME, to search the latent space of a level-generating GAN and illuminate the behavior space of possible level mechanics. In this tutorial, we mimic their approach to illuminate the latent space of the MNIST GAN described above.

(1) Since the discriminator of the GAN is only trained to evaluate how realistic an image is, it cannot detect specific digits. Hence, we need the LeNet-5 classifier to check that a generated digit is an 8.

Setup

First, we install pyribs, PyTorch, and several utilities.

%pip install ribs torch torchvision numpy matplotlib
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

Below, we check what device is available for PyTorch. On Colab, activate the GPU by clicking “Runtime” in the toolbar at the top. Then, click “Change Runtime Type”, and select “GPU”.

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
cuda

Loading the GAN and Classifier

For this tutorial, we pretrained a GAN that generates MNIST digits using the code from a beginner GAN tutorial. We also pretrained a LeNet-5 classifier for the MNIST dataset using the code in our Fooling MNIST tutorial. Below, we define the network structures.

class Generator(nn.Module):
    """Generator network for the GAN."""

    def __init__(self, nz):
        super(Generator, self).__init__()

        # Size of the latent space (number of dimensions).
        self.nz = nz
        self.main = nn.Sequential(
            nn.Linear(self.nz, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.main(x).view(-1, 1, 28, 28)


class Discriminator(nn.Module):
    """Discriminator network for the GAN."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_input = 784
        self.main = nn.Sequential(
            nn.Linear(self.n_input, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(-1, 784)
        return self.main(x)


LENET5 = nn.Sequential(
    nn.Conv2d(1, 6, (5, 5), stride=1, padding=0),  # (1,28,28) -> (6,24,24)
    nn.MaxPool2d(2),  # (6,24,24) -> (6,12,12)
    nn.ReLU(),
    nn.Conv2d(6, 16, (5, 5), stride=1, padding=0),  # (6,12,12) -> (16,8,8)
    nn.MaxPool2d(2),  # (16,8,8) -> (16,4,4)
    nn.ReLU(),
    nn.Flatten(),  # (16,4,4) -> (256,)
    nn.Linear(256, 120),  # (256,) -> (120,)
    nn.ReLU(),
    nn.Linear(120, 84),  # (120,) -> (84,)
    nn.ReLU(),
    nn.Linear(84, 10),  # (84,) -> (10,)
    nn.LogSoftmax(dim=1),  # (10,) log probabilities
).to(device)
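# Standard MNIST normalization constants (the dataset's mean and standard
# deviation), used to preprocess LeNet-5 inputs.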
LENET5_MEAN_TRANSFORM = 0.1307
LENET5_STD_DEV_TRANSFORM = 0.3081

Next, we load the pretrained weights for each network.

import os
from urllib.request import urlretrieve
from pathlib import Path

LOCAL_DIR = Path("lsi_mnist_weights")
LOCAL_DIR.mkdir(exist_ok=True)
WEB_DIR = "https://raw.githubusercontent.com/icaros-usc/pyribs/master/examples/tutorials/_static/"

# Download the model files to LOCAL_DIR.
for filename in [
        "mnist_generator.pth",
        "mnist_discriminator.pth",
        "mnist_classifier.pth",
]:
    model_path = LOCAL_DIR / filename
    if not model_path.is_file():
        urlretrieve(WEB_DIR + filename, str(model_path))

# Load the weights of each network from its file.
g_state_dict = torch.load(
    str(LOCAL_DIR / "mnist_generator.pth"),
    map_location=device,
)
d_state_dict = torch.load(
    str(LOCAL_DIR / "mnist_discriminator.pth"),
    map_location=device,
)
c_state_dict = torch.load(
    str(LOCAL_DIR / "mnist_classifier.pth"),
    map_location=device,
)

# Instantiate networks and insert the weights.
generator = Generator(nz=128).to(device)
discriminator = Discriminator().to(device)
generator.load_state_dict(g_state_dict)
discriminator.load_state_dict(d_state_dict)
LENET5.load_state_dict(c_state_dict)
<All keys matched successfully>

LSI with CMA-ME on the MNIST GAN

After loading the GAN and the classifier, we can begin exploring the latent space of the GAN with the pyribs implementation of CMA-ME. To do so, we import and initialize the GridArchive, ImprovementEmitter, and Optimizer from pyribs.

For the GridArchive, we choose a 2D behavior space with “boldness” and “lightness” as the behavior characteristics. We approximate the “boldness” of a digit by counting the number of white pixels in the image, and we approximate its “lightness” by averaging the values of the white pixels. We define a “white” pixel as one with value at least 0.5 (pixel values are bounded to the range \([0, 1]\)). Since there are 784 pixels in an image, boldness is bounded to the range \([0, 784]\). Meanwhile, lightness is bounded to the range \([0.5, 1]\), as that is the range of a white pixel’s value.

from ribs.archives import GridArchive

archive = GridArchive(
    [200, 200],  # 200 bins in each dimension.
    [(0, 784), (0.5, 1)],  # Boldness range, lightness range.
)
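
To make these definitions concrete, below is a minimal sketch of the behavior characteristic computation for a batch of flattened images in \([0, 1]\). It mirrors the computation in the main loop later in this tutorial, including the +1 in the denominator that guards against division by zero when an image has no white pixels.

def compute_bcs(flat_imgs):
    """Sketch: boldness and lightness for images of shape (batch, 784)."""
    white = flat_imgs >= 0.5  # Boolean mask of "white" pixels.
    # Boldness: number of white pixels per image. Shape: (batch, 1).
    boldness = np.count_nonzero(white, axis=1, keepdims=True)
    # Lightness: mean value of the white pixels, with +1 in the denominator
    # to avoid division by zero. Shape: (batch, 1).
    lightness = (np.where(white, flat_imgs, 0).sum(axis=1, keepdims=True) /
                 (boldness + 1))
    return np.concatenate([boldness, lightness], axis=1)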

Next, we use 5 instances of ImprovementEmitter, each with a batch size of 30. Each emitter begins at the zero vector (with the same dimensionality as the latent space) with an initial step size of \(\sigma = 0.2\).

from ribs.emitters import ImprovementEmitter

emitters = [
    ImprovementEmitter(
        archive,
        np.zeros(generator.nz),
        0.2,
        batch_size=30,
    ) for _ in range(5)
]

Finally, we construct the optimizer to connect the archive and emitters together.

from ribs.optimizers import Optimizer

optimizer = Optimizer(archive, emitters)

With the components created, we now generate latent vectors. Since we use 5 emitters, each with a batch size of 30, and run 30,000 iterations, we evaluate 30,000 × 5 × 30 = 4,500,000 latent vectors in total. On a GPU, this loop takes under 10 minutes (see the timings below); on a CPU, it will take significantly longer.

total_itrs = 30_000
flat_img_size = 784  # 28 * 28
start_time = time.time()

for itr in range(1, total_itrs + 1):
    sols = optimizer.ask()

    with torch.no_grad():
        tensor_sols = torch.tensor(
            sols,
            dtype=torch.float32,
            device=device,
        )

        # Shape: len(sols) x 1 x 28 x 28
        generated_imgs = generator(tensor_sols)

        # Normalize the images from [-1,1] to [0,1].
        normalized_imgs = (generated_imgs + 1.0) / 2.0

        # We optimize the score of the digit being 8. Other digits may also be
        # used.
        lenet5_normalized = ((normalized_imgs - LENET5_MEAN_TRANSFORM) /
                             LENET5_STD_DEV_TRANSFORM)
        objs = torch.exp(LENET5(lenet5_normalized)[:, 8]).cpu().numpy()

        # Shape: len(sols) x 784
        flattened_imgs = normalized_imgs.cpu().numpy().reshape(
            (-1, flat_img_size))

        # The first bc is the "boldness" of the digit (i.e. number of white
        # pixels). We consider pixels with values larger than or equal to 0.5
        # to be "white".
        # Shape: len(sols) x 1
        boldness = np.count_nonzero(flattened_imgs >= 0.5,
                                    axis=1,
                                    keepdims=True)

        # The second bc is the "lightness" of the digit (i.e. the mean value of
        # the white pixels).
        # Shape: len(sols) x 1
        flattened_imgs[flattened_imgs < 0.5] = 0  # Set non-white pixels to 0.
        # Add 1 to avoid dividing by zero.
        lightness = (np.sum(flattened_imgs, axis=1, keepdims=True) /
                     (boldness + 1))

        # Each BC entry is [boldness, lightness].
        bcs = np.concatenate([boldness, lightness], axis=1)

    optimizer.tell(objs, bcs)

    if itr % 1000 == 0:
        print(
            f"Iteration {itr} complete after {time.time() - start_time}s - "
            f"Archive size: {len(archive.as_pandas(include_solutions=False))}")
Iteration 1000 complete after 18.84685206413269s - Archive size: 13697
Iteration 2000 complete after 36.04245662689209s - Archive size: 17369
Iteration 3000 complete after 56.91603946685791s - Archive size: 18743
Iteration 4000 complete after 72.13739275932312s - Archive size: 19262
Iteration 5000 complete after 85.93850421905518s - Archive size: 19429
Iteration 6000 complete after 99.70822525024414s - Archive size: 19574
Iteration 7000 complete after 112.46592545509338s - Archive size: 19668
Iteration 8000 complete after 124.9775128364563s - Archive size: 19759
Iteration 9000 complete after 138.07127404212952s - Archive size: 19858
Iteration 10000 complete after 150.87091493606567s - Archive size: 19930
Iteration 11000 complete after 168.31051087379456s - Archive size: 19958
Iteration 12000 complete after 182.0728931427002s - Archive size: 19976
Iteration 13000 complete after 194.43875789642334s - Archive size: 20016
Iteration 14000 complete after 207.76799893379211s - Archive size: 20052
Iteration 15000 complete after 220.59500741958618s - Archive size: 20073
Iteration 16000 complete after 232.94806599617004s - Archive size: 20122
Iteration 17000 complete after 245.77648758888245s - Archive size: 20143
Iteration 18000 complete after 259.391832113266s - Archive size: 20160
Iteration 19000 complete after 274.1263542175293s - Archive size: 20183
Iteration 20000 complete after 290.25037837028503s - Archive size: 20229
Iteration 21000 complete after 302.33102107048035s - Archive size: 20255
Iteration 22000 complete after 314.84559893608093s - Archive size: 20287
Iteration 23000 complete after 333.2425458431244s - Archive size: 20905
Iteration 24000 complete after 351.90096521377563s - Archive size: 20980
Iteration 25000 complete after 367.71741127967834s - Archive size: 21005
Iteration 26000 complete after 381.1285123825073s - Archive size: 21028
Iteration 27000 complete after 396.30028200149536s - Archive size: 21046
Iteration 28000 complete after 414.1779029369354s - Archive size: 21060
Iteration 29000 complete after 432.1986894607544s - Archive size: 21070
Iteration 30000 complete after 451.1240584850311s - Archive size: 21083
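
As a quick sanity check after the run, we can pull a few summary statistics from the archive’s pandas representation (a small sketch using the same as_pandas call as the logging above):

stats = archive.as_pandas(include_solutions=False)
print(f"Cells filled: {len(stats)} / {200 * 200}")
print(f"Max objective: {stats['objective'].max():.3f}")
print(f"Mean objective: {stats['objective'].mean():.3f}")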

Visualization

Below, we visualize the archive after all evaluations. The x-axis is boldness, the y-axis is lightness, and the color indicates the objective value, i.e., the LeNet-5 confidence that the image is an eight. We can see that we found many images that the classifier strongly believed to be eights.

from ribs.visualize import grid_archive_heatmap

plt.figure(figsize=(8, 6))
grid_archive_heatmap(archive)
plt.title("LSI MNIST")
plt.xlabel("Boldness")
plt.ylabel("Lightness")
plt.show()
[Figure: heatmap of the archive, with boldness on the x-axis, lightness on the y-axis, and objective value as color]

Next, we display a grid of digits generated from a selected set of latent vectors in the archive.

from torchvision.utils import make_grid


def show_grid_img(x_start,
                  x_num,
                  x_step_size,
                  y_start,
                  y_num,
                  y_step_size,
                  archive,
                  figsize=(8, 6)):
    """Displays a grid of images from the archive.
    
    Args:
        x_start (int): Starting index along x-axis.
        x_num (int): Number of images to generate along x-axis.
        x_step_size (int): Index step size along x-axis.
        y_start (int): Starting index along y-axis.
        y_num (int): Number of images to generate along y-axis.
        y_step_size (int): Index step size along y-axis.
        archive (GridArchive): Archive with results from CMA-ME.
        figsize ((int, int)): Size of the figure for the image.
    """
    elites = archive.as_pandas()
    x_range = np.arange(x_start, x_start + x_step_size * x_num, x_step_size)
    y_range = np.arange(y_start, y_start + y_step_size * y_num, y_step_size)
    grid_indexes = [(x, y) for y in np.flip(y_range) for x in x_range]

    imgs = []
    img_size = (28, 28)
    for index in grid_indexes:
        x, y = index
        sol_row = elites[(elites["index_0"] == x) & (elites["index_1"] == y)]
        if sol_row.empty:
            print(f"No solution exists at index ({x}, {y}).")
            return
        latent_vec = sol_row.iloc[0]["solution_0":].to_numpy()
        with torch.no_grad():
            img = generator(
                torch.tensor(latent_vec.reshape(1, generator.nz),
                             dtype=torch.float32,
                             device=device))
            # Normalize images to [0,1].
            normalized = (img.reshape(1, *img_size) + 1) / 2
            imgs.append(normalized)

    plt.figure(figsize=figsize)
    img_grid = make_grid(imgs, nrow=x_num, padding=0)
    plt.imshow(np.transpose(img_grid.cpu().numpy(), (1, 2, 0)),
               interpolation='nearest',
               cmap='gray')

    # Change labels to be BC values.
    plt.xlabel("Boldness")
    plt.ylabel("Lightness")
    x_ticklabels = [
        round(archive.boundaries[0][i])
        for i in [x_start + x_step_size * k for k in range(x_num + 1)]
    ]
    y_ticklabels = [
        round(archive.boundaries[1][i], 2) for i in [
            y_start + y_step_size * y_num - y_step_size * k
            for k in range(y_num + 1)
        ]
    ]
    plt.xticks([img_size[0] * x for x in range(x_num + 1)], x_ticklabels)
    plt.yticks([img_size[0] * x for x in range(y_num + 1)], y_ticklabels)

As we can see below, the digits get bolder as we move right along the x-axis and lighter as we move up the y-axis. For instance, the image in the bottom right corner is grey and bold, while the image in the top left corner is white and thin.

show_grid_img(10, 8, 7, 105, 6, 15, archive)
[Figure: grid of generated eights from a narrow region of the archive]

Here we display images from a wider range of the archive. Note that in order to achieve high boldness values, CMA-ME generated images that do not look like realistic digits (see the bottom right corner in particular).

show_grid_img(10, 8, 15, 90, 6, 15, archive)
[Figure: grid of generated eights from a wider region of the archive]

To determine how realistic all of the images in the archive are, we can evaluate them with the discriminator network of the GAN. Below, we create a new archive where the objective value of each solution is the discriminator score. BCs remain the same.

df = archive.as_pandas()
discriminator_archive = GridArchive(
    [200, 200],  # 200 bins in each dimension.
    [(0, 784), (0.5, 1)],  # Boldness range, lightness range.
)
discriminator_archive.initialize(generator.nz)

# Evaluate each solution in the archive and insert it into the new archive.
for _, row in df.iterrows():
    latent = np.array(row.loc["solution_0":])
    bcs = row.loc[["behavior_0", "behavior_1"]]
    # No need to normalize to [0, 1] since the discriminator takes in images in
    # the range [-1, 1].
    with torch.no_grad():  # Inference only, so skip building the autograd graph.
        img = generator(
            torch.tensor(latent.reshape(1, generator.nz),
                         dtype=torch.float32,
                         device=device))
        obj = discriminator(img).item()
    discriminator_archive.add(latent, obj, bcs)
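
Iterating over the archive row by row keeps the code simple, but with roughly 21,000 solutions it is slow. As a design note, here is a sketch of a batched alternative that evaluates every solution in a single forward pass (assuming the whole archive fits in memory):

all_latents = df.loc[:, "solution_0":].to_numpy()
all_bcs = df[["behavior_0", "behavior_1"]].to_numpy()
with torch.no_grad():
    batch = torch.tensor(all_latents, dtype=torch.float32, device=device)
    all_objs = discriminator(generator(batch)).squeeze(1).cpu().numpy()
for latent, obj, bcs in zip(all_latents, all_objs, all_bcs):
    discriminator_archive.add(latent, obj, bcs)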

Now, we can plot a heatmap of the archive with the discriminator score. The large regions of low score (in black) show that many images in the archive are not realistic, even though LeNet-5 had high confidence that these images showed the digit eight.

plt.figure(figsize=(8, 6))
grid_archive_heatmap(discriminator_archive)
plt.title("Discriminator Evaluation")
plt.xlabel("Boldness")
plt.ylabel("Lightness")
plt.show()
[Figure: heatmap of discriminator scores across the archive, showing large black regions of low score]

Conclusion

By searching the latent space of an MNIST GAN, CMA-ME found images of the digit eight that varied in boldness and lightness. Even though the LeNet-5 network had high confidence that these images were eights, many of them were highly unrealistic: when we evaluated them with the GAN’s discriminator network, they mostly received low scores.

In short, we found that large portions of the GAN’s latent space produce unrealistic images. This is not surprising: during training, the GAN only receives latent vectors drawn from a fixed Gaussian distribution, so regions of the latent space that are unlikely under that distribution are rarely visited, and the generator’s outputs there are essentially unconstrained. Thus, we have the following questions, which we leave open for future exploration:

  • How can we ensure that CMA-ME searches for realistic eights? (One possible direction is sketched after this list.)

  • While searching for realistic eights, can we also search for other digits at the same time?
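
As a starting point for the first question, one possible direction (a sketch, not something we explored in this tutorial) is to fold the discriminator into the objective, for instance by scaling the LeNet-5 confidence by the discriminator’s realism score:

def combined_objective(tensor_sols):
    """Sketch: LeNet-5 confidence for an 8, scaled by the discriminator."""
    with torch.no_grad():
        imgs = generator(tensor_sols)  # Images in [-1, 1].
        normalized = (imgs + 1.0) / 2.0  # Rescale to [0, 1] for LeNet-5.
        lenet_in = ((normalized - LENET5_MEAN_TRANSFORM) /
                    LENET5_STD_DEV_TRANSFORM)
        cls_score = torch.exp(LENET5(lenet_in)[:, 8])  # P(digit is 8).
        realism = discriminator(imgs).squeeze(1)  # Realism score in (0, 1).
    return (cls_score * realism).cpu().numpy()

Plugging this function into the main loop in place of the objs computation would reward latent vectors that both classify as eights and fool the discriminator.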