Model
Basic Usage
TODO
Advanced Usage
TODO
Saving and Loading Models
TODO
API Reference
- class Model(modules=None, rng=None)
Bases: object
Model class built from a list of modules.
Methods:
- __init__: Initialize the model with a list of modules and an optional random key.
- append_module: Append a module to the model.
- prepend_module: Prepend a module to the model.
- insert_module: Insert a module at the given index in the model.
- add: Append a module to the model.
- put: Prepend a module to the model.
- insert: Insert a module at the given index in the model.
- Remove a module from the model at the given index.
- pop_module: Pop the last module from the model.
- __getitem__: Get the module at the given index.
- compile: Compile the model for training by compiling each module.
- get_output_shape: Get the output shape of the model given an input shape.
- Get the parameters of the model as a Tuple of numpy arrays.
- set_params: Set the parameters of the model from a Tuple of numpy arrays.
- Get the state of the model as a Tuple of numpy arrays.
- set_state: Set the state of the model from a Tuple of numpy arrays.
- set_rng: Set the random key for the model.
- _get_callable: Return a JAX-jittable, JAX-differentiable callable that evaluates the model.
- __call__: Call the model with the input array.
- predict: Call the model with the input array.
- set_precision: Set the precision of the model parameters and states.
- astype: Convenience wrapper around set_precision that takes a dtype argument and returns self.
- serialize: Serialize the model to a dictionary.
- deserialize: Deserialize the model from a dictionary.
- Save the model to a file.
- Save the model to a compressed file.
- Load the model from a file.
- Load a model from a file and return an instance of the Model class.
- Model.__init__(modules=None, rng=None)
Initialize the model with a list of modules and an optional random key.
- Parameters:
modules (List[BaseModule], optional) – List of modules to initialize the model with. Default is an empty list.
rng (Any, optional) – Initial random key for the model. Default is None. If None, a new random key will be generated using JAX’s random.PRNGKey. If an integer is provided, it will be used as the seed to create the key.
- Model.append_module(module)
Append a module to the model.
- Parameters:
module (BaseModule) – Module to append to the model.
- Return type:
- Model.prepend_module(module)
Prepend a module to the model.
- Parameters:
module (BaseModule) – Module to prepend to the model.
- Return type:
- Model.insert_module(module, index)
Insert a module at the given index in the model.
- Parameters:
module (BaseModule) – Module to insert into the model.
index (int) – Index at which to insert the module.
- Return type:
- Model.add(module)
Append a module to the model.
- Parameters:
module (BaseModule) – Module to append to the model.
- Return type:
- Model.put(module)
Prepend a module to the model.
- Parameters:
module (BaseModule) – Module to prepend to the model.
- Return type:
- Model.insert(module, index)
Insert a module at the given index in the model.
- Parameters:
module (BaseModule) – Module to insert into the model.
index (int) – Index at which to insert the module.
- Return type:
- Model.pop_module()
Pop the last module from the model.
- Returns:
BaseModule – The last module in the model.
- Model.__getitem__(key)
Get the module at the given index.
- Parameters:
key (int) – Index of the module to retrieve.
- Returns:
BaseModule – The module at the specified index.
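The module-list methods above can be pictured as Python list operations. The following is a minimal, self-contained sketch of those semantics, assuming insert_module follows list.insert index rules and that add, put and insert are aliases of the *_module methods (they share the same docstrings). ToyModel and its string "modules" are hypothetical stand-ins, not the library's classes.

```python
from typing import Any, List, Optional


class ToyModel:
    """Hypothetical stand-in illustrating the module-list operations of Model."""

    def __init__(self, modules: Optional[List[Any]] = None):
        # Copy so the caller's list is not mutated; None means "no modules yet".
        self.modules: List[Any] = list(modules) if modules is not None else []

    def append_module(self, module: Any) -> None:
        self.modules.append(module)  # add at the end

    def prepend_module(self, module: Any) -> None:
        self.modules.insert(0, module)  # add at the front

    def insert_module(self, module: Any, index: int) -> None:
        self.modules.insert(index, module)  # same index rules as list.insert

    def pop_module(self) -> Any:
        return self.modules.pop()  # remove and return the last module

    def __getitem__(self, key: int) -> Any:
        return self.modules[key]

    # Assumed aliases mirroring Model.add / Model.put / Model.insert.
    add = append_module
    put = prepend_module
    insert = insert_module


model = ToyModel(["dense"])
model.add("relu")
model.put("input_norm")
model.insert("dropout", 2)
print(model.modules)  # ['input_norm', 'dense', 'dropout', 'relu']
last = model.pop_module()
print(last, model[0])  # relu input_norm
```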
- Model.compile(rngkey, input_shape, verbose=False)
Compile the model for training by compiling each module.
- Parameters:
rngkey (Union[Any, int]) – Random key for initializing the model parameters. JAX PRNGKey or integer seed.
input_shape (Tuple[int, ...]) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.
verbose (bool, optional) – Print debug information during compilation. Default is False.
- Return type:
- Model.get_output_shape(input_shape)
Get the output shape of the model given an input shape.
- Parameters:
input_shape (Tuple[int, ...]) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.
- Return type:
- Returns:
- Tuple[int, …]
Shape of the output array after passing through the model.
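Because input_shape excludes the batch axis, one way to picture get_output_shape is as a fold over the modules, each mapping a per-sample shape to the next. The sketch below assumes every module exposes its own get_output_shape(input_shape); ToyFlatten, ToyDense and model_output_shape are hypothetical names used only for illustration.

```python
from functools import reduce
from math import prod
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


class ToyFlatten:
    """Hypothetical module: collapses all feature axes into one."""

    def get_output_shape(self, input_shape: Shape) -> Shape:
        return (prod(input_shape),)


class ToyDense:
    """Hypothetical module: maps (in_features,) to (out_features,)."""

    def __init__(self, out_features: int):
        self.out_features = out_features

    def get_output_shape(self, input_shape: Shape) -> Shape:
        if len(input_shape) != 1:
            raise ValueError(f"ToyDense expects a 1D feature shape, got {input_shape}")
        return (self.out_features,)


def model_output_shape(modules: Sequence, input_shape: Shape) -> Shape:
    # Thread the per-sample shape (batch axis excluded) through every module in order.
    return reduce(lambda shape, m: m.get_output_shape(shape), modules, tuple(input_shape))


modules = [ToyFlatten(), ToyDense(64), ToyDense(10)]
print(model_output_shape(modules, (28, 28, 1)))  # (10,)
```

With no modules, the input shape passes through unchanged, which matches the intuition that an empty model is the identity.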
- Model.set_params(params)
Set the parameters of the model from a Tuple of numpy arrays.
- Parameters:
params (Tuple[np.ndarray, ...]) – numpy arrays representing the parameters of the model. The order of the parameters should match the order in which they are used in the _get_callable method.
- Return type:
- Model.set_state(state)
Set the state of the model from a Tuple of numpy arrays.
- Parameters:
state (Tuple[np.ndarray, ...]) – numpy arrays representing the state of the model. The order of the states should match the order in which they are used in the _get_callable method.
- Return type:
- Model.set_rng(rng)
Set the random key for the model.
- Parameters:
rng (Any) – Random key to set for the model: a JAX PRNGKey or an integer seed.
- Return type:
- Model._get_callable()
This method must return a callable that can be compiled with jax.jit and differentiated with jax.grad, of the form:
```
(
    params: Tuple[np.ndarray, ...],
    input_NF: np.ndarray[num_samples, num_features],
    training: bool,
    state: Tuple[np.ndarray, ...],
    rng: key<fry>
) -> (
    output_NF: np.ndarray[num_samples, num_output_features],
    new_state: Tuple[np.ndarray, ...]
)
```
That is, all hyperparameters are traced out, and the callable depends explicitly only on a Tuple of parameter numpy arrays, the input array, the training flag, a state Tuple of numpy arrays, and a JAX rng key.
The training flag will be traced out, so it does not need to be jittable.
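The contract above can be illustrated with a toy linear layer. This sketch uses nested Python lists instead of JAX arrays so it runs without dependencies; a real implementation would use jax.numpy so the function can be passed to jax.jit and jax.grad. The factory name make_toy_linear_callable and the state layout (a single counter of training calls) are hypothetical. The training flag is used only for Python-level branching, consistent with the note that it need not be jittable, and rng is accepted but unused because the layer is deterministic.

```python
from typing import Any, List, Tuple

Matrix = List[List[float]]


def make_toy_linear_callable():
    """Hypothetical factory returning a callable of the form
    (params, input_NF, training, state, rng) -> (output_NF, new_state)."""

    def forward(
        params: Tuple[Matrix, List[float]],
        input_NF: Matrix,
        training: bool,
        state: Tuple[int, ...],
        rng: Any,
    ) -> Tuple[Matrix, Tuple[int, ...]]:
        W, b = params  # W: num_features x num_output_features, b: num_output_features
        output_NF = [
            [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]
            for x in input_NF
        ]
        # `training` selects a Python-level branch; it is never treated as array data.
        if training:
            (num_training_calls,) = state
            new_state = (num_training_calls + 1,)
        else:
            new_state = state  # evaluation leaves the state unchanged
        # `rng` is unused: this toy layer has no stochastic behavior.
        return output_NF, new_state

    return forward


forward = make_toy_linear_callable()
params = ([[1.0, 0.0], [0.0, 2.0]], [0.5, -0.5])
out, new_state = forward(params, [[1.0, 3.0]], True, (0,), rng=None)
print(out, new_state)  # [[1.5, 5.5]] (1,)
```

Keeping parameters, state and the key as explicit arguments is what makes such a function pure, which is the property jax.jit and jax.grad rely on.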
- Model.__call__(X, dtype=jax.numpy.float64, rng=None, return_state=False, update_state=False)
Call the model with the input array.
- Parameters:
X (np.ndarray) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.
dtype (Optional[Any], optional) – Data type of the output array. Default is jax.numpy.float64. Training in single precision (float32 and complex64) is strongly recommended; for inference, use double-precision inputs (float64, the default here) with the single-precision weights.
rng (Any, optional) – JAX random key for stochastic modules. Default is None. If None, the saved rng key is used if one exists; this is the final rng key from the last training run. If an integer is provided, it is used as the seed to create a new JAX random key.
return_state (bool, optional) – If True, the model will return the state of the model after evaluation. Default is False.
update_state (bool, optional) – If True, the model will update the state of the model after evaluation. Default is False.
- Return type:
- Returns:
- np.ndarray
Output array of shape (batch_size, <output feature axes>). For example, (batch_size, output_features) for a 1D output or (batch_size, output_height, output_width, output_channels) for a 3D output.
- Tuple[np.ndarray, …], optional
If return_state is True, the model will also return the state of the model as a Tuple of numpy arrays. The order of the states will match the order in which they are used in the _get_callable method.
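The interaction of return_state and update_state (shared by __call__ and predict) can be sketched as a thin wrapper around a callable of the _get_callable form. This is a hypothetical illustration of the documented options, not the library's implementation: ToyCaller and count_calls are invented names, the sketch always evaluates with training=False, and it ignores the dtype argument entirely.

```python
from typing import Any, Callable, Tuple


class ToyCaller:
    """Hypothetical wrapper showing how return_state / update_state could behave."""

    def __init__(self, forward: Callable, params: Tuple, state: Tuple):
        # forward has the form (params, X, training, state, rng) -> (out, new_state).
        self.forward = forward
        self.params = params
        self.state = state

    def __call__(self, X, rng: Any = None, return_state: bool = False, update_state: bool = False):
        # Inference call: training=False is an assumption of this sketch.
        out, new_state = self.forward(self.params, X, False, self.state, rng)
        if update_state:
            self.state = new_state  # persist the post-evaluation state on the object
        if return_state:
            return out, new_state
        return out


def count_calls(params, X, training, state, rng):
    # Toy forward: scales inputs by params[0] and counts evaluations in the state.
    (scale,) = params
    (n,) = state
    return [[scale * v for v in row] for row in X], (n + 1,)


caller = ToyCaller(count_calls, params=(2.0,), state=(0,))
print(caller([[1.0]]))  # [[2.0]]; the stored state is still (0,)
print(caller([[1.0]], return_state=True, update_state=True))  # ([[2.0]], (1,))
print(caller.state)  # (1,)
```

Separating the two flags lets a caller inspect the new state without committing it, or commit it without changing the return type.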
- Model.predict(X, dtype=jax.numpy.float64, rng=None, return_state=False, update_state=False)
Call the model with the input array.
- Parameters:
X (np.ndarray) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.
dtype (Optional[Any], optional) – Data type of the output array. Default is jax.numpy.float64. Training in single precision (float32 and complex64) is strongly recommended; for inference, use double-precision inputs (float64, the default here) with the single-precision weights.
rng (Any, optional) – JAX random key for stochastic modules. Default is None. If None, the saved rng key is used if one exists; this is the final rng key from the last training run. If an integer is provided, it is used as the seed to create a new JAX random key.
return_state (bool, optional) – If True, the model will return the state of the model after evaluation. Default is False.
update_state (bool, optional) – If True, the model will update the state of the model after evaluation. Default is False.
- Return type:
- Returns:
- np.ndarray
Output array of shape (batch_size, <output feature axes>). For example, (batch_size, output_features) for a 1D output or (batch_size, output_height, output_width, output_channels) for a 3D output.
- Tuple[np.ndarray, …], optional
If return_state is True, the model will also return the state of the model as a Tuple of numpy arrays. The order of the states will match the order in which they are used in the _get_callable method.
- Model.set_precision(prec)
Set the precision of the model parameters and states.
- Parameters:
prec (Union[np.dtype, str, int]) – Precision to set for the model parameters and states. Valid options are:
  - 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32
  - 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64
- Return type:
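The accepted precision spellings listed above reduce to two canonical values. Below is a minimal sketch of such a normalization, written without NumPy so it runs standalone; normalize_precision is a hypothetical helper, not part of the library. It assumes a NumPy scalar type such as np.float32 would be matched through its __name__, and an np.dtype instance through its name attribute.

```python
from typing import Any

# Spellings listed in the set_precision documentation, grouped by bit width.
_SINGLE_NAMES = {"float32", "complex64", "single", "f32", "c64"}
_DOUBLE_NAMES = {"float64", "complex128", "double", "f64", "c128"}


def normalize_precision(prec: Any) -> int:
    """Hypothetical helper: map any accepted precision spelling to 32 or 64."""
    if isinstance(prec, int) and not isinstance(prec, bool):
        if prec in (32, 64):
            return prec
        raise ValueError(f"Unsupported precision: {prec!r}")
    if isinstance(prec, str):
        name = prec
    elif isinstance(prec, type):
        name = prec.__name__  # e.g. np.float32.__name__ == "float32"
    else:
        name = getattr(prec, "name", None)  # e.g. np.dtype("complex128").name
    if name in _SINGLE_NAMES:
        return 32
    if name in _DOUBLE_NAMES:
        return 64
    raise ValueError(f"Unsupported precision: {prec!r}")


print([normalize_precision(p) for p in ("single", "c64", 32, "double", "complex128", 64)])
# [32, 32, 32, 64, 64, 64]
```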
- Model.astype(dtype)
Convenience wrapper around set_precision that takes the precision as a dtype argument and returns self.
- Return type:
Model
- Model.train(X, Y=None, Y_unc=None, X_val=None, Y_val=None, Y_val_unc=None, loss_fn='mse', lr=0.001, batch_size=32, num_epochs=100, convergence_threshold=1e-12, early_stopping_patience=10, early_stopping_tolerance=1e-06, initialization_seed=None, callback=None, unroll=None, verbose=True, batch_seed=None, b1=0.9, b2=0.999, eps=1e-08, clip=1000.0)
- Return type:
- Model.serialize()
Serialize the model to a dictionary. This is done by serializing the model’s parameters/metadata and then serializing each module.
- Model.deserialize(data)
Deserialize the model from a dictionary. This is done by deserializing the model’s parameters/metadata and then deserializing each module.
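The two-stage scheme described above (model-level metadata first, then one entry per module) can be sketched as a round trip through a plain dictionary. The dictionary layout, the "type" registry, the ToyScale module and the helper function names are all hypothetical; the library's actual serialized format is not documented here.

```python
import json
from typing import Any, Dict, List


class ToyScale:
    """Hypothetical module with a single parameter."""

    def __init__(self, factor: float):
        self.factor = factor

    def serialize(self) -> Dict[str, Any]:
        return {"type": "ToyScale", "factor": self.factor}

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> "ToyScale":
        return cls(data["factor"])


# Maps the stored type name back to a class when deserializing.
MODULE_REGISTRY = {"ToyScale": ToyScale}


def serialize_model(modules: List[Any], metadata: Dict[str, Any]) -> Dict[str, Any]:
    # Model-level metadata first, then one serialized entry per module, in order.
    return {"metadata": dict(metadata), "modules": [m.serialize() for m in modules]}


def deserialize_model(data: Dict[str, Any]):
    modules = [MODULE_REGISTRY[entry["type"]].deserialize(entry) for entry in data["modules"]]
    return modules, dict(data["metadata"])


data = serialize_model([ToyScale(2.0), ToyScale(0.5)], {"precision": 32})
restored, meta = deserialize_model(json.loads(json.dumps(data)))
print([m.factor for m in restored], meta)  # [2.0, 0.5] {'precision': 32}
```

Passing the dictionary through json.dumps and json.loads checks that the serialized form contains only plain, file-friendly values.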