NonSequentialModel
parametricmatrixmodels.NonSequentialModel
- class NonSequentialModel(modules=None, connections=None, /, *, rng=None, separator='.')
Bases: Model
A nonsequential model that chains modules (or other models) together with directed connections.
For confidence intervals or uncertainty quantification, wrap a trained model with ConformalModel.
See also
jax.tree – PyTree utilities and concepts in JAX.
Model – Abstract base class for all models.
SequentialModel – A model that applies modules in sequence.
ConformalModel – Wrap a trained model to produce confidence intervals.
- __init__(modules=None, connections=None, /, *, rng=None, separator='.')
Initialize a nonsequential model with a PyTree of modules and a random key.
- Parameters:
modules (ModelModules | BaseModule | None) – module(s) to initialize the model with. Default is None, which will become an empty dictionary. Can be a single module, which will be wrapped in a list, or a PyTree of modules (e.g., nested lists, tuples, or dictionaries).
connections (Dict[str, str | List[str] | Tuple[str, ...]] | None) – Directed connections from module outputs to module inputs in the model. Keys are period-separated paths of module outputs, and each value is either a single period-separated path or a list or tuple of such paths, naming the module inputs that receive that output. The reserved names “input” and “output” refer to the model input and output, respectively. The separator can be changed from the default period using the separator argument. Default is None, which will become an empty dictionary.
rng (Any | int | None) – Initial random key for the model. Default is None. If None, a new random key will be generated using JAX’s random.key. If an integer is provided, it will be used as the seed to create the key.
separator (str) – Separator string to use for denoting paths in the connections dictionary. Default is “.”.
- Return type:
None
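The period-separated paths in connections (such as “block1.0” for a nested modules PyTree) can be read as a walk down the modules PyTree, one path component per level. The sketch below illustrates that reading under stated assumptions: dictionary levels are indexed by key, and list or tuple levels by integer position. get_by_path is a hypothetical helper, not part of parametricmatrixmodels, and strings stand in for module objects. A custom separator is presumably useful when a dictionary key itself contains a period.

```python
from typing import Any


def get_by_path(tree: Any, path: str, separator: str = ".") -> Any:
    """Hypothetical helper: follow a separator-delimited path through nested
    dicts (indexed by key) and lists/tuples (indexed by integer position)."""
    if path == "":
        return tree  # an empty path refers to the whole tree
    node = tree
    for part in path.split(separator):
        if isinstance(node, dict):
            node = node[part]
        elif isinstance(node, (list, tuple)):
            node = node[int(part)]
        else:
            raise KeyError(f"cannot descend into {type(node).__name__} with {part!r}")
    return node


# Strings stand in for module instances.
modules = {"block1": ["Module1", "Module2"], "block2": {"M3": "Module3"}}
print(get_by_path(modules, "block1.1"))                  # Module2
print(get_by_path(modules, "block2/M3", separator="/"))  # Module3
```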
Examples
To denote a sequential model where all modules expect a bare array and produce a bare array:
>>> modules = [Module1(), Module2(), Module3()]
>>> connections = {
...     "input": "0",
...     "0": "1",
...     "1": "2",
...     "2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
or, equivalently, with named modules:
>>> modules = {"M0": Module1(), "M1": Module2(), "M2": Module3()}
>>> connections = {
...     "input": "M0",
...     "M0": "M1",
...     "M1": "M2",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
or, equivalently, with nested structures:
>>> modules = {
...     "block1": [Module1(), Module2()],
...     "block2": {"M3": Module3()}
... }
>>> connections = {
...     "input": "block1.0",
...     "block1.0": "block1.1",
...     "block1.1": "block2.M3",
...     "block2.M3": "output"
... }
>>> model = NonSequentialModel(modules, connections)
All three of the above will produce a model that applies the same three modules sequentially.
If a module outputs a PyTree of arrays, or expects a PyTree of arrays as input, the connections can specify the leaf nodes using the same period-separated path syntax. For example, if Module1 outputs a dict with keys “a” and “b”, and Module2 expects a tuple of two arrays as input, the connections can be specified as:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input": "M1",
...     "M1.a": "M2.1",
...     "M1.b": "M2.0",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
This will send the “a” output of Module1 to the second input of Module2, and the “b” output of Module1 to the first input of Module2.
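One way to picture this routing: each leaf connection copies one entry of Module1’s output into one slot of Module2’s tuple input, and the tuple is ordered by destination index rather than by the order of the connections. A minimal sketch, assuming a dict-valued source, a tuple-valued destination whose slots 0 to n-1 are all connected, and strings in place of arrays; route_leaves is hypothetical.

```python
from typing import Any, Dict, Tuple


def route_leaves(source_output: Dict[str, Any], connections: Dict[str, str],
                 src: str, dst: str, separator: str = ".") -> Tuple[Any, ...]:
    """Hypothetical: build dst's tuple input from 'src.<key>' -> 'dst.<index>' edges."""
    slots: Dict[int, Any] = {}
    for out_path, in_path in connections.items():
        out_parts = out_path.split(separator)
        in_parts = in_path.split(separator)
        # Only leaf-to-leaf edges from src to dst are relevant here.
        if (len(out_parts) == 2 and len(in_parts) == 2
                and out_parts[0] == src and in_parts[0] == dst):
            slots[int(in_parts[1])] = source_output[out_parts[1]]
    # Order by destination index, not by dictionary order.
    return tuple(slots[i] for i in range(len(slots)))


connections = {"input": "M1", "M1.a": "M2.1", "M1.b": "M2.0", "M2": "output"}
m1_output = {"a": "a-array", "b": "b-array"}
print(route_leaves(m1_output, connections, "M1", "M2"))  # ('b-array', 'a-array')
```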
Note
If a module expects a Tuple or List as input, it is best to write the module to accept both types, since the model cannot infer at compile time whether a List or a Tuple is expected.
If the entire model input or output is a PyTree of arrays, the connections use the same period-separated path syntax with the reserved keys “input” and “output”. For example, if the model input is a dict with keys “x1” and “x2”, and the model output is a Tuple of two arrays, the connections can be specified as:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input.x1": "M1",
...     "M1": "M2",
...     "M2": "output.1",
...     "input.x2": "output.0"
... }
>>> model = NonSequentialModel(modules, connections)
This passes the “x1” input through Module1 and then Module2 and sends the result to the second output of the model, while the “x2” input is sent unchanged directly to the first output of the model.
Modules that output PyTrees need not be fully traversed if entire subtrees are to be passed between modules. For example, if Module1 outputs a dict with keys “a” and “b”, and Module2 expects a dict with keys “a” and “b” as input, the connections can be specified as:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input": "M1",
...     "M1": "M2",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
or equivalently:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input": "M1",
...     "M1.a": "M2.a",
...     "M1.b": "M2.b",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
Both ways will pass the entire output dict of Module1 to Module2.
Module outputs can be sent to multiple module inputs by specifying a list or tuple of input paths in the connections dictionary. For example, to send the output of Module1 to both Module2 and Module3:
>>> modules = {"M1": Module1(), "M2": Module2(), "M3": Module3()}
>>> connections = {
...     "input": "M1",
...     "M1": ["M2", "M3"],
...     "M2": "output.0",
...     "M3": "output.1"
... }
>>> model = NonSequentialModel(modules, connections)
This creates a model that sends the output of Module1 to both Module2 and Module3 in parallel and collects their outputs in a Tuple that forms the model output.
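Because a connection value may be a single path or a list or tuple of paths, it can help to normalize every value to a list before walking the graph; fan-out then simply means a list with more than one entry. normalize_connections and consumers_of are hypothetical helpers sketching this, not library functions.

```python
from typing import Dict, List, Sequence, Union

Targets = Union[str, Sequence[str]]


def normalize_connections(connections: Dict[str, Targets]) -> Dict[str, List[str]]:
    """Hypothetical: turn every connection value into a list of input paths."""
    return {
        out_path: [targets] if isinstance(targets, str) else list(targets)
        for out_path, targets in connections.items()
    }


def consumers_of(connections: Dict[str, Targets], out_path: str) -> List[str]:
    """Hypothetical: every input path that receives the output at out_path."""
    return normalize_connections(connections).get(out_path, [])


connections = {"input": "M1", "M1": ["M2", "M3"], "M2": "output.0", "M3": "output.1"}
print(consumers_of(connections, "M1"))  # ['M2', 'M3']
print(consumers_of(connections, "M2"))  # ['output.0']
```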
The order of the connections in the dictionary does not matter, as long as the connections form a valid directed acyclic graph from the model input to the model output. Not every part of the model input, nor every module, needs to be used, but leaving either unused raises a warning. Leaving parts of a module’s output unused is allowed and raises no warning. However, every input of every module that is used must be connected.
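The usage rules above can be sketched as a simple module-level check. This is a hypothetical helper, not the library’s validation: it assumes a flat dictionary of modules (so the first path component is the module name), warns about modules that appear in no connection, and raises if a used module has no incoming connection. It does not check individual leaves of PyTree inputs or unused parts of the model input, and it does not detect cycles.

```python
import warnings
from typing import Dict, List, Sequence, Set, Union


def check_usage(module_names: List[str],
                connections: Dict[str, Union[str, Sequence[str]]],
                separator: str = ".") -> Set[str]:
    """Hypothetical module-level check, assuming a flat dict of modules so the
    first path component ('M1' in 'M1.a') is the module name."""
    receives: Set[str] = set()  # nodes that are the target of some connection
    sends: Set[str] = set()     # nodes whose output is routed somewhere
    for out_path, targets in connections.items():
        targets = [targets] if isinstance(targets, str) else list(targets)
        sends.add(out_path.split(separator)[0])
        receives.update(t.split(separator)[0] for t in targets)
    used = (receives | sends) & set(module_names)
    for name in module_names:
        if name not in used:
            warnings.warn(f"module {name!r} is not used by any connection")
    unfed = sorted(name for name in used if name not in receives)
    if unfed:
        raise ValueError(f"modules with no connected input: {unfed}")
    return used


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    used = check_usage(["M1", "M2", "M3"], {"input": "M1", "M1": "M2", "M2": "output"})
print(sorted(used), len(caught))  # ['M1', 'M2'] 1
```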
See also
ModelModules – Type alias for a PyTree of modules in a model.
jax.random.key – JAX function to create a random key.
jax.tree_util.keystr – JAX function to create string paths for PyTree KeyPaths in the format expected by the connections dictionary.
Methods
__call__ – Call the model with the input data.
_get_callable – Returns a jax.jit-able and jax.grad-able callable that represents the model’s forward pass.
add – Append a module to the end of the model.
add_module – Append a module to the end of the model.
append – Append a module to the end of the model.
append_module – Append a module to the end of the model.
astype – Convenience wrapper around set_precision that uses the dtype argument as the precision; returns self.
compile – Compile the model for training by finding the execution order of the directed graph defined by the connections, and compiling each module in that order.
deserialize – Deserialize the model from a dictionary.
from_file – Load a model from a file and return an instance of the Model class.
get_execution_order – Resolve the connections dictionary to determine the order in which the modules are executed.
get_hyperparameters – Get the hyperparameters of the model as a dictionary.
get_modules – Get the modules of the model.
get_num_trainable_floats – Returns the number of trainable floats in the module.
get_output_shape – Get the output shape of the model given an input shape.
get_params – Get the parameters of the model.
get_state – Get the state of all modules in the model as a PyTree.
grad_input – Gradient of the model output with respect to the input data.
grad_params – Gradient of the model output with respect to the model parameters.
insert – Insert a module at a specific index in the model.
insert_module – Insert a module at a specific index in the model.
is_ready – Check if the model is compiled and ready for use.
load – Load the model from a file.
pop – Remove and return a module by key or index in the model.
pop_module_by_index – Remove and return a module at a specific index in the model.
pop_module_by_key – Remove and return a module by key or index in the model.
predict – Call the model with the input data.
prepend – Prepend a module to the beginning of the model.
prepend_module – Prepend a module to the beginning of the model.
reset – Reset the compiled state of the model.
save – Save the model to a file.
save_compressed – Save the model to a compressed file.
serialize – Serialize the model to a dictionary.
set_hyperparameters – Set the hyperparameters of the model from a dictionary.
set_params – Set the parameters of the model from a PyTree of PyTrees of numpy arrays.
set_precision – Set the precision of the model parameters and states.
set_rng – Set the random key for the model.
set_state – Set the state of the model from a PyTree of PyTrees of numpy arrays.
Attributes
name – Returns the name of the module. Unless overridden, this is the class name.
- __call__(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, max_batch_size=None)
Call the model with the input data.
- Parameters:
X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.
dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to perform training in single precision (float32 and complex64) and inference with double precision inputs (float64, the default here) with single precision weights.
rng (Any | int | None) – JAX random key for stochastic modules. Default is None, in which case the saved rng key is used if it exists (this would be the final rng key from the last training run); otherwise a new random key is generated. If an integer is provided, it will be used as the seed to create a new JAX random key.
return_state (bool) – If True, the model will return the state of the model after evaluation. Default is False.
update_state (bool) – If True, the model will update the state of the model after evaluation. Default is False.
max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input will be processed in a single batch.
- Returns:
Data – Output data as a PyTree of JAX arrays, the structure and shape of which are determined by the model’s specific modules.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
Tuple[Data, ModelState] | Data
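The max_batch_size behaviour can be pictured as chunking along the leading (sample) axis and concatenating the per-chunk results. The sketch below uses Python lists in place of JAX arrays to stay dependency-free; apply_in_batches is hypothetical, and its equivalence with a single full-batch call assumes each sample is processed independently of the others.

```python
from typing import Callable, List, Optional, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def apply_in_batches(fn: Callable[[Sequence[T]], List[U]], samples: Sequence[T],
                     max_batch_size: Optional[int] = None) -> List[U]:
    """Hypothetical: run fn on consecutive chunks of at most max_batch_size samples."""
    if max_batch_size is None:
        return list(fn(samples))  # a single batch, as with the default of None
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be a positive integer")
    outputs: List[U] = []
    for start in range(0, len(samples), max_batch_size):
        outputs.extend(fn(samples[start:start + max_batch_size]))
    return outputs


batch_sizes: List[int] = []


def double(batch: Sequence[int]) -> List[int]:
    batch_sizes.append(len(batch))  # record the size of each chunk
    return [2 * x for x in batch]


print(apply_in_batches(double, list(range(10)), max_batch_size=4))
# [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
print(batch_sizes)  # [4, 4, 2]
```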
- _get_callable()
Returns a jax.jit-able and jax.grad-able callable that represents the model’s forward pass.
This method must be implemented by all subclasses and must return a jax.jit-able and jax.grad-able callable in the form of
    model_callable(
        params: parametricmatrixmodels.model_util.ModelParams,
        data: parametricmatrixmodels.typing.Data,
        training: bool,
        state: parametricmatrixmodels.model_util.ModelState,
        rng: Any,
    ) -> (
        output: parametricmatrixmodels.typing.Data,
        new_state: parametricmatrixmodels.model_util.ModelState,
    )
That is, all hyperparameters are traced out and the callable depends explicitly only on
- the model’s parameters, as a PyTree with leaf nodes as JAX arrays,
- the input data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),
- the training flag, as a boolean,
- the model’s state, as a PyTree with leaf nodes as JAX arrays,
- the JAX random key (rng),
and returns
- the output data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),
- the new model state, as a PyTree with leaf nodes as JAX arrays. The PyTree structure must match that of the input state, and all leaf nodes must have the same shapes as the input state leaf nodes.
The training flag will be traced out, so it doesn’t need to be jittable.
- Returns:
A callable that takes the model’s parameters, input data, training flag, state, and rng key and returns the output data and new state.
- Return type:
pmm.model_util.ModelCallable
See also
__call__ – Calls the model with the current parameters and given input, state, and rng.
ModelCallable – Typing for the callable returned by this method.
Params – Typing for the model parameters.
Data – Typing for the input and output data.
State – Typing for the model state.
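As an illustration of the callable contract above, the toy callable below closes over a hyperparameter (offset), so its output depends only on params, data, training, state, and rng, and it returns a new state with the same structure as the input state. Plain Python floats and lists stand in for JAX arrays, so this sketch is not itself jittable; make_toy_callable and its “calls” counter are hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple

Params = Dict[str, float]
State = Dict[str, int]
ToyCallable = Callable[[Params, List[float], bool, State, Any], Tuple[List[float], State]]


def make_toy_callable(offset: float) -> ToyCallable:
    """Hypothetical: 'offset' is a hyperparameter captured by the closure,
    so the returned callable depends only on its explicit arguments."""

    def model_callable(params: Params, data: List[float], training: bool,
                       state: State, rng: Any) -> Tuple[List[float], State]:
        # One output per sample; rng is unused because nothing is stochastic.
        output = [params["scale"] * x + offset for x in data]
        # Branching on training is fine because it is a static Python bool.
        # The new state keeps the input state's structure: one int leaf.
        new_state = {"calls": state["calls"] + (1 if training else 0)}
        return output, new_state

    return model_callable


fn = make_toy_callable(offset=1.0)
out, new_state = fn({"scale": 2.0}, [0.0, 1.0, 2.0], True, {"calls": 0}, None)
print(out, new_state)  # [1.0, 3.0, 5.0] {'calls': 1}
```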
- _get_module_input_dependencies()
Get the input dependencies for each module in the execution order.
- Returns:
module_input_dependencies – List of PyTrees of str paths to the required inputs for each module in the execution order. The first entry corresponds to the “input” node and is None.
output_input_dependencies – PyTree of str paths to the required inputs for the “output” node.
- _get_shape_progression(input_shape, /)
Get the progression of output shapes through the model given an input shape. The first entry is the model input shape, and the last entry is the model output shape.
- Parameters:
input_shape (DataShape) – Shape of the input, excluding the batch dimension. For example, (input_features,) for 1D bare-array input, or (input_height, input_width, input_channels) for 3D bare-array input, [(input_features1,), (input_features2,)] for a List (PyTree) of 1D arrays, etc.
- Returns:
input_shapes – PyTree of input shapes at each module in the execution order, with the same structure as the modules in the execution order.
output_shapes – PyTree of output shapes at each module in the execution order, with the same structure as the modules in the execution order.
output_shape – Shape of the output after passing through the model.
- Return type:
Tuple[PyTree[DataShape | None], PyTree[DataShape | None], DataShape]
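For a purely sequential chain of bare-array modules, the shape progression reduces to applying each module’s shape rule in turn, starting from the input shape with the batch dimension excluded. The shape functions below are hypothetical stand-ins for real modules, and shape_progression is not the library method.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]


def shape_progression(input_shape: Shape,
                      shape_fns: List[Callable[[Shape], Shape]]) -> List[Shape]:
    """Hypothetical: propagate a per-sample shape through a chain of modules.
    The first entry is the input shape and the last is the output shape."""
    shapes = [input_shape]
    for fn in shape_fns:
        shapes.append(fn(shapes[-1]))
    return shapes


def flatten(shape: Shape) -> Shape:
    """(H, W, C) -> (H * W * C,), batch dimension excluded."""
    size = 1
    for dim in shape:
        size *= dim
    return (size,)


def dense_16(shape: Shape) -> Shape:
    """Any 1D feature vector -> 16 features."""
    return (16,)


print(shape_progression((28, 28, 1), [flatten, dense_16]))
# [(28, 28, 1), (784,), (16,)]
```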
- _verify_input(X)
- Parameters:
X (pmm.typing.Data)
- Return type:
None
- add(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- add_module(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- append(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- append_module(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
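The documented key behaviour of add, add_module, append, and append_module can be sketched for the two simplest storage types. This is a hypothetical helper, not the library’s implementation: it handles only a flat list or dict, uses uuid.uuid4 for a missing dict key, and, as its own design choice, refuses to overwrite an existing key.

```python
import uuid
from typing import Any, Dict, List, Optional, Union


def append_module(modules: Union[List[Any], Dict[str, Any]], module: Any,
                  key: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: append to a list (key ignored) or to a dict under
    `key`, generating a UUID string when no key is given. Returns the key used."""
    if isinstance(modules, list):
        modules.append(module)  # key is ignored for list storage
        return None
    if key is None:
        key = str(uuid.uuid4())
    if key in modules:
        # Design choice for this sketch: refuse to silently overwrite.
        raise KeyError(f"key {key!r} already exists")
    modules[key] = module  # dicts keep insertion order, so this is "the end"
    return key


mods = {"M1": "Module1"}
new_key = append_module(mods, "Module2")
print(list(mods)[-1] == new_key)  # True
```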
- astype(dtype, /)
Convenience wrapper around set_precision that uses the dtype argument as the precision; returns self.
- compile(rng, input_shape, /, *, verbose=False)
Compile the model for training by finding the execution order of the directed graph defined by the connections, and compiling each module in that order.
- Parameters:
rng (Any | int | None) – Random key for initializing the model parameters. JAX PRNGKey or integer seed.
input_shape (pmm.typing.DataShape) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.
verbose (bool) – Print debug information during compilation. Default is False.
- Raises:
ValueError – If the connections do not form a valid directed acyclic graph from the model input to the model output.
- Return type:
None
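Finding an execution order for the connections graph is a topological sort. A minimal sketch using Kahn’s algorithm, assuming a flat module dictionary so that paths collapse to their first component (“M1.a” becomes “M1”); execution_order is hypothetical and is not the library’s get_execution_order. A cycle leaves some nodes unprocessed, which the sketch reports with a ValueError, matching the documented error type.

```python
from collections import deque
from typing import Dict, List, Sequence, Set, Union


def execution_order(connections: Dict[str, Union[str, Sequence[str]]],
                    separator: str = ".") -> List[str]:
    """Hypothetical sketch of Kahn's algorithm over module-level nodes."""
    edges: Dict[str, Set[str]] = {}
    indegree: Dict[str, int] = {}
    for out_path, targets in connections.items():
        targets = [targets] if isinstance(targets, str) else list(targets)
        src = out_path.split(separator)[0]   # 'M1.a' -> 'M1'
        indegree.setdefault(src, 0)
        for target in targets:
            dst = target.split(separator)[0]
            indegree.setdefault(dst, 0)
            if dst not in edges.setdefault(src, set()):
                edges[src].add(dst)           # count each module pair once
                indegree[dst] += 1
    ready = deque(sorted(n for n, d in indegree.items() if d == 0))
    order: List[str] = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for dst in sorted(edges.get(node, set())):
            indegree[dst] -= 1
            if indegree[dst] == 0:
                ready.append(dst)
    if len(order) != len(indegree):
        raise ValueError("connections do not form a directed acyclic graph")
    return order


print(execution_order({"input.x1": "M1", "M1": "M2", "M2": "output.1",
                       "input.x2": "output.0"}))
# ['input', 'M1', 'M2', 'output']
```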
- deserialize(data, /)
Deserialize the model from a dictionary. This is done by deserializing the model’s parameters/metadata and then deserializing each module.
- Parameters:
data (Dict[str, Any]) – Dictionary containing the serialized model data.
- Return type:
None
- classmethod from_file(file, /)
Load a model from a file and return an instance of the Model class.
- get_execution_order()
Resolve the connections dictionary to determine the order in which the modules are executed.
- Raises:
ValueError – If the connections do not form a valid directed acyclic graph from the model input to the model output.
- get_hyperparameters()
Get the hyperparameters of the model as a dictionary.
- Returns:
Dict[str, Any] – Dictionary containing the hyperparameters of the model.
- Return type:
pmm.typing.HyperParams
- get_modules()
Get the modules of the model.
- Returns:
modules – PyTree of modules in the model. The structure of the PyTree will match that of the modules in the model.
- Return type:
pmm.model_util.ModelModules
See also
ModelModules – Type alias for a PyTree of modules in a model.
- get_num_trainable_floats()
Returns the number of trainable floats in the module. If the module does not have trainable parameters, returns 0. If the module is not ready, returns None.
- Returns:
Number of trainable floats in the module, or None if the module is not ready.
- Return type:
int | None
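A minimal sketch of counting trainable floats over a nested parameter PyTree, assuming every leaf is a real-valued array whose element count is the product of its shape. How complex parameters are counted is not stated here, so they are not modeled. ArraySpec and num_trainable_floats are hypothetical.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class ArraySpec:
    """Hypothetical stand-in for a parameter array; only its shape matters here."""
    shape: Tuple[int, ...]


def count_floats(tree: Any) -> int:
    """Sum element counts over nested dicts, lists, and tuples of ArraySpec leaves."""
    if tree is None:
        return 0
    if isinstance(tree, ArraySpec):
        return prod(tree.shape)  # prod(()) == 1 for a scalar parameter
    if isinstance(tree, dict):
        return sum(count_floats(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_floats(v) for v in tree)
    raise TypeError(f"unexpected leaf type: {type(tree).__name__}")


def num_trainable_floats(params: Any, ready: bool = True) -> Optional[int]:
    """Mirror the documented contract: None if not ready, 0 if no parameters."""
    return count_floats(params) if ready else None


params = {"M1": {"w": ArraySpec((4, 3)), "b": ArraySpec((3,))}, "M2": ()}
print(num_trainable_floats(params))  # 15
```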
- get_output_shape(input_shape, /)
Get the output shape of the model given an input shape. Must be implemented by all subclasses.
- Parameters:
input_shape (pmm.typing.DataShape) – Shape of the input, excluding the batch dimension. For example, (input_features,) for 1D bare-array input, or (input_height, input_width, input_channels) for 3D bare-array input, [(input_features1,), (input_features2,)] for a List (PyTree) of 1D arrays, etc.
- Returns:
output_shape – Shape of the output after passing through the model.
- Return type:
pmm.typing.DataShape
- get_params()
Get the parameters of the model.
- Returns:
params – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree will be a composite structure where the upper level structure matches that of the modules in the model, and the lower level structure matches that of the parameters of each module.
- Return type:
pmm.model_util.ModelParams
See also
ModelParams – Type alias for a PyTree of parameters in a model.
get_modules – Get the modules of the model, in the same structure as the parameters returned by this method.
set_params – Set the parameters of the model from a corresponding PyTree of PyTrees of numpy arrays.
- get_state()
Get the state of all modules in the model as a PyTree.
Override of the base method in order to ignore modules that are not in the execution order.
- Returns:
A PyTree of the states of all modules in the model with the same
structure as the modules PyTree. Modules that are not in the
execution order will have state None.
- Return type:
pmm.model_util.ModelState
- grad_input(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)
Compute the gradient (Jacobian) of the model output with respect to the input data.
- Parameters:
fwd (bool | None) – If True, use jax.jacfwd; if False, use jax.jacrev. Default is None, which decides based on the input and output shapes.
max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input will be processed in a single batch. If max_batch_size is set to 1, the gradient will be computed one sample at a time without batching. This case is particularly important for grad_input, since the Jacobian contains gradients across different batch samples and thus scales with the square of the batch size.
X (PyTree[Inexact[Array, 'num_samples ...'], ' DataStruct'])
dtype (jax.typing.DTypeLike)
rng (Any | int | None)
return_state (bool)
update_state (bool)
- Returns:
PyTree – Gradient of the model output with respect to the input data, as a PyTree of JAX arrays, the structure of which matches that of the output structure of the model composed above the input data structure. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, input_dim1, input_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the input dimensions correspond to the shape of the input data for that leaf.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
Tuple[PyTree[Inexact[Array, ‘num_samples …’], ’… DataStruct’], ModelState] | PyTree[Inexact[Array, ‘num_samples …’], ’… DataStruct’]
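The quadratic scaling mentioned under max_batch_size can be made concrete by counting Jacobian elements. The sketch assumes that “gradients across different batch samples” means the Jacobian of the whole batched output with respect to the whole batched input, of shape (num_samples, *output_shape, num_samples, *input_shape), whereas the documented per-leaf result has shape (num_samples, *output_shape, *input_shape). jacobian_elements is a hypothetical helper.

```python
from math import prod
from typing import Tuple


def jacobian_elements(num_samples: int, out_shape: Tuple[int, ...],
                      in_shape: Tuple[int, ...], per_sample: bool) -> int:
    """Element count of d(output leaf)/d(input leaf).

    per_sample=True: documented grad_input leaf shape
        (num_samples, *out_shape, *in_shape).
    per_sample=False: full Jacobian of a batched function,
        (num_samples, *out_shape, num_samples, *in_shape).
    """
    block = prod(out_shape) * prod(in_shape)
    return num_samples * block if per_sample else num_samples ** 2 * block


n = 1000
print(jacobian_elements(n, (10,), (100,), per_sample=True))   # 1000000
print(jacobian_elements(n, (10,), (100,), per_sample=False))  # 1000000000
```

With 1000 samples, 10 outputs, and 100 inputs, that is 8 MB versus 8 GB at 8 bytes per float64 element, which is why setting max_batch_size to 1 can matter for grad_input.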
- grad_params(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)
Compute the gradient (Jacobian) of the model output with respect to the model parameters.
- Parameters:
fwd (bool | None) – If True, use jax.jacfwd; if False, use jax.jacrev. Default is None, which decides based on the input and output shapes.
max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input will be processed in a single batch. Only applies if batched=True.
X (PyTree[jaxtyping.Inexact[Array, 'num_samples ...'], 'DataStruct'])
return_state (bool)
update_state (bool)
- Returns:
PyTree – Gradient of the model output with respect to the model parameters, as a PyTree of PyTrees of JAX arrays, the upper level structure of which matches that of the model’s modules, and the lower level structure of which matches that of the parameters of each module. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, param_dim1, param_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the param dimensions correspond to the shape of the parameter for that leaf.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
tuple[PyTree[jaxtyping.Inexact[Array, ‘num_samples …’] | tuple[] | None], pmm.model_util.ModelState] | PyTree[jaxtyping.Inexact[Array, ‘num_samples …’] | tuple[] | None]
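The documentation does not state how fwd=None chooses between jax.jacfwd and jax.jacrev. A common rule of thumb, sketched below as an assumption: forward mode costs roughly one Jacobian-vector product per input element and reverse mode roughly one vector-Jacobian product per output element, so pick whichever count is smaller. For grad_params, the “inputs” are the model parameters. prefer_forward_mode is hypothetical.

```python
from math import prod
from typing import Tuple


def prefer_forward_mode(in_shape: Tuple[int, ...], out_shape: Tuple[int, ...]) -> bool:
    """Hypothetical rule of thumb for fwd=None.

    Forward mode (jax.jacfwd) costs roughly one JVP per input element;
    reverse mode (jax.jacrev) roughly one VJP per output element.
    Ties go to forward mode here, an arbitrary choice.
    """
    return prod(in_shape) <= prod(out_shape)


print(prefer_forward_mode((3,), (100,)))   # True: 3 inputs, 100 outputs
print(prefer_forward_mode((1000,), (1,)))  # False: many params, scalar output
```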
- insert(index, module, /, key=None)
Insert a module at a specific index in the model.
- Parameters:
index (int) – Index to insert the module at.
module (BaseModule) – Module to insert.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- insert_module(index, module, /, key=None)
Insert a module at a specific index in the model.
- Parameters:
index (int) – Index to insert the module at.
module (BaseModule) – Module to insert.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
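A sketch of inserting at an index for flat list or dict storage, assuming a dict “index” means a position in insertion order; insert_module here is a hypothetical helper, not the model method. In the list case, positions after index shift by one in this sketch, so an integer path such as “1” would then refer to a different module.

```python
import uuid
from typing import Any, Dict, List, Optional, Union


def insert_module(modules: Union[List[Any], Dict[str, Any]], index: int, module: Any,
                  key: Optional[str] = None) -> Optional[str]:
    """Hypothetical: insert in place so the module sits at position `index`."""
    if isinstance(modules, list):
        modules.insert(index, module)  # later entries shift up by one
        return None
    if key is None:
        key = str(uuid.uuid4())
    if key in modules:
        raise KeyError(f"key {key!r} already exists")
    items = list(modules.items())
    items.insert(index, (key, module))
    modules.clear()          # rebuild in place to control insertion order
    modules.update(items)
    return key


mods = {"M1": "Module1", "M3": "Module3"}
insert_module(mods, 1, "Module2", key="M2")
print(list(mods))  # ['M1', 'M2', 'M3']
```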
- is_ready()
Check if the model is compiled and ready for use. Overrides the base implementation since not all modules need to be ready, as some may not appear in the execution order.
- Returns:
True if the model is compiled and ready, False otherwise.
- Return type:
bool
- load(file, /)
Load the model from a file. Supports both compressed and uncompressed files.
- property name: str
Returns the name of the module. Unless overridden, this is the class name.
- Returns:
Name of the module.
- pop(key, /)
Remove and return a module by key or index in the model.
- Returns:
The removed module.
- Parameters:
key – Key or index of the module to remove.
- pop_module_by_index(index, /)
Remove and return a module at a specific index in the model.
- Returns:
The removed module.
- Parameters:
index (int) – Index of the module to remove.
- pop_module_by_key(key, /)
Remove and return a module by key or index in the model.
- Returns:
The removed module.
- Parameters:
key – Key or index of the module to remove.
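Popping “by key or index” can be sketched for flat storage as follows, assuming that an integer addresses a dict entry by its insertion-order position and a string addresses it by key. pop_module is hypothetical; with integer dictionary keys this rule would be ambiguous.

```python
from typing import Any, Dict, List, Union


def pop_module(modules: Union[List[Any], Dict[str, Any]], key: Union[str, int]) -> Any:
    """Hypothetical: remove and return a module by key or by position.

    For a dict, an int is treated as a position in insertion order and a
    str as a dictionary key; for a list, key must be an int index.
    """
    if isinstance(modules, list):
        return modules.pop(key)
    if isinstance(key, int):
        key = list(modules)[key]  # translate position -> key at that position
    return modules.pop(key)


mods = {"M1": "A", "M2": "B", "M3": "C"}
print(pop_module(mods, 1), list(mods))  # B ['M1', 'M3']
```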
- predict(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, max_batch_size=None)
Call the model with the input data.
- Parameters:
X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.
dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to perform training in single precision (float32 and complex64) and inference with double precision inputs (float64, the default here) with single precision weights.
rng (Any | int | None) – JAX random key for stochastic modules. Default is None, in which case the saved rng key is used if it exists (this would be the final rng key from the last training run); otherwise a new random key is generated. If an integer is provided, it will be used as the seed to create a new JAX random key.
return_state (bool) – If True, the model will return the state of the model after evaluation. Default is False.
update_state (bool) – If True, the model will update the state of the model after evaluation. Default is False.
max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input will be processed in a single batch.
- Returns:
Data – Output data as a PyTree of JAX arrays, the structure and shape of which are determined by the model’s specific modules.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
Tuple[Data, ModelState] | Data
- prepend(module, /, key=None)
Prepend a module to the beginning of the model.
- Parameters:
module (BaseModule) – Module to prepend.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- prepend_module(module, /, key=None)
Prepend a module to the beginning of the model.
- Parameters:
module (BaseModule) – Module to prepend.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- reset()
Reset the compiled state of the model. This will require recompilation before the model can be used again.
- Return type:
None
- save(file, /)
Save the model to a file.
- save_compressed(file, /)
Save the model to a compressed file.
- serialize()
Serialize the model to a dictionary. This is done by serializing the model’s parameters/metadata and then serializing each module.
- set_hyperparameters(hyperparams, /)
Set the hyperparameters of the model from a dictionary.
- Parameters:
hyperparams (Dict[str, Any]) – Dictionary containing the hyperparameters of the model.
- Return type:
None
- set_params(params, /)
Set the parameters of the model from a PyTree of PyTrees of numpy arrays.
- Parameters:
params (pmm.model_util.ModelParams) – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the parameters of each module.
- Return type:
None
See also
ModelParams – Type alias for a PyTree of parameters in a model.
get_modules – Get the modules of the model, in the same structure as the parameters accepted by this method.
get_params – Get the parameters of the model, in the same structure as the parameters accepted by this method.
- set_precision(prec, /)
Set the precision of the model parameters and states.
- Parameters:
prec (str | type[Any] | dtype | SupportsDType | int) – Precision to set for the model parameters and states. Valid options are:
For 32-bit precision (all options are equivalent):
- np.float32, np.complex64, “float32”, “complex64”
- “single”, “f32”, “c64”, 32
For 64-bit precision (all options are equivalent):
- np.float64, np.complex128, “float64”, “complex128”
- “double”, “f64”, “c128”, 64
- Return type:
None
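The alias list above amounts to a lookup from each accepted value to a bit width. A minimal sketch covering the string and integer aliases (NumPy dtype objects are omitted so that only the standard library is needed); precision_bits is hypothetical.

```python
from typing import Union

# String and integer aliases exactly as listed above.
_ALIASES = {
    32: {"float32", "complex64", "single", "f32", "c64"},
    64: {"float64", "complex128", "double", "f64", "c128"},
}


def precision_bits(prec: Union[str, int]) -> int:
    """Hypothetical helper: map a documented precision alias to 32 or 64."""
    if isinstance(prec, int) and prec in _ALIASES:
        return prec
    if isinstance(prec, str):
        for bits, names in _ALIASES.items():
            if prec in names:
                return bits
    raise ValueError(f"unsupported precision: {prec!r}")


print(precision_bits("single"), precision_bits("c128"), precision_bits(64))  # 32 64 64
```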
- set_rng(rng, /)
Set the random key for the model.
- Parameters:
rng (Any) – Random key to set for the model. JAX PRNGKey, integer seed, or None. If None, a new random key will be generated using JAX’s random.key. If an integer is provided, it will be used as the seed to create the key.
- Return type:
None
- set_state(state, /)
Set the state of the model from a PyTree of PyTrees of numpy arrays.
- Parameters:
state (pmm.model_util.ModelState) – PyTree of PyTrees of numpy arrays representing the state of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the state of each module.
- Return type:
None
See also
ModelState – Type alias for a PyTree of states in a model.
get_modules – Get the modules of the model, in the same structure as the state accepted by this method.
get_state – Get the state of the model, in the same structure as the state accepted by this method.
- train(X, /, Y=None, *, Y_unc=None, X_val=None, Y_val=None, Y_val_unc=None, loss_fn='mse', lr=0.001, batch_size=32, epochs=100, target_loss=-inf, early_stopping_patience=100, early_stopping_min_delta=-inf, initialization_seed=None, callback=None, unroll=None, verbose=True, batch_rng=None, b1=0.9, b2=0.999, eps=1e-08, clip=1000.0)
- Parameters:
X (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*features'], 'InStruct'])
Y (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)
Y_unc (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)
X_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*features'], 'InStruct'] | None)
Y_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)
Y_val_unc (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)
batch_size (int)
epochs (int)
target_loss (float)
early_stopping_patience (int)
early_stopping_min_delta (float)
callback (Callable | None)
unroll (int | None)
verbose (bool)
b1 (float)
b2 (float)
eps (float)
clip (float)
- Return type:
None
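The optimizer used by train is not documented here, but the b1, b2, and eps parameters and their defaults (0.9, 0.999, 1e-08) match the standard Adam optimizer. Assuming Adam, a scalar version of the update with bias correction is sketched below. Clipping each gradient to [-clip, clip] is a further assumption about what clip does, and adam_minimize is hypothetical: it ignores batching, validation data, and early stopping, and it uses lr=0.1 rather than the default 0.001 only so that the example converges in a few hundred steps.

```python
import math
from typing import Callable


def adam_minimize(grad: Callable[[float], float], x0: float, lr: float = 1e-3,
                  steps: int = 100, b1: float = 0.9, b2: float = 0.999,
                  eps: float = 1e-8, clip: float = 1000.0) -> float:
    """Hypothetical scalar Adam loop with bias-corrected moment estimates."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = max(-clip, min(clip, grad(x)))  # assumed value clipping
        m = b1 * m + (1 - b1) * g           # first-moment (mean) estimate
        v = b2 * v + (1 - b2) * g * g       # second-moment estimate
        m_hat = m / (1 - b1 ** t)           # bias corrections
        v_hat = v / (1 - b2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Minimize (x - 3)^2, whose gradient is 2 (x - 3).
x_opt = adam_minimize(lambda x: 2.0 * (x - 3.0), x0=0.0, lr=0.1, steps=500)
print(round(x_opt, 3))
```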