AffineEigenvaluePMM
parametricmatrixmodels.modules.AffineEigenvaluePMM
- class AffineEigenvaluePMM(matrix_size=None, num_eig=1, which='SA', smoothing=None, Ms=None, init_magnitude=0.01, bias_term=True)
Bases: SequentialModel
AffineEigenvaluePMM is a module that implements the affine eigenvalue Parametric Matrix Model (PMM) using two primitive modules combined in a SequentialModel: an AffineHermitianMatrix module followed by an Eigenvalues module.
The Affine Eigenvalue PMM (AEPMM) is described in [1] and is summarized as follows:
Given input features \(x_1, \ldots, x_p\), construct the Hermitian matrix that is affine in these features as
\[M(x) = M_0 + \sum_{i=1}^p x_i M_i\]
where \(M_0, \ldots, M_p\) are trainable Hermitian matrices. An optional smoothing term \(s C\), parameterized by the smoothing hyperparameter \(s\), can be added to smooth the eigenvalues and eigenvectors of \(M(x)\). The matrix \(C\) is the imaginary unit times the sum of all commutators of the \(M_i\); since the commutator of two Hermitian matrices is anti-Hermitian, \(C\) is Hermitian, so adding \(s C\) keeps the matrix Hermitian. The requested eigenvalues of \(M(x)\) are then computed and returned as the output of the module.
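A minimal NumPy sketch of the forward pass described above, for illustration only. It assumes the matrices are stacked in an array Ms of shape (p+1, n, n) with \(M_0\) first (matching the documented Ms shape when bias_term is True) and that the 'SA' eigenvalues are requested. It also assumes that "the sum of all commutators" means the sum of \([M_a, M_b]\) over unordered pairs \(a < b\), including \(M_0\). The helper names (random_hermitian, affine_matrix, smoothing_matrix, aepmm_eigenvalues) are hypothetical, and the library's JAX implementation may differ.

```python
import numpy as np


def random_hermitian(rng, n, scale):
    """Random n x n Hermitian matrix with entries of order `scale`."""
    A = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
    return scale * (A + A.conj().T) / 2


def affine_matrix(Ms, x, bias_term=True):
    """M(x) = M_0 + sum_i x_i M_i, or sum_i x_i M_i without the bias term."""
    if bias_term:
        return Ms[0] + np.tensordot(x, Ms[1:], axes=1)
    return np.tensordot(x, Ms, axes=1)


def smoothing_matrix(Ms):
    """C = i * sum_{a<b} [M_a, M_b] (ASSUMED reading of "all commutators").

    Summing over ordered pairs would cancel to zero, so unordered pairs are used.
    """
    C = np.zeros_like(Ms[0], dtype=complex)
    for a in range(len(Ms)):
        for b in range(a + 1, len(Ms)):
            C += Ms[a] @ Ms[b] - Ms[b] @ Ms[a]
    return 1j * C  # i times an anti-Hermitian matrix is Hermitian


def aepmm_eigenvalues(Ms, x, num_eig=1, smoothing=0.0):
    """'SA' eigenvalues of M(x) + s*C in ascending order."""
    M = affine_matrix(Ms, x)
    if smoothing:
        M = M + smoothing * smoothing_matrix(Ms)
    return np.linalg.eigvalsh(M)[:num_eig]  # eigvalsh sorts ascending


rng = np.random.default_rng(0)
n, p = 4, 2
Ms = np.stack([random_hermitian(rng, n, 1e-2) for _ in range(p + 1)])
x = np.array([0.3, -1.2])
print(aepmm_eigenvalues(Ms, x, num_eig=2))                 # two smallest eigenvalues
print(aepmm_eigenvalues(Ms, x, num_eig=2, smoothing=0.1))  # with the assumed C
```

Stacking all \(M_i\) into one array lets the affine combination be a single tensordot over the feature axis.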
See also
AffineHermitianMatrix: Module that constructs the affine Hermitian matrix \(M(x)\) from trainable Hermitian matrices \(M_i\) and input features.
Eigenvalues: Module that computes the eigenvalues of a matrix.
SequentialModel: Module and Model that evaluates multiple modules in sequence.
References
- __init__(matrix_size=None, num_eig=1, which='SA', smoothing=None, Ms=None, init_magnitude=0.01, bias_term=True)
Initialize the AffineEigenvaluePMM module.
By default this module is initialized to compute the smallest algebraic eigenvalue (the ground state).
- Parameters:
matrix_size (int) – Size of the PMM matrices, shorthand \(n\).
num_eig (int) – Number of eigenvalues to compute, shorthand \(k\). Default is 1.
which (str) – Which eigenvalues to compute. Default is 'SA'. Options are:
  - 'SA' for smallest algebraic (default)
  - 'LA' for largest algebraic
  - 'SM' for smallest magnitude
  - 'LM' for largest magnitude
  - 'EA' for exterior algebraically
  - 'EM' for exterior by magnitude
  - 'IA' for interior algebraically
  - 'IM' for interior by magnitude
For algebraic which options, the eigenvalues are returned in ascending algebraic order.
For magnitude which options, the eigenvalues are returned in ascending order of magnitude.
smoothing (float) – Optional smoothing parameter. Set to 0.0 to disable smoothing. Default is None (equivalent to 0.0, no smoothing).
Ms (Array) – Optional array of shape (input_size+1, matrix_size, matrix_size) (if bias_term is True) or (input_size, matrix_size, matrix_size) (if bias_term is False), containing the \(M_i\) Hermitian matrices. If not provided, the matrices are initialized randomly when the module is compiled. Default is None (random initialization).
init_magnitude (float) – Initial magnitude for the random matrices if Ms is not provided. Default is 1e-2.
bias_term (bool) – If True, include the bias term \(M_0\) in the affine matrix. Default is True.
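The ordering rules for which can be illustrated with a small NumPy helper. This is a hypothetical sketch, not the library's code: it covers only 'SA', 'LA', 'SM' and 'LM', because this page does not spell out how the exterior and interior options split the spectrum.

```python
import numpy as np


def select_eigenvalues(eigs, k, which="SA"):
    """Pick k eigenvalues following the documented ordering rules (k >= 1).

    Hypothetical helper; only 'SA', 'LA', 'SM' and 'LM' are sketched.
    """
    eigs = np.sort(np.asarray(eigs))  # ascending algebraic order
    if which == "SA":
        return eigs[:k]
    if which == "LA":
        return eigs[-k:]  # the k largest, still in ascending algebraic order
    by_mag = eigs[np.argsort(np.abs(eigs), kind="stable")]  # ascending magnitude
    if which == "SM":
        return by_mag[:k]
    if which == "LM":
        return by_mag[-k:]  # the k largest magnitudes, still ascending in magnitude
    raise ValueError(f"option {which!r} is not covered by this sketch")
```

For the spectrum [-3, -0.5, 1, 2] and k = 2, 'SM' yields [-0.5, 1] and 'LM' yields [2, -3]: both are ordered by magnitude, not by value.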
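Warning
When using smoothing, avoid the which options involving magnitude. The smoothing only guarantees that eigenvalues close to each other algebraically are smoothed, whereas magnitude-based selection can switch between eigenvalues that are far apart in the spectrum (for example, \(-\lambda\) and \(+\lambda\)).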
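To see why magnitude-based selection interacts badly with smoothing, consider the toy 2x2 matrix \(M(x) = \mathrm{diag}(x, x - 2)\), chosen purely for illustration and not taken from the library. Its two eigenvalues are always 2 apart, so they never approach each other algebraically. Even so, the 'SM' output jumps from about +1 to about -1 as x crosses 1. A smoothing term that only acts on algebraically nearby eigenvalues would not be expected to remove a jump of this kind.

```python
import numpy as np


def smallest_magnitude_eig(x):
    """'SM' eigenvalue of the toy matrix M(x) = diag(x, x - 2)."""
    eigs = np.linalg.eigvalsh(np.diag([x, x - 2.0]))
    return eigs[np.argmin(np.abs(eigs))]


# The eigenvalues x and x - 2 never come close algebraically,
# yet the selected 'SM' eigenvalue flips sign across x = 1.
left = smallest_magnitude_eig(1.0 - 1e-6)
right = smallest_magnitude_eig(1.0 + 1e-6)
print(left, right)
```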
- __call__(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, max_batch_size=None)
Call the model with the input data.
- Parameters:
X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.
dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) and single-precision weights.
rng (Any | int | None) – JAX random key for stochastic modules. If an integer is provided, it is used as the seed to create a new JAX random key. If None (the default), the saved rng key is used if it exists, which is the final rng key from the last training run; otherwise a new random key is generated.
return_state (bool) – If True, the model returns the state of the model after evaluation. Default is False.
update_state (bool) – If True, the model updates the state of the model after evaluation. Default is False.
max_batch_size (int | None) – If provided, the input is split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input is processed in a single batch.
- Returns:
Data – Output data as a PyTree of JAX arrays, the structure and shape of which is determined by the model’s specific modules.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
Tuple[Data, ModelState] | Data
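The max_batch_size behavior can be pictured with the following sketch. It assumes a bare-array input and a function that acts independently on each sample. apply_in_batches is a hypothetical helper; the library's own batching, which also handles PyTree inputs and model state, may differ.

```python
import numpy as np


def apply_in_batches(fn, X, max_batch_size=None):
    """Evaluate fn on X in chunks of at most max_batch_size samples.

    Hypothetical helper: fn maps an array of shape (b, ...) to (b, ...).
    """
    if max_batch_size is None:
        return fn(X)  # single batch, as with the default None
    chunks = [fn(X[i:i + max_batch_size])
              for i in range(0, X.shape[0], max_batch_size)]
    return np.concatenate(chunks, axis=0)  # reassemble along the batch axis
```

Peak memory then scales with max_batch_size rather than with the full number of samples, at the cost of sequential evaluation.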
- _get_callable()
Returns a jax.jit-able and jax.grad-able callable that represents the model’s forward pass.
This method must be implemented by all subclasses and must return a callable of the form
  model_callable(
      params: parametricmatrixmodels.model_util.ModelParams,
      data: parametricmatrixmodels.typing.Data,
      training: bool,
      state: parametricmatrixmodels.model_util.ModelState,
      rng: Any,
  ) -> (
      output: parametricmatrixmodels.typing.Data,
      new_state: parametricmatrixmodels.model_util.ModelState,
  )
That is, all hyperparameters are traced out and the callable depends explicitly only on
  - the model’s parameters, as a PyTree with leaf nodes as JAX arrays,
  - the input data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),
  - the training flag, as a boolean,
  - the model’s state, as a PyTree with leaf nodes as JAX arrays,
  - the JAX random key,
and returns
  - the output data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),
  - the new model state, as a PyTree with leaf nodes as JAX arrays. The PyTree structure must match that of the input state, and every leaf node must have the same shape as the corresponding input state leaf node.
The training flag is traced out, so it does not need to be jittable.
- Returns:
A callable that takes the model’s parameters, input data, training flag, state, and rng key and returns the output data and new state.
- Return type:
pmm.model_util.ModelCallable
See also
__call__: Calls the model with the current parameters and given input, state, and rng.
ModelCallable: Typing for the callable returned by this method.
Params: Typing for the model parameters.
Data: Typing for the input and output data.
State: Typing for the model state.
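As an illustration of this contract, the following NumPy sketch builds a forward pass for the AEPMM without smoothing, selecting the 'SA' eigenvalues. It assumes params is a single array Ms of shape (p+1, n, n) and that the module is stateless, so state is returned unchanged. make_aepmm_callable is a hypothetical name. The hyperparameters are captured by closure, so the callable depends only on its five arguments. A library implementation would presumably use jax.numpy instead of numpy so that the callable works under jax.jit and jax.grad.

```python
import numpy as np


def make_aepmm_callable(num_eig=1, bias_term=True):
    """Return model_callable(params, data, training, state, rng) -> (output, new_state).

    Hypothetical sketch: hyperparameters are closed over, params is a single
    array Ms of shape (p [+1], n, n), and the model is stateless.
    """
    def model_callable(params, data, training, state, rng):
        # training and rng are unused: this toy model is deterministic
        Ms = params
        if bias_term:
            M = Ms[0] + np.einsum("sp,pij->sij", data, Ms[1:])
        else:
            M = np.einsum("sp,pij->sij", data, Ms)
        # eigvalsh handles the stacked (num_samples, n, n) array at once
        output = np.linalg.eigvalsh(M)[:, :num_eig]  # 'SA', shape (num_samples, k)
        return output, state  # stateless: state passes through unchanged
    return model_callable
```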
- _verify_input(X)
- Parameters:
X (pmm.typing.Data)
- Return type:
None
- add(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
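The keying rule for key can be summarized with a hypothetical helper, append_to_container, which is not part of the library. A dict container gets either the supplied key or a freshly generated UUID, while list and tuple containers ignore the key. Storing the generated UUID as a string is an assumption of this sketch.

```python
import uuid


def append_to_container(modules, module, key=None):
    """Hypothetical sketch of the documented keying rule for add/append."""
    if isinstance(modules, dict):
        if key is None:
            key = str(uuid.uuid4())  # ASSUMPTION: UUID stored as a string key
        modules[key] = module
        return modules
    if isinstance(modules, tuple):
        return modules + (module,)  # tuples are immutable, so build a new one
    modules.append(module)  # list: the key is ignored
    return modules
```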
- add_module(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- append(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- append_module(module, /, key=None)
Append a module to the end of the model.
- Parameters:
module (BaseModule) – Module to append.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- astype(dtype, /)
Convenience wrapper around set_precision that takes a dtype argument and returns self.
- compile(rng, input_shape, verbose=False)
Compile the model for training by compiling each module. Must be implemented by all subclasses.
- Parameters:
rng (Any | int | None) – Random key for initializing the model parameters. JAX PRNGKey or integer seed.
input_shape (pmm.typing.DataShape) – Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input.
verbose (bool) – Print debug information during compilation. Default is False.
- Return type:
None
- deserialize(data, /)
Deserialize the model from a dictionary. This is done by deserializing the model’s parameters/metadata and then deserializing each module.
- Parameters:
data (Dict[str, Any]) – Dictionary containing the serialized model data.
- Return type:
None
- classmethod from_file(file, /)
Load a model from a file and return an instance of the Model class.
- get_hyperparameters()
Get the hyperparameters of the model as a dictionary.
- Returns:
Dict[str, Any] – Dictionary containing the hyperparameters of the model.
- Return type:
pmm.typing.HyperParams
- get_modules()
Get the modules of the model.
- Returns:
modules – PyTree of modules in the model. The structure of the PyTree will match that of the modules in the model.
- Return type:
pmm.model_util.ModelModules
See also
ModelModules: Type alias for a PyTree of modules in a model.
- get_num_trainable_floats()
Returns the number of trainable floats in the module. If the module does not have trainable parameters, returns 0. If the module is not ready, returns None.
- Returns:
Number of trainable floats in the module, or None if the module is not ready.
- Return type:
int | None
- get_output_shape(input_shape)
Get the output shape of the model given an input shape. Must be implemented by all subclasses.
- Parameters:
input_shape (pmm.typing.DataShape) – Shape of the input, excluding the batch dimension. For example, (input_features,) for 1D bare-array input, or (input_height, input_width, input_channels) for 3D bare-array input, [(input_features1,), (input_features2,)] for a List (PyTree) of 1D arrays, etc.
- Returns:
output_shape – Shape of the output after passing through the model.
- Return type:
pmm.typing.DataShape
- get_params()
Get the parameters of the model.
- Returns:
params – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree will be a composite structure where the upper level structure matches that of the modules in the model, and the lower level structure matches that of the parameters of each module.
- Return type:
pmm.model_util.ModelParams
See also
ModelParams: Type alias for a PyTree of parameters in a model.
get_modules: Get the modules of the model, in the same structure as the parameters returned by this method.
set_params: Set the parameters of the model from a corresponding PyTree of PyTrees of numpy arrays.
- get_state()
Get the state of the model.
- Returns:
state – PyTree of PyTrees of numpy arrays representing the state of each module in the model. The structure of the PyTree will be a composite structure where the upper level structure matches that of the modules in the model, and the lower level structure matches that of the state of each module.
- Return type:
pmm.model_util.ModelState
See also
ModelState: Type alias for a PyTree of states in a model.
get_modules: Get the modules of the model, in the same structure as the state returned by this method.
set_state: Set the state of the model from a corresponding PyTree of PyTrees of numpy arrays.
- grad_input(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)
Compute the gradient (Jacobian) of the model output with respect to the input data.
- Parameters:
fwd (bool | None) – If True, use jax.jacfwd, otherwise use jax.jacrev. Default is None, which decides based on the input and output shapes.
max_batch_size (int | None) – If provided, the input is split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input is processed in a single batch. If max_batch_size is set to 1, the gradient is computed one sample at a time without batching. This case is particularly important for grad_input, since the Jacobian contains gradients across different batch samples and thus scales with the square of the batch size.
X (PyTree[Inexact[Array, 'num_samples ...'], ' DataStruct'])
dtype (jax.typing.DTypeLike)
rng (Any | int | None)
return_state (bool)
update_state (bool)
- Returns:
PyTree – Gradient of the model output with respect to the input data, as a PyTree of JAX arrays, the structure of which matches that of the output structure of the model composed above the input data structure. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, input_dim1, input_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the input dimensions correspond to the shape of the input data for that leaf.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
Tuple[PyTree[Inexact[Array, 'num_samples ...'], '... DataStruct'], ModelState] | PyTree[Inexact[Array, 'num_samples ...'], '... DataStruct']
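For the AEPMM specifically, the input Jacobian has a closed form. By the Hellmann-Feynman theorem, for a nondegenerate eigenvalue \(\lambda_j\) with normalized eigenvector \(v_j\), \(\partial \lambda_j / \partial x_i = v_j^\dagger M_i v_j\). The NumPy sketch below uses this to build an array of shape (num_samples, num_eig, p), one Jacobian per sample, matching the per-leaf layout described above. It assumes no smoothing, the 'SA' option, and Ms stacked with \(M_0\) first. aepmm_input_jacobian is a hypothetical helper; grad_input itself uses jax.jacfwd or jax.jacrev. Differentiating the whole batch at once would presumably involve an (N, k, N, p) array whose cross-sample blocks are zero, which is the quadratic scaling mentioned under max_batch_size.

```python
import numpy as np


def aepmm_input_jacobian(Ms, X, num_eig=1):
    """Hellmann-Feynman Jacobian d(lambda_j)/d(x_i) = v_j^H M_i v_j.

    Ms: (p + 1, n, n) Hermitian, bias M_0 first; X: (num_samples, p).
    Assumes nondegenerate eigenvalues and no smoothing.
    Returns shape (num_samples, num_eig, p).
    """
    M = Ms[0] + np.einsum("sp,pij->sij", X, Ms[1:])
    _, V = np.linalg.eigh(M)  # V[s, :, j] is the j-th eigenvector
    V = V[:, :, :num_eig]  # keep the 'SA' eigenvectors
    # <v_j| M_i |v_j> for every sample s, eigenvalue j and feature i
    J = np.einsum("saj,iab,sbj->sji", V.conj(), Ms[1:], V)
    return J.real  # the imaginary part vanishes for Hermitian M_i
```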
- grad_params(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, fwd=None, max_batch_size=None)
Compute the gradient (Jacobian) of the model output with respect to the model parameters.
- Parameters:
fwd (bool | None) – If True, use jax.jacfwd, otherwise use jax.jacrev. Default is None, which decides based on the input and output shapes.
max_batch_size (int | None) – If provided, the input is split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input is processed in a single batch.
X (PyTree[jaxtyping.Inexact[Array, 'num_samples ...'], 'DataStruct'])
return_state (bool)
update_state (bool)
- Returns:
PyTree – Gradient of the model output with respect to the model parameters, as a PyTree of PyTrees of JAX arrays, the upper level structure of which matches that of the model’s modules, and the lower level structure of which matches that of the parameters of each module. Each leaf array will have shape (num_samples, output_dim1, output_dim2, …, param_dim1, param_dim2, …), where the output dimensions correspond to the shape of the model output for that leaf, and the param dimensions correspond to the shape of the parameter for that leaf.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
tuple[PyTree[jaxtyping.Inexact[Array, 'num_samples ...'] | tuple[] | None], pmm.model_util.ModelState] | PyTree[jaxtyping.Inexact[Array, 'num_samples ...'] | tuple[] | None]
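A similar closed form exists for the parameter gradient. Treating every entry of real symmetric matrices Ms as an independent parameter, first-order perturbation theory gives \(\partial \lambda_j / \partial (M_i)_{ab} = x_i (v_j)_a (v_j)_b\), with \(x_0 = 1\) for the bias matrix. The sketch below returns an array of shape (num_samples, num_eig, p+1, n, n), following the (num_samples, output dims, param dims) layout described above. The parameter layout is an assumption. If the library parameterizes its Hermitian matrices differently, the shapes and values returned by grad_params would differ. aepmm_param_jacobian is a hypothetical helper.

```python
import numpy as np


def aepmm_param_jacobian(Ms, X, num_eig=1):
    """d(lambda_j)/d(Ms[i, a, b]) = x_i * v_j[a] * v_j[b] for real symmetric Ms.

    ASSUMPTION: every entry of Ms is an independent parameter; the library's
    Hermitian parameterization may differ. Returns shape
    (num_samples, num_eig, p + 1, n, n).
    """
    ones = np.ones((X.shape[0], 1))
    Xb = np.concatenate([ones, X], axis=1)  # x_0 = 1 multiplies the bias M_0
    M = np.einsum("sp,pij->sij", Xb, Ms)
    _, V = np.linalg.eigh(M)
    V = V[:, :, :num_eig]  # 'SA' eigenvectors, shape (s, n, k)
    outer = np.einsum("saj,sbj->sjab", V, V)  # v_j v_j^T for each sample
    return np.einsum("sp,sjab->sjpab", Xb, outer)
```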
- insert(index, module, /, key=None)
Insert a module at a specific index in the model.
- Parameters:
index (int) – Index to insert the module at.
module (BaseModule) – Module to insert.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- insert_module(index, module, /, key=None)
Insert a module at a specific index in the model.
- Parameters:
index (int) – Index to insert the module at.
module (BaseModule) – Module to insert.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- is_ready()
Return True if the module is initialized and ready for training or inference.
- Returns:
True if the module is ready, False otherwise.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
bool
- load(file, /)
Load the model from a file. Supports both compressed and uncompressed files.
- property name: str
Returns the name of the module. Unless overridden, this is the class name.
- Returns:
Name of the module.
- pop(key, /)
Remove and return a module by key or index in the model.
- Parameters:
key – Key or index of the module to remove.
- Returns:
The removed module.
- pop_module_by_index(index, /)
Remove and return a module at a specific index in the model.
- Parameters:
index (int) – Index of the module to remove.
- Returns:
The removed module.
- pop_module_by_key(key, /)
Remove and return a module by key or index in the model.
- Parameters:
key – Key or index of the module to remove.
- Returns:
The removed module.
- predict(X, /, *, dtype=<class 'jax.numpy.float64'>, rng=None, return_state=False, update_state=False, max_batch_size=None)
Call the model with the input data.
- Parameters:
X (Data) – Input array of shape (batch_size, <input feature axes>). For example, (batch_size, input_features) for a 1D input or (batch_size, input_height, input_width, input_channels) for a 3D input.
dtype (jax.typing.DTypeLike) – Data type of the output array. Default is jax.numpy.float64. It is strongly recommended to train in single precision (float32 and complex64) and to run inference with double-precision inputs (float64, the default here) and single-precision weights.
rng (Any | int | None) – JAX random key for stochastic modules. If an integer is provided, it is used as the seed to create a new JAX random key. If None (the default), the saved rng key is used if it exists, which is the final rng key from the last training run; otherwise a new random key is generated.
return_state (bool) – If True, the model returns the state of the model after evaluation. Default is False.
update_state (bool) – If True, the model updates the state of the model after evaluation. Default is False.
max_batch_size (int | None) – If provided, the input is split into batches of at most this size and processed sequentially to avoid out-of-memory (OOM) errors. Default is None, which means the input is processed in a single batch.
- Returns:
Data – Output data as a PyTree of JAX arrays, the structure and shape of which is determined by the model’s specific modules.
ModelState – If return_state is True, the state of the model after evaluation as a PyTree of PyTrees of JAX arrays, the structure of which matches that of the model’s modules.
- Return type:
Tuple[Data, ModelState] | Data
- prepend(module, /, key=None)
Prepend a module to the beginning of the model.
- Parameters:
module (BaseModule) – Module to prepend.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- prepend_module(module, /, key=None)
Prepend a module to the beginning of the model.
- Parameters:
module (BaseModule) – Module to prepend.
key (tuple[KeyEntry, ...] | str | int | None) – Key to name the module if modules are stored in a dictionary. If None and modules are stored in a dictionary, a UUID will be generated and used as the key. Default is None. Ignored if modules are stored in a list, tuple, or other structure.
- Return type:
None
- reset()
- Return type:
None
- save(file, /)
Save the model to a file.
- save_compressed(file, /)
Save the model to a compressed file.
- serialize()
Serialize the model to a dictionary. This is done by serializing the model’s parameters/metadata and then serializing each module.
- set_hyperparameters(hyperparams)
Set the hyperparameters of the model from a dictionary.
- Parameters:
hyperparams (Dict[str, Any]) – Dictionary containing the hyperparameters of the model.
- Return type:
None
- set_params(params, /)
Set the parameters of the model from a PyTree of PyTrees of numpy arrays.
- Parameters:
params (pmm.model_util.ModelParams) – PyTree of PyTrees of numpy arrays representing the parameters of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the parameters of each module.
- Return type:
None
See also
ModelParams: Type alias for a PyTree of parameters in a model.
get_modules: Get the modules of the model, in the same structure as the parameters accepted by this method.
get_params: Get the parameters of the model, in the same structure as the parameters accepted by this method.
- set_precision(prec, /)
Set the precision of the model parameters and states.
- Parameters:
prec (str | type[Any] | dtype | SupportsDType | int) – Precision to set for the model parameters and states. Valid options are:
  - 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32
  - 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64
- Return type:
None
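Since every option listed above collapses to either 32-bit or 64-bit precision, the mapping can be sketched as a small normalization function. precision_bits is a hypothetical helper for illustration; the library's internal handling of these options is not shown on this page.

```python
import numpy as np

_SINGLE = {"float32", "complex64", "single", "f32", "c64", 32}
_DOUBLE = {"float64", "complex128", "double", "f64", "c128", 64}


def precision_bits(prec):
    """Map any documented set_precision option to 32 or 64 (hypothetical)."""
    if isinstance(prec, np.dtype):
        prec = prec.name  # np.dtype("float64") -> "float64"
    elif isinstance(prec, type) and issubclass(prec, np.generic):
        prec = np.dtype(prec).name  # np.complex64 -> "complex64"
    if prec in _SINGLE:
        return 32
    if prec in _DOUBLE:
        return 64
    raise ValueError(f"unsupported precision: {prec!r}")
```

Real and complex types of the same width map to the same setting, which matches the "all options are equivalent" grouping above.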
- set_rng(rng, /)
Set the random key for the model.
- Parameters:
rng (Any) – Random key to set for the model: a JAX PRNGKey, an integer seed, or None. If None, a new random key is generated using jax.random.key. If an integer is provided, it is used as the seed to create the key.
- Return type:
None
- set_state(state, /)
Set the state of the model from a PyTree of PyTrees of numpy arrays.
- Parameters:
state (pmm.model_util.ModelState) – PyTree of PyTrees of numpy arrays representing the state of each module in the model. The structure of the PyTree must match that of the modules in the model, and the lower level structure must match that of the state of each module.
- Return type:
None
See also
ModelState: Type alias for a PyTree of states in a model.
get_modules: Get the modules of the model, in the same structure as the state accepted by this method.
get_state: Get the state of the model, in the same structure as the state accepted by this method.
- train(X, /, Y=None, *, Y_unc=None, X_val=None, Y_val=None, Y_val_unc=None, loss_fn='mse', lr=0.001, batch_size=32, epochs=100, target_loss=-inf, early_stopping_patience=100, early_stopping_min_delta=-inf, initialization_seed=None, callback=None, unroll=None, verbose=True, batch_rng=None, b1=0.9, b2=0.999, eps=1e-08, clip=1000.0)
- Parameters:
X (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*features'], 'InStruct'])
Y (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)
Y_unc (PyTree[jaxtyping.Inexact[Array, 'num_samples ?*targets'], 'OutStruct'] | None)
X_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*features'], 'InStruct'] | None)
Y_val (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)
Y_val_unc (PyTree[jaxtyping.Inexact[Array, 'num_val_samples ?*targets'], 'OutStruct'] | None)
batch_size (int)
epochs (int)
target_loss (float)
early_stopping_patience (int)
early_stopping_min_delta (float)
callback (Callable | None)
unroll (int | None)
verbose (bool)
b1 (float)
b2 (float)
eps (float)
clip (float)
- Return type:
None
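The lr, b1, b2, eps and clip arguments of train have the names and defaults conventionally used for the Adam optimizer, but this page does not state which optimizer train uses or how clip is applied. Assuming an Adam-style update with gradient clipping by global norm, one update step would look like the sketch below. adam_step is hypothetical and handles real-valued parameters only.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, clip=1000.0):
    """One Adam update with global-norm clipping (ASSUMED optimizer, t >= 1)."""
    norm = np.linalg.norm(grad)
    if norm > clip:
        grad = grad * (clip / norm)  # ASSUMPTION: clip rescales by global norm
    m = b1 * m + (1 - b1) * grad  # first-moment (mean) estimate
    v = b2 * v + (1 - b2) * grad**2  # second-moment estimate
    m_hat = m / (1 - b1**t)  # bias corrections for zero-initialized moments
    v_hat = v / (1 - b2**t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v
```

On the first step the bias-corrected update reduces to roughly lr times the sign of each gradient component, which is why lr directly sets the initial step size.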