"""Contains the DNSArchive."""
from __future__ import annotations
from collections.abc import Collection, Iterator
from typing import Literal, overload
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist
from ribs._utils import (
check_batch_shape,
check_finite,
check_shape,
validate_batch,
validate_single,
)
from ribs.archives._archive_base import ArchiveBase
from ribs.archives._archive_data_frame import ArchiveDataFrame
from ribs.archives._archive_stats import ArchiveStats
from ribs.archives._array_store import ArrayStore
from ribs.archives._utils import fill_sentinel_values, parse_all_dtypes
from ribs.typing import BatchData, FieldDesc, Float, Int, SingleData
class DNSArchive(ArchiveBase):
    r"""An archive that maintains a fixed-size population via Dominated Novelty Search.

    Each time :meth:`add` is called, the incoming candidates are merged with the
    current population, and survivors are selected by their Dominated Novelty Search
    (DNS) score. For each solution, the DNS score is the mean Euclidean distance in
    measure space to the k nearest neighbors with strictly higher objective ("fitter"
    neighbors). If no fitter neighbors exist, the DNS score is ``+inf``, so such
    solutions rank highest during survival selection.

    More info can be found in the `DNS paper <https://arxiv.org/abs/2502.00593>`_ by
    Bahlous-Boldi, R., Faldor, M., et al.

    By default, this archive stores the following data fields: ``solution``,
    ``objective``, ``measures``, and ``index``.

    Args:
        solution_dim: Dimensionality of the solution space. Scalar or multi-dimensional
            solution shapes are allowed by passing an empty tuple or tuple of integers,
            respectively.
        measure_dim: Dimensionality of the measure space.
        capacity: Fixed population size to maintain. Must be at least 1.
        k_neighbors: Number of fitter neighbors to average over when computing DNS.
            Must be at least 1.
        qd_score_offset: Subtracted from objective values when computing QD score.
        seed: Value to seed the random number generator.
        solution_dtype: Data type of the solutions. Defaults to float64 (numpy's default
            floating point type).
        objective_dtype: Data type of the objectives. Defaults to float64 (numpy's
            default floating point type).
        measures_dtype: Data type of the measures. Defaults to float64 (numpy's default
            floating point type).
        dtype: Shortcut for providing data type of the solutions, objectives, and
            measures. Defaults to float64 (numpy's default floating point type). This
            parameter sets all the dtypes simultaneously. To set individual dtypes, pass
            ``solution_dtype``, ``objective_dtype``, and ``measures_dtype``. Note that
            ``dtype`` cannot be used at the same time as those parameters.
        extra_fields: Extra fields to store alongside solutions.
        kdtree_kwargs: Kwargs for :class:`scipy.spatial.KDTree` used in retrieval.

    Raises:
        ValueError: ``extra_fields`` uses one of the reserved field names.
        ValueError: ``capacity`` or ``k_neighbors`` is less than 1.
    """

    def __init__(
        self,
        *,
        solution_dim: Int | tuple[Int, ...],
        measure_dim: Int,
        capacity: Int,
        k_neighbors: Int,
        qd_score_offset: Float = 0.0,
        seed: Int | None = None,
        solution_dtype: DTypeLike = None,
        objective_dtype: DTypeLike = None,
        measures_dtype: DTypeLike = None,
        dtype: DTypeLike = None,
        extra_fields: FieldDesc | None = None,
        kdtree_kwargs: dict | None = None,
    ) -> None:
        self._rng = np.random.default_rng(seed)

        ArchiveBase.__init__(
            self,
            solution_dim=solution_dim,
            objective_dim=(),
            measure_dim=measure_dim,
        )

        # Set up the ArrayStore, which is a data structure that stores all the elites'
        # data in arrays sharing a common index.
        extra_fields = extra_fields or {}
        reserved_fields = {"solution", "objective", "measures", "index"}
        if reserved_fields & extra_fields.keys():
            raise ValueError(
                "The following names are not allowed in "
                f"extra_fields: {reserved_fields}"
            )
        if capacity < 1:
            raise ValueError("capacity must be at least 1.")
        if k_neighbors < 1:
            # With zero neighbors every DNS score would be +inf, which makes survival
            # selection meaningless.
            raise ValueError("k_neighbors must be at least 1.")

        solution_dtype, objective_dtype, measures_dtype = parse_all_dtypes(
            dtype, solution_dtype, objective_dtype, measures_dtype, np
        )
        self._store = ArrayStore(
            field_desc={
                "solution": (self.solution_dim, solution_dtype),
                "objective": ((), objective_dtype),
                "measures": (self.measure_dim, measures_dtype),
                **extra_fields,
            },
            capacity=capacity,
        )

        # Set up constant properties.
        self._k_neighbors = int(k_neighbors)
        self._kdtree_kwargs = {} if kdtree_kwargs is None else kdtree_kwargs.copy()
        self._qd_score_offset = np.asarray(
            qd_score_offset, dtype=self.dtypes["objective"]
        )

        # Set up k-D tree with current measures in the archive. Rebuilt whenever the
        # store contents change (add() and clear()).
        self._cur_kd_tree = KDTree(self._store.data("measures"), **self._kdtree_kwargs)

        # Set up statistics -- objective_sum is the sum of all objective values in the
        # archive; it is useful for computing qd_score and obj_mean.
        self._best_elite = None
        self._objective_sum = None
        self._stats = None
        self._stats_reset()

    ## Properties inherited from ArchiveBase ##

    @property
    def field_list(self) -> list[str]:
        return self._store.field_list_with_index

    @property
    def dtypes(self) -> dict[str, np.dtype]:
        return self._store.dtypes_with_index

    @property
    def stats(self) -> ArchiveStats:
        return self._stats

    @property
    def empty(self) -> bool:
        return len(self._store) == 0

    ## Properties that are not in ArchiveBase ##
    ## Roughly ordered by the parameter list in the constructor. ##

    @property
    def best_elite(self) -> SingleData | None:
        """The elite with the highest objective in the archive.

        None if there are no elites in the archive.
        """
        return self._best_elite

    @property
    def k_neighbors(self) -> int:
        """The number of fitter neighbors for computing DNS."""
        return self._k_neighbors

    @property
    def capacity(self) -> int:
        """Fixed number of solutions stored in this archive."""
        return self._store.capacity

    @property
    def cells(self) -> int:
        """Total capacity of the archive (for coverage/statistics)."""
        return self.capacity

    @property
    def qd_score_offset(self) -> float:
        """Subtracted from objective values when computing the QD score."""
        return self._qd_score_offset

    ## dunder methods ##

    def __len__(self) -> int:
        return len(self._store)

    def __iter__(self) -> Iterator[SingleData]:
        return iter(self._store)

    ## Utilities ##

    def _stats_reset(self) -> None:
        """Resets the archive stats."""
        self._best_elite = None
        self._objective_sum = np.asarray(0.0, dtype=self.dtypes["objective"])
        self._stats = ArchiveStats(
            num_elites=0,
            coverage=np.asarray(0.0, dtype=self.dtypes["objective"]),
            qd_score=np.asarray(0.0, dtype=self.dtypes["objective"]),
            norm_qd_score=np.asarray(0.0, dtype=self.dtypes["objective"]),
            obj_max=None,
            obj_mean=None,
        )

    def _update_derived_state(self) -> None:
        """Recomputes the k-D tree, stats, and best elite from the store contents."""
        self._cur_kd_tree = KDTree(self._store.data("measures"), **self._kdtree_kwargs)

        if self.empty:
            self._stats_reset()
            return

        objectives = self._store.data("objective")
        indices = self._store.data("index")
        num_elites = len(objectives)
        objective_dtype = self.dtypes["objective"]

        objective_sum = np.sum(objectives)
        qd_score = objective_sum - num_elites * self._qd_score_offset

        # Note: QD score is not an especially informative statistic for DNS, since
        # the archive has no predefined cells; `cells` here is just the capacity.
        self._objective_sum = objective_sum
        self._stats = ArchiveStats(
            num_elites=num_elites,
            coverage=np.asarray(num_elites / self.cells, dtype=objective_dtype),
            qd_score=qd_score,
            norm_qd_score=np.asarray(qd_score / self.cells, dtype=objective_dtype),
            obj_max=np.max(objectives),
            obj_mean=np.mean(objectives),
        )

        # The population is replaced on every add, so the best elite is whichever
        # stored solution currently has the highest objective.
        best_index = indices[np.argmax(objectives)]
        _, best_data = self._store.retrieve(np.array([best_index]))
        self._best_elite = {name: arr[0] for name, arr in best_data.items()}

    def index_of(self, measures: ArrayLike) -> np.ndarray:
        """Returns the index of the closest solution to the given measures.

        Unlike the structured archives like :class:`~ribs.archives.GridArchive`, this
        archive does not have indexed cells where each measure "belongs." Thus, this
        method instead returns the index of the solution with the closest measure to
        each solution passed in.

        This means that :meth:`retrieve` will return the solution with the closest
        measure to each measure passed into that method.

        Args:
            measures: (batch_size, :attr:`measure_dim`) array of coordinates in measure
                space.

        Returns:
            (batch_size,) array of integer indices representing the location of the
            solution in the archive.

        Raises:
            RuntimeError: There were no entries in the archive.
            ValueError: ``measures`` is not of shape (batch_size, :attr:`measure_dim`).
            ValueError: ``measures`` has non-finite values (inf or NaN).
        """
        measures = np.asarray(measures, dtype=self.dtypes["measures"])
        check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
        check_finite(measures, "measures")

        if self.empty:
            raise RuntimeError(
                "There were no solutions in the archive. "
                "`DNSArchive.index_of` computes the nearest "
                "neighbor to the input measures, so there must be at least one "
                "solution present in the archive."
            )

        # k-D tree rows line up with storage indices because add() stores the
        # survivors at indices 0..n-1 in order.
        _, indices = self._cur_kd_tree.query(measures)
        return indices.astype(np.int32)

    def index_of_single(self, measures: ArrayLike) -> Int:
        """Returns the index of the measures for one solution.

        See :meth:`index_of`.

        Args:
            measures: (:attr:`measure_dim`,) array of measures for a single solution.

        Returns:
            Integer index of the measures in the archive's storage arrays.

        Raises:
            ValueError: ``measures`` is not of shape (:attr:`measure_dim`,).
            ValueError: ``measures`` has non-finite values (inf or NaN).
        """
        measures = np.asarray(measures, dtype=self.dtypes["measures"])
        check_shape(measures, "measures", self.measure_dim, "measure_dim")
        check_finite(measures, "measures")
        return int(self.index_of(measures[None])[0])

    def compute_dns(self, measures: ArrayLike, objectives: ArrayLike) -> np.ndarray:
        """Computes DNS scores for a population with respect to itself.

        For each solution ``i``, its fitter neighbors are the other solutions ``j``
        with ``objectives[j] > objectives[i]``. The DNS score of ``i`` is the mean
        Euclidean distance (over ``measures``) to its ``min(k_neighbors, #fitter)``
        nearest fitter neighbors, or ``+inf`` if it has no fitter neighbor.

        Args:
            measures: (n, d) array of measures for the population.
            objectives: (n,) array of objective values for the population.

        Returns:
            (n,) array of DNS scores, in the archive's measures dtype.
        """
        measures = np.asarray(measures, dtype=self.dtypes["measures"])
        objectives = np.asarray(objectives, dtype=self.dtypes["objective"])
        n_ind = measures.shape[0]

        # A solution has at most n - 1 other solutions to compare against.
        k_eff = min(self._k_neighbors, n_ind - 1)
        if k_eff <= 0:
            # Zero or one solution: nobody has a fitter neighbor.
            return np.full(n_ind, np.inf, dtype=self.dtypes["measures"])

        dist = cdist(measures, measures)
        np.fill_diagonal(dist, np.inf)  # A solution is never its own neighbor.

        # fitter[i, j] is True when solution j has a strictly higher objective than i.
        fitter = objectives[None, :] > objectives[:, None]
        dist_fitter = np.where(fitter, dist, np.inf)

        # The k_eff smallest distances in each row (unordered); inf entries mark
        # slots with no fitter neighbor.
        nearest = np.partition(dist_fitter, k_eff - 1, axis=1)[:, :k_eff]
        finite = np.isfinite(nearest)
        counts = finite.sum(axis=1)
        sums = np.where(finite, nearest, 0.0).sum(axis=1)

        # Guard the division; rows with no fitter neighbors become +inf below.
        means = sums / np.maximum(counts, 1)
        scores = np.where(counts == 0, np.inf, means)
        return scores.astype(self.dtypes["measures"])

    ## Methods for writing to the archive ##

    def _merge_with_population(self, data: BatchData) -> BatchData:
        """Concatenates the current population (first) with the validated batch."""
        for name in self._store.field_list:
            if name not in data:
                raise ValueError(
                    f"Field '{name}' is in the archive but was not passed "
                    f"in the batch data. All fields must be provided."
                )

        if self.empty:
            return {name: data[name] for name in self._store.field_list}

        current = self._store.data(return_type="dict")
        return {
            name: np.concatenate((current[name], data[name]), axis=0)
            for name in self._store.field_list
        }

    def _select_survivors(self, dns_scores: np.ndarray) -> np.ndarray:
        """Returns row indices of the top-``capacity`` DNS scores, sorted ascending."""
        n_total = dns_scores.shape[0]
        cap = self.capacity
        if n_total <= cap:
            return np.arange(n_total)

        # Take the `cap` largest scores, then order them by ascending score.
        top = np.argpartition(dns_scores, -cap)[-cap:]
        return top[np.argsort(dns_scores[top])]

    def add(
        self,
        solution: ArrayLike,
        objective: ArrayLike | None,
        measures: ArrayLike,
        **fields: ArrayLike,
    ) -> BatchData:
        """Inserts a batch of solutions with DNS-based survival selection.

        The current population and the incoming batch are merged, DNS scores are
        computed over the union, and the ``capacity`` solutions with the highest DNS
        scores are kept. All other solutions, including previously stored ones, are
        discarded.

        Args:
            solution: Parameters of each solution in the batch.
            objective: Objective value of each solution, or None to use 0 for every
                solution. When all objectives are equal, no solution has a strictly
                fitter neighbor, so every DNS score is ``+inf``. In that case, if the
                merged population exceeds ``capacity``, the choice of survivors is
                an arbitrary tie-break.
            measures: Coordinates in measure space of each solution.
            fields: Additional data for each solution.

        Returns:
            Dict with one entry per batch solution:

            - ``status``: int32 array; 2 if the solution survived selection (was
              added), 0 otherwise.
            - ``dns``: DNS score of each solution, computed over the merged
              population.

        Raises:
            ValueError: A field stored in the archive was not provided in the batch.
                Other validation errors are raised by ``validate_batch``.
        """
        if objective is None:
            objective = np.zeros(len(solution), dtype=self.dtypes["objective"])

        data = validate_batch(
            self,
            {
                "solution": solution,
                "objective": objective,
                "measures": measures,
                **fields,
            },
        )

        # Delete these so that we only use the clean, validated data in `data`.
        del solution, objective, measures, fields

        cur_size = len(self)
        combined = self._merge_with_population(data)
        dns_scores = self.compute_dns(combined["measures"], combined["objective"])
        survivor_indices = self._select_survivors(dns_scores)

        # The batch occupies rows [cur_size, cur_size + batch_size) of the union.
        batch_size = len(data["measures"])
        batch_rows = np.arange(cur_size, cur_size + batch_size)
        status = np.zeros(batch_size, dtype=np.int32)
        status[np.isin(batch_rows, survivor_indices)] = 2
        add_info = {
            "status": status,
            "dns": dns_scores[batch_rows].astype(self.dtypes["measures"]),
        }

        # Replace the whole population with the survivors.
        survivors = {
            name: combined[name][survivor_indices] for name in self._store.field_list
        }
        self._store.clear()
        num_survivors = survivor_indices.shape[0]
        if num_survivors > 0:
            self._store.add(np.arange(num_survivors), survivors)

        self._update_derived_state()
        return add_info

    def add_single(
        self,
        solution: ArrayLike,
        objective: ArrayLike | None,
        measures: ArrayLike,
        **fields: ArrayLike,
    ) -> SingleData:
        """Inserts a single solution into the archive.

        Args:
            solution: Parameters of the solution.
            objective: Set to None to get the default value of 0; otherwise, a valid
                objective value is also acceptable.
            measures: Coordinates in measure space of the solution.
            fields: Additional data for the solution.

        Returns:
            Information describing the result of the add operation. The dict contains
            scalar ``status`` and ``dns`` entries; refer to :meth:`add` for their
            meaning.

        Raises:
            ValueError: The array arguments do not match their specified shapes.
            ValueError: ``objective`` is non-finite (inf or NaN) or ``measures`` has
                non-finite values.
        """
        if objective is None:
            objective = 0.0

        data = validate_single(
            self,
            {
                "solution": solution,
                "objective": objective,
                "measures": measures,
                **fields,
            },
        )

        add_info = self.add(**{key: [val] for key, val in data.items()})
        # Unwrap the length-1 batch results into scalars.
        return {key: val[0] for key, val in add_info.items()}

    def clear(self) -> None:
        """Removes all elites in the archive."""
        self._store.clear()
        self._update_derived_state()

    ## Methods for reading from the archive ##
    ## Refer to ArchiveBase for documentation of these methods. ##

    def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
        measures = np.asarray(measures, dtype=self.dtypes["measures"])
        check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
        check_finite(measures, "measures")

        occupied, data = self._store.retrieve(self.index_of(measures))
        fill_sentinel_values(occupied, data)

        return occupied, data

    def retrieve_single(self, measures: ArrayLike) -> tuple[bool, SingleData]:
        measures = np.asarray(measures, dtype=self.dtypes["measures"])
        check_shape(measures, "measures", self.measure_dim, "measure_dim")
        check_finite(measures, "measures")

        occupied, data = self.retrieve(measures[None])

        return bool(occupied[0]), {field: arr[0] for field, arr in data.items()}

    @overload
    def data(
        self,
        fields: str,
        return_type: Literal["dict", "tuple", "pandas"] = "dict",
    ) -> np.ndarray: ...

    @overload
    def data(
        self,
        fields: None | Collection[str] = None,
        return_type: Literal["dict"] = "dict",
    ) -> BatchData: ...

    @overload
    def data(
        self,
        fields: None | Collection[str] = None,
        return_type: Literal["tuple"] = "tuple",
    ) -> tuple[np.ndarray]: ...

    @overload
    def data(
        self,
        fields: None | Collection[str] = None,
        return_type: Literal["pandas"] = "pandas",
    ) -> ArchiveDataFrame: ...

    def data(
        self,
        fields: None | Collection[str] | str = None,
        return_type: Literal["dict", "tuple", "pandas"] = "dict",
    ) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame:
        return self._store.data(fields, return_type)

    def sample_elites(self, n: Int, replace: bool = True) -> BatchData:
        if self.empty:
            raise IndexError("No elements in archive.")
        if not replace and n > len(self._store):
            raise ValueError(
                "Cannot take a larger sample than the number of elites "
                "in the archive when 'replace=False'"
            )

        random_indices = self._rng.choice(len(self._store), size=n, replace=replace)
        selected_indices = self._store.occupied_list[random_indices]
        _, elites = self._store.retrieve(selected_indices)
        return elites