Source code for parametricmatrixmodels.modules.lowrankaffineobservablepmm

from __future__ import annotations

import warnings
from typing import Any

import jax.numpy as np

from .basemodule import BaseModule
from .bias import Bias
from .eigenvectors import Eigenvectors
from .lowrankaffinehermitianmatrix import LowRankAffineHermitianMatrix
from .lowranktransitionamplitudesum import LowRankTransitionAmplitudeSum
from .multimodule import MultiModule


class LowRankAffineObservablePMM(MultiModule):
    r"""
    ``LowRankAffineObservablePMM`` is a module that implements a general
    regression model via the affine observable Parametric Matrix Model (PMM)
    with low-rank trainable matrices using four primitive modules combined in
    a MultiModule: a LowRankAffineHermitianMatrix module followed by an
    Eigenvectors module followed by a LowRankTransitionAmplitudeSum module
    followed optionally by a Bias module.

    The Affine Observable PMM (AOPMM) is described in [1]_ and is summarized
    as follows:

    Given input features :math:`x_1, \ldots, x_p`, construct the Hermitian
    matrix that is affine in these features as

    .. math::

        M(x) = M_0 + \sum_{i=1}^p x_i M_i

    where :math:`M_0, \ldots, M_p` are trainable Hermitian matrices. An
    optional smoothing term :math:`s C` parameterized by the smoothing
    hyperparameter :math:`s` can be added to smooth the eigenvalues and
    eigenvectors of :math:`M(x)`. The matrix :math:`C` is equal to the
    imaginary unit times the sum of all commutators of the :math:`M_i`.

    Then, take the leading :math:`r` eigenvectors (by default corresponding
    to the largest magnitude eigenvalues if there is no smoothing, or the
    smallest algebraic if there is smoothing) of :math:`M(x)` and compute the
    sum of the transition amplitudes of these eigenvectors with trainable
    Hermitian observable matrices (secondaries) :math:`D_{ql}` to form the
    output vector :math:`z` with :math:`q` components as

    .. math::

        z_k = \sum_{m=1}^l \left(
            \sum_{i,j=1}^r |v_i^H D_{km} v_j|^2
            - \frac{r^2}{2} ||D_{km}||^2_2
        \right)

    where :math:`||\cdot||_2` is the operator 2-norm (largest singular value)
    so for Hermitian :math:`D`, :math:`||D||_2` is the largest absolute
    eigenvalue.

    The :math:`-\frac{1}{2} ||D_{km}||^2_2` term centers the value of each
    term and can be disabled by setting the ``centered`` parameter to
    ``False``.

    Finally, an optional trainable bias term :math:`b_k` can be added to each
    component.

    In this module, the trainable Hermitian matrices, both :math:`M_i` and
    :math:`D_{km}`, are parametrized in low-rank form by sums of rank-1 terms
    constructed from outer products of complex vectors. This reduces the
    number of trainable parameters.

    .. warning::
        Even though the math shows that the centering term should be
        multiplied by :math:`r^2`, in practice this doesn't work well and
        instead setting the centering term to
        :math:`\frac{1}{2} ||D_{km}||^2_2` works much better. This
        non-:math:`r^2` scaling is used here.

    See Also
    --------
    LowRankAffineHermitianMatrix
        Module that constructs the affine Hermitian matrix :math:`M(x)` from
        low-rank trainable Hermitian matrices :math:`M_i` and input features.
    Eigenvectors
        Module that computes the eigenvectors of a matrix.
    LowRankTransitionAmplitudeSum
        Module that computes the sum of transition amplitudes of eigenvectors
        with trainable low-rank observable matrices.
    Bias
        Module that adds a trainable bias term to the output.
    MultiModule
        Module that combines multiple modules in sequence.
    AffineObservablePMM
        Full-rank version of this module.

    References
    ----------
    .. [1] Cook, P., Jammooa, D., Hjorth-Jensen, M. et al. Parametric matrix
        models. Nat Commun 16, 5929 (2025).
        https://doi.org/10.1038/s41467-025-61362-4
    """

    def __init__(
        self,
        matrix_size: int | None = None,
        primary_rank: int | None = None,
        num_eig: int | None = None,
        which: str | None = None,
        smoothing: float | None = None,
        affine_bias_matrix: bool = True,
        num_secondaries: int = 1,
        secondary_rank: int | None = None,
        output_size: int | None = None,
        centered: bool = True,
        bias_term: bool = True,
        uMs: np.ndarray | None = None,
        uDs: np.ndarray | None = None,
        b: np.ndarray | None = None,
        init_magnitude: float = 0.01,
    ):
        r"""
        Initialize the module.

        Parameters
        ----------
        matrix_size
            Size of the trainable matrices, shorthand :math:`n`.
        primary_rank
            Rank of the trainable Hermitian matrices :math:`M_i`.
        num_eig
            Number of eigenvectors to use in the transition amplitude
            calculation, shorthand :math:`r`.
        which
            Which eigenvectors to use based on eigenvalue. Options are:

            - 'SA' for smallest algebraic (default with smoothing)
            - 'LA' for largest algebraic
            - 'SM' for smallest magnitude
            - 'LM' for largest magnitude (default without smoothing)
            - 'EA' for exterior algebraically
            - 'EM' for exterior by magnitude
            - 'IA' for interior algebraically
            - 'IM' for interior by magnitude
        smoothing
            Optional smoothing parameter for the affine matrix. Set to
            ``None``/``0.0`` to disable smoothing. Default is
            ``None``/``0.0`` (no smoothing).
        affine_bias_matrix
            If ``True``, include the bias term :math:`M_0` in the affine
            matrix. Default is ``True``.
        num_secondaries
            Number of secondary observable matrices :math:`D_{km}` per output
            component. Shorthand :math:`l`. Default is ``1``.
        secondary_rank
            Rank of the trainable Hermitian observable matrices
            :math:`D_{km}`.
        output_size
            Size of the output vector, shorthand :math:`q`.
        centered
            If ``True``, include the centering term in the transition
            amplitude sum. Default is ``True``.
        bias_term
            If ``True``, include a trainable bias term :math:`b_k` in the
            output. Default is ``True``.
        uMs
            Optional array of shape
            ``(input_size+1, primary_rank, matrix_size)`` (if
            ``affine_bias_matrix`` is ``True``) or
            ``(input_size, primary_rank, matrix_size)`` (if
            ``affine_bias_matrix`` is ``False``), containing the complex
            vectors which parameterize the low-rank :math:`M_i` Hermitian
            matrices. If not provided, the vectors will be initialized
            randomly when the module is compiled. Default is ``None``
            (random initialization).
        uDs
            Optional array of shape
            ``(output_size, num_secondaries, secondary_rank, matrix_size)``
            containing the complex vectors which parameterize the low-rank
            :math:`D_{km}` Hermitian observable matrices. If not provided,
            the vectors will be initialized randomly when the module is
            compiled. Default is ``None`` (random initialization).
        b
            Optional array of shape ``(output_size,)`` containing the bias
            terms :math:`b_k`. If not provided, the bias terms will be
            randomly initialized when the module is compiled. Default is
            ``None`` (random initialization).
        init_magnitude
            Initial magnitude for the random initialization. Default is
            ``1e-2``.

        Warns
        -----
        UserWarning
            If smoothing is enabled together with a magnitude-based
            ``which`` option.

        .. warning::
            When using smoothing, the ``which`` options involving magnitude
            should be avoided, as the smoothing only guarantees that
            eigenvalues near each other algebraically are smoothed, not
            across the spectrum.
        """
        # select default which based on smoothing
        if which is None:
            if smoothing in (None, 0.0):
                which = "LM"
            else:
                which = "SA"

        self.matrix_size = matrix_size
        self.primary_rank = primary_rank
        self.num_eig = num_eig
        self.which = which
        self.smoothing = smoothing
        self.affine_bias_matrix = affine_bias_matrix
        self.num_secondaries = num_secondaries
        self.secondary_rank = secondary_rank
        self.output_size = output_size
        self.centered = centered
        self.bias_term = bias_term
        self.uMs = uMs
        self.uDs = uDs
        self.b = b
        self.init_magnitude = init_magnitude

        # raise a warning if smoothing is used with magnitude 'which'
        # (every magnitude option, and no algebraic option, contains an "m")
        if smoothing not in (None, 0.0) and "m" in which.lower():
            warnings.warn(
                "Using smoothing with magnitude 'which' options may lead to "
                "unexpected behavior, as the smoothing only guarantees that "
                "eigenvalues near each other algebraically are smoothed, not "
                "across the spectrum.",
                UserWarning,
                # attribute the warning to the code constructing the module
                stacklevel=2,
            )

        self.modules = self._build_pmm_modules()

        super().__init__(*self.modules)

    def _build_pmm_modules(self) -> tuple[BaseModule, ...]:
        """
        Build the sequence of primitive modules from the current attributes.

        Single source of truth for the pipeline layout, shared by
        ``__init__`` and ``set_hyperparameters`` so the two cannot drift
        apart. Reads the hyperparameters and the optional initial arrays
        (``uMs``, ``uDs``, ``b``) stored on ``self``.

        Returns
        -------
        tuple
            ``(LowRankAffineHermitianMatrix, Eigenvectors,
            LowRankTransitionAmplitudeSum)``, followed by a ``Bias`` module
            when ``self.bias_term`` is truthy.
        """
        modules = (
            LowRankAffineHermitianMatrix(
                matrix_size=self.matrix_size,
                rank=self.primary_rank,
                smoothing=self.smoothing,
                us=self.uMs,
                init_magnitude=self.init_magnitude,
                bias_term=self.affine_bias_matrix,
                flatten=False,
            ),
            Eigenvectors(
                num_eig=self.num_eig,
                which=self.which,
            ),
            LowRankTransitionAmplitudeSum(
                num_observables=self.num_secondaries,
                rank=self.secondary_rank,
                output_size=self.output_size,
                centered=self.centered,
                us=self.uDs,
                init_magnitude=self.init_magnitude,
            ),
        )

        if self.bias_term:
            modules += (
                Bias(
                    bias=self.b,
                    init_magnitude=self.init_magnitude,
                    real=True,
                    scalar=False,
                    trainable=True,
                ),
            )

        return modules

    def name(self) -> str:
        """
        Return the module name: ``"LowRankAffineObservablePMM as "`` followed
        by the name reported by the parent ``MultiModule``.
        """
        multistr = super().name()
        return f"LowRankAffineObservablePMM as {multistr}"

    def get_hyperparameters(self) -> dict[str, Any]:
        """
        Return this module's hyperparameters merged with the parent's.

        The optional initial arrays (``uMs``, ``uDs``, ``b``) are not
        included. On a key collision the value reported by the parent class
        takes precedence, because it is merged in last.
        """
        data = {
            "matrix_size": self.matrix_size,
            "primary_rank": self.primary_rank,
            "num_eig": self.num_eig,
            "which": self.which,
            "smoothing": self.smoothing,
            "init_magnitude": self.init_magnitude,
            "affine_bias_matrix": self.affine_bias_matrix,
            "num_secondaries": self.num_secondaries,
            "secondary_rank": self.secondary_rank,
            "output_size": self.output_size,
            "centered": self.centered,
            "bias_term": self.bias_term,
        }
        return {
            **data,
            **super().get_hyperparameters(),
        }

    def set_hyperparameters(self, hyperparams: dict[str, Any]) -> None:
        """
        Restore hyperparameters from ``hyperparams`` and rebuild the
        sub-modules.

        Parameters
        ----------
        hyperparams
            Mapping that must contain every key written by
            ``get_hyperparameters`` of this class, plus
            ``"parameter_counts"``, ``"state_counts"``, ``"input_shape"``
            and ``"output_shape"`` (presumably supplied by the parent's
            ``get_hyperparameters``; verify against ``MultiModule``).

        Raises
        ------
        KeyError
            If any required key is missing.

        Notes
        -----
        The rebuilt sub-modules reuse the ``uMs``, ``uDs`` and ``b`` values
        already stored on the instance; they are not read from
        ``hyperparams``.
        """
        self.matrix_size = hyperparams["matrix_size"]
        self.primary_rank = hyperparams["primary_rank"]
        self.num_eig = hyperparams["num_eig"]
        self.which = hyperparams["which"]
        self.smoothing = hyperparams["smoothing"]
        self.init_magnitude = hyperparams["init_magnitude"]
        self.bias_term = hyperparams["bias_term"]
        self.affine_bias_matrix = hyperparams["affine_bias_matrix"]
        self.num_secondaries = hyperparams["num_secondaries"]
        self.secondary_rank = hyperparams["secondary_rank"]
        self.output_size = hyperparams["output_size"]
        self.centered = hyperparams["centered"]

        self.parameter_counts = hyperparams["parameter_counts"]
        self.state_counts = hyperparams["state_counts"]
        self.input_shape = hyperparams["input_shape"]
        self.output_shape = hyperparams["output_shape"]

        self.modules = self._build_pmm_modules()

    def serialize(self) -> dict[str, Any]:
        """
        Serialize via ``BaseModule.serialize``, deliberately bypassing the
        ``MultiModule`` implementation.
        """
        # revert to BaseModule serialization
        return BaseModule.serialize(self)

    def deserialize(self, data: dict[str, Any]) -> None:
        """
        Deserialize via ``BaseModule.deserialize``, deliberately bypassing
        the ``MultiModule`` implementation.
        """
        # revert to BaseModule deserialization
        BaseModule.deserialize(self, data)