# Source code for parametricmatrixmodels.modules.affineeigenvaluepmm

from __future__ import annotations

import warnings

import jax
import jax.numpy as np

from ..sequentialmodel import SequentialModel
from ..tree_util import is_shape_leaf, is_single_leaf
from ..typing import (
    Any,
    DataShape,
    HyperParams,
    Tuple,
)
from .affinehermitianmatrix import AffineHermitianMatrix
from .basemodule import BaseModule
from .eigenvalues import Eigenvalues


class AffineEigenvaluePMM(SequentialModel):
    r"""
    ``AffineEigenvaluePMM`` is a module that implements the affine eigenvalue
    Parametric Matrix Model (PMM) using two primitive modules combined in a
    SequentialModel: an AffineHermitianMatrix module followed by an
    Eigenvalues module.

    The Affine Eigenvalue PMM (AEPMM) is described in [1]_ and is summarized
    as follows:

    Given input features :math:`x_1, \ldots, x_p`, construct the Hermitian
    matrix that is affine in these features as

    .. math::

        M(x) = M_0 + \sum_{i=1}^p x_i M_i

    where :math:`M_0, \ldots, M_p` are trainable Hermitian matrices (the bias
    term :math:`M_0` is omitted when ``bias_term`` is ``False``). An optional
    smoothing term :math:`s C` parameterized by the smoothing hyperparameter
    :math:`s` can be added to smooth the eigenvalues and eigenvectors of
    :math:`M(x)`. The matrix :math:`C` is equal to the imaginary unit times
    the sum of all commutators of the :math:`M_i`.

    The requested eigenvalues of :math:`M(x)` are then computed and returned
    as the output of the module.

    See Also
    --------
    AffineHermitianMatrix
        Module that constructs the affine Hermitian matrix :math:`M(x)` from
        trainable Hermitian matrices :math:`M_i` and input features.
    Eigenvalues
        Module that computes the eigenvalues of a matrix.
    SequentialModel
        Module and Model that evaluates multiple modules in sequence.

    References
    ----------
    .. [1] Cook, P., Jammooa, D., Hjorth-Jensen, M. et al.
        Parametric matrix models. Nat Commun 16, 5929 (2025).
        https://doi.org/10.1038/s41467-025-61362-4
    """

    def __init__(
        self,
        matrix_size: int | None = None,
        num_eig: int = 1,
        which: str = "SA",
        smoothing: float | None = None,
        Ms: np.ndarray | None = None,
        init_magnitude: float = 0.01,
        bias_term: bool = True,
    ):
        r"""
        Initialize the ``AffineEigenvaluePMM`` module.

        By default this module is initialized to compute the smallest
        algebraic eigenvalue (ground state).

        Parameters
        ----------
        matrix_size
            Size of the PMM matrices, shorthand :math:`n`. If ``None`` and
            ``Ms`` is provided, the size is inferred from the last axis of
            ``Ms`` when the module is compiled. Default is ``None``.
        num_eig
            Number of eigenvalues to compute, shorthand :math:`k`. Default
            is 1.
        which
            Which eigenvalues to compute. Default is "SA". Options are:

            - 'SA' for smallest algebraic (default)
            - 'LA' for largest algebraic
            - 'SM' for smallest magnitude
            - 'LM' for largest magnitude
            - 'EA' for exterior algebraically
            - 'EM' for exterior by magnitude
            - 'IA' for interior algebraically
            - 'IM' for interior by magnitude

            For algebraic 'which' options, the eigenvalues are returned in
            ascending algebraic order. For magnitude 'which' options, the
            eigenvalues are returned in ascending magnitude order.
        smoothing
            Optional smoothing parameter. Set to ``0.0`` to disable
            smoothing. Default is ``None``/``0.0`` (no smoothing).
        Ms
            Optional array of shape ``(input_size+1, matrix_size,
            matrix_size)`` (if ``bias_term`` is ``True``) or
            ``(input_size, matrix_size, matrix_size)`` (if ``bias_term`` is
            ``False``), containing the :math:`M_i` Hermitian matrices. If not
            provided, the matrices will be initialized randomly when the
            module is compiled. Default is ``None`` (random initialization).
        init_magnitude
            Initial magnitude for the random matrices if ``Ms`` is not
            provided. Default is ``1e-2``.
        bias_term
            If ``True``, include the bias term :math:`M_0` in the affine
            matrix. Default is ``True``.

        .. warning::
            When using smoothing, the ``which`` options involving magnitude
            should be avoided, as the smoothing only guarantees that
            eigenvalues near each other algebraically are smoothed, not
            across the spectrum.
        """
        self.matrix_size = matrix_size
        self.num_eig = num_eig
        self.which = which
        self.smoothing = smoothing
        self.Ms = Ms
        self.init_magnitude = init_magnitude
        self.bias_term = bias_term
        # The (AffineHermitianMatrix, Eigenvalues) pair is only built in
        # ``compile``, once the input shape is known.
        self.modules: Tuple[BaseModule, ...] | None = None
        super().__init__()

    @staticmethod
    def _require_single_leaf_shape(input_shape: DataShape) -> None:
        """Raise ``ValueError`` unless ``is_single_leaf`` accepts the shape."""
        valid, _ = is_single_leaf(input_shape, is_leaf=is_shape_leaf)
        if not valid:
            raise ValueError(
                "Input shape must be a PyTree with a single leaf."
            )

    def compile(
        self,
        rng: Any | int | None,
        input_shape: DataShape,
        verbose: bool = False,
    ) -> None:
        r"""
        Build the sub-modules (if needed) and compile the sequential model.

        The ``AffineHermitianMatrix`` and ``Eigenvalues`` sub-modules are
        (re)created when no modules exist yet or when ``input_shape`` differs
        from the currently stored ``self.input_shape``. Compilation is then
        delegated to ``SequentialModel.compile``.

        Parameters
        ----------
        rng
            Random key or seed forwarded to ``SequentialModel.compile``.
        input_shape
            Shape PyTree of the input; must contain a single leaf.
        verbose
            Forwarded to ``SequentialModel.compile``.

        Raises
        ------
        ValueError
            If ``input_shape`` is not a single-leaf PyTree, or if
            ``matrix_size`` is ``None`` and cannot be inferred from ``Ms``.

        Warns
        -----
        UserWarning
            If smoothing is enabled together with a magnitude-based
            ``which`` option.
        """
        self._require_single_leaf_shape(input_shape)

        if self.matrix_size is None and self.Ms is not None:
            # Ms is documented as (num_matrices, n, n), so its last axis is n.
            ms_shape = np.shape(self.Ms)
            if ms_shape:
                self.matrix_size = int(ms_shape[-1])
        if self.matrix_size is None:
            raise ValueError("matrix_size must be specified before compiling.")

        # All magnitude-based options ('SM', 'LM', 'EM', 'IM') contain an 'm';
        # none of the algebraic ones do.
        if self.smoothing not in (None, 0.0) and "m" in self.which.lower():
            warnings.warn(
                "Using smoothing with magnitude 'which' options may lead to "
                "unexpected behavior, as the smoothing only guarantees that "
                "eigenvalues near each other algebraically are smoothed, not "
                "across the spectrum.",
                UserWarning,
                # Attribute the warning to the caller of ``compile``.
                stacklevel=2,
            )

        if self.input_shape != input_shape or self.modules is None:
            affine_module = AffineHermitianMatrix(
                matrix_size=self.matrix_size,
                smoothing=self.smoothing,
                Ms=self.Ms,
                init_magnitude=self.init_magnitude,
                bias_term=self.bias_term,
            )
            eigen_module = Eigenvalues(
                num_eig=self.num_eig,
                which=self.which,
            )
            self.modules = (affine_module, eigen_module)

        super().compile(rng, input_shape, verbose=verbose)

    def get_output_shape(self, input_shape: DataShape) -> DataShape:
        r"""
        Return the output shape for a given input shape.

        Parameters
        ----------
        input_shape
            Shape PyTree of the input; must contain a single leaf.

        Returns
        -------
        DataShape
            ``input_shape`` with its shape leaf replaced by ``(num_eig,)``.

        Raises
        ------
        ValueError
            If ``input_shape`` is not a single-leaf PyTree.
        """
        self._require_single_leaf_shape(input_shape)
        return jax.tree.map(
            lambda s: (self.num_eig,), input_shape, is_leaf=is_shape_leaf
        )

    def get_hyperparameters(self) -> HyperParams:
        r"""
        Return this module's hyperparameters merged with those of
        ``SequentialModel``.

        ``Ms`` is not included; on key collisions the parent's values win
        because they are unpacked last.
        """
        return {
            "matrix_size": self.matrix_size,
            "num_eig": self.num_eig,
            "which": self.which,
            "smoothing": self.smoothing,
            "init_magnitude": self.init_magnitude,
            "bias_term": self.bias_term,
            **super().get_hyperparameters(),
        }

    def set_hyperparameters(self, hyperparams: HyperParams) -> None:
        r"""
        Restore hyperparameters produced by ``get_hyperparameters``.

        The full ``hyperparams`` mapping is then forwarded to
        ``SequentialModel.set_hyperparameters``.

        Raises
        ------
        KeyError
            If any of this module's hyperparameter keys is missing.
        """
        self.matrix_size = hyperparams["matrix_size"]
        self.num_eig = hyperparams["num_eig"]
        self.which = hyperparams["which"]
        self.smoothing = hyperparams["smoothing"]
        self.init_magnitude = hyperparams["init_magnitude"]
        self.bias_term = hyperparams["bias_term"]
        super().set_hyperparameters(hyperparams)