# Source code for ribs.emitters.opt._gradient_opt_base
"""Provides GradientOptBase."""
from abc import ABC, abstractmethod
class GradientOptBase(ABC):
    """Abstract interface for gradient-based optimizers.

    .. note::
        Implementations are meant to perform gradient *ascent* rather than
        gradient descent.

    An optimizer keeps track of a current solution point :math:`\\theta`.
    The point is read through the :attr:`theta` property, moved by handing a
    gradient to :meth:`step`, and replaced outright with :meth:`reset`.

    Subclass constructors may accept extra arguments in addition to
    ``theta0`` and ``lr``, but should expect that these two arguments will
    always be passed in.

    Args:
        theta0 (array-like): Initial solution. 1D array.
        lr (float): Learning rate for the update.
    """

    def __init__(self, theta0, lr):
        # Deliberately a no-op: the base class stores neither argument, so
        # keeping ``theta0`` and ``lr`` is left to the concrete subclass.
        pass

    @property
    @abstractmethod
    def theta(self):
        """The current solution point."""

    @abstractmethod
    def reset(self, theta0):
        """Resets the solution point to a new value.

        Args:
            theta0 (array-like): The new solution point. 1D array.
        """

    @abstractmethod
    def step(self, gradient):
        """Ascends the solution based on the given gradient.

        Args:
            gradient (array-like): The (estimated) gradient of the current
                solution point. 1D array.
        """