# Source code for ribs.emitters.opt._gradient_opt_base
"""Provides GradientOptBase."""
from abc import ABC, abstractmethod
import numpy as np
from numpy.typing import ArrayLike
from ribs.typing import Float
class GradientOptBase(ABC):
    r"""Base class for gradient-based optimizers.

    .. note::
        These optimizers are designed for gradient ascent rather than gradient
        descent.

    These optimizers maintain a current solution point :math:`\theta`. The
    solution point is obtained with the :attr:`theta` property, and it is
    updated by passing a gradient to :meth:`step`. Finally, the point can be
    reset to a new value with :meth:`reset`.

    Your constructor may take in additional arguments beyond ``theta0`` and
    ``lr``, but expect that these two arguments will always be passed in.

    Args:
        theta0: Initial solution. 1D array.
        lr: Learning rate for the update.
    """

    def __init__(self, theta0: ArrayLike, lr: Float) -> None:
        # Deliberately a no-op: the base class keeps no state of its own. The
        # signature documents the two arguments every subclass constructor
        # is expected to accept.
        pass

    @property
    @abstractmethod
    def theta(self) -> np.ndarray:
        """The current solution point."""

    @abstractmethod
    def reset(self, theta0: ArrayLike) -> None:
        """Resets the solution point to a new value.

        Args:
            theta0: The new solution point. 1D array.
        """

    @abstractmethod
    def step(self, gradient: ArrayLike) -> None:
        """Ascends the solution based on the given gradient.

        Args:
            gradient: The (estimated) gradient of the current solution point.
                1D array.
        """