Source code for parametricmatrixmodels.modules.bias

from __future__ import annotations

from typing import Any, Callable

import jax
import jax.numpy as np

from .basemodule import BaseModule


[docs] class Bias(BaseModule): r""" A simple bias module that adds a (trainable by default) bias array (default) or scalar to the input. Can be real (default) or complex-valued. """
[docs] def __init__( self, bias: np.ndarray | float | complex | None = None, init_magnitude: float = 1e-2, real: bool = True, scalar: bool = False, trainable: bool = True, ) -> None: """ Parameters ---------- bias Bias array or scalar. If None, it will be initialized randomly init_magnitude Magnitude for the random initialization of the bias. Default is ``1e-2``. real If ``True``, the biases will be real-valued. If ``False``, they will be complex-valued. Default is ``True``. scalar If ``True`` the bias will be a scalar shared across all input features. If ``False``, the bias will be a array with one entry per input feature. Default is ``False``. trainable If ``True``, the bias will be trainable. Default is ``True``. """ self.bias = bias self.init_magnitude = init_magnitude self.real = real self.scalar = scalar self.trainable = trainable if self.bias is not None: # input validation if self.scalar and not np.isscalar(self.bias): raise ValueError( "If scalar is True, bias must be a scalar or None" ) if not self.scalar and not isinstance(self.bias, np.ndarray): raise ValueError( "If scalar is False, bias must be a numpy array or None" ) if self.real and not np.isrealobj(self.bias): raise ValueError("Bias must be real-valued for a real module") if not self.real and np.isrealobj(self.bias): raise ValueError( "Bias must be complex-valued for a complex module" ) if self.scalar: self.bias = np.array(self.bias).reshape((1,))
[docs] def name(self) -> str: return f"Bias(real={self.real})"
[docs] def is_ready(self) -> bool: return self.bias is not None
[docs] def _get_callable(self) -> Callable: return lambda params, input_NF, training, state, rng: ( input_NF + params[0], state, # state is not used in this module, return it unchanged )
[docs] def compile(self, rng: Any, input_shape: tuple[int, ...]) -> None: # if the module is already ready, just verify the input shape if self.is_ready(): if self.bias.shape != (1,) and self.bias.shape != input_shape: raise ValueError( f"Bias shape {self.bias.shape} does not match input " f"shape {input_shape}" ) return shape = (1,) if self.scalar else input_shape # otherwise, initialize the bias subkey_real, subkey_imag = jax.random.split(rng, 2) if self.bias is None: if self.real: self.bias = ( jax.random.normal(subkey_real, shape) * self.init_magnitude ) else: real_part = ( jax.random.normal(subkey_real, shape) * self.init_magnitude ) imag_part = ( jax.random.normal(subkey_imag, shape) * self.init_magnitude ) self.bias = real_part + 1j * imag_part
[docs] def get_output_shape( self, input_shape: tuple[int, ...] ) -> tuple[int, ...]: return input_shape
[docs] def get_hyperparameters(self) -> dict[str, Any]: return { "init_magnitude": self.init_magnitude, "real": self.real, "scalar": self.scalar, "trainable": self.trainable, }
[docs] def set_hyperparameters(self, hyperparams: dict[str, Any]) -> None: super(Bias, self).set_hyperparameters(hyperparams)
[docs] def get_params(self) -> tuple[np.ndarray, ...]: return (self.bias,)
[docs] def set_params(self, params: tuple[np.ndarray, ...]) -> None: if not isinstance(params, tuple) or not all( isinstance(p, np.ndarray) for p in params ): raise ValueError("params must be a tuple of numpy arrays") if len(params) != 1: raise ValueError(f"Expected 1 parameter array, got {len(params)}") if self.real and not np.isrealobj(params[0]): raise ValueError( "Parameter array 0 must be real-valued for a real module" ) if not self.real and np.isrealobj(params[0]): raise ValueError( "Parameter array 0 must be complex-valued for a complex module" ) if self.scalar and params[0].shape != (1,): raise ValueError( "Parameter array 0 must be a scalar array with shape (1,)," f" got {params[0].shape}" ) self.bias = params[0]