import numpy as np

Refer to `Kingma and Ba 2014 <https://arxiv.org/pdf/1412.6980.pdf>`_ for
more information on hyperparameters.

Args:
theta0 (array-like): Initial solution. 1D array.
lr (float): Learning rate for the update.
beta1 (float): Exponential decay rate for the moment estimates.
beta2 (float): Another exponential decay rate for the moment estimates.
epsilon (float): Hyperparameter for numerical stability.
l2_coeff (float): Coefficient for L2 regularization. Note this is
**not** the same as "weight decay" -- see `this blog post
<https://www.fast.ai/posts/2018-07-02-adam-weight-decay.html>`_ and
`Loshchilov and Hutter 2019 <https://arxiv.org/abs/1711.05101>`_ for
more info.
"""

def __init__(self, theta0, lr=0.001, beta1=0.9, beta2=0.999,  # pylint: disable = super-init-not-called
             epsilon=1e-8, l2_coeff=0.0):
    """Store the Adam hyperparameters and initialize state from ``theta0``.

    Args:
        theta0 (array-like): Initial solution; handed straight to ``reset``.
        lr (float): Learning rate for the update.
        beta1 (float): Exponential decay rate for the moment estimates.
        beta2 (float): Another exponential decay rate for the moment
            estimates.
        epsilon (float): Hyperparameter for numerical stability.
        l2_coeff (float): Coefficient for L2 regularization (not weight
            decay).
    """
    # Hyperparameters are stored as given; no validation is performed.
    self._lr = lr
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon
    self._l2_coeff = l2_coeff

    # State placeholders; reset() fills in the solution, moments and counter.
    self._theta = None
    self._m = self._v = self._t = None
    self.reset(theta0)

@property
def theta(self):
    """The current solution point held by the optimizer.

    This is the stored ``_theta`` object itself, not a copy, so any in-place
    modification by the caller also changes the optimizer's state.
    """
    current = self._theta
    return current

def reset(self, theta0):
    """Restart the optimizer from a new initial solution.

    The solution is copied, so later changes to the caller's ``theta0`` do
    not leak into the optimizer (and vice versa). The moment estimates are
    zeroed and the step counter is set back to 0.

    Args:
        theta0 (array-like): New initial solution. The class docs describe it
            as a 1D array; this method does not enforce the shape.
    """
    self._theta = np.copy(theta0)
    # Moment estimates start at zero with the same shape and dtype as theta.
    self._m = np.zeros_like(self._theta)
    self._v = np.zeros_like(self._theta)
    # Step counter; presumably advanced by the update step (not visible here).
    self._t = 0

# Invert gradient since we seek to maximize -- see pseudocode here:

# L2 regularization (not weight decay).