Source code for parametricmatrixmodels.modules.eigenvalues

from __future__ import annotations

from typing import Any, Callable

import jax
import jax.numpy as np

from ._affine_backing_funcs import select_eigenvalues
from .basemodule import BaseModule


[docs] class Eigenvalues(BaseModule):
[docs] def __init__( self, num_eig: int = 1, which: str = "SA", ) -> None: if num_eig <= 0 or not isinstance(num_eig, int): raise ValueError("num_eig must be a positive integer") if which.lower() not in [ "sa", "la", "sm", "lm", "ea", "em", "ia", "im", ]: raise ValueError( "which must be one of: 'SA', 'LA', 'SM', 'LM', 'EA', 'EM', " f"'IA', 'IM'. Got: {which}" ) self.num_eig = num_eig self.which = which.lower()
[docs] def name(self) -> str: if self.num_eig == 1 and self.which == "sa": return "Eigenvalues(ground state)" else: return ( f"Eigenvalues(num_eig={self.num_eig}," f" which={self.which.upper()})" )
[docs] def is_ready(self) -> bool: return True
[docs] def get_num_trainable_floats(self) -> int | None: return 0
[docs] def _get_callable(self) -> Callable[ [ tuple[np.ndarray, ...], np.ndarray, bool, tuple[np.ndarray, ...], Any, ], tuple[np.ndarray, tuple[np.ndarray, ...]], ]: return lambda params, input_NF, training, state, rng: ( jax.vmap(select_eigenvalues, in_axes=(0, None, None))( np.linalg.eigvalsh(input_NF), self.num_eig, self.which ), state, # state is not used in this module, return it unchanged )
[docs] def compile(self, rng: Any, input_shape: tuple[int, ...]) -> None: # ensure input shape is valid if len(input_shape) != 2 or input_shape[0] != input_shape[1]: raise ValueError( f"Input shape must be a square matrix, got {input_shape}" )
[docs] def get_output_shape( self, input_shape: tuple[int, ...] ) -> tuple[int, ...]: return (self.num_eig,)
[docs] def get_hyperparameters(self) -> dict[str, Any]: return { "num_eig": self.num_eig, "which": self.which, }
[docs] def set_hyperparameters(self, hyperparams: dict[str, Any]) -> None: super(Eigenvalues, self).set_hyperparameters(hyperparams)
[docs] def get_params(self) -> tuple[np.ndarray, ...]: return ()
[docs] def set_params(self, params: tuple[np.ndarray, ...]) -> None: return