# Source code for ribs.discount_models

"""Discount models and related utilities.

.. autosummary::
    :toctree:

    MLP
    DiscountModelManager
"""

from __future__ import annotations

from collections.abc import Callable, Collection
from typing import Literal

import numpy as np
from numpy.typing import ArrayLike

from ribs.typing import Float, Int

__all__ = [
    "MLP",
    "DiscountModelManager",
]

# PyTorch is an optional dependency. Try to import it and record whether the
# import succeeded; the classes in this module check IS_TORCH_AVAILABLE in their
# constructors and raise ImportError there, rather than failing when this module
# is merely imported.
try:
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    IS_TORCH_AVAILABLE = True
except ImportError:
    # Placeholder standing in for ``torch.nn`` so that ``class MLP(nn.Module)``
    # can still be defined at import time when PyTorch is not installed.
    # pylint: disable = invalid-name, missing-class-docstring
    class nn:  # noqa: N801
        class Module:
            pass

    IS_TORCH_AVAILABLE = False


# Developer Note: The documentation for this class is hacked. To list new methods,
# manually modify the template in docs/_templates/autosummary/class.rst


class MLP(nn.Module):
    """PyTorch multi-layer perceptron model.

    The MLP has identical activations on every layer, and no activation on the last
    layer. Each layer can be configured to have biases.

    .. note::

        This model requires `PyTorch <https://pytorch.org/>`_ to be installed, e.g.,
        by running ``pip install torch``.

    Args:
        layer_specs: List of tuples specifying the linear layers. Each tuple can
            either contain ``(in_features, out_features)`` or ``(in_features,
            out_features, bias)``, where ``in_features`` and ``out_features`` are
            integers specifying the input and output shapes of the network, while
            ``bias`` is a bool indicating whether the layer should have a bias.
        activation: Activation layer class, e.g., :class:`torch.nn.Tanh`

    Raises:
        ImportError: PyTorch is not installed.
    """

    def __init__(
        self,
        layer_specs: Collection[tuple[int, int] | tuple[int, int, bool]],
        activation: Callable,
    ) -> None:
        if not IS_TORCH_AVAILABLE:
            raise ImportError("PyTorch must be installed to use the MLP.")

        super().__init__()

        last_index = len(layer_specs) - 1
        layers = []
        for i, spec in enumerate(layer_specs):
            layers.append(
                nn.Linear(
                    in_features=spec[0],
                    out_features=spec[1],
                    # The bias flag is optional and defaults to True.
                    bias=spec[2] if len(spec) == 3 else True,
                )
            )
            # No activation after the final linear layer.
            if i != last_index:
                layers.append(activation())

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Passes the inputs through the MLP."""
        return self.model(x)

    def num_params(self) -> int:
        """Counts number of parameters in the model.

        The count covers the same parameters, in the same order, as
        :meth:`serialize`, :meth:`deserialize`, and :meth:`gradient`.
        """
        return sum(p.numel() for p in self.parameters())

    def serialize(self) -> np.ndarray:
        """Returns 1D array with all parameters in the model.

        Essentially, all the parameters of the model are retrieved, flattened, and
        concatenated together.

        Returns:
            1D array whose length corresponds to the number of parameters in the
            model.
        """
        return nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()

    def deserialize(self, array: np.ndarray) -> MLP:
        """Loads parameters from 1D array.

        For example, given the array output by :meth:`serialize`, this method can be
        used to load that array back into the parameters of this model.

        The values are copied into the existing parameters, so each parameter keeps
        its dtype and device, and the model does not share memory with ``array``.

        Args:
            array: 1D array holding the flattened parameters, in the order produced
                by :meth:`serialize`. It must contain at least :meth:`num_params`
                entries; any entries beyond that are ignored.

        Returns:
            The model itself, so that it is possible to call ``model =
            MLP(...).deserialize(x)``
        """
        flat = torch.from_numpy(array)
        offset = 0
        # Copy in place rather than rebinding ``param.data`` to views of ``flat``:
        # rebinding would make the parameters alias the caller's NumPy buffer and
        # silently adopt its dtype (e.g., float64) and CPU placement.
        with torch.no_grad():
            for param in self.parameters():
                count = param.numel()
                param.copy_(flat[offset : offset + count].view_as(param))
                offset += count
        return self

    def gradient(self) -> np.ndarray:
        """Returns 1D array with gradient of all parameters in the model.

        Essentially, all the gradients of the model's parameters are retrieved,
        flattened, and concatenated together. Parameters that currently have no
        gradient (e.g., frozen parameters, or before any backward pass) contribute
        zeros, so that the result always lines up with :meth:`serialize`.

        Returns:
            1D array whose length corresponds to the total size of all gradients in
            the model.
        """
        pieces = []
        for param in self.parameters():
            grad = param.grad if param.grad is not None else torch.zeros_like(param)
            pieces.append(grad.detach().cpu().numpy().ravel())
        return np.concatenate(pieces)
class DiscountModelManager:
    """Wraps a PyTorch model so it can be used as a discount model.

    This class handles operations like training the model to match new discount
    value targets (in :meth:`training_loop`) and performing inference (in
    :meth:`inference`).

    .. note::

        This class assumes all input and output data is of type float32, which is
        the default type in PyTorch. If different data types are needed, one
        solution may be to cast the data before/after calls to this class (as is
        done in :class:`~ribs.archives.DiscountArchive`).

    .. note::

        This class requires `PyTorch <https://pytorch.org/>`_ to be installed, e.g.,
        by running ``pip install torch``.

    Args:
        model: A PyTorch model that can take in batches of measures and output
            batches of scalar discount values. We assume this model has already
            been placed on the desired device.
        optimizer: A PyTorch optimizer that is set up to optimize the model's
            parameters. We use this to train the discount model to output new
            discount value targets. The optimizer state is maintained across calls
            to :meth:`training_loop`.
        device: A PyTorch device for placing tensors during training.
        train_epochs: When :meth:`training_loop` is called, the model will train
            until either (1) the total loss on each epoch is less than the
            ``train_cutoff_loss`` described below, or (2) the number of epochs
            reaches ``train_epochs``.
        train_cutoff_loss: See ``train_epochs``.
        train_batch_size: During each epoch of :meth:`training_loop`, the dataset
            of measures and targets will be used to train the model with this batch
            size.
        normalize_measures: Whether to normalize the measures. Pass None (default)
            to indicate no normalization. Alternatively, pass "zero_one" to
            normalize to ``[0, 1]`` or "negative_one_one" to normalize to ``[-1,
            1]`` (along each dimension). To normalize to these values, we linearly
            transform from the range defined by ``measures_low`` and
            ``measures_high``, described below.
        measures_low: If ``normalize_measures`` is set, this is the lower bound of
            the measures for normalizing.
        measures_high: If ``normalize_measures`` is set, this is the upper bound of
            the measures for normalizing.
        normalize_discount: Whether to normalize the discount values. Pass None
            (default) to indicate no normalization. During training, the targets
            are linearly transformed to a target range such as [0, 1], and during
            inference, the discount values output by the model are un-normalized
            before being returned. Pass "zero_one" to set the range to [0, 1], and
            "negative_one_one" to set the range to [-1, 1].
        discount_low: If ``normalize_discount`` is set, this is the lower bound of
            the discount values for normalizing.
        discount_high: If ``normalize_discount`` is set, this is the upper bound of
            the discount values for normalizing.

    Raises:
        ImportError: PyTorch is not installed.
        ValueError: A normalization mode was requested without its bounds.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        *,
        train_epochs: Int,
        train_cutoff_loss: Float,
        train_batch_size: Int,
        normalize_measures: Literal["zero_one", "negative_one_one"] | None = None,
        measures_low: ArrayLike | None = None,
        measures_high: ArrayLike | None = None,
        normalize_discount: Literal["zero_one", "negative_one_one"] | None = None,
        discount_low: Float | None = None,
        discount_high: Float | None = None,
    ) -> None:
        if not IS_TORCH_AVAILABLE:
            raise ImportError(
                "PyTorch must be installed to use the DiscountModelManager."
            )

        self.model = model
        self.model.train()  # Assume the model is in train mode by default.
        self.optimizer = optimizer
        self.device = device

        self.train_epochs = train_epochs
        self.train_cutoff_loss = train_cutoff_loss
        self.train_batch_size = train_batch_size

        self.normalize_measures = normalize_measures
        if self.normalize_measures is None:
            self.measures_low = None
            self.measures_high = None
        else:
            if measures_low is None or measures_high is None:
                raise ValueError(
                    "If normalize_measures is not None, measures_low and "
                    "measures_high must be passed in."
                )
            # Bounds are constants on the manager's device; they never need grads.
            self.measures_low = torch.asarray(
                measures_low, device=self.device, dtype=torch.float32
            ).requires_grad_(False)
            self.measures_high = torch.asarray(
                measures_high, device=self.device, dtype=torch.float32
            ).requires_grad_(False)

        self.normalize_discount = normalize_discount
        if self.normalize_discount is None:
            # Bounds are unused without normalization; they are stored as given.
            self.discount_low = discount_low
            self.discount_high = discount_high
        else:
            if discount_low is None or discount_high is None:
                raise ValueError(
                    "If normalize_discount is not None, discount_low and "
                    "discount_high must be passed in."
                )
            self.discount_low = torch.asarray(
                discount_low, device=self.device, dtype=torch.float32
            ).requires_grad_(False)
            self.discount_high = torch.asarray(
                discount_high, device=self.device, dtype=torch.float32
            ).requires_grad_(False)

    def _normalize(
        self,
        x: ArrayLike,
        normalize: Literal["zero_one", "negative_one_one"] | None,
        low: torch.Tensor | None,
        high: torch.Tensor | None,
    ) -> torch.Tensor:
        """Places x on the manager's device and normalizes it.

        ``low`` and ``high`` are ignored when ``normalize`` is None.

        Raises:
            ValueError: ``normalize`` is not a recognized method.
        """
        x = torch.asarray(x, device=self.device, dtype=torch.float32)
        if normalize is None:
            return x
        if normalize == "negative_one_one":
            return 2.0 * (x - low) / (high - low) - 1.0
        if normalize == "zero_one":
            return (x - low) / (high - low)
        raise ValueError(f"Unknown normalization method {normalize}.")

    def _unnormalize(
        self,
        x: torch.Tensor,
        normalize: Literal["zero_one", "negative_one_one"] | None,
        low: torch.Tensor | None,
        high: torch.Tensor | None,
    ) -> torch.Tensor:
        """Unnormalizes x to the given range.

        x is assumed to already be a torch.Tensor on the manager's device. ``low``
        and ``high`` are ignored when ``normalize`` is None.

        Raises:
            ValueError: ``normalize`` is not a recognized method.
        """
        if normalize is None:
            return x
        if normalize == "negative_one_one":
            return (x + 1.0) / 2.0 * (high - low) + low
        if normalize == "zero_one":
            return x * (high - low) + low
        raise ValueError(f"Unknown normalization method {normalize}.")

    def training_loop(self, measures: ArrayLike, targets: ArrayLike) -> list[float]:
        """Regresses the discount model to match the given targets at the given measures.

        Training proceeds until either (1) the total loss on each epoch is less
        than the ``train_cutoff_loss``, or (2) the number of epochs reaches
        ``train_epochs``. The loss function used during training is
        :class:`~torch.nn.MSELoss`.

        Args:
            measures: (batch_size, measure_dim) array of measure values.
            targets: (batch_size,) array of target values for the discount
                function.

        Returns:
            A list with the total MSE loss accumulated on each epoch,
            normalized/divided by the size of the dataset. Strictly speaking, the
            model is updated after every batch is passed through it, so this is not
            the loss that one would obtain if the measures were all passed through
            the model at once. If the dataset is empty, no training happens and an
            empty list is returned.
        """
        normalized_measures = self._normalize(
            measures, self.normalize_measures, self.measures_low, self.measures_high
        )
        normalized_targets = self._normalize(
            targets, self.normalize_discount, self.discount_low, self.discount_high
        )

        dataset = TensorDataset(normalized_measures, normalized_targets)
        dataset_size = len(dataset)
        if dataset_size == 0:
            # Nothing to fit; also avoids dividing the epoch loss by zero.
            return []

        dataloader = DataLoader(dataset, self.train_batch_size, shuffle=True)
        criterion = nn.MSELoss(reduction="mean")

        all_epoch_loss = []
        for _ in range(self.train_epochs):
            epoch_loss = 0.0

            for b_norm_measures, b_norm_targets in dataloader:
                cur = self.model(b_norm_measures)
                # Turn (X, 1) into (X,), mirroring the handling in inference().
                if cur.ndim == 2:
                    cur = cur.squeeze(dim=1)

                self.optimizer.zero_grad()
                loss = criterion(cur, b_norm_targets)
                loss.backward()
                self.optimizer.step()

                # Multiply so that we track the total loss even if batch size varies.
                epoch_loss += loss.item() * len(b_norm_measures)

            # Divide by total elements in dataset.
            epoch_loss /= dataset_size
            all_epoch_loss.append(epoch_loss)

            if epoch_loss <= self.train_cutoff_loss:
                break

        return all_epoch_loss

    def inference(
        self,
        measures: ArrayLike,
        batch_size: int | None = None,
    ) -> np.ndarray:
        """Computes discount values at the given measures using the model.

        This method also temporarily puts the model in eval mode and uses
        :class:`torch.no_grad`. The model's previous train/eval mode is restored
        afterwards, even if the forward pass raises.

        Args:
            measures: Inputs to the model of size (n_measures, measure_dim).
            batch_size: If passed in, the model will only be passed ``batch_size``
                inputs at a time. This can be useful if, for instance, the model is
                very large and there is insufficient memory to handle many inputs
                simultaneously.

        Returns:
            The discount values at the input measures. If ``measures`` is empty, an
            empty float32 array of shape (0,) is returned without calling the model.
        """
        normalized_measures = self._normalize(
            measures, self.normalize_measures, self.measures_low, self.measures_high
        )

        num_inputs = len(normalized_measures)
        if num_inputs == 0:
            # DataLoader rejects batch_size=0 and torch.cat rejects an empty list.
            return np.empty((0,), dtype=np.float32)

        if batch_size is None:
            batch_size = num_inputs

        dataloader = DataLoader(
            dataset=TensorDataset(normalized_measures),
            batch_size=batch_size,
            shuffle=False,
        )

        was_training = self.model.training
        self.model.eval()
        discounts = []
        try:
            with torch.no_grad():
                for (b_norm_measures,) in dataloader:
                    discounts.append(self.model(b_norm_measures))
        finally:
            # Restore whichever mode the model was in before this call.
            self.model.train(was_training)

        # Concatenate all the chunks together and unnormalize them.
        discounts = torch.cat(discounts, dim=0)
        discounts = self._unnormalize(
            discounts, self.normalize_discount, self.discount_low, self.discount_high
        )

        # Turn (X, 1) into (X,).
        if discounts.ndim == 2:
            discounts = discounts.squeeze(dim=1)

        return discounts.detach().cpu().numpy()