Source code for parametricmatrixmodels.modules.bias

from __future__ import annotations

import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import Array, Inexact, PyTree, jaxtyped

from ..tree_util import is_shape_leaf
from ..typing import (
    Any,
    Data,
    DataShape,
    HyperParams,
    ModuleCallable,
    Params,
    State,
    Tuple,
)
from .basemodule import BaseModule


[docs] class Bias(BaseModule): r""" A simple bias module that adds a (trainable by default) bias array (default) or scalar to the input. Can be real (default) or complex-valued. If the input is a ``PyTree`` of arrays, the same bias will be added to each leaf array and therefore the bias shape must match the shape of each leaf array (or be a scalar). """
[docs] def __init__( self, bias: ( PyTree[Inexact[Array, "..."]] | Inexact[Array, "..."] | float | complex | None ) = None, init_magnitude: float = 1e-2, real: bool = True, scalar: bool = False, trainable: bool = True, ) -> None: """ Parameters ---------- bias Bias array, scalar, or PyTree of arrays. If None, the shape will be inferred at compile time and values will be initialized randomly init_magnitude Magnitude for the random initialization of the bias. Default is ``1e-2``. real If ``True``, the biases will be real-valued. If ``False``, they will be complex-valued. Default is ``True``. scalar If ``True`` the bias will be a scalar shared across all input features. If ``False``, the bias will be a array with one entry per input feature. Default is ``False``. trainable If ``True``, the bias will be trainable. Default is ``True``. """ self.bias = bias self.init_magnitude = init_magnitude self.real = real self.scalar = scalar self.trainable = trainable if self.bias is not None: # input validation if self.scalar and not np.isscalar(self.bias): raise ValueError( "If scalar is True, bias must be a scalar or None" ) if not self.scalar and not isinstance( jax.tree.leaves(self.bias)[0], np.ndarray ): raise ValueError( "If scalar is False, bias must be a PyTree of numpy " "arrays, a numpy array, or None" ) if self.real and not jax.tree.all( jax.tree.map(np.isrealobj, self.bias) ): raise ValueError("Bias must be real-valued for a real module") if not self.real and jax.tree.all( jax.tree.map(np.isrealobj, self.bias) ): raise ValueError( "Bias must have at least one complex-valued leaf for a" " complex module" ) if self.scalar: self.bias = np.array(self.bias).reshape((1,))
@property def name(self) -> str: return f"Bias(real={self.real})"
[docs] def is_ready(self) -> bool: return self.bias is not None
[docs] def _get_callable(self) -> ModuleCallable: if isinstance(self.bias, np.ndarray): @jaxtyped(typechecker=beartype) def bias_callable( params: Params, data: Data, training: bool, state: State, rng: Any, ) -> Tuple[Data, State]: # tree map over data to add bias bias = params def add_bias(x: np.ndarray) -> np.ndarray: return x + bias output = jax.tree.map(add_bias, data) return output, state else: @jaxtyped(typechecker=beartype) def bias_callable( params: Params, data: Data, training: bool, state: State, rng: Any, ) -> Tuple[Data, State]: biases = params def add_bias(x: np.ndarray, b: np.ndarray) -> np.ndarray: return x + b output = jax.tree.map(add_bias, data, biases) return output, state return bias_callable
[docs] def compile(self, rng: Any, input_shape: DataShape) -> None: # if the module is already ready, just verify the input shape if self.is_ready(): if ( isinstance(self.bias, np.ndarray) and self.bias.shape != (1,) and self.bias.shape != () ): # check if input shape is a single tuple (no PyTree) if ( isinstance(input_shape, tuple) and all(isinstance(i, int) for i in input_shape) and self.bias.shape != input_shape ): raise ValueError( f"Bias shape {self.bias.shape} does not match input " f"shape {input_shape}" ) # else if the input shape is a PyTree of tuples, # all shapes must match elif any( [ shape != self.bias.shape for shape in jax.tree.leaves( input_shape, is_leaf=is_shape_leaf ) ] ): input_shapes_list = jax.tree.leaves( input_shape, is_leaf=is_shape_leaf ) raise ValueError( f"Bias shape {self.bias.shape} does not match all " f"input leaf shapes {input_shapes_list}" ) elif isinstance(self.bias, np.ndarray) and ( self.bias.shape == (1,) or self.bias.shape == () ): pass else: # pytree case def check_shape(x: np.ndarray, s: Tuple[int, ...]) -> None: if x.shape != s: raise ValueError( f"Bias leaf shape {x.shape} does not match input " f"leaf shape {s}" ) jax.tree.map( check_shape, self.bias, input_shape, is_leaf=is_shape_leaf, ) if self.bias is None: shape = (1,) if self.scalar else input_shape # otherwise, initialize the bias # bare array case if isinstance(shape, tuple) and all( isinstance(i, int) for i in shape ): if self.real: self.bias = self.init_magnitude * jax.random.normal( rng, shape, dtype=np.float32 ) else: rkey, ikey = jax.random.split(rng) self.bias = self.init_magnitude * ( jax.random.normal(rkey, shape, dtype=np.complex64) + 1j * jax.random.normal(ikey, shape, dtype=np.complex64) ) else: def init_bias(s: Tuple[int, ...], k: Any) -> np.ndarray: if self.real: return self.init_magnitude * jax.random.normal( k, s, dtype=np.float32 ) else: rkey, ikey = jax.random.split(k) return self.init_magnitude * ( jax.random.normal(rkey, s, dtype=np.complex64) + 1j * jax.random.normal(ikey, s, dtype=np.complex64) ) keys = jax.random.split( rng, len(jax.tree.leaves(shape, is_leaf=is_shape_leaf)) ) keys = jax.tree.unflatten( jax.tree.structure(shape, is_leaf=is_shape_leaf), keys ) self.bias = jax.tree.map( init_bias, shape, keys, is_leaf=is_shape_leaf, )
[docs] def get_output_shape(self, input_shape: DataShape) -> DataShape: return input_shape
[docs] def get_hyperparameters(self) -> HyperParams: return { "init_magnitude": self.init_magnitude, "real": self.real, "scalar": self.scalar, "trainable": self.trainable, }
[docs] def set_hyperparameters(self, hyperparams: HyperParams) -> None: super(Bias, self).set_hyperparameters(hyperparams)
[docs] def get_params(self) -> Params: return self.bias
[docs] def set_params(self, params: Params) -> None: if self.bias is not None: # ensure structure matches param_struct = jax.tree.structure(params) bias_struct = jax.tree.structure(self.bias) if param_struct != bias_struct: raise ValueError( "Structure of params does not match existing parameters. " f"Expected structure {bias_struct}, got {param_struct}" ) if not jax.tree.all( jax.tree.map(lambda p: isinstance(p, np.ndarray), params) ): raise ValueError("All leaves must be numpy arrays") if self.scalar and not ( isinstance(params, np.ndarray) and params.shape == (1,) ): raise ValueError( "For scalar bias, params must be a single numpy array with " f"shape (1,), got {params}" ) # check real/complex consistency if self.real and not jax.tree.all(jax.tree.map(np.isrealobj, params)): raise ValueError("Params must be real-valued for a real module") if not self.real and jax.tree.all(jax.tree.map(np.isrealobj, params)): raise ValueError( "Params must have at least one complex-valued leaf for a " "complex module" ) self.bias = params