ribs.discount_models.MLP

class ribs.discount_models.MLP(layer_specs: Collection[tuple[int, int] | tuple[int, int, bool]], activation: Callable)

PyTorch multi-layer perceptron model.

The MLP applies the same activation after every layer except the last, which has no activation. Each layer can optionally include a bias.
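The layer structure described above can be sketched in NumPy. This is not the ribs implementation: `mlp_forward` is a hypothetical helper, `np.tanh` stands in for an activation such as torch.nn.Tanh, and each weight is assumed to have shape (out_features, in_features) as in torch.nn.Linear, so a layer computes `x @ weight.T + bias`.

```python
import numpy as np

def mlp_forward(x, layers, activation=np.tanh):
    """Hypothetical NumPy stand-in for MLP.forward.

    layers: list of (weight, bias) pairs; weight has shape (out, in) like
    torch.nn.Linear, and bias is an array of shape (out,) or None.
    """
    for i, (weight, bias) in enumerate(layers):
        x = x @ weight.T
        if bias is not None:
            x = x + bias
        # The shared activation follows every layer except the last one.
        if i < len(layers) - 1:
            x = activation(x)
    return x

rng = np.random.default_rng(0)
layers = [
    (rng.standard_normal((8, 4)), np.zeros(8)),  # in=4, out=8, with bias
    (rng.standard_normal((2, 8)), None),         # in=8, out=2, no bias
]
out = mlp_forward(rng.standard_normal((5, 4)), layers)
print(out.shape)  # (5, 2)
```

Because the last layer has no activation, the outputs are not squashed into the activation's range, which is why a bounded activation such as tanh can still be used for hidden layers of a model with unbounded outputs.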

Note

This model requires PyTorch to be installed, e.g., by running pip install torch.

Parameters:
layer_specs: Collection[tuple[int, int] | tuple[int, int, bool]]

List of tuples specifying the linear layers. Each tuple contains either (in_features, out_features) or (in_features, out_features, bias), where in_features and out_features are integers giving the input and output sizes of that layer, and bias is a bool indicating whether the layer has a bias.

activation: Callable

Activation layer class, e.g., torch.nn.Tanh.
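As an illustration, `MLP([(4, 16), (16, 16, False), (16, 1)], torch.nn.Tanh)` would describe three linear layers mapping 4 → 16 → 16 → 1, where the middle layer has no bias. The hypothetical helper below expands 2-tuples into 3-tuples and checks that each layer's out_features matches the next layer's in_features. The default of `bias=True` for 2-tuples is an assumption (it mirrors the default of torch.nn.Linear); the documentation above does not state it.

```python
def normalize_layer_specs(layer_specs, default_bias=True):
    """Hypothetical helper: expand each spec to (in_features, out_features, bias)
    and check that consecutive layers fit together."""
    normalized = []
    for spec in layer_specs:
        if len(spec) == 2:
            in_features, out_features = spec
            bias = default_bias  # assumed default; not stated in the docs
        elif len(spec) == 3:
            in_features, out_features, bias = spec
        else:
            raise ValueError(f"expected 2 or 3 entries, got {spec!r}")
        normalized.append((in_features, out_features, bias))
    # Each layer's output size must equal the next layer's input size.
    for (_, out_prev, _), (in_next, _, _) in zip(normalized, normalized[1:]):
        if out_prev != in_next:
            raise ValueError(
                f"out_features={out_prev} does not match in_features={in_next}"
            )
    return normalized

print(normalize_layer_specs([(4, 16), (16, 16, False), (16, 1)]))
# [(4, 16, True), (16, 16, False), (16, 1, True)]
```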

Methods

deserialize(array)

Loads parameters from 1D array.

forward(x)

Passes the inputs through the MLP.

gradient()

Returns 1D array with gradient of all parameters in the model.

num_params()

Counts the number of parameters in the model.

serialize()

Returns 1D array with all parameters in the model.

deserialize(array: ndarray) → MLP

Loads parameters from 1D array.

For example, given the array output by serialize(), this method can be used to load that array back into the parameters of this model.

Returns:

The model itself, so that calls can be chained, e.g., model = MLP(...).deserialize(x).
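Conceptually, deserialize() reverses serialize(): it splits the flat array into consecutive chunks and reshapes each chunk into one parameter. The NumPy sketch below illustrates that idea with a hypothetical `unflatten` helper; the parameter order (each layer's weight followed by its bias) is an assumption about how ribs orders parameters, not something documented here.

```python
import numpy as np

def unflatten(flat, shapes):
    """Hypothetical sketch of the idea behind deserialize: split a 1D array
    into consecutive chunks and reshape each chunk to a parameter shape."""
    flat = np.asarray(flat)
    sizes = [int(np.prod(shape)) for shape in shapes]
    if flat.ndim != 1 or flat.size != sum(sizes):
        raise ValueError(f"expected a 1D array of length {sum(sizes)}")
    params = []
    offset = 0
    for shape, size in zip(shapes, sizes):
        params.append(flat[offset:offset + size].reshape(shape))
        offset += size
    return params

# Layers (3, 2, True) and (2, 1, True): weight, bias, weight, bias.
shapes = [(2, 3), (2,), (1, 2), (1,)]
params = unflatten(np.arange(11.0), shapes)
print(params[1].tolist())  # [6.0, 7.0]
```

With the real class, the chaining described above allows copying parameters between two models of the same architecture, e.g., `clone = MLP(specs, torch.nn.Tanh).deserialize(model.serialize())`, where `specs` and `model` are placeholders.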

forward(x: Tensor) → Tensor

Passes the inputs through the MLP.

gradient() → ndarray

Returns 1D array with gradient of all parameters in the model.

Essentially, all the gradients of the model’s parameters are retrieved, flattened, and concatenated together.

Returns:

1D array whose length corresponds to the total size of all gradients in the model.
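In PyTorch, per-parameter gradients are typically populated by calling `.backward()` on a scalar loss computed from forward(); gradient() then returns them as one flat vector. The NumPy sketch below works through a single linear layer y = W @ x + b with loss L = sum(y), whose gradients can be written down by hand, and flattens them. It assumes gradients are concatenated in the same order as the parameters returned by serialize(), which the descriptions suggest but do not state.

```python
import numpy as np

# Single linear layer y = W @ x + b with loss L = sum(y); W: (2, 3), b: (2,).
x = np.array([1.0, 2.0, 3.0])
W = np.zeros((2, 3))
b = np.zeros(2)

# Analytic gradients: dL/dW[i, j] = x[j] and dL/db[i] = 1.
grad_W = np.outer(np.ones(2), x)
grad_b = np.ones(2)

# Flatten and concatenate in the same (assumed) order as the parameters.
flat_grad = np.concatenate([grad_W.ravel(), grad_b.ravel()])
print(flat_grad.tolist())  # [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 1.0]

# A plain gradient-descent step on the flat parameter vector.
flat_params = np.concatenate([W.ravel(), b.ravel()])
flat_params -= 0.1 * flat_grad
```

Under the same ordering assumption, a gradient-descent step on the real model could be written as `model.deserialize(model.serialize() - lr * model.gradient())`, with `lr` a placeholder learning rate.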

num_params() → int

Counts the number of parameters in the model.
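A linear layer with in_features inputs and out_features outputs has in_features * out_features weights, plus out_features biases if bias is enabled, so the total can be computed from layer_specs alone. The hypothetical `count_params` helper below does this; like the other sketches here, it assumes that 2-tuples default to having a bias.

```python
def count_params(layer_specs, default_bias=True):
    """Hypothetical helper: count the parameters implied by layer_specs.

    Each linear layer has in_features * out_features weights, plus
    out_features biases when bias is enabled. default_bias is an assumption
    for specs given as 2-tuples.
    """
    total = 0
    for spec in layer_specs:
        in_features, out_features = spec[0], spec[1]
        bias = spec[2] if len(spec) == 3 else default_bias
        total += in_features * out_features + (out_features if bias else 0)
    return total

# 4*16 + 16 = 80, 16*16 = 256 (no bias), 16*1 + 1 = 17.
print(count_params([(4, 16), (16, 16, False), (16, 1)]))  # 353
```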

serialize() → ndarray

Returns 1D array with all parameters in the model.

Essentially, all the parameters of the model are retrieved, flattened, and concatenated together.

Returns:

1D array whose length corresponds to the number of parameters in the model.
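The flatten-and-concatenate step can be illustrated in NumPy with a hypothetical `flatten_params` helper. This is a sketch of the idea rather than the ribs code, and it assumes each layer's weight is stored before its bias.

```python
import numpy as np

def flatten_params(params):
    """Hypothetical sketch of the idea behind serialize: flatten each
    parameter array and concatenate them into one 1D vector."""
    return np.concatenate([np.ravel(p) for p in params])

# Parameters of one linear layer with in_features=3, out_features=2, bias=True.
weight = np.arange(6.0).reshape(2, 3)
bias = np.array([6.0, 7.0])
flat = flatten_params([weight, bias])
print(flat.tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(flat.size)      # 8, i.e., 3 * 2 weights + 2 biases
```

The resulting length matches the parameter count of the layer, consistent with the return value described above.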