LowRankAffineObservablePMM#
parametricmatrixmodels.modules.LowRankAffineObservablePMM
- class LowRankAffineObservablePMM(matrix_size=None, primary_rank=None, num_eig=None, which=None, smoothing=None, affine_bias_matrix=True, num_secondaries=1, secondary_rank=None, output_size=None, centered=True, bias_term=True, lambdaMs=None, uMs=None, lambdaDs=None, uDs=None, b=None, init_magnitude=0.01)[source]#
Bases:
MultiModule
LowRankAffineObservablePMM is a module that implements a general regression model via the affine observable Parametric Matrix Model (PMM) with low-rank trainable matrices, using four primitive modules combined in a MultiModule: a LowRankAffineHermitianMatrix module, followed by an Eigenvectors module, followed by a LowRankTransitionAmplitudeSum module, followed optionally by a Bias module.
The Affine Observable PMM (AOPMM) is described in [1] and is summarized as follows:
Given input features \(x_1, \ldots, x_p\), construct the Hermitian matrix that is affine in these features as
\[M(x) = M_0 + \sum_{i=1}^p x_i M_i\]
where \(M_0, \ldots, M_p\) are trainable Hermitian matrices. An optional smoothing term \(s C\) parameterized by the smoothing hyperparameter \(s\) can be added to smooth the eigenvalues and eigenvectors of \(M(x)\). The matrix \(C\) is equal to the imaginary unit times the sum of all commutators of the \(M_i\).
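As a concrete illustration (a NumPy sketch with hypothetical names, not the package's implementation), the affine Hermitian matrix \(M(x)\) can be assembled as follows:

```python
import numpy as np

def affine_matrix(x, Ms, M0=None):
    """Assemble M(x) = M0 + sum_i x_i * M_i from Hermitian matrices M_i.

    x  : (p,) real feature vector
    Ms : (p, n, n) stack of Hermitian matrices M_1..M_p
    M0 : optional (n, n) Hermitian bias matrix
    """
    M = np.einsum('i,ijk->jk', x, Ms)
    if M0 is not None:
        M = M + M0
    return M

rng = np.random.default_rng(0)
p, n = 3, 4
# A + A^H is always Hermitian, so this produces random Hermitian M_i
A = rng.normal(size=(p, n, n)) + 1j * rng.normal(size=(p, n, n))
Ms = A + A.conj().transpose(0, 2, 1)
x = rng.normal(size=p)

M = affine_matrix(x, Ms)
# M(x) is Hermitian, so its eigenvalues are real
assert np.allclose(M, M.conj().T)
```

Since the map is affine, with no bias matrix the output scales linearly in the features, which is easy to verify numerically.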
Then, take the leading \(r\) eigenvectors (by default corresponding to the largest magnitude eigenvalues if there is no smoothing, or the smallest algebraic if there is smoothing) of \(M(x)\) and compute the sum of the transition amplitudes of these eigenvectors with trainable Hermitian observable matrices (secondaries) \(D_{ql}\) to form the output vector \(z\) with \(q\) components as
\[z_k = \sum_{m=1}^l \left( \left[\sum_{i,j=1}^r |v_i^H D_{km} v_j|^2 \right] - \frac{r^2}{2} ||D_{km}||^2_2 \right)\]
where \(||\cdot||_2\) is the operator 2-norm (largest singular value), so for Hermitian \(D\), \(||D||_2\) is the largest absolute eigenvalue. The \(-\frac{1}{2} ||D_{km}||^2_2\) term centers the value of each term and can be disabled by setting the centered parameter to False.
Finally, an optional trainable bias term \(b_k\) can be added to each component.
In this module, the trainable Hermitian matrices, both \(M_i\) and \(D_{km}\), are parametrized in low-rank form by sums of rank-1 terms constructed from outer products of complex vectors. This reduces the number of trainable parameters.
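A minimal sketch of this low-rank parametrization (hypothetical names, not the package's internals): a rank-\(r\) Hermitian matrix is built as \(\sum_k \lambda_k u_k u_k^H\) from real coefficients \(\lambda_k\) and complex vectors \(u_k\):

```python
import numpy as np

def lowrank_hermitian(lambdas, us):
    """Build sum_k lambda_k * u_k u_k^H.

    lambdas : (r,) real coefficients
    us      : (r, n) complex vectors
    """
    return np.einsum('k,ki,kj->ij', lambdas, us, us.conj())

rng = np.random.default_rng(1)
r, n = 2, 5
lambdas = rng.normal(size=r)
us = rng.normal(size=(r, n)) + 1j * rng.normal(size=(r, n))

D = lowrank_hermitian(lambdas, us)
assert np.allclose(D, D.conj().T)     # Hermitian by construction
assert np.linalg.matrix_rank(D) <= r  # rank at most r
```

This form needs \(r(2n + 1)\) real parameters per matrix instead of the \(n^2\) needed for a full Hermitian matrix, which is the source of the parameter reduction.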
Warning
Even though the math shows that the centering term should be multiplied by \(r^2\), in practice this doesn’t work well and instead setting the centering term to \(\frac{1}{2} ||D_{km}||^2_2\) works much better. This non-\(r^2\) scaling is used here.
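The output computation can be sketched as follows (a hypothetical NumPy re-derivation using the \(\frac{1}{2}||D_{km}||^2_2\) centering that the warning says is used in practice; this is not the package's code):

```python
import numpy as np

def transition_amplitude_sum(V, Ds, centered=True):
    """z_k = sum_m ( sum_{i,j} |v_i^H D_{km} v_j|^2 - 0.5 * ||D_{km}||_2^2 ).

    V  : (n, r) matrix whose columns are the r chosen eigenvectors
    Ds : (q, l, n, n) Hermitian observables D_{km}
    """
    # amps[k, m, i, j] = v_i^H D_{km} v_j
    amps = np.einsum('ni,qlnp,pj->qlij', V.conj(), Ds, V)
    z = (np.abs(amps) ** 2).sum(axis=(2, 3))  # sum over i, j
    if centered:
        # operator 2-norm = largest singular value
        z = z - 0.5 * np.linalg.norm(Ds, ord=2, axis=(2, 3)) ** 2
    return z.sum(axis=1)                      # sum over the l secondaries

# sanity check: for D = I and orthonormal V, v_i^H I v_j = delta_ij,
# so the double sum gives r and the centered output is r - 0.5
V = np.eye(3)[:, :2]        # r = 2 orthonormal eigenvectors
Ds = np.eye(3)[None, None]  # q = l = 1, D = identity
z = transition_amplitude_sum(V, Ds)
assert np.allclose(z, [1.5])
```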
See also
LowRankAffineHermitianMatrix
Module that constructs the affine Hermitian matrix \(M(x)\) from low-rank trainable Hermitian matrices \(M_i\) and input features.
Eigenvectors
Module that computes the eigenvectors of a matrix.
LowRankTransitionAmplitudeSum
Module that computes the sum of transition amplitudes of eigenvectors with trainable low-rank observable matrices.
Bias
Module that adds a trainable bias term to the output.
MultiModule
Module that combines multiple modules in sequence.
AffineObservablePMM
Full-rank version of this module.
References
- __init__(matrix_size=None, primary_rank=None, num_eig=None, which=None, smoothing=None, affine_bias_matrix=True, num_secondaries=1, secondary_rank=None, output_size=None, centered=True, bias_term=True, lambdaMs=None, uMs=None, lambdaDs=None, uDs=None, b=None, init_magnitude=0.01)[source]#
Initialize the module.
- Parameters:
  - matrix_size (int) – Size of the trainable matrices, shorthand \(n\).
  - primary_rank (int) – Rank of the trainable Hermitian matrices \(M_i\).
  - num_eig (int) – Number of eigenvectors to use in the transition amplitude calculation, shorthand \(r\).
  - which (str) – Which eigenvectors to use based on eigenvalue. Options are:
    - 'SA' for smallest algebraic (default with smoothing)
    - 'LA' for largest algebraic
    - 'SM' for smallest magnitude
    - 'LM' for largest magnitude (default without smoothing)
    - 'EA' for exterior algebraically
    - 'EM' for exterior by magnitude
    - 'IA' for interior algebraically
    - 'IM' for interior by magnitude
  - smoothing (float) – Optional smoothing parameter for the affine matrix. Set to None / 0.0 to disable smoothing. Default is None / 0.0 (no smoothing).
  - affine_bias_matrix (bool) – If True, include the bias term \(M_0\) in the affine matrix. Default is True.
  - num_secondaries (int) – Number of secondary observable matrices \(D_{km}\) per output component, shorthand \(l\). Default is 1.
  - secondary_rank (int) – Rank of the trainable Hermitian observable matrices \(D_{km}\).
  - output_size (int) – Size of the output vector, shorthand \(q\).
  - centered (bool) – If True, include the centering term in the transition amplitude sum. Default is True.
  - bias_term (bool) – If True, include a trainable bias term \(b_k\) in the output. Default is True.
  - lambdaMs (Array) – Optional array of shape (input_size+1, primary_rank) (if affine_bias_matrix is True) or (input_size, primary_rank) (if affine_bias_matrix is False), containing the \(\lambda_k^i\) real coefficients used to construct the low-rank \(M_i\) matrices. If not provided, the coefficients will be initialized randomly when the module is compiled. Default is None (random initialization).
  - uMs (Array) – Optional array of shape (input_size+1, primary_rank, matrix_size) (if affine_bias_matrix is True) or (input_size, primary_rank, matrix_size) (if affine_bias_matrix is False), containing the complex vectors which parameterize the low-rank \(M_i\) Hermitian matrices. If not provided, the vectors will be initialized randomly when the module is compiled. Default is None (random initialization).
  - lambdaDs (Array) – Optional array of shape (output_size, num_secondaries, secondary_rank) containing the \(\lambda_k^m\) real coefficients used to construct the low-rank \(D_{km}\) observable matrices. If not provided, the coefficients will be initialized randomly when the module is compiled. Default is None (random initialization).
  - uDs (Array) – Optional array of shape (output_size, num_secondaries, secondary_rank, matrix_size) containing the complex vectors which parameterize the low-rank \(D_{km}\) Hermitian observable matrices. If not provided, the vectors will be initialized randomly when the module is compiled. Default is None (random initialization).
  - b (Array) – Optional array of shape (output_size,) containing the bias terms \(b_k\). If not provided, the bias terms will be randomly initialized when the module is compiled. Default is None (random initialization).
  - init_magnitude (float) – Initial magnitude for the random initialization. Default is 1e-2.
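For illustration, the optional initialization arrays have the following shapes under hypothetical hyperparameter values (zeros used only to show the shapes; the module would otherwise initialize randomly):

```python
import numpy as np

# hypothetical hyperparameters, chosen only to illustrate the documented shapes
input_size, matrix_size = 3, 8
primary_rank, secondary_rank = 2, 2
output_size, num_secondaries = 4, 1
affine_bias_matrix = True

# with affine_bias_matrix=True there are input_size + 1 matrices (M_0 included)
num_Ms = input_size + 1 if affine_bias_matrix else input_size
lambdaMs = np.zeros((num_Ms, primary_rank))
uMs = np.zeros((num_Ms, primary_rank, matrix_size), dtype=complex)
lambdaDs = np.zeros((output_size, num_secondaries, secondary_rank))
uDs = np.zeros((output_size, num_secondaries, secondary_rank, matrix_size), dtype=complex)
b = np.zeros(output_size)
```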
Warning
When using smoothing, the which options involving magnitude should be avoided, as the smoothing only guarantees that eigenvalues near each other algebraically are smoothed, not across the spectrum.
- __call__ – Call the module with the current parameters and given input, state, and rng.
- _get_callable – Returns a jax.jit-able and jax.grad-able callable that represents the module's forward pass.
- astype – Convenience wrapper to set_precision using the dtype argument, returns self.
- compile – Compile the module to be used with the given input shape.
- deserialize – Deserialize the module from a dictionary.
- get_hyperparameters – Get the hyperparameters of the module.
- get_num_trainable_floats – Returns the number of trainable floats in the module.
- get_output_shape – Get the output shape of the module given the input shape.
- get_params – Get the current trainable parameters of the module.
- get_state – Get the current state of the module.
- is_ready – Return True if the module is initialized and ready for training or inference.
- name – Returns the name of the module; unless overridden, this is the class name.
- serialize – Serialize the module to a dictionary.
- set_hyperparameters – Set the hyperparameters of the module.
- set_params – Set the trainable parameters of the module.
- set_precision – Set the precision of the module parameters and state.
- set_state – Set the state of the module.
- __call__(input_NF, training=False, state=(), rng=None)#
Call the module with the current parameters and given input, state, and rng.
- Parameters:
- Return type:
- Returns:
Output array of shape (num_samples, num_output_features) and new state.
- Raises:
ValueError – If the module is not ready (i.e., compile() has not been called).
See also
_get_callable
Returns a callable that can be used to compute the output and new state given the parameters, input, training flag, state, and rng.
- _get_callable()#
Returns a jax.jit-able and jax.grad-able callable that represents the module's forward pass.
This method must be implemented by all subclasses and must return a jax.jit-able and jax.grad-able callable in the form of

module_callable(
    params: tuple[np.ndarray, ...],
    input_NF: np.ndarray[num_samples, num_features],
    training: bool,
    state: tuple[np.ndarray, ...],
    rng: Any
) -> (
    output_NF: np.ndarray[num_samples, num_output_features],
    new_state: tuple[np.ndarray, ...]
)

That is, all hyperparameters are traced out and the callable depends explicitly only on a tuple of parameter jax.numpy arrays, the input array, the training flag, a state tuple of jax.numpy arrays, and a JAX rng key.
The training flag will be traced out, so it doesn't need to be jittable.
- Return type:
- Returns:
A callable that takes the module’s parameters, input data, training flag, state, and rng key and returns the output data and new state.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
See also
__call__
Calls the module with the current parameters and given input, state, and rng.
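As an illustration of this contract (a toy stateless module invented for this example, not one from the package), a conforming callable might look like:

```python
import numpy as np

def module_callable(params, input_NF, training, state, rng):
    """Toy forward pass matching the required signature: scales the input
    by a single trainable parameter and passes the state through unchanged."""
    (scale,) = params
    output_NF = input_NF * scale
    return output_NF, state

params = (np.array(2.0),)
x = np.arange(6.0).reshape(2, 3)  # (num_samples, num_features)
out, new_state = module_callable(params, x, False, (), None)
assert out.shape == (2, 3)
```

Because the callable is a pure function of its arguments, it can be passed to jax.jit and differentiated with jax.grad with respect to params.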
- astype(dtype)#
Convenience wrapper to set_precision using the dtype argument, returns self.
- Parameters:
dtype (dtype | str) – Precision to set for the module parameters. Valid options are:
  - For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32
  - For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64
- Return type:
- Returns:
BaseModule – The module itself, with updated precision.
- Raises:
ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.
RuntimeError – If the module is not ready (i.e., compile() has not been called).
See also
set_precision
Sets the precision of the module parameters and state.
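The accepted precision aliases can be illustrated with a small normalization helper (a sketch of the documented options, not the library's actual implementation):

```python
import numpy as np

# aliases accepted for each precision, per the documented options above
_ALIASES_32 = {np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32}
_ALIASES_64 = {np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64}

def normalize_precision(prec):
    """Map any documented alias to the number of bits (32 or 64)."""
    if prec in _ALIASES_32:
        return 32
    if prec in _ALIASES_64:
        return 64
    raise ValueError(f"invalid precision: {prec!r}")
```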
- compile(rng, input_shape)#
Compile the module to be used with the given input shape.
This method initializes the module’s parameters and state based on the input shape and random key.
This is needed since Models are built before the input data is given; before training or inference can be done, the module needs to be compiled, and each module passes its output shape to the next module's compile method.
The RNG key is used to initialize random parameters, if needed.
This is not used to trace or jit the module's callable; that is done automatically later.
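The shape-passing behavior can be sketched with toy modules (hypothetical classes, not the package's own):

```python
# Each module's compile() receives the shape produced by the previous module,
# mirroring how a MultiModule chains its submodules at compile time.
class ToyModule:
    def __init__(self, out_features):
        self.out_features = out_features
        self.ready = False

    def compile(self, rng, input_shape):
        self.in_features = input_shape[-1]
        self.ready = True

    def get_output_shape(self, input_shape):
        return input_shape[:-1] + (self.out_features,)

def compile_chain(rng, modules, input_shape):
    """Compile each module in sequence, threading the output shape forward."""
    shape = input_shape
    for m in modules:
        m.compile(rng, shape)
        shape = m.get_output_shape(shape)
    return shape

mods = [ToyModule(5), ToyModule(2)]
final_shape = compile_chain(None, mods, (None, 7))
assert final_shape == (None, 2)
```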
- deserialize(data)[source]#
Deserialize the module from a dictionary.
This method sets the module’s parameters and state based on the provided dictionary.
The default implementation expects the dictionary to contain the module’s name, trainable parameters, and state.
- Parameters:
data (dict[str, Any]) – Dictionary containing the serialized module data.
- Raises:
ValueError – If the serialized data does not contain the expected keys or if the version of the serialized data is not compatible with the current package version.
- Return type:
- get_hyperparameters()[source]#
Get the hyperparameters of the module.
Hyperparameters are used to configure the module and are not trainable. They can be set via set_hyperparameters.
- get_num_trainable_floats()#
Returns the number of trainable floats in the module. If the module does not have trainable parameters, returns 0. If the module is not ready, returns None.
- get_output_shape(input_shape)#
Get the output shape of the module given the input shape.
- get_params()#
Get the current trainable parameters of the module. If the module has no trainable parameters, this method should return an empty tuple.
- Return type:
- Returns:
Tuple of numpy arrays representing the module’s parameters.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- get_state()#
Get the current state of the module.
States are used to store “memory” or other information that is not passed between modules, is not trainable, but may be updated during either training or inference. e.g. batch normalization state.
The state is optional, in which case this method should return the empty tuple.
- is_ready()#
Return True if the module is initialized and ready for training or inference.
- Return type:
- Returns:
True if the module is ready, False otherwise.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- name()[source]#
Returns the name of the module, unless overridden, this is the class name.
- Return type:
- Returns:
Name of the module.
- serialize()[source]#
Serialize the module to a dictionary.
This method returns a dictionary representation of the module, including its parameters and state.
The default implementation serializes the module’s name, hyperparameters, trainable parameters, and state via a simple dictionary.
This only works if the module’s hyperparameters are auto-serializable. This includes basic types as well as numpy arrays.
- set_hyperparameters(hyperparams)[source]#
Set the hyperparameters of the module.
Hyperparameters are used to configure the module and are not trainable. They can be set via this method.
The default implementation uses setattr to set the hyperparameters as attributes of the class instance.
- set_params(params)#
Set the trainable parameters of the module.
- Parameters:
params (tuple[Array, ...]) – Tuple of numpy arrays representing the new parameters.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
- set_precision(prec)#
Set the precision of the module parameters and state.
- Parameters:
prec (dtype | str | int) – Precision to set for the module parameters. Valid options are:
  - For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32
  - For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64
- Raises:
ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.
RuntimeError – If the module is not ready (i.e., compile() has not been called).
- Return type:
See also
astype
Convenience wrapper to set_precision using the dtype argument, returns self.