"""Provides visualization functions for QDax repertoires."""
from __future__ import annotations
from collections.abc import Collection
from typing import TYPE_CHECKING
import numpy as np
from typing_extensions import ParamSpec
from ribs.archives import CVTArchive
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap
if TYPE_CHECKING:
# Only import for type checking since QDax is not installed by default.
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
P = ParamSpec("P")
def _as_cvt_archive(
repertoire: MapElitesRepertoire,
ranges: Collection[tuple[float, float]],
) -> CVTArchive:
"""Converts a QDax repertoire into a CVTArchive."""
# Construct a CVTArchive. We set solution_dim to 0 since we are only plotting and do
# not need to have the solutions available.
cvt_archive = CVTArchive(
solution_dim=0,
centroids=repertoire.centroids,
ranges=ranges,
)
# Fitness is originally (N, 1). Reshape it to (N,).
fitnesses = repertoire.fitnesses
if fitnesses.shape[1:] == (1,):
fitnesses = fitnesses.squeeze(1)
else:
raise ValueError(
"This method only supports visualizing single-objective "
"archives (i.e., there can only be one fitness)."
)
# Add everything to the CVTArchive.
occupied = fitnesses != -np.inf
cvt_archive.add(
np.empty((occupied.sum(), 0)),
fitnesses[occupied],
repertoire.descriptors[occupied],
)
return cvt_archive
[docs]
def qdax_repertoire_heatmap(
repertoire: MapElitesRepertoire,
ranges: Collection[tuple[float, float]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""Plots a heatmap of a single-objective QDax MapElitesRepertoire.
Internally, this function converts a
:class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire` into a
:class:`~ribs.archives.CVTArchive` and plots it with :meth:`cvt_archive_heatmap`.
Args:
repertoire: A MAP-Elites repertoire output by an algorithm in QDax.
ranges: Upper and lower bound of each dimension of the measure space, e.g.
``[(-1, 1), (-2, 2)]`` indicates the first dimension should have bounds
:math:`[-1,1]` (inclusive), and the second dimension should have bounds
:math:`[-2,2]` (inclusive). This is needed since the MapElitesRepertoire
does not store measure space bounds.
*args: Positional arguments to pass to :meth:`cvt_archive_heatmap`.
**kwargs: Keyword arguments to pass to :meth:`cvt_archive_heatmap`.
Raises:
ValueError: The repertoire passed in has more than one fitness.
"""
cvt_archive_heatmap(_as_cvt_archive(repertoire, ranges), *args, **kwargs)
[docs]
def qdax_repertoire_3d_plot(
repertoire: MapElitesRepertoire,
ranges: Collection[tuple[float, float]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""Plots a single-objective QDax MapElitesRepertoire with 3D measure space.
Internally, this function converts a
:class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire` into a
:class:`~ribs.archives.CVTArchive` and plots it with :meth:`cvt_archive_3d_plot`.
Args:
repertoire: A MAP-Elites repertoire output by an algorithm in QDax.
ranges: Upper and lower bound of each dimension of the measure space, e.g.
``[(-1, 1), (-2, 2), (-3, 3)]`` indicates the first dimension should have
bounds :math:`[-1,1]` (inclusive), the second dimension should have bounds
:math:`[-2,2]`, and the third dimension should have bounds :math:`[-3,3]`
(inclusive). This is needed since the MapElitesRepertoire does not store
measure space bounds.
*args: Positional arguments to pass to :meth:`cvt_archive_3d_plot`.
**kwargs: Keyword arguments to pass to :meth:`cvt_archive_3d_plot`.
Raises:
ValueError: The repertoire passed in has more than one fitness.
"""
cvt_archive_3d_plot(_as_cvt_archive(repertoire, ranges), *args, **kwargs)