LowRankAffineHermitianMatrix

parametricmatrixmodels.modules.LowRankAffineHermitianMatrix

class LowRankAffineHermitianMatrix(matrix_size=None, rank=None, smoothing=None, lambdas=None, us=None, init_magnitude=0.01, bias_term=True, flatten=False)

Bases: BaseModule

Module that builds a parametric Hermitian matrix from an affine function of the input features, using low-rank matrices.

\(M(x) = M_0 + x_1 M_1 + \dots + x_p M_p + s C\), where \(M_0, M_1, \dots, M_p\) are trainable low-rank Hermitian matrices, \(x_1, \dots, x_p\) are the input features, \(s\) is the smoothing hyperparameter, and \(C\) is the imaginary unit times the sum of the commutators of the \(M_i\) over ordered index pairs. (Summing over all pairs \(i \neq j\) would give zero, since \([M_i, M_j] = -[M_j, M_i]\); the factor \(i\) makes \(C\) Hermitian, because the commutator of two Hermitian matrices is anti-Hermitian.) By the linearity of the commutator, \(C\) is computed efficiently with cumulative sums:

\[\begin{split}C &= i\sum_{\substack{i,j\\i > j}} \left[M_i, M_j\right] \\ &= i\sum_{i} \left[M_i, \sum_{j<i} M_j\right]\end{split}\]
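The cumulative-sum identity can be checked numerically. Below is a minimal NumPy sketch, not the library's implementation: it assumes the sum runs over ordered pairs \(i > j\), uses random dense Hermitian matrices in place of the trained \(M_i\), and the helper names are hypothetical.

```python
import numpy as np


def commutator(a, b):
    """Matrix commutator [a, b] = ab - ba."""
    return a @ b - b @ a


def smoothing_pairwise(ms):
    """C = i * sum_{i>j} [M_i, M_j], one commutator per pair (O(P^2) commutators).

    Ordered pairs only: the sum over all i != j cancels because [A, B] = -[B, A].
    """
    c = np.zeros_like(ms[0])
    for i in range(len(ms)):
        for j in range(i):
            c += commutator(ms[i], ms[j])
    return 1j * c


def smoothing_cumsum(ms):
    """Same C via prefix sums: i * sum_i [M_i, sum_{j<i} M_j] (O(P) commutators)."""
    ms = np.asarray(ms)
    # prefix[i] = sum_{j<i} M_j (exclusive cumulative sum, prefix[0] = 0)
    prefix = np.cumsum(ms, axis=0) - ms
    return 1j * np.sum(ms @ prefix - prefix @ ms, axis=0)


def random_hermitian(rng, n):
    a = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
    return (a + a.conj().T) / 2


rng = np.random.default_rng(0)
ms = np.stack([random_hermitian(rng, 4) for _ in range(5)])
c_fast = smoothing_cumsum(ms)
print(np.allclose(smoothing_pairwise(ms), c_fast))  # True
print(np.allclose(c_fast, c_fast.conj().T))  # True: C is Hermitian
```

The prefix-sum form needs one commutator per matrix instead of one per pair, which is the efficiency gain the text refers to.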

Each \(M_i\) is a low-rank Hermitian matrix, parametrized as \(M_i = \sum_{k=1}^{r} \lambda_k^i u_k^i (u_k^i)^H\), where the \(u_k^i\) are \(r\) complex vectors of size \(n\), the \(\lambda_k^i\) are \(r\) real coefficients, and \(r\) is the rank (so each \(M_i\) has rank at most \(r\)).
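As an illustration of this parametrization, the NumPy sketch below builds every \(M_i\) at once from arrays laid out like the lambdas and us constructor parameters (shapes (num_matrices, r) and (num_matrices, r, n), with num_matrices = input_size + 1 when bias_term is True). The function name is hypothetical and this is not the library's code.

```python
import numpy as np


def low_rank_hermitian(lambdas, us):
    """Build M_i = sum_k lambda_k^i u_k^i (u_k^i)^H for every i at once.

    lambdas: real array, shape (num_matrices, r)
    us:      complex array, shape (num_matrices, r, n)
    returns: complex array, shape (num_matrices, n, n)
    """
    # (u u^H)_{ab} = u_a * conj(u_b); weight each outer product by lambda, sum over k
    return np.einsum("ik,ika,ikb->iab", lambdas, us, us.conj())


rng = np.random.default_rng(1)
num_matrices, r, n = 3, 2, 5
lambdas = rng.normal(size=(num_matrices, r))
us = rng.normal(size=(num_matrices, r, n)) + 1j * rng.normal(size=(num_matrices, r, n))
ms = low_rank_hermitian(lambdas, us)

print(ms.shape)  # (3, 5, 5)
print(np.allclose(ms, ms.conj().transpose(0, 2, 1)))  # True: each M_i is Hermitian
print([np.linalg.matrix_rank(m) for m in ms])  # each rank is at most r = 2
```

Real \(\lambda_k^i\) guarantee Hermiticity for free, and storing \(r\) vectors of length \(n\) per matrix is what makes the parametrization cheaper than a full \(n \times n\) Hermitian matrix when \(r \ll n\).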

See also

AffineHermitianMatrix

Full-rank version of this module that uses full-rank Hermitian matrices instead of low-rank ones.

__init__(matrix_size=None, rank=None, smoothing=None, lambdas=None, us=None, init_magnitude=0.01, bias_term=True, flatten=False)

Create a LowRankAffineHermitianMatrix module.

Parameters:
  • matrix_size (int) – Size of the PMM matrices (square), shorthand \(n\).

  • rank (int) – Rank of the low-rank Hermitian matrices, shorthand \(r\). Must be a positive integer less than or equal to matrix_size.

  • smoothing (float) – Optional smoothing hyperparameter \(s\). Set to 0.0 to disable smoothing. Default is None, which is equivalent to 0.0 (no smoothing).

  • lambdas (Array) – Optional array of shape (input_size+1, rank) (if bias_term is True) or (input_size, rank) (if bias_term is False), containing the \(\lambda_k^i\) real coefficients used to construct the low-rank Hermitian matrices. If not provided, the coefficients will be initialized randomly when the module is compiled.

  • us (Array) – Optional array of shape (input_size+1, rank, matrix_size) (if bias_term is True) or (input_size, rank, matrix_size) (if bias_term is False), containing the \(u_k^i\) complex vectors used to construct the low-rank Hermitian matrices. If not provided, the vectors will be initialized randomly when the module is compiled.

  • init_magnitude (float) – Optional initial magnitude of the random matrices, used when initializing the module. Default is 1e-2.

  • bias_term (bool) – If True, include the bias term \(M_0\) in the affine matrix. Default is True.

  • flatten (bool) – If True, the output will be flattened to a 1D array. Useful when combining with SubsetModule or other modules in order to avoid ragged arrays. Default is False.
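If you supply lambdas and us yourself, their shapes must follow the layout above. A small hypothetical helper (not part of the library) that computes the documented shapes and applies the documented rank constraint:

```python
def expected_param_shapes(input_size, matrix_size, rank, bias_term=True):
    """Return the documented shapes of (lambdas, us).

    Hypothetical helper, e.g. for checking user-supplied arrays before
    passing them to the constructor.
    """
    # Documented constraint: rank is a positive integer <= matrix_size
    if not (isinstance(rank, int) and 1 <= rank <= matrix_size):
        raise ValueError("rank must be a positive integer <= matrix_size")
    # One extra matrix M_0 when the bias term is included
    num_matrices = input_size + 1 if bias_term else input_size
    return (num_matrices, rank), (num_matrices, rank, matrix_size)


print(expected_param_shapes(input_size=3, matrix_size=8, rank=2))
# ((4, 2), (4, 2, 8))
print(expected_param_shapes(input_size=3, matrix_size=8, rank=2, bias_term=False))
# ((3, 2), (3, 2, 8))
```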

Methods

  • __call__ – Call the module with the current parameters and given input, state, and rng.

  • _get_callable – Returns a jax.jit-able and jax.grad-able callable that represents the module’s forward pass.

  • astype – Convenience wrapper to set_precision using the dtype argument, returns self.

  • compile – Compile the module to be used with the given input shape.

  • deserialize – Deserialize the module from a dictionary.

  • get_hyperparameters – Get the hyperparameters of the module.

  • get_num_trainable_floats – Returns the number of trainable floats in the module.

  • get_output_shape – Get the output shape of the module given the input shape.

  • get_params – Get the current trainable parameters of the module.

  • get_state – Get the current state of the module.

  • is_ready – Return True if the module is initialized and ready for training or inference.

  • name – Returns the name of the module; unless overridden, this is the class name.

  • serialize – Serialize the module to a dictionary.

  • set_hyperparameters – Set the hyperparameters of the module.

  • set_params – Set the trainable parameters of the module.

  • set_precision – Set the precision of the module parameters and state.

  • set_state – Set the state of the module.

__call__(input_NF, training=False, state=(), rng=None)

Call the module with the current parameters and given input, state, and rng.

Parameters:
  • input_NF (Array) – Input array of shape (num_samples, num_features).

  • training (bool) – Whether the module is in training mode, by default False.

  • state (tuple[Array, ...]) – State of the module, by default ().

  • rng (Any) – JAX random key, by default None.

Return type:

tuple[Array, tuple[Array, ...]]

Returns:

Output array of shape (num_samples, num_output_features) and new state.

Raises:

ValueError – If the module is not ready (i.e., compile() has not been called).

See also

_get_callable

Returns a callable that can be used to compute the output and new state given the parameters, input, training flag, state, and rng.

_get_callable()

Returns a jax.jit-able and jax.grad-able callable that represents the module’s forward pass.

This method must be implemented by all subclasses and must return a jax.jit-able and jax.grad-able callable of the form:

from typing import Any

import numpy as np

def module_callable(
    params: tuple[np.ndarray, ...],
    input_NF: np.ndarray,  # shape (num_samples, num_features)
    training: bool,
    state: tuple[np.ndarray, ...],
    rng: Any,
) -> tuple[np.ndarray, tuple[np.ndarray, ...]]:
    # returns (output_NF, new_state), where output_NF has shape
    # (num_samples, num_output_features)
    ...

That is, all hyperparameters are traced out and the callable depends explicitly only on a tuple of parameter jax.numpy arrays, the input array, the training flag, a state tuple of jax.numpy arrays, and a JAX rng key.

The training flag will be traced out, so it doesn’t need to be jittable.
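To make this contract concrete, here is a NumPy sketch of a callable with the form above that evaluates \(M(x) = M_0 + \sum_i x_i M_i + sC\) for a batch. It is not the library's code: the real module presumably uses jax.numpy so it can be jitted, and the factory name make_module_callable, the parameter ordering (lambdas, us), the closure over hyperparameters, and the unflattened (num_samples, n, n) output are all assumptions.

```python
import numpy as np


def make_module_callable(smoothing=0.0, bias_term=True):
    """Hypothetical factory: hyperparameters are closed over, not passed in."""

    def module_callable(params, input_NF, training, state, rng):
        lambdas, us = params  # assumed ordering and documented shapes
        # Low-rank Hermitian matrices, shape (num_matrices, n, n)
        ms = np.einsum("ik,ika,ikb->iab", lambdas, us, us.conj())
        if bias_term:
            m0, mi = ms[0], ms[1:]
        else:
            m0, mi = np.zeros_like(ms[0]), ms
        # Affine part for every sample: M_0 + sum_i x_i M_i -> (N, n, n)
        out = m0 + np.einsum("Ni,iab->Nab", input_NF, mi)
        if smoothing:
            # s * C over all M_i (including M_0 if present), via prefix sums
            prefix = np.cumsum(ms, axis=0) - ms
            c = 1j * np.sum(ms @ prefix - prefix @ ms, axis=0)
            out = out + smoothing * c
        return out, state  # this module keeps no state, so it is passed through

    return module_callable


rng = np.random.default_rng(2)
p, r, n, N = 3, 2, 4, 6
params = (
    rng.normal(size=(p + 1, r)),
    rng.normal(size=(p + 1, r, n)) + 1j * rng.normal(size=(p + 1, r, n)),
)
f = make_module_callable(smoothing=0.1)
out, new_state = f(params, rng.normal(size=(N, p)), False, (), None)
print(out.shape)  # (6, 4, 4)
print(np.allclose(out, out.conj().transpose(0, 2, 1)))  # True
```

Because everything the callable depends on arrives through its arguments or the closure, the same structure written with jax.numpy could be passed to jax.jit and differentiated with respect to params.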

Return type:

Callable

Returns:

A callable that takes the module’s parameters, input data, training flag, state, and rng key and returns the output data and new state.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

See also

__call__

Calls the module with the current parameters and given input, state, and rng.

astype(dtype)

Convenience wrapper to set_precision using the dtype argument, returns self.

Parameters:

dtype (dtype | str) – Precision to set for the module parameters. Valid options are: For 32-bit precision (all options are equivalent) np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32. For 64-bit precision (all options are equivalent) np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Return type:

BaseModule

Returns:

BaseModule – The module itself, with updated precision.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

See also

set_precision

Sets the precision of the module parameters and state.

compile(rng, input_shape)

Compile the module to be used with the given input shape.

This method initializes the module’s parameters and state based on the input shape and random key.

This is needed because Models are built before the input data is given: before training or inference can be done, each module must be compiled, and each module passes its output shape to the next module’s compile method.

The RNG key is used to initialize random parameters, if needed.

This method is not used to trace or jit the module’s callable; that is done automatically later.
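The shape hand-off described above can be sketched with hypothetical toy classes. ToyModule and compile_chain are illustrative names only; the real Model class and modules may organize this differently.

```python
class ToyModule:
    """Hypothetical stand-in for a module: maps (F,) inputs to (out_features,)."""

    def __init__(self, out_features):
        self.out_features = out_features
        self.input_shape = None

    def compile(self, rng, input_shape):
        # A real module would initialize its parameters from rng here.
        self.input_shape = input_shape

    def get_output_shape(self, input_shape):
        return (self.out_features,)

    def is_ready(self):
        return self.input_shape is not None


def compile_chain(modules, rng, input_shape):
    """Compile modules in order, feeding each one's output shape to the next."""
    shape = input_shape
    for module in modules:
        module.compile(rng, shape)
        shape = module.get_output_shape(shape)
    return shape


mods = [ToyModule(16), ToyModule(4)]
print(compile_chain(mods, rng=None, input_shape=(3,)))  # (4,)
print([m.input_shape for m in mods])  # [(3,), (16,)]
```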

Parameters:
  • rng (Any) – JAX random key.

  • input_shape (tuple[int, ...]) – Shape of the input data, e.g. (num_features,).

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

None

deserialize(data)

Deserialize the module from a dictionary.

This method sets the module’s parameters and state based on the provided dictionary.

The default implementation expects the dictionary to contain the module’s name, trainable parameters, and state.

Parameters:

data (dict[str, Any]) – Dictionary containing the serialized module data.

Raises:

ValueError – If the serialized data does not contain the expected keys or if the version of the serialized data is not compatible with the current package version.

Return type:

None

get_hyperparameters()

Get the hyperparameters of the module.

Hyperparameters are used to configure the module and are not trainable. They can be set via set_hyperparameters.

Return type:

dict[str, Any]

Returns:

Dictionary containing the hyperparameters of the module.

get_num_trainable_floats()

Returns the number of trainable floats in the module. If the module does not have trainable parameters, returns 0. If the module is not ready, returns None.

Return type:

int | None

Returns:

Number of trainable floats in the module, or None if the module is not ready.
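For a rough sense of the savings from the low-rank parametrization, the sketch below counts real numbers in the documented lambdas and us arrays. It assumes each complex entry of us counts as two floats, which may not match exactly what get_num_trainable_floats reports; the full-rank figure uses the fact that an \(n \times n\) Hermitian matrix has \(n^2\) real degrees of freedom. Both function names are hypothetical.

```python
def low_rank_float_count(input_size, matrix_size, rank, bias_term=True):
    """Real floats in lambdas (real) plus us (complex).

    Assumption: each complex entry counts as two real floats.
    """
    num_matrices = input_size + 1 if bias_term else input_size
    return num_matrices * rank + 2 * num_matrices * rank * matrix_size


def full_rank_float_count(input_size, matrix_size, bias_term=True):
    """An n x n Hermitian matrix has n^2 real degrees of freedom."""
    num_matrices = input_size + 1 if bias_term else input_size
    return num_matrices * matrix_size**2


print(low_rank_float_count(input_size=3, matrix_size=64, rank=2))  # 1032
print(full_rank_float_count(input_size=3, matrix_size=64))  # 16384
```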

get_output_shape(input_shape)

Get the output shape of the module given the input shape.

Parameters:

input_shape (tuple[int, ...]) – Shape of the input data, e.g. (num_features,).

Return type:

tuple[int, ...]

Returns:

Shape of the output data, e.g. (num_output_features,).

Raises:

NotImplementedError – If the method is not implemented in the subclass.

get_params()

Get the current trainable parameters of the module. If the module has no trainable parameters, this method should return an empty tuple.

Return type:

tuple[Array, ...]

Returns:

Tuple of numpy arrays representing the module’s parameters.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

get_state()

Get the current state of the module.

States are used to store “memory” or other information that is not passed between modules and is not trainable, but may be updated during training or inference, e.g., batch normalization state.

The state is optional; a module without state should return the empty tuple.

Return type:

tuple[Array, ...]

Returns:

Tuple of numpy arrays representing the module’s state.

is_ready()

Return True if the module is initialized and ready for training or inference.

Return type:

bool

Returns:

True if the module is ready, False otherwise.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

name()

Returns the name of the module; unless overridden, this is the class name.

Return type:

str

Returns:

Name of the module.

serialize()

Serialize the module to a dictionary.

This method returns a dictionary representation of the module, including its parameters and state.

The default implementation serializes the module’s name, hyperparameters, trainable parameters, and state via a simple dictionary.

This only works if the module’s hyperparameters are auto-serializable. This includes basic types as well as numpy arrays.

Return type:

dict[str, Any]

Returns:

Dictionary containing the serialized module data.
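The default serialize/deserialize round trip can be sketched with a hypothetical toy class. The dictionary keys used here ("name", "hyperparameters", "params", "state") are illustrative assumptions, not the library's actual schema, and the version check mentioned under deserialize is omitted.

```python
import numpy as np


class ToySerializable:
    """Hypothetical module illustrating the default dictionary round trip."""

    REQUIRED_KEYS = {"name", "hyperparameters", "params", "state"}

    def __init__(self, smoothing=0.0):
        self.smoothing = smoothing
        self.params = (np.zeros(3),)
        self.state = ()

    def name(self):
        return type(self).__name__

    def serialize(self):
        # Key names are illustrative, not the library's actual schema.
        return {
            "name": self.name(),
            "hyperparameters": {"smoothing": self.smoothing},
            "params": tuple(np.asarray(p) for p in self.params),
            "state": self.state,
        }

    def deserialize(self, data):
        missing = self.REQUIRED_KEYS - set(data)
        if missing:
            raise ValueError(f"serialized data is missing keys: {sorted(missing)}")
        # Hyperparameters are restored as attributes, mirroring set_hyperparameters
        for key, value in data["hyperparameters"].items():
            setattr(self, key, value)
        self.params = tuple(data["params"])
        self.state = tuple(data["state"])


src = ToySerializable(smoothing=0.5)
src.params = (np.arange(3.0),)
dst = ToySerializable()
dst.deserialize(src.serialize())
print(dst.smoothing, dst.params[0])  # 0.5 [0. 1. 2.]
```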

set_hyperparameters(hyperparams)

Set the hyperparameters of the module.

Hyperparameters are used to configure the module and are not trainable. They can be set via this method.

The default implementation uses setattr to set the hyperparameters as attributes of the class instance.

Parameters:

hyperparams – Dictionary containing the hyperparameters to set.

Raises:

TypeError – If hyperparams is not a dictionary.

Return type:

None

set_params(params)

Set the trainable parameters of the module.

Parameters:

params (tuple[Array, ...]) – Tuple of numpy arrays representing the new parameters.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

None

set_precision(prec)

Set the precision of the module parameters and state.

Parameters:

prec (dtype | str | int) – Precision to set for the module parameters. Valid options are: For 32-bit precision (all options are equivalent) np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32. For 64-bit precision (all options are equivalent) np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

Return type:

None
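The documented precision aliases form two equivalence classes. The hypothetical helper below (not part of the library) normalizes any listed alias to its bit width; the real method additionally checks JAX's 64-bit setting, as described under Raises, which this sketch does not attempt.

```python
import numpy as np

# Aliases listed in the documentation, grouped by bit width.
_PRECISION_ALIASES = {
    32: {np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32},
    64: {np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64},
}


def precision_bits(prec):
    """Hypothetical helper: map any documented precision alias to 32 or 64."""
    for bits, aliases in _PRECISION_ALIASES.items():
        if prec in aliases:
            return bits
    raise ValueError(f"invalid precision: {prec!r}")


print(precision_bits("c64"), precision_bits(np.complex128), precision_bits(64))  # 32 64 64
```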

See also

astype

Convenience wrapper to set_precision using the dtype argument, returns self.

set_state(state)

Set the state of the module.

This method is optional.

Parameters:

state (tuple[Array, ...]) – Tuple of numpy arrays representing the new state.

Return type:

None