# Source code for parametricmatrixmodels.sequentialmodel

from __future__ import annotations

import random

import jax
from beartype import beartype
from jaxtyping import jaxtyped

from .model import Model
from .model_util import (
    ModelCallable,
    ModelModules,
    ModelParams,
    ModelState,
)
from .modules import BaseModule
from .typing import (
    Any,
    Data,
    DataShape,
    ModuleCallable,
    Params,
    State,
    Tuple,
)


class SequentialModel(Model):
    r"""
    A simple sequential model that chains modules (or other models)
    together sequentially.

    For confidence intervals or uncertainty quantification, wrap a
    trained model with ``ConformalModel``.

    See Also
    --------
    jax.tree
        PyTree utilities and concepts in JAX.
    Model
        Abstract base class for all models.
    NonsequentialModel
        A model that allows for non-sequential connections between
        modules.
    ConformalModel
        Wrap a trained model to produce confidence intervals.
    """

    # Version of this class's serialized form, independent of the package
    # version.
    __version__: str = "0.0.0"
[docs] def __init__( self, modules: ModelModules | BaseModule | None = None, /, *, rng: Any | int | None = None, ) -> None: r""" Initialize a sequential model with a PyTree of modules and a random key. For sequential models, the modules are applied in the order they appear in the flattened PyTree. Since insertion order is preserved in dictionaries since Python 3.7, using a dictionary to specify modules is a convenient way to name modules while controlling their application order. If a sequential model is initialized with a dictionary, the ``append``/``prepend``-style methods will will use UUIDs to name new modules if the optional key argument is not provided. Parameters ---------- modules module(s) to initialize the model with. Default is None, which will become an empty list. rng Initial random key for the model. Default is None. If None, a new random key will be generated using JAX's ``random.key``. If an integer is provided, it will be used as the seed to create the key. See Also -------- ModelModules Type alias for a PyTree of modules in a model. jax.random.key JAX function to create a random key. jax.tree.flatten JAX function to flatten a PyTree, which determines the order of module application in a sequential model. jax.tree.leaves JAX function to flatten a PyTree without returning the structure. Equivalent to ``jax.tree.flatten(x)[0]``. """ # no custom initialization needed for sequential model super().__init__(modules, rng=rng)
[docs] def compile( self, rng: Any | int | None, input_shape: DataShape, /, *, verbose: bool = False, ) -> None: r""" Compile the model for training by compiling each module. Must be implemented by all subclasses. Parameters ---------- rng Random key for initializing the model parameters. JAX PRNGKey or integer seed. input_shape Shape of the input array, excluding the batch size. For example, (input_features,) for a 1D input or (input_height, input_width, input_channels) for a 3D input. verbose Print debug information during compilation. Default is False. """ if rng is None: rng = jax.random.key(random.randint(0, 2**32 - 1)) elif isinstance(rng, int): rng = jax.random.key(rng) if verbose: print(f"Compiling {self.name} for input shape {input_shape}.") self.input_shape = input_shape for i, module in enumerate(jax.tree.leaves(self.modules)): rng, modrng = jax.random.split(rng) try: module.compile(modrng, input_shape) except Exception as e: raise RuntimeError( f"Error compiling module {i} ({module.name}) " f"with input shape {input_shape}: {e}" ) from e input_shape = module.get_output_shape(input_shape) if verbose: print(f" {i}: {module.name} output shape: {input_shape}") self.output_shape = input_shape
[docs] def get_output_shape(self, input_shape: DataShape, /) -> DataShape: r""" Get the output shape of the model given an input shape. Must be implemented by all subclasses. Parameters ---------- input_shape Shape of the input, excluding the batch dimension. For example, (input_features,) for 1D bare-array input, or (input_height, input_width, input_channels) for 3D bare-array input, [(input_features1,), (input_features2,)] for a List (PyTree) of 1D arrays, etc. Returns ------- output_shape Shape of the output after passing through the model. """ if self.is_ready(): return self.output_shape else: shape = input_shape for module in jax.tree.leaves(self.modules): shape = module.get_output_shape(shape) return shape
    def _get_callable(
        self,
    ) -> ModelCallable:
        r"""
        Returns a ``jax.jit``-able and ``jax.grad``-able callable that
        represents the model's forward pass.

        This method must be implemented by all subclasses and must return
        a ``jax-jit``-able and ``jax-grad``-able callable in the form of

        .. code-block:: python

            model_callable(
                params: parametricmatrixmodels.model_util.ModelParams,
                data: parametricmatrixmodels.typing.Data,
                training: bool,
                state: parametricmatrixmodels.model_util.ModelState,
                rng: Any,
            ) -> (
                output: parametricmatrixmodels.typing.Data,
                new_state: parametricmatrixmodels.model_util.ModelState,
            )

        That is, all hyperparameters are traced out and the callable
        depends explicitly only on

        * the model's parameters, as a PyTree with leaf nodes as JAX
          arrays,
        * the input data, as a PyTree with leaf nodes as JAX arrays, each
          of which has shape (num_samples, ...),
        * the training flag, as a boolean,
        * the model's state, as a PyTree with leaf nodes as JAX arrays

        and returns

        * the output data, as a PyTree with leaf nodes as JAX arrays,
          each of which has shape (num_samples, ...),
        * the new model state, as a PyTree with leaf nodes as JAX arrays.
          The PyTree structure must match that of the input state and
          additionally all leaf nodes must have the same shape as the
          input state leaf nodes.

        The training flag will be traced out, so it doesn't need to be
        jittable.

        Returns
        -------
        A callable that takes the model's parameters, input data,
        training flag, state, and rng key and returns the output data and
        new state.

        See Also
        --------
        __call__ : Calls the model with the current parameters and given
            input, state, and rng.
        ModelCallable : Typing for the callable returned by this method.
        Params : Typing for the model parameters.
        Data : Typing for the input and output data.
        State : Typing for the model state.
        """
        if not self.is_ready():
            raise RuntimeError(
                f"{self.name} is not ready. Call compile() first."
            )

        # get the callables for each module and put them in a PyTree with the
        # same structure
        module_callables = jax.tree.map(
            lambda m: m._get_callable(), self.modules
        )
        # structure is reused below to unflatten rngs and the new state
        modules_structure = jax.tree.structure(self.modules)

        @jaxtyped(typechecker=beartype)
        def sequential(
            carry: Tuple[Data, ModelState],
            module_data: Tuple[ModuleCallable, Params, State, Any],
            training: bool,
        ) -> Tuple[Data, ModelState]:
            # Reduction step: apply one module to the running data and
            # accumulate its new state.
            # carry is (data, [flattened model state])
            # module_data is (module_callable, module_params,
            # module_state, module_rng)
            (
                module_callable,
                module_params,
                module_state,
                module_rng,
            ) = module_data
            data, modelstate_flat = carry
            output, new_module_state = module_callable(
                module_params,
                data,
                training,
                module_state,
                module_rng,
            )
            # append this module's new state to the flat list carried along
            return output, modelstate_flat + [new_module_state]

        @jaxtyped(typechecker=beartype)
        def model_callable(
            params: ModelParams,
            data: Data,
            training: bool,
            state: ModelState,
            rng: Any,
        ) -> Tuple[Data, ModelState]:
            # params, state, and module_callables are PyTrees with the same
            # structure as self.modules

            # split rng for each module, put in a PyTree with same structure
            rngs = jax.random.split(rng, len(jax.tree.leaves(self.modules)))
            rngs = jax.tree.unflatten(modules_structure, rngs)

            # use jax.tree.reduce to sequentially apply each module
            # then reconstruct the new state PyTree
            # we apply reduce over the zipped module_callables, params,
            # state, and rngs PyTrees
            module_data = jax.tree.map(
                lambda mc, mp, ms, mr: (mc, mp, ms, mr),
                module_callables,
                params,
                state,
                rngs,
            )
            # is_leaf stops recursion at each zipped 4-tuple so reduce sees
            # one (callable, params, state, rng) leaf per module.
            # NOTE(review): the x[1] isinstance check assumes module params
            # are an ndarray/list/tuple/dict — confirm against Params alias.
            output, new_state_flat = jax.tree.reduce(
                lambda ds, md: sequential(ds, md, training),
                module_data,
                initializer=(data, []),
                is_leaf=lambda x: isinstance(x, tuple)
                and len(x) == 4
                and callable(x[0])
                and isinstance(x[1], (jax.numpy.ndarray, list, tuple, dict)),
            )
            # reconstruct new state PyTree
            new_state = jax.tree.unflatten(modules_structure, new_state_flat)
            return output, new_state

        return model_callable