AffineObservablePMM

parametricmatrixmodels.modules.AffineObservablePMM

class AffineObservablePMM(matrix_size=None, num_eig=None, which=None, smoothing=None, affine_bias_matrix=True, num_secondaries=1, output_size=None, centered=True, bias_term=True, use_expectation_values=False, Ms=None, Ds=None, b=None, init_magnitude=0.01)

Bases: SequentialModel

AffineObservablePMM is a module that implements a general regression model via the affine observable Parametric Matrix Model (PMM). It combines four primitive modules in a SequentialModel: an AffineHermitianMatrix module, followed by an Eigenvectors module, followed by a TransitionAmplitudeSum module (or an ExpectationValueSum module), optionally followed by a Bias module.

The Affine Observable PMM (AOPMM) is described in [1] and is summarized as follows:

Given input features \(x_1, \ldots, x_p\), construct the Hermitian matrix that is affine in these features as

\[M(x) = M_0 + \sum_{i=1}^p x_i M_i\]

where \(M_0, \ldots, M_p\) are trainable Hermitian matrices. An optional smoothing term \(s C\) parameterized by the smoothing hyperparameter \(s\) can be added to smooth the eigenvalues and eigenvectors of \(M(x)\). The matrix \(C\) is equal to the imaginary unit times the sum of all commutators of the \(M_i\).
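The construction above can be sketched in plain NumPy. This is only an illustration, not the library's implementation: the helper names (`random_hermitian`, `affine_hermitian`, `commutator_term`) are hypothetical, and the set of commutators summed in \(C\) is an assumption. The sketch sums over pairs \(i < j\), since a sum over all ordered pairs would cancel to zero.

```python
import numpy as np

def random_hermitian(rng, n, scale=1.0):
    """Random n x n Hermitian matrix (illustrative initialization only)."""
    a = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
    return scale * (a + a.conj().T) / 2

def affine_hermitian(x, Ms):
    """M(x) = M_0 + sum_i x_i M_i, with Ms of shape (p + 1, n, n)."""
    return Ms[0] + np.tensordot(x, Ms[1:], axes=1)

def commutator_term(Ms):
    """ASSUMPTION: C = i * sum_{i<j} [M_i, M_j]; the pair convention is a guess."""
    n = Ms.shape[-1]
    C = np.zeros((n, n), dtype=complex)
    for i in range(len(Ms)):
        for j in range(i + 1, len(Ms)):
            C += 1j * (Ms[i] @ Ms[j] - Ms[j] @ Ms[i])
    return C

rng = np.random.default_rng(0)
p, n, s = 3, 4, 0.1  # number of features, matrix size, smoothing strength
Ms = np.stack([random_hermitian(rng, n) for _ in range(p + 1)])
x = rng.normal(size=p)

M = affine_hermitian(x, Ms)
M_smooth = M + s * commutator_term(Ms)

# Both matrices stay Hermitian, so their eigenvalues are real.
print(np.allclose(M, M.conj().T), np.allclose(M_smooth, M_smooth.conj().T))
```

Because \(i[A, B]\) is Hermitian whenever \(A\) and \(B\) are, adding \(s C\) keeps the matrix Hermitian under this assumed convention.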

Then, take the leading \(r\) eigenvectors \(v_1, \ldots, v_r\) of \(M(x)\) (by default those corresponding to the largest-magnitude eigenvalues if there is no smoothing, or the smallest algebraic eigenvalues if there is smoothing) and compute the sum of the transition amplitudes of these eigenvectors with trainable Hermitian observable matrices (secondaries) \(D_{km}\), where \(k = 1, \ldots, q\) and \(m = 1, \ldots, l\), to form the output vector \(z\) with \(q\) components as

\[z_k = \sum_{m=1}^l \left( \left[\sum_{i,j=1}^r |v_i^H D_{km} v_j|^2 \right] - \frac{r^2}{2} ||D_{km}||^2_2 \right)\]

or, if using expectation values instead of transition amplitudes,

\[z_k = \sum_{m=1}^l \left( \left[\sum_{i=1}^r v_i^H D_{km} v_i \right] - \frac{r}{2} ||D_{km}||^2_2 \right)\]

where \(||\cdot||_2\) is the operator 2-norm (largest singular value) so for Hermitian \(D\), \(||D||_2\) is the largest absolute eigenvalue.

The \(-\frac{1}{2} ||D_{km}||^2_2\) term centers the value of each term and can be disabled by setting the centered parameter to False.

Finally, an optional trainable bias term \(b_k\) can be added to each component.

Warning

Although the math shows that the centering term should be multiplied by \(r^2\) (transition amplitudes) or \(r\) (expectation values), this doesn’t work well in practice. Setting the centering term to \(\frac{1}{2} ||D_{km}||^2_2\), without that factor, works much better, and this non-\(r^2\)/\(r\) scaling is what is used here.
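For orientation, here is a minimal NumPy sketch of the output computation for a single sample, using the unscaled centering term from the warning above. It is not the library's code: the function names are hypothetical, and the choice of the largest-magnitude eigenvalues as "leading" is the stated no-smoothing default, assumed here.

```python
import numpy as np

def random_hermitian(rng, shape, n, scale=1.0):
    """Stack of random n x n Hermitian matrices with leading dims `shape`."""
    a = rng.normal(size=(*shape, n, n)) + 1j * rng.normal(size=(*shape, n, n))
    return scale * (a + np.conj(np.swapaxes(a, -1, -2))) / 2

def aopmm_output(x, Ms, Ds, r, b=None, use_expectation_values=False, centered=True):
    """Sketch of z for one sample x of shape (p,).

    Ms: (p + 1, n, n) Hermitian, Ds: (q, l, n, n) Hermitian, b: (q,) or None.
    """
    M = Ms[0] + np.tensordot(x, Ms[1:], axes=1)
    evals, evecs = np.linalg.eigh(M)
    # ASSUMPTION: "leading" means largest-magnitude eigenvalues (no smoothing).
    V = evecs[:, np.argsort(-np.abs(evals))[:r]]  # shape (n, r)
    # A[k, m, i, j] = v_i^H D_km v_j
    A = np.einsum("ai,kmab,bj->kmij", V.conj(), Ds, V)
    if use_expectation_values:
        terms = np.einsum("kmii->km", A).real
    else:
        terms = (np.abs(A) ** 2).sum(axis=(-2, -1))
    if centered:
        # Unscaled centering (no r or r^2 factor), as the warning describes.
        terms = terms - 0.5 * np.linalg.norm(Ds, ord=2, axis=(-2, -1)) ** 2
    z = terms.sum(axis=1)  # sum over the l secondaries
    return z if b is None else z + b

rng = np.random.default_rng(0)
p, n, r, q, l = 3, 5, 2, 4, 2
Ms = random_hermitian(rng, (p + 1,), n)
Ds = random_hermitian(rng, (q, l), n)
x = rng.normal(size=p)

z = aopmm_output(x, Ms, Ds, r, b=np.zeros(q))
print(z.shape)  # (4,)
```

A quick sanity check on the sketch: with \(r = n\) and centering disabled, the expectation-value sum reduces to the trace of \(D_{km}\) and the transition-amplitude sum to its squared Frobenius norm, independent of \(x\).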

See also

AffineHermitianMatrix

Module that constructs the affine Hermitian matrix \(M(x)\) from trainable Hermitian matrices \(M_i\) and input features.

Eigenvectors

Module that computes the eigenvectors of a matrix.

TransitionAmplitudeSum

Module that computes the sum of transition amplitudes of eigenvectors with trainable observable matrices.

ExpectationValueSum

Module that computes the sum of expectation values of eigenvectors with trainable observable matrices.

Bias

Module that adds a trainable bias term to the output.

MultiModule

Module that combines multiple modules in sequence.

LowRankAffineObservablePMM

Low-rank version of this module.

References

Parameters:
  • matrix_size (int)

  • num_eig (int)

  • which (str)

  • smoothing (float)

  • affine_bias_matrix (bool)

  • num_secondaries (int)

  • output_size (int)

  • centered (bool)

  • bias_term (bool)

  • use_expectation_values (bool)

  • Ms (np.ndarray)

  • Ds (np.ndarray)

  • b (np.ndarray)

  • init_magnitude (float)

__init__(matrix_size=None, num_eig=None, which=None, smoothing=None, affine_bias_matrix=True, num_secondaries=1, output_size=None, centered=True, bias_term=True, use_expectation_values=False, Ms=None, Ds=None, b=None, init_magnitude=0.01)

Initialize the module.

Parameters:
  • matrix_size (int) – Size of the trainable matrices, shorthand \(n\).

  • num_eig (int) – Number of eigenvectors to use in the transition amplitude calculation, shorthand \(r\).

  • which (str) – Which eigenvectors to use based on eigenvalue. Options are:

    - ‘SA’ for smallest algebraic (default with smoothing)
    - ‘LA’ for largest algebraic
    - ‘SM’ for smallest magnitude
    - ‘LM’ for largest magnitude (default without smoothing)
    - ‘EA’ for exterior algebraically
    - ‘EM’ for exterior by magnitude
    - ‘IA’ for interior algebraically
    - ‘IM’ for interior by magnitude

  • smoothing (float) – Optional smoothing parameter for the affine matrix. Set to None/0.0 to disable smoothing. Default is None/0.0 (no smoothing).

  • affine_bias_matrix (bool) – If True, include the bias term \(M_0\) in the affine matrix. Default is True.

  • num_secondaries (int) – Number of secondary observable matrices \(D_{km}\) per output component. Shorthand \(l\). Default is 1.

  • output_size (int) – Size of the output vector, shorthand \(q\).

  • centered (bool) – If True, include the centering term in the transition amplitude sum. Default is True.

  • bias_term (bool) – If True, include a trainable bias term \(b_k\) in the output. Default is True.

  • use_expectation_values (bool) – If True, use expectation values instead of transition amplitudes in the output calculation. Default is False.

  • Ms (Array) – Optional array of shape (input_size+1, matrix_size, matrix_size) (if affine_bias_matrix is True) or (input_size, matrix_size, matrix_size) (if affine_bias_matrix is False), containing the \(M_i\) Hermitian matrices. If not provided, the matrices will be initialized randomly when the module is compiled. Default is None (random initialization).

  • Ds (Array) – Optional array of shape (output_size, num_secondaries, matrix_size, matrix_size) containing the \(D_{km}\) Hermitian observable matrices. If not provided, the matrices will be initialized randomly when the module is compiled. Default is None (random initialization).

  • b (Array) – Optional array of shape (output_size,) containing the bias terms \(b_k\). If not provided, the bias terms will be randomly initialized when the module is compiled. Default is None (random initialization).

  • init_magnitude (float) – Initial magnitude for the random initialization. Default is 1e-2.

Warning

When using smoothing, the which options involving magnitude should be avoided, as the smoothing only guarantees that eigenvalues near each other algebraically are smoothed, not across the spectrum.
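To make the distinction between algebraic and magnitude-based selection concrete, the sketch below picks eigenvalue indices for four of the `which` options. The function `select_eigen_indices` is a hypothetical helper, not part of the library, and the exterior/interior options are left out because their exact ordering is not specified here.

```python
import numpy as np

def select_eigen_indices(evals, r, which):
    """Indices of the r selected eigenvalues for a subset of the `which` codes."""
    sort_keys = {
        "SA": evals,           # smallest algebraic first
        "LA": -evals,          # largest algebraic first
        "SM": np.abs(evals),   # smallest magnitude first
        "LM": -np.abs(evals),  # largest magnitude first
    }
    if which not in sort_keys:
        raise ValueError(f"unsupported option in this sketch: {which!r}")
    return np.argsort(sort_keys[which], kind="stable")[:r]

evals = np.array([-3.0, -0.5, 1.0, 2.0])
for code in ("SA", "LA", "SM", "LM"):
    idx = select_eigen_indices(evals, 2, code)
    print(code, evals[idx])
```

In this toy spectrum, ‘LM’ selects -3.0 and 2.0, which sit at opposite ends of the spectrum, whereas ‘SA’ selects the algebraically adjacent -3.0 and -0.5. This is the situation the warning refers to.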

__call__

Call the model with the input data.

_get_callable

Returns a jax.jit-able and jax.grad-able callable that represents the model's forward pass.

add

Append a module to the end of the model.

add_module

Append a module to the end of the model.

append

Append a module to the end of the model.

append_module

Append a module to the end of the model.

astype

Convenience wrapper to set_precision using the dtype argument, returns self.

compile

Compile the model for training by compiling each module.

deserialize

Deserialize the model from a dictionary.

from_file

Load a model from a file and return an instance of the Model class.

get_hyperparameters

Get the hyperparameters of the model as a dictionary.

get_modules

Get the modules of the model.

get_num_trainable_floats

Returns the number of trainable floats in the module.

get_output_shape

Get the output shape of the model given an input shape.

get_params

Get the parameters of the model.

get_rng

get_state

Get the state of the model.

grad_input

Compute the gradient of the model output with respect to the input data.

grad_params

Compute the gradient of the model output with respect to the model parameters.

insert

Insert a module at a specific index in the model.

insert_module

Insert a module at a specific index in the model.

is_ready

Return True if the module is initialized and ready for training or inference.

load

Load the model from a file.

pop

Remove and return a module by key or index in the model.

pop_module_by_index

Remove and return a module at a specific index in the model.

pop_module_by_key

Remove and return a module by key or index in the model.

predict

Call the model with the input data.

prepend

Prepend a module to the beginning of the model.

prepend_module

Prepend a module to the beginning of the model.

reset

save

Save the model to a file.

save_compressed

Save the model to a compressed file.

serialize

Serialize the model to a dictionary.

set_hyperparameters

Set the hyperparameters of the model from a dictionary.

set_params

Set the parameters of the model from a PyTree of PyTrees of numpy arrays.

set_precision

Set the precision of the model parameters and states.

set_rng

Set the random key for the model.

set_state

Set the state of the model from a PyTree of PyTrees of numpy arrays.

train

name

Returns the name of the module. Unless overridden, this is the class name.

__call__(X, /, *, dtype=jax.numpy.float64, rng=None, return_state=False, update_state=False, max_batch_size=None)

Call the model with the input data.

Parameters:
  • X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.

  • dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) and single-precision weights.

  • rng (Any | int | None) – JAX random key for stochastic modules. If None (the default), the saved rng key is used if it exists, which would be the final rng key from the last training run; otherwise a new random key is generated. If an integer is provided, it is used as the seed to create a new JAX random key.

  • return_state (bool) – If True, the model will return the state of the model after evaluation. Default is False.

  • update_state (bool) – If True, the model will update the state of the model after evaluation. Default is False.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch.

Returns:

  • Data – Output data as a PyTree of JAX arrays, the structure and shape of which is determined by the model’s specific modules.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

Tuple[Data, ModelState] | Data
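The effect of max_batch_size can be pictured with the following NumPy sketch, which splits the leading (batch) axis into chunks, evaluates each chunk in turn, and concatenates the results. The names `call_in_batches` and `toy_model` are hypothetical stand-ins, and the library's own batching logic may differ in detail.

```python
import numpy as np

def call_in_batches(fn, X, max_batch_size=None):
    """Evaluate fn on X in chunks of at most max_batch_size rows."""
    if max_batch_size is None:
        return fn(X)  # single batch
    chunks = [
        fn(X[start:start + max_batch_size])
        for start in range(0, X.shape[0], max_batch_size)
    ]
    return np.concatenate(chunks, axis=0)

def toy_model(X):
    """Stand-in forward pass that acts on each sample independently."""
    return (X ** 2).sum(axis=1, keepdims=True)

X = np.arange(20.0).reshape(10, 2)
full = call_in_batches(toy_model, X)
batched = call_in_batches(toy_model, X, max_batch_size=3)
print(np.array_equal(full, batched))  # True
```

Peak memory then scales with the chunk size rather than the full batch, at the cost of sequential evaluation.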

_get_callable()

Returns a jax.jit-able and jax.grad-able callable that represents the model’s forward pass.

This method must be implemented by all subclasses and must return a jax.jit-able and jax.grad-able callable in the form of

model_callable(
    params: parametricmatrixmodels.model_util.ModelParams,
    data: parametricmatrixmodels.typing.Data,
    training: bool,
    state: parametricmatrixmodels.model_util.ModelState,
    rng: Any,
) -> (
    output: parametricmatrixmodels.typing.Data,
    new_state: parametricmatrixmodels.model_util.ModelState,
    )

That is, all hyperparameters are traced out and the callable depends explicitly only on

  • the model’s parameters, as a PyTree with leaf nodes as JAX arrays,

  • the input data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),

  • the training flag, as a boolean,

  • the model’s state, as a PyTree with leaf nodes as JAX arrays,

  • the rng key

and returns

  • the output data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),

  • the new model state, as a PyTree with leaf nodes as JAX arrays. The PyTree structure must match that of the input state and additionally all leaf nodes must have the same shape as the input state leaf nodes.

The training flag will be traced out, so it doesn’t need to be jittable.

Returns:

A callable that takes the model’s parameters, input data, training flag, state, and rng key and returns the output data and new state.

Return type:

pmm.model_util.ModelCallable
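As a rough picture of this contract, the toy function below follows the same argument order and return convention, with NumPy arrays standing in for JAX arrays and plain dictionaries standing in for the PyTrees. Everything in it (the name `toy_model_callable`, the keys "W", "b", and "calls") is hypothetical, and it is not jitted.

```python
import numpy as np

def toy_model_callable(params, data, training, state, rng):
    """Toy forward pass: an affine map plus a call counter kept in the state.

    `training` and `rng` are accepted to match the signature but unused here.
    """
    output = data @ params["W"] + params["b"]  # shape (num_samples, 2)
    # The new state has the same structure and leaf shapes as the input state.
    new_state = {"calls": state["calls"] + 1}
    return output, new_state

params = {"W": np.ones((3, 2)), "b": np.zeros(2)}
state = {"calls": np.zeros((), dtype=np.int64)}
data = np.arange(15.0).reshape(5, 3)

output, new_state = toy_model_callable(params, data, False, state, rng=None)
print(output.shape, int(new_state["calls"]))  # (5, 2) 1
```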

See also

__call__

Calls the model with the current parameters and given input, state, and rng.

ModelCallable

Typing for the callable returned by this method.

Params

Typing for the model parameters.

Data

Typing for the input and output data.

State

Typing for the model state.

_verify_input(X)
Parameters:

X (pmm.typing.Data)

Return type:

None

add(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

add_module(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

append(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

append_module(module, /, key=None)

Append a module to the end of the model.

Parameters:
  • module (BaseModule) – Module to append.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

astype(dtype, /)

Convenience wrapper to set_precision using the dtype argument, returns self.

Parameters:

dtype (str | type[Any] | dtype | SupportsDType | int)

Return type:

Model

compile(rng, input_shape, verbose=False)

Compile the model for training by compiling each module. Must be implemented by all subclasses.

Parameters:
  • rng (Any | int | None) – Random key for initializing the model parameters. JAX PRNGKey or integer seed.

  • input_shape (pmm.typing.DataShape) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.

  • verbose (bool) – Print debug information during compilation. Default is False.

Return type:

None

deserialize(data, /)

Deserialize the model from a dictionary. This is done by deserializing the model’s parameters/metadata and then deserializing each module.

Parameters:

data (Dict[str, Any]) – Dictionary containing the serialized model data.

Return type:

None

classmethod from_file(file, /)

Load a model from a file and return an instance of the Model class.

Parameters:

file (str) – File to load the model from.

Returns:

Model – An instance of the Model class with the loaded parameters.

Return type:

Model

get_hyperparameters()

Get the hyperparameters of the model as a dictionary.

Returns:

Dict[str, Any] – Dictionary containing the hyperparameters of the model.

Return type:

pmm.typing.HyperParams

get_modules()

Get the modules of the model.

Returns:

modules – PyTree of modules in the model. The structure of the PyTree will match that of the modules in the model.

Return type:

pmm.model_util.ModelModules

See also

ModelModules

Type alias for a PyTree of modules in a model.

get_num_trainable_floats()

Returns the number of trainable floats in the module. If the module does not have trainable parameters, returns 0. If the module is not ready, returns None.

Returns:

Number of trainable floats in the module, or None if the module is not ready.

Return type:

int | None

get_output_shape(input_shape)

Get the output shape of the model given an input shape. Must be implemented by all subclasses.

Parameters:

input_shape (pmm.typing.DataShape) – Shape of the input, excluding the batch dimension. For example, (input_features,) for 1D bare-array input, or (input_height, input_width, input_channels) for 3D bare-array input, [(input_features1,), (input_features2,)] for a List (PyTree) of 1D arrays, etc.

Returns:

output_shape – Shape of the output after passing through the model.

Return type:

pmm.typing.DataShape

get_params()

Get the parameters of the model.

Returns:

params – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree will be a composite structure where the upper level structure matches that of the modules in the model, and the lower level structure matches that of the parameters of each module.

Return type:

pmm.model_util.ModelParams

See also

ModelParams

Type alias for a PyTree of parameters in a model.

get_modules

Get the modules of the model, in the same structure as the parameters returned by this method.

set_params

Set the parameters of the model from a corresponding PyTree of PyTrees of numpy arrays.

get_rng()
Return type:

Any

get_state()

Get the state of the model. The state is a PyTree of PyTrees of numpy arrays representing the state of each module in the model. The structure of the PyTree will be a composite structure where the upper level structure matches that of the modules in the model, and the lower level structure matches that of the state of each module.

Returns:

state – PyTree of PyTrees of numpy arrays representing the state of each module in the model. The structure of the PyTree will be a composite structure where the upper level structure matches that of the modules in the model, and the lower level structure matches that of the state of each module.

Return type:

pmm.model_util.ModelState

See also

ModelState

Type alias for a PyTree of states in a model.

get_modules

Get the modules of the model, in the same structure as the state returned by this method.

set_state

Set the state of the model from a corresponding PyTree of PyTrees of numpy arrays.

grad_input(X, /, *, dtype=jax.numpy.float64, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)

Compute the gradient of the model output with respect to the input data.

Parameters:
  • fwd (bool | None) – If True, use jax.jacfwd, otherwise use jax.jacrev. Default is None, which decides based on the input and output shapes.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch. If max_batch_size is set to 1, the gradient will be computed one sample at a time without batching. This case is particularly important for grad_input since the Jacobian contains gradients across different batch samples and thus scales with the square of the batch size.

  • X (PyTree[Inexact[Array, 'num_samples ...'], ' DataStruct'])

  • dtype (jax.typing.DTypeLike)

  • rng (Any | int | None)

  • return_state (bool)

  • update_state (bool)

Returns:

  • PyTree – Gradient of the model output with respect to the input data, as a PyTree of JAX arrays, the structure of which matches that of the output structure of the model composed above the input data structure. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, input_dim1, input_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the input dimensions correspond to the shape of the input data for that leaf.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

Tuple[PyTree[Inexact[Array, ‘num_samples …’], ’… DataStruct’], ModelState] | PyTree[Inexact[Array, ‘num_samples …’], ’… DataStruct’]
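The remark that the batched Jacobian scales with the square of the batch size can be checked on a toy linear map. The sketch below is an assumption-laden illustration in NumPy, not the library's gradient code: it assumes a model that acts on each sample independently, so every cross-sample block of the full Jacobian is zero.

```python
import numpy as np

B, n_out, n_in = 4, 2, 3  # batch size, output features, input features
W = np.arange(n_out * n_in, dtype=float).reshape(n_out, n_in)

# Toy model: out = X @ W.T, applied to each sample independently.
# Full Jacobian over the whole batch:
#   J[b, o, c, i] = d out[b, o] / d X[c, i] = delta(b, c) * W[o, i]
full_jac = np.einsum("bc,oi->boci", np.eye(B), W)

# Per-sample Jacobian, matching the documented
# (num_samples, output_dims..., input_dims...) layout.
per_sample_jac = np.broadcast_to(W, (B, n_out, n_in))

print(full_jac.shape, per_sample_jac.shape)  # (4, 2, 4, 3) (4, 2, 3)
print(full_jac.size // per_sample_jac.size)  # 4
```

The full Jacobian holds B times as many entries as the per-sample one, and all of the extra entries are zeros in this setting, which is why small values of max_batch_size matter most for grad_input.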

grad_params(X, /, *, dtype=jax.numpy.float64, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)

Compute the gradient of the model output with respect to the model parameters.

Parameters:
  • fwd (bool | None) – If True, use jax.jacfwd, otherwise use jax.jacrev. Default is None, which decides based on the input and output shapes.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch. Only applies if batched=True.

  • X (PyTree[jaxtyping.Inexact[Array, 'num_samples ...'], 'DataStruct'])

  • dtype (str | type[Any] | dtype | SupportsDType)

  • rng (Any | int | None)

  • return_state (bool)

  • update_state (bool)

Returns:

  • PyTree – Gradient of the model output with respect to the model parameters, as a PyTree of PyTrees of JAX arrays, the upper level structure of which matches that of the model’s modules, and the lower level structure of which matches that of the parameters of each module. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, param_dim1, param_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the param dimensions correspond to the shape of the parameter for that leaf.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

tuple[PyTree[jaxtyping.Inexact[Array, ‘num_samples …’] | tuple[] | None], pmm.model_util.ModelState] | PyTree[jaxtyping.Inexact[Array, ‘num_samples …’] | tuple[] | None]

insert(index, module, /, key=None)

Insert a module at a specific index in the model.

Parameters:
  • index (int) – Index to insert the module at.

  • module (BaseModule) – Module to insert.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

insert_module(index, module, /, key=None)

Insert a module at a specific index in the model.

Parameters:
  • index (int) – Index to insert the module at.

  • module (BaseModule) – Module to insert.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

is_ready()

Return True if the module is initialized and ready for training or inference.

Returns:

True if the module is ready, False otherwise.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

bool

load(file, /)

Load the model from a file. Supports both compressed and uncompressed files.

Parameters:

file (str | IO | Path) – File to load the model from.

Return type:

None

property name: str

Returns the name of the module. Unless overridden, this is the class name.

Returns:

Name of the module.

pop(key, /)

Remove and return a module by key or index in the model.

Returns:

The removed module.

Parameters:

key (tuple[KeyEntry, ...] | str | int) – Key or index of the module to remove.

Return type:

BaseModule

pop_module_by_index(index, /)

Remove and return a module at a specific index in the model.

Returns:

The removed module.

Parameters:

index (int) – Index of the module to remove.

Return type:

BaseModule

pop_module_by_key(key, /)

Remove and return a module by key or index in the model.

Returns:

The removed module.

Parameters:

key (tuple[KeyEntry, ...] | str | int) – Key or index of the module to remove.

Return type:

BaseModule

predict(X, /, *, dtype=jax.numpy.float64, rng=None, return_state=False, update_state=False, max_batch_size=None)

Call the model with the input data.

Parameters:
  • X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.

  • dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) and single-precision weights.

  • rng (Any | int | None) – JAX random key for stochastic modules. If None (the default), the saved rng key is used if it exists, which would be the final rng key from the last training run; otherwise a new random key is generated. If an integer is provided, it is used as the seed to create a new JAX random key.

  • return_state (bool) – If True, the model will return the state of the model after evaluation. Default is False.

  • update_state (bool) – If True, the model will update the state of the model after evaluation. Default is False.

  • max_batch_size (int | None) – If provided, the input will be split into batches of at most this size and processed sequentially to avoid OOM errors. Default is None, which means the input will be processed in a single batch.

Returns:

  • Data – Output data as a PyTree of JAX arrays, the structure and shape of which is determined by the model’s specific modules.

  • ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.

Return type:

Tuple[Data, ModelState] | Data

prepend(module, /, key=None)

Prepend a module to the beginning of the model.

Parameters:
  • module (BaseModule) – Module to prepend.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

prepend_module(module, /, key=None)

Prepend a module to the beginning of the model.

Parameters:
  • module (BaseModule) – Module to prepend.

  • key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.

Return type:

None

reset()
Return type:

None

save(file, /)

Save the model to a file.

Parameters:

file (str | IO | Path) – File to save the model to.

Return type:

None

save_compressed(file, /)

Save the model to a compressed file.

Parameters:

file (str | IO | Path) – File to save the model to.

Return type:

None

serialize()

Serialize the model to a dictionary. This is done by serializing the model’s parameters/metadata and then serializing each module.

Returns:

Dict[str, Any]

Return type:

dict[str, Any]

set_hyperparameters(hyperparams)

Set the hyperparameters of the model from a dictionary.

Parameters:

hyperparams (Dict[str, Any]) – Dictionary containing the hyperparameters of the model.

Return type:

None

set_params(params, /)

Set the parameters of the model from a PyTree of PyTrees of numpy arrays.

Parameters:

params (pmm.model_util.ModelParams) – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the parameters of each module.

Return type:

None

See also

ModelParams

Type alias for a PyTree of parameters in a model.

get_modules

Get the modules of the model, in the same structure as the parameters accepted by this method.

get_params

Get the parameters of the model, in the same structure as the parameters accepted by this method.

set_precision(prec, /)

Set the precision of the model parameters and states.

Parameters:

prec (str | type[Any] | dtype | SupportsDType | int) – Precision to set for the model parameters and states. Valid options are:

  • for 32-bit precision (all options are equivalent): np.float32, np.complex64, “float32”, “complex64”, “single”, “f32”, “c64”, 32

  • for 64-bit precision (all options are equivalent): np.float64, np.complex128, “float64”, “complex128”, “double”, “f64”, “c128”, 64

Return type:

None

set_rng(rng, /)

Set the random key for the model.

Parameters:

rng (Any) – Random key to set for the model. JAX PRNGKey, integer seed, or None. If None, a new random key will be generated using JAX’s random.key. If an integer is provided, it will be used as the seed to create the key.

Return type:

None

set_state(state, /)

Set the state of the model from a PyTree of PyTrees of numpy arrays.

Parameters:

state (pmm.model_util.ModelState) – PyTree of PyTrees of numpy arrays representing the state of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the state of each module.

Return type:

None

See also

ModelState

Type alias for a PyTree of states in a model.

get_modules

Get the modules of the model, in the same structure as the state accepted by this method.

get_state

Get the state of the model, in the same structure as the state accepted by this method.

train(X, /, Y=None, *, Y_unc=None, X_val=None, Y_val=None, Y_val_unc=None, loss_fn='mse', lr=0.001, batch_size=32, epochs=100, target_loss=-inf, early_stopping_patience=100, early_stopping_min_delta=-inf, initialization_seed=None, callback=None, unroll=None, verbose=True, batch_rng=None, b1=0.9, b2=0.999, eps=1e-08, clip=1000.0)
Parameters:
  • X (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*features'], 'InStruct'])

  • Y (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)

  • Y_unc (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)

  • X_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*features'], 'InStruct'] | None)

  • Y_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)

  • Y_val_unc (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)

  • loss_fn (str | Callable)

  • lr (float | Callable[[int], float])

  • batch_size (int)

  • epochs (int)

  • target_loss (float)

  • early_stopping_patience (int)

  • early_stopping_min_delta (float)

  • initialization_seed (Any | int | None)

  • callback (Callable | None)

  • unroll (int | None)

  • verbose (bool)

  • batch_rng (Any | int | None)

  • b1 (float)

  • b2 (float)

  • eps (float)

  • clip (float)

Return type:

None