Source code for parametricmatrixmodels.model_util

from __future__ import annotations

from functools import partial, wraps
from typing import TypeAlias

import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import (
    Array,
    Inexact,
    Num,
    PyTree,
    jaxtyped,
)

from .modules import BaseModule
from .progressbar import ProgressBar
from .tree_util import concatenate, pytree_split
from .typing import (
    Any,
    Callable,
    Data,
    List,
    ModuleCallable,
    Params,
    State,
    Tuple,
)

# type aliases for Models

#: ModelStruct is a jaxtyping structure tag (the leading space marks it as a
#: structure name rather than a shape) for PyTrees that mirror the layout of
#: a model's modules.
ModelStruct: str = " modelstruct"

#: ModelParamsStruct is a type tag for PyTrees of model parameters. It really
#: should be the composition of ModelStruct and "params" or " ..." but this
#: doesn't work currently with jaxtyping due to unbound types in returns.
#: See: https://github.com/patrick-kidger/jaxtyping/issues/357
ModelParamsStruct: str = " modelparamsstruct"
#: ModelStateStruct is a type tag for PyTrees of model states. It really
#: should be the composition of ModelStruct and "state" or " ..." but this
#: doesn't work currently with jaxtyping due to unbound types in returns.
#: See: https://github.com/patrick-kidger/jaxtyping/issues/357
ModelStateStruct: str = " modelstatestruct"

#: Type alias for a PyTree of modules in a model.
ModelModules: TypeAlias = PyTree[BaseModule, ModelStruct]

#: Type alias for a PyTree of parameters in a model.
ModelParams: TypeAlias = PyTree[Inexact[Array, "..."], ModelParamsStruct]

#: Type alias for a PyTree of states in a model.
ModelState: TypeAlias = PyTree[
    Num[Array, "..."] | Tuple | List, ModelStateStruct
]

#: Type alias for the callable signature of a model's forward method. Similar
#: (actually identical) to ModuleCallable, but suggestively uses ModelParams
#: and ModelState, which are nested PyTrees.
#: Signature: (params, X, training, state, rng) -> (output, new_state).
ModelCallable: TypeAlias = Callable[
    [
        ModelParams,
        Data,
        bool,
        ModelState,
        Any,
    ],
    Tuple[Data, ModelState],
]


@jaxtyped(typechecker=beartype)
def autobatch(
    fn: ModelCallable | ModuleCallable,
    max_batch_size: int | None,
    verbose: bool = False,
) -> ModelCallable | ModuleCallable:
    r"""
    Decorator to automatically limit the batch size of a ``ModelCallable``
    or ``ModuleCallable`` function.

    This is not the same as taking a function and vmap'ing it. The original
    function must already be able to handle batches of data. This decorator
    simply breaks up large batches into smaller batches of size
    ``max_batch_size``, calls the original function on each smaller batch,
    and then concatenates the results. This would usually be used on a
    function that has already been jit-compiled.

    The returned state is the state returned from the last batch processed.
    The rng parameter is passed through to each call of the original
    function unchanged.

    Parameters
    ----------
    fn
        The function to be decorated. Must be a ``ModelCallable`` or
        ``ModuleCallable``.
    max_batch_size
        The maximum batch size to use when calling the function. If
        ``None``, then no batching is performed and the original function
        is returned.
    verbose
        If ``True``, drive a :class:`ProgressBar` from inside the traced
        computation via ``jax.pure_callback`` as each sub-batch finishes.

    Returns
    -------
    ModelCallable | ModuleCallable
        A jit-compiled wrapper with the same signature as ``fn`` that
        processes its input in chunks of at most ``max_batch_size`` along
        the leading axis and returns ``(output, final_state)``.
    """
    # rewrite to be jittable; `training` is marked static because it selects
    # behavior at trace time rather than being a traced value
    @partial(jax.jit, static_argnames=["training"])
    @wraps(fn)
    @jaxtyped(typechecker=beartype)
    def batched_fn(
        params: Params | ModelParams,
        X: Data,
        training: bool,
        state: State | ModelState,
        rng: Any,
    ) -> Tuple[Data, ModelState]:
        # batch size is read off the leading axis of the first leaf of X;
        # assumes all leaves share that leading batch axis — TODO confirm
        orig_batch_size = jax.tree.leaves(X)[0].shape[0]
        # max_batch_size is a static argument and so is the size of the batch,
        # so this condition will be traced out
        if max_batch_size is None or orig_batch_size <= max_batch_size:
            # nothing to do
            return fn(params, X, training, state, rng)
        else:
            num_batches = orig_batch_size // max_batch_size
            remainder = orig_batch_size % max_batch_size
            if verbose:
                # one tick per full batch, one extra slot, and one more if
                # there is a remainder batch
                pb = ProgressBar(
                    num_batches + 1 + (1 if remainder > 0 else 0), length=20
                )

                # host-side callbacks: each threads `out` through unchanged
                # so jax.pure_callback can sequence the side effect into the
                # traced computation
                def update_pb(out: Data, i: int) -> Data:
                    pb.update(i)
                    return out

                def end_pb(out: Data) -> Data:
                    pb.end()
                    return out

            def body_fn(i_new_state, X_batch):
                # scan carry is (batch index, running state); the state from
                # one batch feeds the next
                i, new_state = i_new_state
                out, new_state = fn(params, X_batch, training, new_state, rng)
                if verbose:
                    out = jax.pure_callback(update_pb, out, out, i + 1)
                return (i + 1, new_state), out

            # split X into [num_batches, max_batch_size, ...] plus an
            # optional remainder of size < max_batch_size
            X_batches, X_remainder = pytree_split(X, max_batch_size, axis=0)
            i_new_state, out = jax.lax.scan(
                body_fn,
                (0, state),
                X_batches,
            )
            _, new_state = i_new_state
            # out is of shape [num_batches, batch_size, ...], so we need
            # to reshape it to [num_batches * batch_size, ...]
            out = np.reshape(out, (-1,) + out.shape[2:])
            if remainder > 0:
                # the remainder batch runs after the scan, so the returned
                # state is the one produced by this final (partial) batch
                if verbose:
                    # NOTE(review): this ticks num_batches + 2 while the scan
                    # ticks reach num_batches, skipping num_batches + 1 —
                    # confirm ProgressBar.update indexing is intended
                    out = jax.pure_callback(
                        update_pb, out, out, num_batches + 2
                    )
                out_remainder, new_state = fn(
                    params,
                    X_remainder,
                    training,
                    new_state,
                    rng,
                )
                out = concatenate([out, out_remainder], axis=0)
            if verbose:
                out = jax.pure_callback(end_pb, out, out)
            return out, new_state

    return batched_fn