# Source code for parametricmatrixmodels.modules.funcbase

from abc import abstractmethod
from typing import final

import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import jaxtyped

from parametricmatrixmodels.typing import (
    Any,
    Data,
    DataShape,
    HyperParams,
    ModuleCallable,
    Params,
    State,
    Tuple,
)

from ..tree_util import get_shapes, is_shape_leaf
from .basemodule import BaseModule


class FuncBase(BaseModule):
    """
    Base class for simple non-trainable function modules.

    Subclasses implement `f`; the rest of the module interface (readiness,
    parameters, hyperparameters, compilation and shape inference) is
    implemented here in terms of `f`.

    Not to be instantiated directly.
    """

    @final
    def __init__(self):
        """
        Initialize the function module.
        """
        pass

    @property
    def name(self) -> str:
        """
        Display name of the module.

        Returns
        -------
            The string ``FuncBase(<class name>)``, where ``<class name>`` is
            the name of the concrete class of this instance.
        """
        return f"FuncBase({self.__class__.__name__})"

    @final
    def is_ready(self) -> bool:
        """
        Funcs are always ready to be used.

        Returns
        -------
            Always returns True.
        """
        return True

    @final
    def get_num_trainable_floats(self) -> int | None:
        """
        Funcs do not have trainable parameters.

        Returns
        -------
            Always returns 0.
        """
        return 0

    @abstractmethod
    def f(self, data: Data) -> Data:
        """
        Apply the function to the input data

        Parameters
        ----------
        data
            Input Data (PyTree of arrays).

        Returns
        -------
            Output Data (PyTree of arrays).
        """
        raise NotImplementedError("Subclasses must implement `f`.")

    @final
    def _get_callable(self) -> ModuleCallable:
        """
        Get the callable for the function module.

        Returns
        -------
            A callable that applies the function to the input data in the form
            the PMM library expects.
        """

        # ``params``, ``training`` and ``rng`` are accepted only to satisfy the
        # expected call signature; the function depends on ``data`` alone and
        # ``state`` is handed back unchanged.
        @jaxtyped(typechecker=beartype)
        def func_callable(
            params: Params,
            data: Data,
            training: bool,
            state: State,
            rng: Any,
        ) -> Tuple[Data, State]:
            return self.f(data), state

        return func_callable

    @final
    def compile(self, rng: Any, input_shape: DataShape) -> None:
        """
        Compile the function module.

        Function modules have nothing to build, so this only checks that `f`
        can be evaluated on an input of the given shape.

        Parameters
        ----------
        rng
            Random number generator state. Unused.
        input_shape
            Shape of the input arrays.

        Raises
        ------
        RuntimeError
            If `f` fails on a dummy input built from ``input_shape`` (raised
            by `get_output_shape`).
        """
        # make sure that f can handle the input shape
        self.get_output_shape(input_shape)

    @final
    def get_output_shape(self, input_shape: DataShape) -> DataShape:
        """
        Get the output shape of the function given an input shape.

        The shape is found by evaluating `f` on an all-zeros float32 input
        with a leading size-1 batch dimension added to every input shape.

        Parameters
        ----------
        input_shape
            Shape of the input arrays.

        Returns
        -------
            Output shapes after applying the function.

        Raises
        ------
        RuntimeError
            If `f` raises any exception on the dummy input. The original
            exception is chained as the cause.
        """
        # only way to do this automatically is to run a dummy input

        # add batch dimension to all shapes
        input_w_batch_shape = jax.tree.map(
            lambda s: (1,) + s, input_shape, is_leaf=is_shape_leaf
        )
        dummy_input = jax.tree.map(
            lambda s: np.zeros(s, dtype=np.float32),
            input_w_batch_shape,
            is_leaf=is_shape_leaf,
        )

        try:
            dummy_output = self.f(dummy_input)
        except Exception as e:
            raise RuntimeError(
                "Failed to compute output shape in `get_output_shape`. "
                "Make sure the function `f` can handle F32 inputs with shape "
                f"{get_shapes(dummy_input)}, which includes a leading size-1 "
                "batch dimension."
            ) from e

        # axis=slice(1, None) skips axis 0, i.e. the batch axis added above
        output_shape = get_shapes(dummy_output, axis=slice(1, None))
        return output_shape

    @final
    def get_hyperparameters(self) -> HyperParams:
        """
        Get the hyperparameters of the function module, of which there are
        none.

        Returns
        -------
            An empty dictionary, as function modules do not have
            hyperparameters.
        """
        return {}

    @final
    def get_params(self) -> Params:
        """
        Get the parameters of the function module, of which there are none.

        Returns
        -------
            An empty tuple, as function modules do not have parameters.
        """
        return ()

    @final
    def set_params(self, params: Params) -> None:
        """
        Set the parameters of the function module.

        Function modules have no parameters, so this does nothing.

        Parameters
        ----------
        params
            Ignored.
        """
        return