# Source code for parametricmatrixmodels.modules.linearnn

from __future__ import annotations

import jax

from ..sequentialmodel import SequentialModel

# direct import to avoid circular import
from ..tree_util import is_shape_leaf
from ..typing import Any, DataShape, Dict, HyperParams, List, Serializable
from .basemodule import BaseModule
from .bias import Bias
from .flatten import Flatten
from .matmul import MatMul


class LinearNN(SequentialModel):
    r"""
    A Module (SequentialModel) representing a single linear neural network
    layer.

    This module first flattens the input data, then applies a linear
    transformation using a weight matrix and bias vector. Optionally, an
    element-wise activation function can be applied after the linear
    transformation.

    This module accepts only bare arrays or PyTrees with only a single leaf
    array.
    """

    # Serialization-format version of this module.
    __version__: str = "0.0.0"
[docs] def __init__( self, out_features: int | None = None, bias: bool = True, activation: BaseModule | None = None, init_magnitude: float = 1e-2, real: bool = True, ): r""" Initialize the LinearNN module. Parameters ---------- out_features The number of output features. If None, this must be set later using ``set_hyperparameters``. """ self.modules = None self.out_features = out_features self.bias = bias self.activation = activation self.init_magnitude = init_magnitude self.real = real super().__init__()
[docs] def compile( self, rng: Any | int | None, input_shape: DataShape, verbose: bool = False, ) -> None: r""" Compile the LinearNN module by initializing its sub-modules, then calling the compile method of the parent SequentialModel class. Parameters ---------- rng Random key for initializing the model parameters. JAX PRNGKey or integer seed. input_shape Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input. verbose Print debug information during compilation. Default is False. """ if self.out_features is None: raise ValueError( "out_features must be specified before compiling the module." ) if len(jax.tree.leaves(input_shape, is_leaf=is_shape_leaf)) != 1: raise ValueError( "LinearNN only supports input shapes with a single leaf array." ) if input_shape != self.input_shape or self.modules is None: # don't overwrite modules if input shape hasn't changed modules: List[BaseModule] = [ Flatten(), MatMul( output_shape=self.out_features, trainable=True, init_magnitude=self.init_magnitude, real=self.real, ), ] if self.bias: modules.append( Bias( init_magnitude=self.init_magnitude, real=self.real, scalar=False, trainable=True, ) ) if self.activation is not None: modules.append(self.activation) self.modules = modules super().compile(rng, input_shape, verbose=verbose)
[docs] def get_output_shape(self, input_shape: DataShape) -> DataShape: r""" Get the output shape of the LinearNN module given the input shape. Parameters ---------- input_shape Shape of the input array, excluding the batch size. Returns ------- DataShape Shape of the output array, excluding the batch size. """ if self.out_features is None: raise ValueError( "out_features must be specified before getting output shape." ) return (self.out_features,)
[docs] def get_hyperparameters(self) -> HyperParams: r""" Get the hyperparameters of the LinearNN module. Returns ------- HyperParams A dictionary containing the hyperparameters of the module. """ return { "out_features": self.out_features, "bias": self.bias, "activation": self.activation, "init_magnitude": self.init_magnitude, "real": self.real, **super().get_hyperparameters(), }
[docs] def set_hyperparameters(self, hyperparams: HyperParams) -> None: r""" Set the hyperparameters of the LinearNN module. Parameters ---------- hyperparams A dictionary containing the hyperparameters to set. """ self.out_features = hyperparams.get("out_features", self.out_features) self.bias = hyperparams.get("bias", self.bias) self.activation = hyperparams.get("activation", self.activation) self.init_magnitude = hyperparams.get( "init_magnitude", self.init_magnitude ) self.real = hyperparams.get("real", self.real) super().set_hyperparameters(hyperparams)
[docs] def deserialize( self, data: Dict[str, Serializable], /, *, strict_package_version: bool = False, ) -> None: r""" Deserialize the LinearNN module from a dictionary. Parameters ---------- data A dictionary containing the serialized data of the module. """ # base implementation takes care of almost everything super().deserialize( data, strict_package_version=strict_package_version ) # just need to read .modules to make sure .activation matches the # modules if self.modules is not None: # figure out if modules contains an activation function # by checking the number of modules if len(self.modules) == 2: self.activation = None elif len(self.modules) == 3 and not self.bias: self.activation = self.modules[2] elif len(self.modules) == 4 and self.bias: self.activation = self.modules[3] else: self.activation = None