Source code for parametricmatrixmodels.modules.nonnegativelinearnn

from __future__ import annotations

from typing import Any, Callable

import jax
import jax.numpy as np

from .basemodule import BaseModule


class NonnegativeLinearNN(BaseModule):
    """
    Module that implements a single linear NN layer with non-negative weights
    and biases.

    The stored parameters ``W`` and ``b`` are unconstrained. Non-negativity
    is obtained in the forward callable by squaring them element-wise, i.e.
    the layer computes ``x @ W**2 + b**2``.
    """

    def __init__(
        self,
        k: int | None = None,
        W: np.ndarray | None = None,
        b: np.ndarray | None = None,
        init_magnitude: float = 1e-2,
        real: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        k
            Number of output features. If ``None`` and a 2D ``W`` is
            provided, it is taken from ``W.shape[1]``.
        W
            Weight matrix of shape (num_features, k). If None, it will be
            initialized randomly.
        b
            Bias vector of shape (k,). If None, it will be initialized
            randomly
        init_magnitude
            Magnitude for the random initialization of weights and biases.
            Default is ``1e-2``.
        real
            If ``True``, the weights and biases will be real-valued. If
            ``False``, they will be complex-valued. Default is ``True``.

        Raises
        ------
        ValueError
            If only one of ``W`` and ``b`` is provided, or if a provided
            ``W`` or ``b`` is not real-valued.
        NotImplementedError
            If ``real`` is ``False``.
        """
        self.k = k
        self.W = W
        self.b = b
        self.init_magnitude = init_magnitude
        self.real = real

        # make sure either neither W nor b are provided, or both are provided
        if (W is None) != (b is None):
            raise ValueError(
                "Either both W and b must be provided, or neither."
            )

        if W is not None:
            self.p = W.shape[0]  # number of input features
            # Without k the module could never report itself ready even
            # though its parameters are fully specified, so infer it from W.
            if self.k is None and len(W.shape) == 2:
                self.k = W.shape[1]
        else:
            self.p = None

        # ensure that real is True
        if not real:
            raise NotImplementedError(
                "Complex-valued weights and biases are not supported in this "
                "module."
            )

        # ensure that W and b are real if provided
        if real and W is not None and not np.isrealobj(W):
            raise ValueError("W must be real-valued for real weights")
        if real and b is not None and not np.isrealobj(b):
            raise ValueError("b must be real-valued for real biases")

    def name(self) -> str:
        """Return a short description including ``k`` and ``real``."""
        return f"NonnegativeLinearNN(k={self.k}, real={self.real})"

    def is_ready(self) -> bool:
        """Return ``True`` once ``k``, ``p``, ``W`` and ``b`` are all set."""
        return (
            self.k is not None
            and self.p is not None
            and self.W is not None
            and self.b is not None
        )

    def get_num_trainable_floats(self) -> int | None:
        """
        Return the number of trainable floats, or ``None`` if the module is
        not ready. Complex parameters count as two floats each.
        """
        if not self.is_ready():
            return None

        num_params = self.k * self.p + self.k  # W and b
        return num_params if self.real else 2 * num_params

    def _get_callable(self) -> Callable:
        """
        Return the forward function
        ``(params, input_NF, training, state, rng) -> (output, state)``.
        """

        def forward(params, input_NF, training, state, rng):
            # nonnegativity is ensured by taking the square of the weights
            # and biases
            output = input_NF @ (params[0] ** 2) + (params[1] ** 2)[None, :]
            # state is not used in this module, return it unchanged
            return output, state

        return forward

    def compile(self, rng: Any, input_shape: tuple[int, ...]) -> None:
        """
        Validate ``input_shape`` and randomly initialize any missing
        parameters.

        Parameters
        ----------
        rng
            JAX random key; it is split into four subkeys (real/imaginary
            parts of ``W`` and of ``b``).
        input_shape
            Shape of a single input sample; must be 1D.

        Raises
        ------
        ValueError
            If ``input_shape`` is not 1D, or if the module is already ready
            and ``input_shape[0]`` differs from the stored feature count.
        """
        # input shape must be 1D
        if len(input_shape) != 1:
            raise ValueError(
                f"Input shape must be 1D, got {len(input_shape)}D shape: "
                f"{input_shape}"
            )

        # if the module is already ready, just verify the input shape
        if self.is_ready():
            if self.p != input_shape[0]:
                raise ValueError(
                    f"Input shape {input_shape} does not match the expected "
                    f"number of features {self.p}"
                )
            return

        # otherwise, initialize the matrices
        self.p = input_shape[0]  # number of input features

        subkey_real_W, subkey_imag_W, subkey_real_b, subkey_imag_b = (
            jax.random.split(rng, 4)
        )

        def random_init(key_real, key_imag, shape):
            # Scaled normal draw; the imaginary part is only drawn for a
            # complex module.
            real_part = jax.random.normal(key_real, shape) * self.init_magnitude
            if self.real:
                return real_part
            imag_part = jax.random.normal(key_imag, shape) * self.init_magnitude
            return real_part + 1j * imag_part

        if self.W is None:
            self.W = random_init(
                subkey_real_W, subkey_imag_W, (self.p, self.k)
            )

        if self.b is None:
            self.b = random_init(subkey_real_b, subkey_imag_b, (self.k,))

    def get_output_shape(
        self, input_shape: tuple[int, ...]
    ) -> tuple[int, ...]:
        """Return ``(k,)``; ``input_shape`` is not consulted."""
        return (self.k,)

    def get_hyperparameters(self) -> dict[str, Any]:
        """Return ``k``, ``p``, ``init_magnitude`` and ``real`` as a dict."""
        return {
            "k": self.k,
            "p": self.p,
            "init_magnitude": self.init_magnitude,
            "real": self.real,
        }

    def set_hyperparameters(self, hyperparams: dict[str, Any]) -> None:
        """
        Delegate to the base class, refusing once ``W`` or ``b`` is set.

        Raises
        ------
        ValueError
            If the module already holds ``W`` or ``b``.
        """
        if self.W is not None or self.b is not None:
            raise ValueError(
                "Cannot set hyperparameters after the module has parameters"
            )
        super().set_hyperparameters(hyperparams)

    def get_params(self) -> tuple[np.ndarray, ...]:
        """Return the parameters as the tuple ``(W, b)``."""
        return (self.W, self.b)

    def set_params(self, params: tuple[np.ndarray, ...]) -> None:
        """
        Replace ``W`` and ``b`` after validating type, shape and realness.

        Parameters
        ----------
        params
            Tuple ``(W, b)`` with shapes ``(p, k)`` and ``(k,)``.

        Raises
        ------
        ValueError
            If ``params`` is not a tuple of two arrays, a shape does not
            match, or the realness of an array disagrees with ``self.real``.
        """
        if not isinstance(params, tuple) or not all(
            isinstance(p, np.ndarray) for p in params
        ):
            raise ValueError("params must be a tuple of numpy arrays")
        if len(params) != 2:
            raise ValueError(f"Expected 2 parameter array, got {len(params)}")

        W, b = params

        if W.shape != (self.p, self.k):
            # The message reports the (p, k) shape that is actually checked;
            # it must not reference attributes this class does not define.
            raise ValueError(
                f"Parameter array 0 must be of shape ({self.p}, {self.k}), "
                f"got {W.shape}"
            )
        if b.shape != (self.k,):
            raise ValueError(
                f"Parameter array 1 must be of shape ({self.k},), "
                f"got {b.shape}"
            )

        if self.real and not np.isrealobj(W):
            raise ValueError(
                "Parameter array 0 must be real-valued for a real module"
            )
        if not self.real and np.isrealobj(W):
            raise ValueError(
                "Parameter array 0 must be complex-valued for a complex module"
            )
        if self.real and not np.isrealobj(b):
            raise ValueError(
                "Parameter array 1 must be real-valued for a real module"
            )
        if not self.real and np.isrealobj(b):
            raise ValueError(
                "Parameter array 1 must be complex-valued for a complex module"
            )

        self.W = W
        self.b = b