Source code for parametricmatrixmodels.modules.affinehermitianmatrix

from __future__ import annotations

import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import Array, Inexact, jaxtyped

from ..tree_util import is_shape_leaf, is_single_leaf
from ..typing import (
    Any,
    Data,
    DataShape,
    HyperParams,
    ModuleCallable,
    Params,
    State,
    Tuple,
)
from ._smoothing import exact_smoothing_matrix
from .basemodule import BaseModule


class AffineHermitianMatrix(BaseModule):
    r"""
    Module that builds a parametric Hermitian matrix that is affine in the
    input features.

    :math:`M(x) = M_0 + x_1 M_1 + ... + x_p M_p + s C`

    where :math:`M_0, M_1, ..., M_p` are (trainable) Hermitian matrices,
    :math:`x_1, ..., x_p` are the input features, :math:`s` is the smoothing
    hyperparameter, and :math:`C` is the imaginary unit times the sum of the
    commutators of all pairs of the :math:`M_i` matrices, evaluated
    efficiently using cumulative sums and the linearity of the commutator:

    .. math::

        C &= i\sum_{\substack{i,j\\i<j}} \left[M_i, M_j\right] \\
          &= i\sum_{j} \left[\,\sum_{k<j} M_k,\; M_j\right]

    Only accepts PyTree data that has a single leaf array that is 1D,
    excluding the batch dimension. The PyTree structure is preserved in the
    output.

    See Also
    --------
    AffineEigenvaluePMM
        Module that builds a parametric matrix that is affine in the input
        features, the same as this module, but returns the eigenvalues of
        said matrix.
    AffineObservablePMM
        Module that builds a parametric matrix that is affine in the input
        features, the same as this module, but returns the sum of trainable
        observables and transition probabilities of eigenstates of said
        matrix.
    Eigenvalues
        Module that computes the eigenvalues of a given Hermitian matrix.
        Can be applied after this module to effectively re-create the
        ``AffineEigenvaluePMM`` module.
    """

    def __init__(
        self,
        matrix_size: int | None = None,
        smoothing: float | None = None,
        Ms: Inexact[Array, "_ n n"] | None = None,
        init_magnitude: float = 1e-2,
        bias_term: bool = True,
    ) -> None:
        """
        Create an ``AffineHermitianMatrix`` module.

        Parameters
        ----------
        matrix_size
            Size of the PMM matrices (square), shorthand :math:`n`.
        smoothing
            Optional smoothing parameter. Set to ``0.0`` to disable
            smoothing. Default is ``None``/``0.0`` (no smoothing).
        Ms
            Optional array of matrices :math:`M_0, M_1, ..., M_p` that
            define the parametric affine matrix. Each :math:`M_i` must be
            Hermitian. If not provided, the matrices will be randomly
            initialized when the module is compiled. Default is ``None``.
        init_magnitude
            Optional initial magnitude of the random matrices, used when
            initializing the module. Default is ``1e-2``.
        bias_term
            If ``True``, include the bias term :math:`M_0` in the equation
            for :math:`M(x)`. Default is ``True``.
        """
        # input validation
        if matrix_size is not None and (
            not isinstance(matrix_size, int) or matrix_size <= 0
        ):
            raise ValueError("matrix_size must be a positive integer")
        if Ms is not None:
            if not isinstance(Ms, np.ndarray):
                raise ValueError("Ms must be a numpy array")
            matrix_size = matrix_size or Ms.shape[1]
            if Ms.shape != (Ms.shape[0], matrix_size, matrix_size):
                raise ValueError(
                    "Ms must be a 3D array of shape (input_size+1,"
                    f" matrix_size, matrix_size) [({Ms.shape[0]},"
                    f" {matrix_size}, {matrix_size})], got {Ms.shape}"
                )
            # ensure Ms are Hermitian
            if not np.allclose(Ms, Ms.conj().transpose((0, 2, 1))):
                raise ValueError("Ms must be Hermitian matrices")

        self.matrix_size = matrix_size
        self.smoothing = smoothing if smoothing is not None else 0.0
        self.bias_term = bias_term
        self.Ms = Ms  # matrices M0, M1, ..., Mp
        self.init_magnitude = init_magnitude

    @property
    def name(self) -> str:
        return (
            f"AffineHermitianMatrix({self.matrix_size}x{self.matrix_size},"
            f"{'' if self.smoothing == 0.0 else f' smooth={self.smoothing}'}"
            f"{'' if self.bias_term else ', no bias'})"
        )

    def is_ready(self) -> bool:
        return self.Ms is not None

    def get_num_trainable_floats(self) -> int | None:
        if not self.is_ready():
            return None
        # each matrix M is Hermitian, and so contains n * (n - 1) / 2
        # distinct complex numbers and n distinct real numbers on the
        # diagonal; the total number of trainable floats is then just n^2
        # per matrix, so Ms contributes (p + 1) * n^2 floats
        return self.Ms.size

    def _get_callable(self) -> ModuleCallable:

        if self.bias_term:

            @jaxtyped(typechecker=beartype)
            def make_affine_hermitian_matrix(
                Ms: Inexact[Array, "p n n"],
                features: Inexact[Array, "b p-1"],
            ) -> Inexact[Array, "b n n"]:
                # convert to common dtype, this should be traced out
                dtype = np.result_type(Ms.dtype, features.dtype)
                Ms = Ms.astype(dtype)
                features = features.astype(dtype)
                M = Ms[0][None, :, :] + np.einsum(
                    "ni,ijk->njk", features, Ms[1:]
                )
                if self.smoothing != 0.0:
                    M += (
                        self.smoothing
                        * exact_smoothing_matrix(Ms)[None, :, :]
                    )
                return M

        else:

            @jaxtyped(typechecker=beartype)
            def make_affine_hermitian_matrix(
                Ms: Inexact[Array, "p n n"],
                features: Inexact[Array, "b p"],
            ) -> Inexact[Array, "b n n"]:
                # convert to common dtype, this should be traced out
                dtype = np.result_type(Ms.dtype, features.dtype)
                Ms = Ms.astype(dtype)
                features = features.astype(dtype)
                M = np.einsum("ni,ijk->njk", features, Ms)
                if self.smoothing != 0.0:
                    M += (
                        self.smoothing
                        * exact_smoothing_matrix(Ms)[None, :, :]
                    )
                return M

        @jaxtyped(typechecker=beartype)
        def affine_hermitian_matrix(
            params: Params,
            data: Data,
            training: bool,
            state: State,
            rng: Any,
        ) -> Tuple[Data, State]:
            Ms = params
            # enforce Hermitian matrices
            Ms = (Ms + Ms.conj().transpose((0, 2, 1))) / 2.0
            # tree map over the data to preserve the PyTree structure
            # compile will have validated that there is only one leaf that
            # is a 1D array, excluding the batch dimension
            M = jax.tree.map(
                lambda x: make_affine_hermitian_matrix(Ms, x),
                data,
            )
            return (M, state)

        return affine_hermitian_matrix

    def compile(self, rng: Any, input_shape: DataShape) -> None:
        valid, leaf = is_single_leaf(
            input_shape, ndim=1, is_leaf=is_shape_leaf
        )
        # input shape must be a PyTree with a single leaf that is 1D
        if not valid:
            raise ValueError(
                "Input shape must be a PyTree with a single leaf consisting"
                f" of a 1D array, got: {input_shape}"
            )
        # number of matrices is number of features + 1 (bias) if bias is used
        p = leaf[0] + 1 if self.bias_term else leaf[0]
        # if the module is already ready, just verify the input shape
        if self.is_ready():
            if self.Ms.shape[0] != p:
                expected_features = (
                    self.Ms.shape[0] - 1
                    if self.bias_term
                    else self.Ms.shape[0]
                )
                raise ValueError(
                    f"Input shape {leaf} does not match the expected number"
                    f" of features {expected_features}"
                )
            return

        rng_Mreal, rng_Mimag = jax.random.split(rng, 2)
        self.Ms = self.init_magnitude * (
            jax.random.normal(
                rng_Mreal,
                (p, self.matrix_size, self.matrix_size),
                dtype=np.complex64,
            )
            + 1j
            * jax.random.normal(
                rng_Mimag,
                (p, self.matrix_size, self.matrix_size),
                dtype=np.complex64,
            )
        )
        # ensure the matrices are Hermitian
        self.Ms = (self.Ms + self.Ms.conj().transpose((0, 2, 1))) / 2.0

    def get_output_shape(self, input_shape: DataShape) -> DataShape:
        # return (n,n) with the same PyTree structure as input_shape, so
        # long as the input shape is valid
        valid, _ = is_single_leaf(input_shape, ndim=1, is_leaf=is_shape_leaf)
        if not valid:
            raise ValueError(
                "Input shape must be a PyTree with a single leaf consisting"
                f" of a 1D array, got: {input_shape}"
            )
        return jax.tree.map(
            lambda x: (self.matrix_size, self.matrix_size),
            input_shape,
            is_leaf=is_shape_leaf,
        )

    def get_hyperparameters(self) -> HyperParams:
        return {
            "matrix_size": self.matrix_size,
            "smoothing": self.smoothing,
            "init_magnitude": self.init_magnitude,
            "bias_term": self.bias_term,
        }

    def set_hyperparameters(self, hyperparams: HyperParams) -> None:
        if self.Ms is not None:
            raise ValueError(
                "Cannot set hyperparameters after the module has parameters"
            )
        super().set_hyperparameters(hyperparams)

    def get_params(self) -> Params:
        return self.Ms

    def set_params(self, params: Params) -> None:
        if not isinstance(params, np.ndarray):
            raise ValueError("Params must be a numpy array")
        Ms = params
        expected_shape = (
            Ms.shape[0] if self.Ms is None else self.Ms.shape[0],
            self.matrix_size,
            self.matrix_size,
        )
        if Ms.shape != expected_shape:
            raise ValueError(
                "Ms must be a 3D array of shape (input_size"
                f"{'+1' if self.bias_term else ''}, matrix_size,"
                f" matrix_size) [{expected_shape}], got {Ms.shape}"
            )
        # ensure Ms are Hermitian
        if not np.allclose(Ms, Ms.conj().transpose((0, 2, 1))):
            raise ValueError("Ms must be Hermitian matrices")
        self.Ms = Ms
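
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the library: the smoothing matrix from the
# class docstring, C = i * sum_{i<j} [M_i, M_j], written two equivalent ways.
# The library computes this via ``exact_smoothing_matrix`` (imported above),
# whose internals may differ; the helper names below are hypothetical and
# only make the cumulative-sum identity concrete.
def _smoothing_reference_naive(Ms):
    # direct double sum over ordered pairs i < j
    p = Ms.shape[0]
    C = np.zeros_like(Ms[0])
    for i in range(p):
        for j in range(i + 1, p):
            C = C + (Ms[i] @ Ms[j] - Ms[j] @ Ms[i])
    return 1j * C


def _smoothing_reference_cumsum(Ms):
    # same quantity via cumulative sums and linearity of the commutator:
    # sum_{i<j} [M_i, M_j] = sum_j [S_j, M_j] with S_j = sum_{k<j} M_k
    S = np.cumsum(Ms, axis=0) - Ms  # S[j] = sum of M_k for k < j
    SM = np.einsum("jab,jbc->ac", S, Ms)  # sum_j S_j @ M_j
    MS = np.einsum("jab,jbc->ac", Ms, S)  # sum_j M_j @ S_j
    return 1j * (SM - MS)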
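
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the library.  It relies only on the API
# defined in this file (``__init__``, ``compile``, ``_get_callable``,
# ``get_params``); the bare shape tuple ``(3,)`` as the input PyTree and the
# empty dict passed for the (unused) state are assumptions and may need to
# be adapted to the surrounding framework.
if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)

    module = AffineHermitianMatrix(matrix_size=4, smoothing=0.1)
    module.compile(rng, (3,))  # 3 features -> M0 + x1 M1 + x2 M2 + x3 M3

    fn = module._get_callable()
    x = jax.random.normal(rng, (8, 3))  # batch of 8 feature vectors
    M, _ = fn(module.get_params(), x, False, {}, rng)
    print(M.shape)  # expected: (8, 4, 4), one Hermitian matrix per sample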