Model

Basic Usage

TODO

Advanced Usage

TODO

Saving and Loading Models

TODO

API Reference

class Model(modules=None, rng=None)

Bases: object

Model class built from a list of modules.

__init__

Initialize the model with a list of modules and an optional random key.

get_num_trainable_floats

reset

append_module

Append a module to the model.

prepend_module

Prepend a module to the model.

insert_module

Insert a module at the given index in the model.

add

Alias for append_module. Append a module to the model.

put

Alias for prepend_module. Prepend a module to the model.

insert

Alias for insert_module. Insert a module at the given index in the model.

remove_module

Remove a module from the model at the given index.

pop_module

Remove and return the last module of the model.

__getitem__

Get the module at the given index.

compile

Compile the model for training by compiling each module.

get_output_shape

Get the output shape of the model given an input shape.

get_params

Get the parameters of the model as a Tuple of numpy arrays.

set_params

Set the parameters of the model from a Tuple of numpy arrays.

get_state

Get the state of the model as a Tuple of numpy arrays.

set_state

Set the state of the model from a Tuple of numpy arrays.

get_rng

Get the random key of the model.

set_rng

Set the random key for the model.

_get_callable

This method must return a JAX-jittable and JAX-differentiable callable of the form `(params: Tuple[np.ndarray, ...], input_NF: np.ndarray[num_samples, num_features], training: bool, state: Tuple[np.ndarray, ...], rng: key<fry>) -> (output_NF: np.ndarray[num_samples, num_output_features], new_state: Tuple[np.ndarray, ...])`. That is, all hyperparameters are traced out, and the callable depends explicitly only on a Tuple of parameter numpy arrays, the input array, the training flag, a state Tuple of numpy arrays, and a JAX rng key.

__call__

Call the model with the input array.

predict

Alias for __call__. Call the model with the input array.

set_precision

Set the precision of the model parameters and states.

astype

Convenience wrapper around set_precision that takes a dtype argument and returns self.

train

serialize

Serialize the model to a dictionary.

deserialize

Deserialize the model from a dictionary.

save

Save the model to a file.

save_compressed

Save the model to a compressed file.

load

Load the model from a file.

from_file

Load a model from a file and return an instance of the Model class.

Model.__init__(modules=None, rng=None)

Initialize the model with a list of modules and an optional random key.

Parameters:
  • modules (List[BaseModule], optional) – List of modules to initialize the model with. Default is an empty list.

  • rng (Any, optional) – Initial random key for the model. Default is None. If None, a new random key will be generated using JAX’s random.PRNGKey. If an integer is provided, it will be used as the seed to create the key.
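The rng argument accepts None, an integer seed, or an existing JAX key. The sketch below shows one way such an argument could be normalized into a key with jax.random.PRNGKey. The helper name resolve_rng is hypothetical, and seeding the None case from os.urandom is an assumption, since the text above does not say how the fresh key is seeded.

```python
import os
from typing import Any, Optional

import jax


def resolve_rng(rng: Optional[Any] = None) -> Any:
    """Hypothetical helper: turn None / int / existing key into a JAX PRNG key."""
    if rng is None:
        # Assumption: seed a fresh key from OS entropy, kept below 2**31 to fit int32.
        seed = int.from_bytes(os.urandom(4), "little") % (2**31)
        return jax.random.PRNGKey(seed)
    if isinstance(rng, int):
        # An integer is used as the seed for a new key.
        return jax.random.PRNGKey(rng)
    # Anything else is assumed to already be a JAX key and is returned unchanged.
    return rng


key_from_seed = resolve_rng(42)
fresh_key = resolve_rng()
```

Passing the same integer twice yields identical keys, which is what makes seeded runs reproducible.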

Model.get_num_trainable_floats()
Return type:

Optional[int]

Model.reset()
Return type:

None

Model.append_module(module)

Append a module to the model.

Parameters:

module (BaseModule) – Module to append to the model.

Return type:

None

Model.prepend_module(module)

Prepend a module to the model.

Parameters:

module (BaseModule) – Module to prepend to the model.

Return type:

None

Model.insert_module(module, index)

Insert a module at the given index in the model.

Parameters:
  • module (BaseModule) – Module to insert into the model.

  • index (int) – Index at which to insert the module.

Return type:

None

Model.add(module)

Alias for append_module. Append a module to the model.

Parameters:

module (BaseModule) – Module to append to the model.

Return type:

None

Model.put(module)

Alias for prepend_module. Prepend a module to the model.

Parameters:

module (BaseModule) – Module to prepend to the model.

Return type:

None

Model.insert(module, index)

Alias for insert_module. Insert a module at the given index in the model.

Parameters:
  • module (BaseModule) – Module to insert into the model.

  • index (int) – Index at which to insert the module.

Return type:

None
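add, put, and insert carry the same signatures and descriptions as append_module, prepend_module, and insert_module. The toy class below is not the library's code: it assumes the short names are plain aliases created by class-level assignment and shows the resulting module order, with strings standing in for modules.

```python
from typing import Any, List, Optional


class ToyModel:
    """Toy stand-in for Model that only tracks the order of its modules."""

    def __init__(self, modules: Optional[List[Any]] = None) -> None:
        self.modules: List[Any] = list(modules) if modules is not None else []

    def append_module(self, module: Any) -> None:
        self.modules.append(module)

    def prepend_module(self, module: Any) -> None:
        self.modules.insert(0, module)

    def insert_module(self, module: Any, index: int) -> None:
        self.modules.insert(index, module)

    # Short aliases defined by plain assignment: they are the same function objects.
    add = append_module
    put = prepend_module
    insert = insert_module


m = ToyModel(["B"])
m.add("C")        # ["B", "C"]
m.put("A")        # ["A", "B", "C"]
m.insert("X", 1)  # ["A", "X", "B", "C"]
```

Note that insert takes the module first and the index second, matching insert_module, which is the reverse of Python's list.insert(index, item).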

Model.remove_module(index)

Remove a module from the model at the given index.

Parameters:

index (int) – Index of the module to remove.

Return type:

None

Model.pop_module()

Remove and return the last module of the model.

Return type:

BaseModule

Returns:

BaseModule

The module that was removed from the end of the model.

Model.__getitem__(key)

Get the module at the given index.

Parameters:

key (Union[int, slice]) – Index of the module to retrieve, or a slice selecting several modules.

Return type:

Union[List[BaseModule], BaseModule]

Returns:

Union[List[BaseModule], BaseModule]

The module at the specified index, or a list of modules if key is a slice.
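The Union[List[BaseModule], BaseModule] return type of __getitem__ suggests that, like a Python list, the model accepts both integer indices and slices. Assuming the model simply delegates to an internal list (an assumption, not confirmed here), remove_module, pop_module, and __getitem__ behave like the list operations below; the module names are placeholder strings.

```python
# Plain Python list standing in for the model's internal module list.
modules = ["dense_1", "activation", "dense_2", "output"]

# __getitem__: an int returns one module, a slice returns a list of modules.
first = modules[0]
middle = modules[1:3]

# remove_module(index): delete the module at the given position.
del modules[1]

# pop_module(): remove and return the last module.
last = modules.pop()
```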

Model.compile(rngkey, input_shape, verbose=False)

Compile the model for training by compiling each module.

Parameters:
  • rngkey (Union[Any, int]) – Random key for initializing the model parameters. JAX PRNGKey or integer seed.

  • input_shape (Tuple[int, ...]) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.

  • verbose (bool, optional) – Print debug information during compilation. Default is False.

Return type:

None

Model.get_output_shape(input_shape)

Get the output shape of the model given an input shape.

Parameters:

input_shape (Tuple[int, ...]) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.

Return type:

Tuple[int, ...]

Returns:

Tuple[int, …]

Shape of the output array after passing through the model.
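Both compile and get_output_shape take an input_shape without the batch axis, so it is X.shape[1:] for a batch X. Assuming each module can report its own output shape, the model's output shape is the composition of the per-module shape maps. The helpers flatten_shape and dense_shape below are hypothetical stand-ins for modules, not library classes.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]


def flatten_shape(shape: Shape) -> Shape:
    # Collapse all feature axes into one, e.g. (28, 28, 1) -> (784,).
    size = 1
    for dim in shape:
        size *= dim
    return (size,)


def dense_shape(units: int) -> Callable[[Shape], Shape]:
    # Replace the last feature axis with `units`.
    return lambda shape: shape[:-1] + (units,)


def get_output_shape(shape_fns: List[Callable[[Shape], Shape]], input_shape: Shape) -> Shape:
    # Thread the batch-free shape through each module in order.
    shape = input_shape
    for fn in shape_fns:
        shape = fn(shape)
    return shape


X_shape = (64, 28, 28, 1)   # (batch_size, height, width, channels)
input_shape = X_shape[1:]    # drop the batch axis
out_shape = get_output_shape([flatten_shape, dense_shape(128), dense_shape(10)], input_shape)
```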

Model.get_params()

Get the parameters of the model as a Tuple of numpy arrays.

Return type:

Tuple[Array, ...]

Returns:

Tuple[np.ndarray, …]

NumPy arrays representing the parameters of the model. Their order matches the order in which they are used in the _get_callable method.

Model.set_params(params)

Set the parameters of the model from a Tuple of numpy arrays.

Parameters:

params (Tuple[np.ndarray, ...]) – numpy arrays representing the parameters of the model. The order of the parameters should match the order in which they are used in the _get_callable method.

Return type:

None
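get_params and set_params exchange a single flat tuple whose order must match _get_callable. If, as an assumption, the model builds that tuple by concatenating each module's parameters in module order, then setting parameters means splitting the flat tuple back using each module's parameter count. The sketch below uses strings in place of arrays; flatten_params and split_params are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def flatten_params(per_module: Sequence[Tuple]) -> Tuple:
    # Concatenate each module's parameter tuple, in module order.
    return tuple(p for module_params in per_module for p in module_params)


def split_params(flat: Tuple, counts: Sequence[int]) -> List[Tuple]:
    # Inverse of flatten_params: cut the flat tuple back into per-module tuples.
    if sum(counts) != len(flat):
        raise ValueError("parameter count mismatch")
    out, start = [], 0
    for n in counts:
        out.append(tuple(flat[start:start + n]))
        start += n
    return out


# Module 0 has (W, b), module 1 has no parameters, module 2 has (W,).
per_module = [("W0", "b0"), (), ("W2",)]
flat = flatten_params(per_module)
restored = split_params(flat, [len(p) for p in per_module])
```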

Model.get_state()

Get the state of the model as a Tuple of numpy arrays.

Return type:

Tuple[Array, ...]

Returns:

Tuple[np.ndarray, …]

NumPy arrays representing the state of the model. Their order matches the order in which they are used in the _get_callable method.

Model.set_state(state)

Set the state of the model from a Tuple of numpy arrays.

Parameters:

state (Tuple[np.ndarray, ...]) – numpy arrays representing the state of the model. The order of the states should match the order in which they are used in the _get_callable method.

Return type:

None

Model.get_rng()

Get the random key of the model.

Return type:

Any

Model.set_rng(rng)

Set the random key for the model.

Parameters:

rng (Any) – Random key to set for the model: a JAX PRNGKey or an integer seed.

Return type:

None

Model._get_callable()

This method must return a JAX-jittable and JAX-differentiable callable of the form

```
(
    params: Tuple[np.ndarray, ...],
    input_NF: np.ndarray[num_samples, num_features],
    training: bool,
    state: Tuple[np.ndarray, ...],
    rng: key<fry>
) -> (
    output_NF: np.ndarray[num_samples, num_output_features],
    new_state: Tuple[np.ndarray, ...]
)
```

That is, all hyperparameters are traced out, and the callable depends explicitly only on a Tuple of parameter numpy arrays, the input array, the training flag, a state Tuple of numpy arrays, and a JAX rng key.

The training flag is traced out as well, so it does not need to be jittable.

Return type:

Callable[[Tuple[Array, ...], Array, bool, Tuple[Array, ...], Any], Tuple[Array, Tuple[Array, ...]]]
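A minimal sketch of a callable that satisfies this contract, written with jax.numpy. The leaky-linear layer, the closure-captured negative_slope hyperparameter, and the call counter used as state are illustrative assumptions, not code from the library. Because training is an ordinary Python bool, it is marked static with static_argnums when jitting, and jax.grad differentiates with respect to the params tuple.

```python
from typing import Any, Tuple

import jax
import jax.numpy as jnp


def make_leaky_dense_callable(negative_slope: float):
    """Return a callable following the (params, input_NF, training, state, rng) contract.

    negative_slope is a hyperparameter captured by the closure, so the returned
    callable depends explicitly only on its five arguments.
    """

    def fn(
        params: Tuple[jax.Array, ...],
        input_NF: jax.Array,
        training: bool,
        state: Tuple[jax.Array, ...],
        rng: Any,
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        W, b = params
        z = input_NF @ W + b
        output_NF = jnp.where(z > 0, z, negative_slope * z)
        # `training` is a plain Python bool, so this branch is resolved at trace time.
        # The state is a single counter of training-mode calls (illustrative only).
        new_state = (state[0] + 1,) if training else state
        # rng is unused by this deterministic callable.
        return output_NF, new_state

    return fn


fn = make_leaky_dense_callable(0.1)
jitted = jax.jit(fn, static_argnums=2)  # mark `training` as static

params = (jnp.ones((3, 2)), jnp.zeros(2))
X = jnp.arange(12.0).reshape(4, 3)
state = (jnp.array(0),)
key = jax.random.PRNGKey(0)

out, new_state = jitted(params, X, True, state, key)


def loss(p):
    y, _ = fn(p, X, False, state, key)
    return jnp.mean(y**2)


grads = jax.grad(loss)(params)  # tuple of gradients with the same structure as params
```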

Model.__call__(X, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False)

Call the model with the input array.

Parameters:
  • X (np.ndarray) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.

  • dtype (Optional[Any], optional) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) and single-precision weights.

  • rng (Any, optional) – JAX random key for stochastic modules. Default is None. If None, the saved rng key is used if it exists; this is the final rng key from the most recent training run. If an integer is provided, it is used as the seed to create a new JAX random key.

  • return_state (bool, optional) – If True, also return the state of the model after evaluation. Default is False.

  • update_state (bool, optional) – If True, store the state produced by this evaluation as the model's new state. Default is False.

Return type:

Array

Returns:

np.ndarray

Output array of shape (batch_size, <output feature axes>). For example, (batch_size, output_features) for a 1D output or (batch_size, output_height, output_width, output_channels) for a 3D output.

Tuple[np.ndarray, …], optional

If return_state is True, the model will also return the state of the model as a Tuple of numpy arrays. The order of the states will match the order in which they are used in the _get_callable method.
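return_state and update_state are independent: the first changes what the call returns, the second changes what the model keeps afterwards. The toy class below is purely illustrative, not the library's implementation; it mimics the documented behavior using a sample counter as its state.

```python
from typing import List, Tuple


class ToyStatefulModel:
    """Toy model whose 'state' counts how many samples it has seen."""

    def __init__(self) -> None:
        self.state: Tuple[int, ...] = (0,)

    def _callable(self, X: List[float], state: Tuple[int, ...]):
        # Doubles every input and advances the counter by the batch size.
        return [2 * x for x in X], (state[0] + len(X),)

    def __call__(self, X, return_state: bool = False, update_state: bool = False):
        output, new_state = self._callable(X, self.state)
        if update_state:
            self.state = new_state  # persist the post-evaluation state
        if return_state:
            return output, new_state  # the caller also receives the state
        return output


model = ToyStatefulModel()
y = model([1, 2, 3])                        # stored state stays (0,)
y, s = model([1, 2, 3], return_state=True)  # s == (3,), stored state still (0,)
y = model([1, 2, 3], update_state=True)     # stored state becomes (3,)
```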

Model.predict(X, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False)

Alias for __call__. Call the model with the input array.

Parameters:
  • X (np.ndarray) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.

  • dtype (Optional[Any], optional) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) and single-precision weights.

  • rng (Any, optional) – JAX random key for stochastic modules. Default is None. If None, the saved rng key is used if it exists; this is the final rng key from the most recent training run. If an integer is provided, it is used as the seed to create a new JAX random key.

  • return_state (bool, optional) – If True, also return the state of the model after evaluation. Default is False.

  • update_state (bool, optional) – If True, store the state produced by this evaluation as the model's new state. Default is False.

Return type:

Array

Returns:

np.ndarray

Output array of shape (batch_size, <output feature axes>). For example, (batch_size, output_features) for a 1D output or (batch_size, output_height, output_width, output_channels) for a 3D output.

Tuple[np.ndarray, …], optional

If return_state is True, the model will also return the state of the model as a Tuple of numpy arrays. The order of the states will match the order in which they are used in the _get_callable method.

Model.set_precision(prec)

Set the precision of the model parameters and states.

Parameters:

prec (Union[np.dtype, str, int]) – Precision to set for the model parameters and states. Valid options are:
  • 32-bit precision (all options are equivalent): np.float32, np.complex64, “float32”, “complex64”, “single”, “f32”, “c64”, 32
  • 64-bit precision (all options are equivalent): np.float64, np.complex128, “float64”, “complex128”, “double”, “f64”, “c128”, 64

Return type:

None
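Because every spelling in each group is equivalent, the argument reduces to a bit width. A sketch of such a normalization, assuming exact (case-sensitive) string matching; normalize_precision is a hypothetical helper, not part of the documented API.

```python
from typing import Any

# Every spelling listed for set_precision, mapped to its bit width.
_ALIASES = {
    "float32": 32, "complex64": 32, "single": 32, "f32": 32, "c64": 32,
    "float64": 64, "complex128": 64, "double": 64, "f64": 64, "c128": 64,
}


def normalize_precision(prec: Any) -> int:
    """Hypothetical helper: map any accepted precision spelling to 32 or 64."""
    if isinstance(prec, bool):
        # bool is a subclass of int in Python; reject it explicitly.
        raise ValueError(f"invalid precision: {prec!r}")
    if isinstance(prec, int):
        if prec in (32, 64):
            return prec
        raise ValueError(f"invalid precision: {prec!r}")
    # NumPy scalar types such as np.float32 expose their name via __name__.
    name = prec if isinstance(prec, str) else getattr(prec, "__name__", None)
    if name not in _ALIASES:
        raise ValueError(f"invalid precision: {prec!r}")
    return _ALIASES[name]
```

Mapping complex types to the same width as their real counterparts reflects the grouping above: complex64 is built from two float32 values.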

Model.astype(dtype)

Convenience wrapper around set_precision that takes a dtype argument and returns self.

Return type:

Model

Model.train(X, Y=None, Y_unc=None, X_val=None, Y_val=None, Y_val_unc=None, loss_fn='mse', lr=0.001, batch_size=32, num_epochs=100, convergence_threshold=1e-12, early_stopping_patience=10, early_stopping_tolerance=1e-06, initialization_seed=None, callback=None, unroll=None, verbose=True, batch_seed=None, b1=0.9, b2=0.999, eps=1e-08, clip=1000.0)
Return type:

None
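train exposes early_stopping_patience and early_stopping_tolerance without describing how they interact. One common patience rule is sketched below as an assumption, not as the library's logic: an epoch counts as an improvement only if the validation loss beats the best value so far by more than the tolerance, and training stops after patience consecutive epochs without improvement. The defaults mirror those in the signature above.

```python
from typing import Iterable, Optional


def early_stopping_epoch(
    val_losses: Iterable[float], patience: int = 10, tolerance: float = 1e-6
) -> Optional[int]:
    """Hypothetical rule: return the epoch at which training would stop, or None."""
    best = float("inf")
    stale = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - tolerance:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return None


stop = early_stopping_epoch([1.0, 0.5, 0.4, 0.4, 0.4, 0.4], patience=3)
```

With patience=3, the three flat epochs after the last improvement trigger a stop at epoch 5.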

Model.serialize()

Serialize the model to a dictionary. This is done by serializing the model’s parameters/metadata and then serializing each module.

Return type:

Dict[str, Union[Any, Dict[str, Any]]]

Returns:

Dict[str, Union[Any, Dict[str, Any]]]

Dictionary containing the serialized model data.

Model.deserialize(data)

Deserialize the model from a dictionary. This is done by deserializing the model’s parameters/metadata and then deserializing each module.

Parameters:

data (Dict[str, Any]) – Dictionary containing the serialized model data.

Return type:

None
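serialize and deserialize first handle the model's own parameters and metadata and then each module in turn. The sketch below shows that nesting with plain dictionaries; the key names "metadata", "modules", and "type", the module names, and the use of lists in place of arrays are all hypothetical, chosen so the result can round-trip through JSON.

```python
import json
from typing import Any, Dict, List, Tuple


def serialize_model(metadata: Dict[str, Any], modules: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Model-level fields first, then one nested dictionary per module, in order.
    return {"metadata": dict(metadata), "modules": [dict(m) for m in modules]}


def deserialize_model(data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    # Inverse of serialize_model: recover the metadata and the ordered module dicts.
    return dict(data["metadata"]), [dict(m) for m in data["modules"]]


metadata = {"precision": 32, "compiled": True}
modules = [
    {"type": "Dense", "params": [[1.0, 2.0]]},
    {"type": "ReLU", "params": []},
]

data = serialize_model(metadata, modules)
# A nested dict of basic types can be written out as text and read back unchanged.
restored_metadata, restored_modules = deserialize_model(json.loads(json.dumps(data)))
```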

Model.save(filename)

Save the model to a file.

Parameters:

filename (str) – Name of the file to save the model to.

Return type:

None

Model.save_compressed(filename)

Save the model to a compressed file.

Parameters:

filename (str) – Name of the file to save the model to.

Return type:

None

Model.load(filename)

Load the model from a file. Supports both compressed and uncompressed files.

Parameters:

filename (str) – Name of the file to load the model from.

Return type:

None
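A single loader can accept both compressed and uncompressed files by inspecting the file contents instead of the file name. The on-disk format used by the library is not documented here, so the sketch below swaps in JSON with optional gzip compression and detects gzip by its two-byte magic number; save_dict and load_dict are hypothetical helpers.

```python
import gzip
import json
import os
import tempfile
from typing import Any, Dict

GZIP_MAGIC = b"\x1f\x8b"  # first two bytes of every gzip stream


def save_dict(data: Dict[str, Any], filename: str, compressed: bool = False) -> None:
    payload = json.dumps(data).encode("utf-8")
    if compressed:
        payload = gzip.compress(payload)
    with open(filename, "wb") as f:
        f.write(payload)


def load_dict(filename: str) -> Dict[str, Any]:
    with open(filename, "rb") as f:
        payload = f.read()
    # Decompress only if the file actually starts with the gzip magic number.
    if payload[:2] == GZIP_MAGIC:
        payload = gzip.decompress(payload)
    return json.loads(payload.decode("utf-8"))


with tempfile.TemporaryDirectory() as tmp:
    plain_path = os.path.join(tmp, "model.json")
    packed_path = os.path.join(tmp, "model.json.gz")
    save_dict({"layers": 2}, plain_path)
    save_dict({"layers": 2}, packed_path, compressed=True)
    loaded_plain = load_dict(plain_path)
    loaded_packed = load_dict(packed_path)
```

Content-based detection keeps the load path identical for both variants, so callers never need to know which save method produced a file.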

classmethod Model.from_file(filename)

Load a model from a file and return an instance of the Model class.

Parameters:

filename (str) – Name of the file to load the model from.

Return type:

Model

Returns:

Model

An instance of the Model class with the loaded parameters.
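load fills an existing instance and returns None, whereas the from_file classmethod returns a new instance. Assuming from_file follows the usual alternate-constructor pattern, it amounts to the toy below; ToyModel stands in for Model, and its load only records the filename instead of reading a file. Using cls rather than a hard-coded class name means subclasses get instances of their own type.

```python
from typing import Any, Dict


class ToyModel:
    """Toy class showing the load / from_file pairing (not the library's code)."""

    def __init__(self) -> None:
        self.data: Dict[str, Any] = {}

    def load(self, filename: str) -> None:
        # Stand-in for reading a file: just record where the data came from.
        self.data = {"source": filename}

    @classmethod
    def from_file(cls, filename: str) -> "ToyModel":
        # Alternate constructor: build an empty instance, then load into it.
        model = cls()
        model.load(filename)
        return model


class CustomModel(ToyModel):
    pass


m = ToyModel.from_file("model_file")
c = CustomModel.from_file("other_file")
```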