model_util Module#
- ModelCallable#
Type alias for the callable signature of a model’s forward method. Identical to ModuleCallable, but named to suggest ModelParams and ModelState, which are nested PyTrees.
alias of
Callable[['], 'modelparamsstruct'],']]|'],bool,'] | tuple | list, 'modelstatestruct'],Any],tuple[']]|'],'] | tuple | list, 'modelstatestruct']]]
- ModelModules#
Type alias for a PyTree of modules in a model.
- ModelParams#
Type alias for a PyTree of parameters in a model.
- ModelParamsStruct: str = ' modelparamsstruct'#
ModelParamsStruct is a type tag for PyTrees of model parameters. It really should be the composition of ModelStruct and “params” or “ …” but this doesn’t work currently with jaxtyping due to unbound types in returns. See: patrick-kidger/jaxtyping#357
- ModelState#
Type alias for a PyTree of states in a model.
- ModelStateStruct: str = ' modelstatestruct'#
ModelStateStruct is a type tag for PyTrees of model states. It really should be the composition of ModelStruct and “state” or “ …” but this doesn’t work currently with jaxtyping due to unbound types in returns. See: patrick-kidger/jaxtyping#357
- autobatch(fn, max_batch_size)#
Decorator to automatically limit the batch size of a ModelCallable or ModuleCallable function. This is not the same as taking a function and vmap’ing it. The original function must already be able to handle batches of data. This decorator simply breaks up large batches into smaller batches of size max_batch_size, calls the original function on each smaller batch, and then concatenates the results.
This would usually be used on a function that has already been jit-compiled.
The returned state is the state returned from the last batch processed.
The rng parameter is passed through to each call of the original function unchanged.
- Parameters:
fn (ModelCallable | ModuleCallable) – The function to be decorated. Must be a ModelCallable or ModuleCallable.
max_batch_size (int | None) – The maximum batch size to use when calling the function. If None, then no batching is performed and the original function is returned.
- Return type:
ModelCallable | ModuleCallable
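The chunking behaviour described for autobatch can be illustrated with a minimal sketch. This is not the library's implementation: autobatch_sketch and toy_model are hypothetical names, the argument order (params, x, training, state, rng) is an assumption about the callable signature, NumPy stands in for JAX arrays, and passing the original input state to every chunk is also an assumption (the documentation only says that the state from the last batch is returned).

```python
import numpy as np


def autobatch_sketch(fn, max_batch_size):
    """Hypothetical sketch of the chunking behaviour; not the library code."""
    if max_batch_size is None:
        # No batching requested: hand back the original function unchanged.
        return fn

    def wrapped(params, x, training, state, rng):
        n = x.shape[0]
        if n <= max_batch_size:
            # Small (or empty) batches go straight through in a single call.
            return fn(params, x, training, state, rng)

        outputs = []
        new_state = state
        for start in range(0, n, max_batch_size):
            chunk = x[start:start + max_batch_size]
            # Assumption: each call sees the original state and the same rng.
            y, new_state = fn(params, chunk, training, state, rng)
            outputs.append(y)

        # Stitch the per-chunk outputs back together along the batch axis;
        # the state kept is the one returned by the last chunk.
        return np.concatenate(outputs, axis=0), new_state

    return wrapped


# Toy stand-in for a model forward function that already handles batches.
calls = []


def toy_model(params, x, training, state, rng):
    calls.append(x.shape[0])  # record the batch size of each call
    return x * params["scale"], {"count": state["count"] + x.shape[0]}


batched = autobatch_sketch(toy_model, max_batch_size=4)
x = np.arange(10.0).reshape(10, 1)
y, new_state = batched({"scale": 2.0}, x, False, {"count": 0}, None)
```

Under these assumptions, a batch of 10 rows with max_batch_size=4 leads to three calls on chunks of 4, 4, and 2 rows, and the concatenated output has the same leading dimension as the input. The wrapped function itself is never vmapped; it only ever receives smaller batches.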