Source code for ribs.visualize._visualize_qdax

"""Provides visualization functions for QDax repertoires."""
import numpy as np

from ribs.archives import CVTArchive
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap


def _as_cvt_archive(repertoire, ranges):
    """Builds a CVTArchive from the contents of a QDax repertoire.

    Args:
        repertoire: QDax repertoire exposing ``centroids``, ``fitnesses``, and
            ``descriptors`` arrays.
        ranges: Measure space bounds, forwarded unchanged to the
            :class:`~ribs.archives.CVTArchive` constructor.

    Returns:
        A CVTArchive with one cell per repertoire centroid, to which the
        fitness and descriptor of every repertoire entry whose fitness is not
        ``-inf`` have been added.
    """
    centroids = repertoire.centroids

    # The archive is only used for plotting, so no solutions are stored;
    # hence solution_dim is 0.
    archive = CVTArchive(
        solution_dim=0,
        cells=centroids.shape[0],
        ranges=ranges,
        custom_centroids=centroids,
    )

    # Entries with a fitness of -inf are left out of the archive.
    fitnesses = repertoire.fitnesses
    filled = fitnesses != -np.inf

    # Zero-width placeholder solutions, one row per entry being added, to
    # match solution_dim=0 above.
    placeholder_solutions = np.empty((filled.sum(), 0))

    archive.add(
        placeholder_solutions,
        fitnesses[filled],
        repertoire.descriptors[filled],
    )

    return archive


def qdax_repertoire_heatmap(repertoire, ranges, *args, **kwargs):
    # pylint: disable = line-too-long
    """Draws a heatmap for a QDax MapElitesRepertoire.

    The repertoire is first converted from a
    :class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire`
    into a :class:`~ribs.archives.CVTArchive`; the resulting archive is then
    handed to :meth:`cvt_archive_heatmap` for plotting.

    Args:
        repertoire (qdax.core.containers.mapelites_repertoire.MapElitesRepertoire):
            A MAP-Elites repertoire output by an algorithm in QDax.
        ranges (array-like of (float, float)): Upper and lower bound of each
            dimension of the measure space, e.g. ``[(-1, 1), (-2, 2)]``
            indicates the first dimension should have bounds :math:`[-1,1]`
            (inclusive), and the second dimension should have bounds
            :math:`[-2,2]` (inclusive). The bounds must be given here because
            the MapElitesRepertoire does not store them.
        *args: Positional arguments to pass to :meth:`cvt_archive_heatmap`.
        **kwargs: Keyword arguments to pass to :meth:`cvt_archive_heatmap`.
    """
    # pylint: enable = line-too-long
    archive = _as_cvt_archive(repertoire, ranges)
    cvt_archive_heatmap(archive, *args, **kwargs)


def qdax_repertoire_3d_plot(repertoire, ranges, *args, **kwargs):
    # pylint: disable = line-too-long
    """Draws a 3D plot for a QDax MapElitesRepertoire with 3D measure space.

    The repertoire is first converted from a
    :class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire`
    into a :class:`~ribs.archives.CVTArchive`; the resulting archive is then
    handed to :meth:`cvt_archive_3d_plot` for plotting.

    Args:
        repertoire (qdax.core.containers.mapelites_repertoire.MapElitesRepertoire):
            A MAP-Elites repertoire output by an algorithm in QDax.
        ranges (array-like of (float, float)): Upper and lower bound of each
            dimension of the measure space, e.g.
            ``[(-1, 1), (-2, 2), (-3, 3)]`` indicates the first dimension
            should have bounds :math:`[-1,1]` (inclusive), the second
            dimension should have bounds :math:`[-2,2]`, and the third
            dimension should have bounds :math:`[-3,3]` (inclusive). The
            bounds must be given here because the MapElitesRepertoire does
            not store them.
        *args: Positional arguments to pass to :meth:`cvt_archive_3d_plot`.
        **kwargs: Keyword arguments to pass to :meth:`cvt_archive_3d_plot`.
    """
    # pylint: enable = line-too-long
    archive = _as_cvt_archive(repertoire, ranges)
    cvt_archive_3d_plot(archive, *args, **kwargs)