model_util Module

ModelCallable

Type alias for the callable signature of a model’s forward method. It is identical to ModuleCallable, but uses ModelParams and ModelState, which are nested PyTrees, to signal that the callable belongs to a whole model rather than a single module.

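alias of Callable[[params, data, bool, state, Any], tuple[output, state]], where params is a PyTree tagged 'modelparamsstruct', data and output are arrays or PyTrees of arrays, state is a PyTree tagged 'modelstatestruct' whose leaves may be arrays, tuples, or lists, and the final Any argument is the rng.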
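As an illustration of the shape of a ModelCallable, the sketch below defines a toy forward function. It assumes the argument order (params, input data, bool flag, state, rng) and an (output, new state) return value; the function name `toy_forward`, the parameter names, the `training` interpretation of the bool flag, and the dict layout of params and state are all hypothetical, and NumPy arrays stand in for JAX arrays.

```python
from typing import Any

import numpy as np


def toy_forward(
    params: dict[str, np.ndarray],
    X: np.ndarray,
    training: bool,
    state: dict[str, np.ndarray],
    rng: Any,
) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Hypothetical forward pass following the assumed ModelCallable shape.

    Applies a batched affine map X @ W + b. The rng argument is accepted
    but unused, as a deterministic model might do.
    """
    out = X @ params["W"] + params["b"]
    # Illustrative use of the bool flag: only count calls while training.
    new_state = {"calls": state["calls"] + 1} if training else state
    return out, new_state


params = {"W": np.ones((3, 2)), "b": np.zeros(2)}
state = {"calls": np.array(0)}
X = np.arange(12.0).reshape(4, 3)  # a batch of 4 samples with 3 features
Y, new_state = toy_forward(params, X, True, state, None)
```

The batch axis is the leading axis of X, which is what a batch-splitting wrapper such as autobatch would slice along.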

ModelModules

Type alias for a PyTree of modules in a model.

ModelParams

Type alias for a PyTree of parameters in a model.

ModelParamsStruct: str = ' modelparamsstruct'

ModelParamsStruct is a type tag for PyTrees of model parameters. Ideally it would be the composition of ModelStruct and “params” or “ …”, but this does not currently work with jaxtyping because of unbound types in return annotations. See: patrick-kidger/jaxtyping#357

ModelState

Type alias for a PyTree of states in a model.
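To make “nested PyTree” concrete, the sketch below builds hypothetical ModelParams and ModelState values for a two-module model, assuming (this is not stated above) that the model-level PyTree nests one sub-PyTree per module, with an empty tuple for a stateless module. The helper `tree_leaves` is a hypothetical pure-Python stand-in for `jax.tree_util.tree_leaves`, restricted to dicts, lists, tuples, and None.

```python
from typing import Any

import numpy as np


def tree_leaves(tree: Any) -> list[Any]:
    """Hypothetical stand-in for jax.tree_util.tree_leaves.

    Dicts are visited in sorted key order and None counts as an empty
    subtree, matching JAX's conventions for these container types.
    """
    if tree is None:
        return []
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for sub in tree for leaf in tree_leaves(sub)]
    return [tree]  # anything else (e.g. an array or scalar) is a leaf


# Hypothetical model: module 0 is an affine layer, module 1 a scaling layer.
model_params = (
    {"W": np.ones((2, 2)), "b": np.zeros(2)},  # module 0 parameters
    {"scale": 2.0},                            # module 1 parameters
)
model_state = ((), {"step": 0})  # module 0 is stateless, module 1 counts steps
```

Flattening `model_params` yields the arrays W and b followed by the scalar 2.0, while `model_state` contributes a single leaf, the step counter.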

ModelStateStruct: str = ' modelstatestruct'

ModelStateStruct is a type tag for PyTrees of model states. Ideally it would be the composition of ModelStruct and “state” or “ …”, but this does not currently work with jaxtyping because of unbound types in return annotations. See: patrick-kidger/jaxtyping#357

autobatch(fn, max_batch_size)

Decorator that automatically limits the batch size passed to a ModelCallable or ModuleCallable. This is not the same as vmap’ing a function: the original function must already be able to handle batches of data. The decorator splits a large batch into smaller batches of at most max_batch_size samples, calls the original function on each one, and concatenates the results.

It is typically applied to a function that has already been jit-compiled.

The returned state is the state returned from the last batch processed.

The rng parameter is passed through to each call of the original function unchanged.

Parameters:
  • fn (ModelCallable | ModuleCallable) – The function to decorate. It must be a ModelCallable or ModuleCallable that already accepts batched input.

  • max_batch_size (int | None) – The maximum batch size to use when calling the function. If None, then no batching is performed and the original function is returned.

Return type:

ModelCallable | ModuleCallable
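The chunk-and-concatenate behaviour described above can be sketched as follows. This is an illustrative re-implementation, not the library’s code: the name `autobatch_sketch` is hypothetical, it assumes the call order fn(params, X, training, state, rng) with a single array X whose leading axis is the batch axis (PyTree inputs are not handled), and it assumes every chunk receives the original input state, since the text above does not say whether state is threaded from one chunk to the next.

```python
from typing import Any, Callable, Optional

import numpy as np

# Assumed call shape: fn(params, X, training, state, rng) -> (Y, new_state).
Fn = Callable[[Any, np.ndarray, bool, Any, Any], tuple[np.ndarray, Any]]


def autobatch_sketch(fn: Fn, max_batch_size: Optional[int]) -> Fn:
    """Hypothetical sketch: split the batch axis, call fn per chunk, concatenate."""
    if max_batch_size is None:
        # No batching requested: return the original function unchanged.
        return fn
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be a positive integer or None")

    def batched(params: Any, X: np.ndarray, training: bool, state: Any, rng: Any):
        n = X.shape[0]
        if n <= max_batch_size:
            # Small (or empty) batches go straight through in one call.
            return fn(params, X, training, state, rng)
        outputs = []
        for start in range(0, n, max_batch_size):
            chunk = X[start : start + max_batch_size]
            # rng is passed through unchanged; each chunk sees the input state.
            out, new_state = fn(params, chunk, training, state, rng)
            outputs.append(out)
        # Concatenate along the batch axis; keep the last chunk's state.
        return np.concatenate(outputs, axis=0), new_state

    return batched
```

For a batch of 7 samples and max_batch_size=3, the wrapped function is called on chunks of 3, 3, and 1 samples. If fn is jit-compiled with jax.jit, a final chunk smaller than max_batch_size has a new shape and therefore triggers one additional trace and compilation.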