"""Provides sliding_boundaries_archive_heatmap."""
import matplotlib.pyplot as plt
import numpy as np
from ribs.visualize._utils import (retrieve_cmap, set_cbar, validate_df,
validate_heatmap_visual_args)
# Matplotlib functions tend to have a ton of args.
# pylint: disable = too-many-arguments
[docs]def sliding_boundaries_archive_heatmap(archive,
ax=None,
*,
df=None,
transpose_measures=False,
cmap="magma",
aspect="auto",
ms=None,
boundary_lw=0,
vmin=None,
vmax=None,
cbar="auto",
cbar_kwargs=None,
rasterized=False):
"""Plots heatmap of a :class:`~ribs.archives.SlidingBoundariesArchive` with
2D measure space.
Since the boundaries of :class:`ribs.archives.SlidingBoundariesArchive` are
dynamic, we plot the heatmap as a scatter plot, in which each marker is an
elite and its color represents the objective value. Boundaries can
optionally be drawn by setting ``boundary_lw`` to a positive value.
Examples:
.. plot::
:context: close-figs
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from ribs.archives import SlidingBoundariesArchive
>>> from ribs.visualize import sliding_boundaries_archive_heatmap
>>> archive = SlidingBoundariesArchive(solution_dim=2,
... dims=[10, 20],
... ranges=[(-1, 1), (-1, 1)],
... seed=42)
>>> # Populate the archive with the negative sphere function.
>>> xy = np.clip(np.random.standard_normal((1000, 2)), -1.5, 1.5)
>>> archive.add(solution=xy,
... objective=-np.sum(xy**2, axis=1),
... measures=xy)
>>> # Plot heatmaps of the archive.
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16,6))
>>> fig.suptitle("Negative sphere function")
>>> sliding_boundaries_archive_heatmap(archive, ax=ax1,
... boundary_lw=0.5)
>>> sliding_boundaries_archive_heatmap(archive, ax=ax2)
>>> ax1.set_title("With boundaries")
>>> ax2.set_title("Without boundaries")
>>> ax1.set(xlabel='x coords', ylabel='y coords')
>>> ax2.set(xlabel='x coords', ylabel='y coords')
>>> plt.show()
Args:
archive (SlidingBoundariesArchive): A 2D
:class:`~ribs.archives.SlidingBoundariesArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
If ``None``, the current axis will be used.
df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from
this argument instead of the data currently in the archive. This
data can be obtained by, for instance, calling
:meth:`ribs.archives.ArchiveBase.data` with ``return_type="pandas"``
and modifying the resulting
:class:`~ribs.archives.ArchiveDataFrame`. Note that, at a minimum,
the data must contain columns for index, objective, and measures. To
display a custom metric, replace the "objective" column.
transpose_measures (bool): By default, the first measure in the archive
will appear along the x-axis, and the second will be along the
y-axis. To switch this behavior (i.e. to transpose the axes), set
this to ``True``.
cmap (str, list, matplotlib.colors.Colormap): Colormap to use when
plotting intensity. Either the name of a
:class:`~matplotlib.colors.Colormap`, a list of RGB or RGBA colors
(i.e. an :math:`N \\times 3` or :math:`N \\times 4` array), or a
:class:`~matplotlib.colors.Colormap` object.
aspect ('auto', 'equal', float): The aspect ratio of the heatmap (i.e.
height/width). Defaults to ``'auto'``. ``'equal'`` is the same as
``aspect=1``. See :meth:`matplotlib.axes.Axes.set_aspect` for more
info.
ms (float): Marker size for the solutions.
boundary_lw (float): Line width when plotting the boundaries.
Set to ``0`` to have no boundaries.
vmin (float): Minimum objective value to use in the plot. If ``None``,
the minimum objective value in the archive is used.
vmax (float): Maximum objective value to use in the plot. If ``None``,
the maximum objective value in the archive is used.
cbar ('auto', None, matplotlib.axes.Axes): By default, this is set to
``'auto'`` which displays the colorbar on the archive's current
:class:`~matplotlib.axes.Axes`. If ``None``, then colorbar is not
displayed. If this is an :class:`~matplotlib.axes.Axes`, displays
the colorbar on the specified Axes.
cbar_kwargs (dict): Additional kwargs to pass to
:func:`~matplotlib.pyplot.colorbar`.
rasterized (bool): Whether to rasterize the heatmap. This can be useful
for saving to a vector format like PDF. Essentially, only the
heatmap will be converted to a raster graphic so that the archive
cells will not have to be individually rendered. Meanwhile, the
surrounding axes, particularly text labels, will remain in vector
format.
Raises:
ValueError: The archive is not 2D.
"""
validate_heatmap_visual_args(
aspect, cbar, archive.measure_dim, [2],
"Heatmap can only be plotted for a 2D SlidingBoundariesArchive")
if aspect is None:
aspect = "auto"
# Try getting the colormap early in case it fails.
cmap = retrieve_cmap(cmap)
# Retrieve archive data.
if df is None:
measures_batch = archive.data("measures")
objective_batch = archive.data("objective")
else:
df = validate_df(df)
measures_batch = df.get_field("measures")
objective_batch = df.get_field("objective")
x = measures_batch[:, 0]
y = measures_batch[:, 1]
x_boundary = archive.boundaries[0]
y_boundary = archive.boundaries[1]
lower_bounds = archive.lower_bounds
upper_bounds = archive.upper_bounds
if transpose_measures:
# Since the archive is 2D, transpose by swapping the x and y measures
# and boundaries and by flipping the bounds (the bounds are arrays of
# length 2).
x, y = y, x
x_boundary, y_boundary = y_boundary, x_boundary
lower_bounds = np.flip(lower_bounds)
upper_bounds = np.flip(upper_bounds)
# Initialize the axis.
ax = plt.gca() if ax is None else ax
ax.set_xlim(lower_bounds[0], upper_bounds[0])
ax.set_ylim(lower_bounds[1], upper_bounds[1])
ax.set_aspect(aspect)
# Create the plot.
vmin = np.min(objective_batch) if vmin is None else vmin
vmax = np.max(objective_batch) if vmax is None else vmax
t = ax.scatter(x,
y,
s=ms,
c=objective_batch,
cmap=cmap,
vmin=vmin,
vmax=vmax,
rasterized=rasterized)
if boundary_lw > 0.0:
# Careful with bounds here. Lines drawn along the x axis should extend
# between the y bounds and vice versa -- see
# https://github.com/icaros-usc/pyribs/issues/270
ax.vlines(x_boundary,
lower_bounds[1],
upper_bounds[1],
color='k',
linewidth=boundary_lw,
rasterized=rasterized)
ax.hlines(y_boundary,
lower_bounds[0],
upper_bounds[0],
color='k',
linewidth=boundary_lw,
rasterized=rasterized)
# Create color bar.
set_cbar(t, ax, cbar, cbar_kwargs)