NonSequentialModel

parametricmatrixmodels.NonSequentialModel

class NonSequentialModel(modules=None, connections=None, /, *, rng=None, separator='.')

Bases: Model

A nonsequential model that chains modules (or other models) together with directed connections.

For confidence intervals or uncertainty quantification, wrap a trained model with ConformalModel.

See also

jax.tree

PyTree utilities and concepts in JAX.

Model

Abstract base class for all models.

SequentialModel

A model that applies modules in sequence.

ConformalModel

Wrap a trained model to produce confidence intervals.

Parameters:
  • modules (ModelModules | BaseModule | None)

  • connections (Dict[str, str | List[str] | Tuple[str, ...]] | None)

  • rng (Any | int | None)

  • separator (str)

__init__(modules=None, connections=None, /, *, rng=None, separator='.')

Initialize a nonsequential model with a PyTree of modules, a dictionary of connections, and a random key.

Parameters:
  • modules (ModelModules | BaseModule | None) – module(s) to initialize the model with. Default is None, which will become an empty dictionary. Can be a single module, which will be wrapped in a list, or a PyTree of modules (e.g., nested lists, tuples, or dictionaries).

  • connections (Dict[str, str | List[str] | Tuple[str, ...]] | None) – Directed connections between module inputs and outputs in the model. Keys are period-separated paths of module outputs, and each value is either a single period-separated path or a list or tuple of such paths, naming the module inputs that receive that output. The reserved names “input” and “output” refer to the model input and output, respectively. The separator can be changed from the default period using the separator argument. Default is None, which will become an empty dictionary.

  • rng (Any | int | None) – Initial random key for the model. Default is None. If None, a new random key will be generated using JAX’s random.key. If an integer is provided, it will be used as the seed to create the key.

  • separator (str) – Separator string to use for denoting paths in the connections dictionary. Default is “.”.

Return type:

None
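Because each connections value may be a single path string or a list or tuple of paths, it can help to picture the dictionary in a normalized form. The helper normalize_connections below is hypothetical (it is not part of the library and not its actual parser); it wraps bare strings into tuples and splits every path on the separator:

```python
from typing import Dict, List, Tuple, Union

# A connections value may be one path or a list/tuple of paths.
PathSpec = Union[str, List[str], Tuple[str, ...]]


def normalize_connections(
    connections: Dict[str, PathSpec], separator: str = "."
) -> Dict[Tuple[str, ...], Tuple[Tuple[str, ...], ...]]:
    """Hypothetical helper: split every path on `separator` and turn
    every value into a tuple of destination paths."""
    normalized = {}
    for src, dsts in connections.items():
        if isinstance(dsts, str):
            dsts = (dsts,)  # a bare string denotes a single destination
        src_path = tuple(src.split(separator))
        normalized[src_path] = tuple(tuple(d.split(separator)) for d in dsts)
    return normalized


conns = normalize_connections({"input": "M1", "M1.a": ["M2.1", "M3"]})
print(conns)
# {('input',): (('M1',),), ('M1', 'a'): (('M2', '1'), ('M3',))}
```

Passing separator="/" to the same helper would accept paths such as "M1/a" instead.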

Examples

To denote a sequential model where all modules expect a bare array and produce a bare array:

>>> modules = [Module1(), Module2(), Module3()]
>>> connections = {
...     "input": "0",
...     "0": "1",
...     "1": "2",
...     "2": "output"
... }
>>> model = NonSequentialModel(modules, connections)

Or, equivalently, with named modules:

>>> modules = {"M0": Module1(), "M1": Module2(), "M2": Module3()}
>>> connections = {
...     "input": "M0",
...     "M0": "M1",
...     "M1": "M2",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)

Or, equivalently, using nested structures:

>>> modules = {
...     "block1": [Module1(), Module2()],
...     "block2": {"M3": Module3()}
... }
>>> connections = {
...     "input": "block1.0",
...     "block1.0": "block1.1",
...     "block1.1": "block2.M3",
...     "block2.M3": "output"
... }
>>> model = NonSequentialModel(modules, connections)

All three of the above will produce a model that applies the same three modules sequentially.

If a module outputs a PyTree of arrays, or expects a PyTree of arrays as input, the connections can specify the leaf nodes using the same period-separated path syntax. For example, if Module1 outputs a dict with keys “a” and “b”, and Module2 expects a tuple of two arrays as input, the connections can be specified as:

>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input": "M1",
...     "M1.a": "M2.1",
...     "M1.b": "M2.0",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)

This will send the “a” output of Module1 to the second input of Module2, and the “b” output of Module1 to the first input of Module2.
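To make the path syntax concrete, here is a minimal sketch of how a separator-joined path such as "M1.a" or "M2.1" can address part of a nested structure. The helper get_by_path is hypothetical; it assumes dict keys are strings and that numeric segments index into lists or tuples, and it only mimics the routing in the example above with placeholder values:

```python
from typing import Any


def get_by_path(tree: Any, path: str, separator: str = ".") -> Any:
    """Hypothetical helper: follow a separator-joined path into nested
    dicts, lists, and tuples."""
    node = tree
    for segment in path.split(separator):
        if isinstance(node, (list, tuple)):
            node = node[int(segment)]  # numeric segment indexes a sequence
        else:
            node = node[segment]  # any other segment is a dict key
    return node


# Placeholder output of Module1: a dict with keys "a" and "b".
outputs = {"M1": {"a": 1.0, "b": 2.0}}

# Route "M1.a" -> "M2.1" and "M1.b" -> "M2.0".
m2_input = [None, None]
m2_input[1] = get_by_path(outputs, "M1.a")
m2_input[0] = get_by_path(outputs, "M1.b")
print(tuple(m2_input))  # (2.0, 1.0)
```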

Note

If a module expects a Tuple or List as input, it is best to write the module to accept both types, since whether such an input is assembled as a List or as a Tuple cannot be inferred at compile time.
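For instance, a module's forward computation can check for either sequence type and then unpack its input, so it behaves identically whichever one it receives. The function below is a hypothetical stand-in for a module's forward pass, not part of the library:

```python
def two_input_forward(x):
    """Hypothetical forward pass that expects two inputs and accepts
    them as either a list or a tuple."""
    if not isinstance(x, (list, tuple)) or len(x) != 2:
        raise TypeError("expected a list or tuple of two inputs")
    first, second = x  # unpacking works the same for list and tuple
    return first - second


print(two_input_forward((5.0, 3.0)), two_input_forward([5.0, 3.0]))  # 2.0 2.0
```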

If the entire model input or output is a PyTree of arrays, the connections use the same period-separated path syntax with the reserved keys “input” and “output”. For example, if the model input is a dict with keys “x1” and “x2”, and the model output is a Tuple of two arrays, the connections can be specified as:

>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input.x1": "M1",
...     "M1": "M2",
...     "M2": "output.1",
...     "input.x2": "output.0"
... }
>>> model = NonSequentialModel(modules, connections)

This passes the “x1” input sequentially through Module1 and Module2 and sends the result to the second model output (“output.1”), while the “x2” input is sent unchanged to the first model output (“output.0”).

Modules that output PyTrees need not be fully traversed if entire subtrees are to be passed between modules. For example, if Module1 outputs a dict with keys “a” and “b”, and Module2 expects a dict with keys “a” and “b” as input, the connections can be specified as:

>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input": "M1",
...     "M1": "M2",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)

or equivalently:

>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
...     "input": "M1",
...     "M1.a": "M2.a",
...     "M1.b": "M2.b",
...     "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)

Both ways will pass the entire output dict of Module1 to Module2.

Module outputs can be sent to multiple module inputs by specifying a list or tuple of input paths in the connections dictionary. For example, to send the output of Module1 to both Module2 and Module3:

>>> modules = {"M1": Module1(), "M2": Module2(), "M3": Module3()}
>>> connections = {
...     "input": "M1",
...     "M1": ["M2", "M3"],
...     "M2": "output.0",
...     "M3": "output.1"
... }
>>> model = NonSequentialModel(modules, connections)

This will create a model that sends the output of Module1 to both Module2 and Module3 in parallel, and collects their outputs as a Tuple as the model output.

The order of the entries in the connections dictionary does not matter, as long as the connections form a valid directed acyclic graph from the model input to the model output. Not every part of the model input, nor every module, needs to be used, but leaving any of them unused raises a warning. Leaving parts of a module’s output unused is allowed and raises no warning. However, every input of every module that appears in the graph must be connected.
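Since the connections must form a directed acyclic graph, a valid execution order can be found by topological sorting. The sketch below uses Kahn's algorithm; the function execution_order and its module_paths argument (the separator-joined paths of all modules) are hypothetical, and this is an illustration of the idea rather than the library's get_execution_order implementation. Each connection path is mapped to the longest module path that prefixes it, so leaf paths such as "M1.a" resolve to "M1":

```python
from collections import deque


def execution_order(connections, module_paths, separator="."):
    """Hypothetical sketch: topologically sort the graph implied by
    `connections` (Kahn's algorithm). Raises ValueError on a cycle."""
    # Prefer the longest matching module path (the most specific match).
    candidates = sorted(module_paths, key=len, reverse=True)

    def node_of(path):
        # Map a connection path (possibly ending in a leaf path) to a node.
        head = path.split(separator)[0]
        if head in ("input", "output"):
            return head
        for p in candidates:
            if path == p or path.startswith(p + separator):
                return p
        raise ValueError(f"unknown module path: {path!r}")

    edges, indegree = {}, {}
    for src, dsts in connections.items():
        if isinstance(dsts, str):
            dsts = [dsts]
        u = node_of(src)
        indegree.setdefault(u, 0)
        for dst in dsts:
            v = node_of(dst)
            indegree.setdefault(v, 0)
            if v not in edges.setdefault(u, set()):
                edges[u].add(v)  # count each module-to-module edge once
                indegree[v] += 1

    # Start from nodes with no incoming edges (normally just "input").
    queue = deque(sorted(n for n, d in indegree.items() if d == 0))
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in sorted(edges.get(node, ())):
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                queue.append(nxt)
    if len(order) != len(indegree):
        raise ValueError("connections do not form a directed acyclic graph")
    return order


order = execution_order(
    {
        "input": "block1.0",
        "block1.0": "block1.1",
        "block1.1": "block2.M3",
        "block2.M3": "output",
    },
    module_paths=["block1.0", "block1.1", "block2.M3"],
)
print(order)  # ['input', 'block1.0', 'block1.1', 'block2.M3', 'output']
```

A fan-out such as "M1": ["M2", "M3"] simply adds two outgoing edges from M1, and any cycle leaves some node with a nonzero in-degree, which triggers the ValueError.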

See also

ModelModules

Type alias for a PyTree of modules in a model.

jax.random.key

JAX function to create a random key.

jax.tree_util.keystr

JAX function to create string paths for PyTree KeyPaths in the format expected by the connections dictionary.

Methods

__call__

Call the model with the input data.

_get_callable

Returns a jax.jit-able and jax.grad-able callable that represents the model's forward pass.

add

Append a module to the end of the model.

add_module

Append a module to the end of the model.

append

Append a module to the end of the model.

append_module

Append a module to the end of the model.

astype

Convenience wrapper around set_precision that takes a dtype argument and returns self.

compile

Compile the model for training by finding the execution order of the directed graph defined by the connections, and compiling each module in that order.

deserialize

Deserialize the model from a dictionary.

from_file

Load a model from a file and return an instance of the Model class.

get_execution_order

Resolve the connections dictionary to find the order in which modules are executed.

get_hyperparameters

Get the hyperparameters of the model as a dictionary.

get_modules

Get the modules of the model.

get_num_trainable_floats

Returns the number of trainable floats in the model.

get_output_shape

Get the output shape of the model given an input shape.

get_params

Get the parameters of the model.

get_rng

get_state

Get the state of all modules in the model as a PyTree.

grad_input

Compute the gradient of the model output with respect to the input data.

grad_params

Compute the gradient of the model output with respect to the model parameters.

insert

Insert a module at a specific index in the model.

insert_module

Insert a module at a specific index in the model.

is_ready

Check if the model is compiled and ready for use.

load

Load the model from a file.

pop

Remove and return a module by key or index in the model.

pop_module_by_index

Remove and return a module at a specific index in the model.

pop_module_by_key

Remove and return a module by key or index in the model.

predict

Call the model with the input data.

prepend

Prepend a module to the beginning of the model.

prepend_module

Prepend a module to the beginning of the model.

reset

Reset the compiled state of the model.

save

Save the model to a file.

save_compressed

Save the model to a compressed file.

serialize

Serialize the model to a dictionary.

set_hyperparameters

Set the hyperparameters of the model from a dictionary.

set_params

Set the parameters of the model from a PyTree of PyTrees of numpy arrays.

set_precision

Set the precision of the model parameters and states.

set_rng

Set the random key for the model.

set_state

Set the state of the model from a PyTree of PyTrees of numpy arrays.

train

Attributes

name

Returns the name of the model; unless overridden, this is the class name.

__call__(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, max_batch_size=None)

Call the model with the input data.

Parameters:
  • X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.

  • dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) on single-precision weights.

  • rng (Any | int | None) – JAX random key for stochastic modules. Default is None, in which case the saved rng key is used if it exists (this is the final rng key from the last training run); otherwise, a new random key is generated. If an integer is provided, it is used as the seed to create a new JAX random key.

  • return_state (bool) – If True, the model will return the state of the model after evaluation. Default is False.

  • update_state (bool) – If True, the model will update the state of the model after evaluation. Default is False.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch.

Returns:

  • Data – Output data as a PyTree of JAX arrays, the structure and shape of which is determined by the model’s specific modules.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

Tuple[Data, ModelState] | Data
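The max_batch_size option can be pictured as splitting the leading (batch) axis into chunks, evaluating each chunk, and concatenating the results. The sketch below assumes a bare-array input represented by a Python list of samples and a hypothetical per-chunk function forward; batched_apply is likewise hypothetical and illustrates the idea only, not the library's batching code:

```python
def batched_apply(forward, X, max_batch_size=None):
    """Hypothetical sketch: apply `forward` to X in chunks of at most
    `max_batch_size` samples and concatenate the per-chunk results."""
    if max_batch_size is None:
        return forward(X)  # single pass over the whole batch
    outputs = []
    for start in range(0, len(X), max_batch_size):
        outputs.extend(forward(X[start:start + max_batch_size]))
    return outputs


calls = []


def forward(chunk):
    calls.append(len(chunk))  # record the size of each chunk processed
    return [2 * x for x in chunk]


print(batched_apply(forward, list(range(5)), max_batch_size=2))  # [0, 2, 4, 6, 8]
print(calls)  # [2, 2, 1]
```

Peak memory then depends on the chunk size rather than on the full batch, at the cost of more sequential calls.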

_get_callable()

Returns a jax.jit-able and jax.grad-able callable that represents the model’s forward pass.

This method must be implemented by all subclasses and must return a jax.jit-able and jax.grad-able callable of the form

model_callable(
    params: parametricmatrixmodels.model_util.ModelParams,
    data: parametricmatrixmodels.typing.Data,
    training: bool,
    state: parametricmatrixmodels.model_util.ModelState,
    rng: Any,
) -> tuple[
    parametricmatrixmodels.typing.Data,  # output
    parametricmatrixmodels.model_util.ModelState,  # new_state
]

That is, all hyperparameters are traced out and the callable depends explicitly only on

  • the model’s parameters, as a PyTree with leaf nodes as JAX arrays,

  • the input data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),

  • the training flag, as a boolean,

  • the model’s state, as a PyTree with leaf nodes as JAX arrays,

  • the JAX random key, rng,

and returns

  • the output data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),

  • the new model state, as a PyTree with leaf nodes as JAX arrays. The PyTree structure must match that of the input state, and all leaf nodes must have the same shapes as the corresponding input state leaf nodes.

The training flag is traced out, so it does not need to be jittable.

Returns:

A callable that takes the model’s parameters, input data, training flag, state, and rng key and returns the output data and new state.

Return type:

pmm.model_util.ModelCallable
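A toy illustration of this contract follows. Plain Python floats and lists stand in for JAX arrays, and the scaling model, its parameter names w and b, the hyperparameter hyper_scale, and the call-counting state are all hypothetical. The hyperparameter is closed over by the factory, the Python-level branch on training mirrors a flag that is fixed when the callable is traced, and the returned state keeps the structure of the input state:

```python
from typing import Any, Dict, List, Tuple

Params = Dict[str, float]
State = Dict[str, int]


def make_model_callable(hyper_scale: float):
    """Return a callable that closes over the hyperparameter `hyper_scale`."""

    def model_callable(
        params: Params,
        data: List[float],
        training: bool,
        state: State,
        rng: Any,
    ) -> Tuple[List[float], State]:
        # The leading axis of `data` is the sample axis.
        output = [hyper_scale * params["w"] * x + params["b"] for x in data]
        # Count only training calls; the state structure is unchanged.
        new_state = {"num_calls": state["num_calls"] + (1 if training else 0)}
        return output, new_state

    return model_callable


f = make_model_callable(hyper_scale=2.0)
out, new_state = f({"w": 0.5, "b": 1.0}, [1.0, 2.0], True, {"num_calls": 0}, None)
print(out, new_state)  # [2.0, 3.0] {'num_calls': 1}
```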

See also

__call__

Calls the model with the current parameters and given input, state, and rng.

ModelCallable

Typing for the callable returned by this method.

Params

Typing for the model parameters.

Data

Typing for the input and output data.

State

Typing for the model state.

_get_module_input_dependencies()

Get the input dependencies for each module in the execution order.

Returns:

  • module_input_dependencies – List of PyTrees of str paths to the required inputs for each module in the execution order. The first entry corresponds to the “input” node and is None.

  • output_input_dependencies – PyTree of str paths to the required inputs for the “output” node.

Return type:

tuple[list[PyTree[str]], PyTree[str]]

_get_shape_progression(input_shape, /)

Get the progression of output shapes through the model given an input shape. The first entry is the model input shape, and the last entry is the model output shape.

Parameters:

input_shape (DataShape) – Shape of the input, excluding the batch dimension. For example, (input_features,) for 1D bare-array input, or (input_height, input_width, input_channels) for 3D bare-array input, [(input_features1,), (input_features2,)] for a List (PyTree) of 1D arrays, etc.

Returns:

  • input_shapes – PyTree of input shapes at each module in the execution order, with the same structure as the modules in the execution order.

  • output_shapes – PyTree of output shapes at each module in the execution order, with the same structure as the modules in the execution order.

  • output_shape – Shape of the output after passing through the model.

Return type:

Tuple[PyTree[DataShape | None], PyTree[DataShape | None], DataShape]

_verify_input(X)
Parameters:

X (pmm.typing.Data)

Return type:

None

add(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None
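The key handling described for add can be sketched on plain containers. The helper append_module below is hypothetical and only mirrors the documented behavior: a dictionary uses the given key or, if none is given, a generated UUID (uuid.uuid4().hex is an assumed format, not necessarily the library's), while lists and tuples ignore the key:

```python
import uuid


def append_module(modules, module, key=None):
    """Hypothetical sketch: append `module` to a dict, list, or tuple."""
    if isinstance(modules, dict):
        if key is None:
            key = uuid.uuid4().hex  # assumed key format
        modules[key] = module  # dicts keep insertion order, so this is "last"
        return modules
    if isinstance(modules, list):
        modules.append(module)  # `key` is ignored for lists
        return modules
    if isinstance(modules, tuple):
        return modules + (module,)  # tuples are immutable: build a new one
    raise TypeError("container type not handled in this sketch")


named = append_module({"M0": "first"}, "second", key="M1")
anonymous = append_module({}, "only")
listed = append_module(["first"], "second", key="ignored")
print(list(named), len(anonymous), listed)  # ['M0', 'M1'] 1 ['first', 'second']
```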

add_module(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

append(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

append_module(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

astype(dtype, /)

Convenience wrapper around set_precision that takes a dtype argument and returns self.

Parameters:

dtype (str | type[Any] | dtype | SupportsDType | int)

Return type:

Model

compile(rng, input_shape, /, *, verbose=False)

Compile the model for training by finding the execution order of the directed graph defined by the connections, and compiling each module in that order.

Parameters:
  • rng (Any | int | None) – Random key for initializing the model parameters. JAX PRNGKey or integer seed.

  • input_shape (pmm.typing.DataShape) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.

  • verbose (bool) – Print debug information during compilation. Default is False.

Raises:

ValueError – If the connections do not form a valid directed acyclic graph from the model input to the model output.

Return type:

None

deserialize(data, /)

Deserialize the model from a dictionary. This is done by deserializing the model’s parameters/metadata and then deserializing each module.

Parameters:

data (Dict[str, Any]) – Dictionary containing the serialized model data.

Return type:

None

classmethod from_file(file, /)

Load a model from a file and return an instance of the Model class.

Parameters:

file (str) – File to load the model from.

Returns:

Model – An instance of the Model class with the loaded parameters.

Return type:

Model

get_execution_order()

Resolve the connections dictionary to find the order in which modules are executed.

Raises:

ValueError – If the connections do not form a valid directed acyclic graph from the model input to the model output.

Return type:

list[str]

get_hyperparameters()

Get the hyperparameters of the model as a dictionary.

Returns:

Dict[str, Any] – Dictionary containing the hyperparameters of the model.

Return type:

pmm.typing.HyperParams

get_modules()

Get the modules of the model.

Returns:

modules – PyTree of modules in the model. The structure of the PyTree will match that of the modules in the model.

Return type:

pmm.model_util.ModelModules

See also

ModelModules

Type alias for a PyTree of modules in a model.

get_num_trainable_floats()

Returns the number of trainable floats in the model. If the model does not have trainable parameters, returns 0. If the model is not ready, returns None.

Returns:

Number of trainable floats in the model, or None if the model is not ready.

Return type:

int | None

get_output_shape(input_shape, /)

Get the output shape of the model given an input shape. Must be implemented by all subclasses.

Parameters:

input_shape (pmm.typing.DataShape) – Shape of the input, excluding the batch dimension. For example, (input_features,) for 1D bare-array input, or (input_height, input_width, input_channels) for 3D bare-array input, [(input_features1,), (input_features2,)] for a List (PyTree) of 1D arrays, etc.

Returns:

output_shape – Shape of the output after passing through the model.

Return type:

pmm.typing.DataShape

get_params()

Get the parameters of the model.

Returns:

params – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree will be a composite structure where the upper level structure matches that of the modules in the model, and the lower level structure matches that of the parameters of each module.

Return type:

pmm.model_util.ModelParams

See also

ModelParams

Type alias for a PyTree of parameters in a model.

get_modules

Get the modules of the model, in the same structure as the parameters returned by this method.

set_params

Set the parameters of the model from a corresponding PyTree of PyTrees of numpy arrays.

get_rng()
Return type:

Any

get_state()

Get the state of all modules in the model as a PyTree.

Override of the base method in order to ignore modules that are not in the execution order.

Returns:

A PyTree of the states of all modules in the model with the same structure as the modules PyTree. Modules that are not in the execution order will have state None.

Return type:

pmm.model_util.ModelState

grad_input(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)

Compute the gradient of the model output with respect to the input data.

Parameters:
  • fwd (bool | None) – If True, use jax.jacfwd, otherwise use jax.jacrev. Default is None, which decides based on the input and output shapes.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch. If max_batch_size is set to 1, the gradient will be computed one sample at a time without batching. This case is particularly important for grad_input since the Jacobian contains gradients across different batch samples and thus scales with the square of the batch size.

  • X (PyTree[Inexact[Array, 'num_samples ...'], ' DataStruct'])

  • dtype (jax.typing.DTypeLike)

  • rng (Any | int | None)

  • return_state (bool)

  • update_state (bool)

Returns:

  • PyTree – Gradient of the model output with respect to the input data, as a PyTree of JAX arrays, the structure of which matches that of the output structure of the model composed above the input data structure. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, input_dim1, input_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the input dimensions correspond to the shape of the input data for that leaf.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

Tuple[PyTree[Inexact[Array, ‘num_samples …’], ’… DataStruct’], ModelState] | PyTree[Inexact[Array, ‘num_samples …’], ’… DataStruct’]
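The quadratic scaling mentioned under max_batch_size follows from counting entries. Differentiating an output of shape (B, m) with respect to an input of shape (B, n) in one pass yields a Jacobian of shape (B, m, B, n), whereas evaluating one sample at a time produces only B blocks of shape (m, n). The sketch below just counts entries for illustrative sizes and does not call the library:

```python
def jacobian_entries(batch, out_dim, in_dim, per_sample):
    """Number of entries in the Jacobian of a (batch, out_dim) output
    with respect to a (batch, in_dim) input."""
    if per_sample:
        # One (out_dim, in_dim) block per sample.
        return batch * out_dim * in_dim
    # Full cross-sample Jacobian of shape (batch, out_dim, batch, in_dim).
    return batch * out_dim * batch * in_dim


B, m, n = 1024, 10, 100
full = jacobian_entries(B, m, n, per_sample=False)
diag = jacobian_entries(B, m, n, per_sample=True)
print(full, diag, full // diag)  # 1048576000 1024000 1024
```

The ratio between the two counts equals the batch size, which is why max_batch_size=1 matters most for grad_input.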

grad_params(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)

Compute the gradient of the model output with respect to the model parameters.

Parameters:
  • fwd (bool | None) – If True, use jax.jacfwd, otherwise use jax.jacrev. Default is None, which decides based on the input and output shapes.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch.

  • X (PyTree[jaxtyping.Inexact[Array, 'num_samples ...'], 'DataStruct'])

  • dtype (str | type[Any] | dtype | SupportsDType)

  • rng (Any | int | None)

  • return_state (bool)

  • update_state (bool)

Returns:

  • PyTree – Gradient of the model output with respect to the model parameters, as a PyTree of PyTrees of JAX arrays, the upper level structure of which matches that of the model’s modules, and the lower level structure of which matches that of the parameters of each module. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, param_dim1, param_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the param dimensions correspond to the shape of the parameter for that leaf.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

tuple[PyTree[jaxtyping.Inexact[Array, ‘num_samples …’] | tuple[] | None], pmm.model_util.ModelState] | PyTree[jaxtyping.Inexact[Array, ‘num_samples …’] | tuple[] | None]

insert(index, module, /, key=None)

Insert a module at a specific index in the model.

Parameters:
  • index (int) – Index to insert the module at.

  • module (BaseModule) – Module to insert.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

insert_module(index, module, /, key=None)

Insert a module at a specific index in the model.

Parameters:
  • index (int) – Index to insert the module at.

  • module (BaseModule) – Module to insert.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

is_ready()

Check if the model is compiled and ready for use. Overrides the base implementation since not all modules need to be ready, as some may not appear in the execution order.

Returns:

True if the model is compiled and ready, False otherwise.

Return type:

bool

load(file, /)

Load the model from a file. Supports both compressed and uncompressed files.

Parameters:

file (str | IO | Path) – File to load the model from.

Return type:

None

property name: str

Returns the name of the model; unless overridden, this is the class name.

Returns:

Name of the model.

pop(key, /)

Remove and return a module by key or index in the model.

Returns:

The removed module.

Parameters:

key (tuple[KeyEntry, ...] | str | int) – Key or index of the module to remove.

Return type:

BaseModule

pop_module_by_index(index, /)

Remove and return a module at a specific index in the model.

Returns:

The removed module.

Parameters:

index (int) – Index of the module to remove.

Return type:

BaseModule

pop_module_by_key(key, /)

Remove and return a module by key or index in the model.

Returns:

The removed module.

Parameters:

key (tuple[KeyEntry, ...] | str | int) – Key or index of the module to remove.

Return type:

BaseModule

predict(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, max_batch_size=None)

Call the model with the input data.

Parameters:
  • X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.

  • dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) on single-precision weights.

  • rng (Any | int | None) – JAX random key for stochastic modules. Default is None, in which case the saved rng key is used if it exists (this is the final rng key from the last training run); otherwise, a new random key is generated. If an integer is provided, it is used as the seed to create a new JAX random key.

  • return_state (bool) – If True, the model will return the state of the model after evaluation. Default is False.

  • update_state (bool) – If True, the model will update the state of the model after evaluation. Default is False.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch.

Returns:

  • Data – Output data as a PyTree of JAX arrays, the structure and shape of which is determined by the model’s specific modules.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

Tuple[Data, ModelState] | Data

prepend(module, /, key=None)

Prepend a module to the beginning of the model.

Parameters:
  • module (BaseModule) – Module to prepend.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

prepend_module(module, /, key=None)

Prepend a module to the beginning of the model.

Parameters:
  • module (BaseModule) – Module to prepend.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

reset()

Reset the compiled state of the model. This will require recompilation before the model can be used again.

Return type:

None

save(file, /)

Save the model to a file.

Parameters:

file (str | IO | Path) – File to save the model to.

Return type:

None

save_compressed(file, /)

Save the model to a compressed file.

Parameters:

file (str | IO | Path) – File to save the model to.

Return type:

None

serialize()

Serialize the model to a dictionary. This is done by serializing the model’s parameters/metadata and then serializing each module.

Returns:

Dict[str, Any]

Return type:

dict[str, Any]

set_hyperparameters(hyperparams, /)

Set the hyperparameters of the model from a dictionary.

Parameters:

hyperparams (Dict[str, Any]) – Dictionary containing the hyperparameters of the model.

Return type:

None

set_params(params, /)

Set the parameters of the model from a PyTree of PyTrees of numpy arrays.

Parameters:

params (pmm.model_util.ModelParams) – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the parameters of each module.

Return type:

None

See also

ModelParams

Type alias for a PyTree of parameters in a model.

get_modules

Get the modules of the model, in the same structure as the parameters accepted by this method.

get_params

Get the parameters of the model, in the same structure as the parameters accepted by this method.

set_precision(prec, /)

Set the precision of the model parameters and states.

Parameters:

prec (str | type[Any] | dtype | SupportsDType | int) – Precision to set for the model parameters and states. Valid options are listed below; all options within a group are equivalent.

  • 32-bit precision: np.float32, np.complex64, “float32”, “complex64”, “single”, “f32”, “c64”, 32

  • 64-bit precision: np.float64, np.complex128, “float64”, “complex128”, “double”, “f64”, “c128”, 64

Return type:

None
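The string and integer aliases listed for set_precision can be collapsed into a bit width with a small lookup. The function precision_bits is hypothetical and covers only the string and integer options; NumPy dtype objects such as np.float32 are left out to keep the sketch dependency-free:

```python
_ALIASES = {
    32: {"float32", "complex64", "single", "f32", "c64"},
    64: {"float64", "complex128", "double", "f64", "c128"},
}


def precision_bits(prec):
    """Hypothetical sketch: map a string or integer precision alias to 32 or 64."""
    if isinstance(prec, int) and not isinstance(prec, bool):
        if prec in _ALIASES:
            return prec  # 32 and 64 are accepted directly
    elif isinstance(prec, str):
        for bits, names in _ALIASES.items():
            if prec in names:
                return bits
    raise ValueError(f"unsupported precision: {prec!r}")


print(precision_bits("c64"), precision_bits("double"), precision_bits(32))  # 32 64 32
```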

set_rng(rng, /)

Set the random key for the model.

Parameters:

rng (Any) – Random key to set for the model. JAX PRNGKey, integer seed, or None. If None, a new random key will be generated using JAX’s random.key. If an integer is provided, it will be used as the seed to create the key.

Return type:

None

set_state(state, /)

Set the state of the model from a PyTree of PyTrees of numpy arrays.

Parameters:

state (pmm.model_util.ModelState) – PyTree of PyTrees of numpy arrays representing the state of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the state of each module.

Return type:

None

See also

ModelState

Type alias for a PyTree of states in a model.

get_modules

Get the modules of the model, in the same structure as the state accepted by this method.

get_state

Get the state of the model, in the same structure as the state accepted by this method.

train(X, /, Y=None, *, Y_unc=None, X_val=None, Y_val=None, Y_val_unc=None, loss_fn='mse', lr=0.001, batch_size=32, epochs=100, target_loss=-inf, early_stopping_patience=100, early_stopping_min_delta=-inf, initialization_seed=None, callback=None, unroll=None, verbose=True, batch_rng=None, b1=0.9, b2=0.999, eps=1e-08, clip=1000.0)
Parameters:
  • X (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*features'], 'InStruct'])

  • Y (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)

  • Y_unc (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)

  • X_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*features'], 'InStruct'] | None)

  • Y_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)

  • Y_val_unc (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)

  • loss_fn (str | Callable)

  • lr (float | Callable[[int], float])

  • batch_size (int)

  • epochs (int)

  • target_loss (float)

  • early_stopping_patience (int)

  • early_stopping_min_delta (float)

  • initialization_seed (Any | int | None)

  • callback (Callable | None)

  • unroll (int | None)

  • verbose (bool)

  • batch_rng (Any | int | None)

  • b1 (float)

  • b2 (float)

  • eps (float)

  • clip (float)

Return type:

None