typing Module
- ArrayData
A special case of Data where the input data is represented as a single JAX array.
- ArrayDataShape
A special case of DataShape where the input data shape is represented as a single tuple of integers.
- BatchlessComplexDataFixed
A PyTree with guaranteed structure and shape and only Complex numbers with no batch dimension.
- BatchlessDataFixed
A PyTree with guaranteed structure and shape and only numerical arrays with no batch dimension.
- BatchlessRealDataFixed
A PyTree with guaranteed structure and shape and only Reals with no batch dimension.
- ComplexDataFixed
A PyTree with guaranteed structure and shape and only Complex numbers.
- Data: TypeAlias = jaxtyping.PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | jaxtyping.Inexact[Array, 'batch_size ...']
A PyTree or JAX array representing input data. If a PyTree, each leaf is a numerical JAX array with a leading batch dimension. The batch dimension must not change during evaluation of a model or module, but all other dimensions, including their number and the structure of the PyTree, may vary. Alternatively, a single JAX array with a leading batch dimension can be used. A module or model may take as input either a PyTree or a single JAX array, and may return either a PyTree or a single JAX array as output, regardless of the input type.
The structure of the PyTree can change throughout evaluation, so it is not specified in the type alias.
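As a rough illustration of the two forms described above, the sketch below builds a single-array input and a PyTree input and checks that every leaf shares the same leading batch dimension. The helper `leading_dims` and the leaf names are hypothetical and not part of this module.

```python
import jax
import jax.numpy as jnp

batch_size = 4

# A single JAX array with a leading batch dimension.
array_data = jnp.ones((batch_size, 3))

# A PyTree whose leaves all share the same leading batch dimension;
# the remaining dimensions differ from leaf to leaf.
tree_data = {
    "image": jnp.zeros((batch_size, 8, 8)),
    "features": (jnp.ones((batch_size, 3)), jnp.ones((batch_size,))),
}


def leading_dims(data):
    """Collect the leading dimension of every leaf (hypothetical helper)."""
    return {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(data)}


# Both forms should report a single, shared batch size.
print(leading_dims(array_data))
print(leading_dims(tree_data))
```

A set with more than one element would indicate that the leaves disagree on the batch dimension.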
- DataFixed
A PyTree with guaranteed structure and shape and only numerical arrays.
- DataShape: TypeAlias = jaxtyping.PyTree[tuple[int | None, ...]] | tuple[int | None, ...]
A PyTree representing the shape of input data. Each leaf is a tuple of integers representing the shape of the corresponding leaf in a Data PyTree, excluding the leading batch dimension. Alternatively, a single tuple of integers can be used to represent the shape of a single JAX array.
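One way to picture the relationship between a Data value and its DataShape is to map over the leaves and drop the first axis. This is only a sketch: `data_shape` is a hypothetical helper, not a function this module is known to provide.

```python
import jax
import jax.numpy as jnp

data = {
    "image": jnp.zeros((4, 8, 8)),
    "features": jnp.ones((4, 3)),
}


def data_shape(data):
    """Per-leaf shape with the leading batch dimension removed (hypothetical helper)."""
    return jax.tree_util.tree_map(lambda leaf: tuple(leaf.shape[1:]), data)


# PyTree input: the result mirrors the structure of the input.
tree_shape = data_shape(data)

# Single-array input: the result is a single tuple of integers.
array_shape = data_shape(jnp.ones((4, 3)))

print(tree_shape)
print(array_shape)
```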
- DictParams
A special case of Params where the parameters are represented as a dictionary of JAX arrays.
- DictState
A special case of State where the state is represented as a dictionary of JAX arrays.
- HyperParams
A dictionary representing hyperparameters for model configuration. The keys are strings representing hyperparameter names, and the values can be of any (ideally serializable) type. If the values are not serializable, then default implementations for saving and loading models and modules to file will need to be overridden in subclasses.
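The sketch below shows a hyperparameter dictionary whose values survive a serialization round trip. JSON is used purely as a stand-in for "serializable"; the format actually used by the default save and load implementations is not stated here, and the keys are invented for illustration.

```python
import json

# Hypothetical hyperparameters: string keys, plain serializable values.
hyperparams = {
    "hidden_size": 64,
    "activation": "tanh",
    "dropout_rate": 0.1,
    "layer_sizes": [32, 16],
}

# Round trip through JSON to confirm nothing is lost.
encoded = json.dumps(hyperparams)
restored = json.loads(encoded)

print(restored == hyperparams)
```

A value such as a function object or a JAX array would fail this round trip, which is the situation in which custom saving and loading would be needed.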
- ListParams
A special case of Params where the parameters are represented as a list of JAX arrays.
alias of
list[']]
- ListState
A special case of State where the state is represented as a list of JAX arrays.
alias of
list[Num[Array, '*?d']]
- ModuleCallable
A Callable that represents the forward pass of a module. The Callable must be JAX-jittable, pure (i.e., no side effects), and JAX-differentiable. It takes the following arguments:
  - params: A PyTree of model parameters.
  - data: A PyTree or JAX array of input data.
  - training: A boolean flag indicating whether the module is being used for training or evaluation. Useful for modules that behave differently during training (e.g., dropout, batch normalization).
  - state: A PyTree representing the current state of the module.
  - rng: An optional JAX random key for stochastic operations.
The Callable returns a tuple containing:
  - data: A PyTree or JAX array of output data.
  - state: A PyTree representing the updated state of the module.
If the module does not maintain any state, the state argument can be passed as an empty tuple () and the returned state will also be an empty tuple.
alias of
Callable[['], 'Params'],']]|'],bool,Num[Array, '*?d'], 'State'],Any],tuple[']]|'],Num[Array, '*?d'], 'State']]]
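A minimal sketch of a function that fits the description above: a stateless linear layer that takes `params`, `data`, `training`, `state`, and `rng`, and returns the output together with the unchanged state. The function name `linear` and the parameter keys are assumptions made for illustration, not names defined by this module.

```python
import jax
import jax.numpy as jnp


def linear(params, data, training, state, rng=None):
    """Hypothetical stateless linear layer following the argument order above."""
    del training, rng  # unused: this layer behaves the same in both modes
    output = data @ params["weights"] + params["bias"]
    return output, state  # state is passed through unchanged


params = {
    "weights": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "bias": jnp.array([1.0, 2.0]),
}
data = jnp.ones((3, 2))  # batch of 3 inputs with 2 features each

# Jittable: `training` is marked static because it is a Python bool.
jitted = jax.jit(linear, static_argnames="training")
output, new_state = jitted(params, data, training=False, state=())


def loss(p):
    """Scalar loss used only to demonstrate differentiability."""
    out, _ = linear(p, data, False, ())
    return jnp.sum(out**2)


# Differentiable: gradients come back with the same structure as params.
grads = jax.grad(loss)(params)
print(output.shape, new_state, jax.tree_util.tree_structure(grads))
```

Because the layer keeps no state, an empty tuple goes in and an empty tuple comes out, matching the convention described above.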
- Params
A PyTree representing model parameters. Each leaf is a numerical JAX array.
Dict example:
    params: Params = {
        "weights": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
        "bias": jnp.array([1.0, 2.0]),
    }
Tuple example:
    params: Params = (jnp.array([1.0, 2.0]), jnp.array([0.5]))
Arbitrary PyTree example:
    params: Params = {
        "w": (jnp.array([1.0, 1.0]), jnp.array([0.0])),
        "b": jnp.array([0.5]),
        "a": [jnp.array([[1.0]]), jnp.array([[2.0]])],
    }
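Because all three forms are PyTrees, the same tree utilities apply to each of them. The sketch below counts scalar parameters with `jax.tree_util.tree_leaves`; the helper `num_parameters` is hypothetical and not part of this module.

```python
import jax
import jax.numpy as jnp

dict_params = {
    "weights": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "bias": jnp.array([1.0, 2.0]),
}
tuple_params = (jnp.array([1.0, 2.0]), jnp.array([0.5]))
nested_params = {
    "w": (jnp.array([1.0, 1.0]), jnp.array([0.0])),
    "b": jnp.array([0.5]),
    "a": [jnp.array([[1.0]]), jnp.array([[2.0]])],
}


def num_parameters(params):
    """Total number of scalar entries across all leaves (hypothetical helper)."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))


for p in (dict_params, tuple_params, nested_params):
    print(num_parameters(p))
```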
- RealDataFixed
A PyTree with guaranteed structure and shape and only Reals.
- State
A PyTree representing the private and persistent state of a module. Each leaf is a numerical JAX array of arbitrary shape. The structure of the PyTree and shape of the arrays must not change during evaluation of a module.
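To illustrate the fixed-structure requirement, here is a sketch of a module that keeps a running mean in its state. The function `running_mean` and the state keys are assumptions for illustration only. The state values change between calls, while the PyTree structure and the leaf shapes stay the same.

```python
import jax
import jax.numpy as jnp


def running_mean(params, data, training, state, rng=None):
    """Hypothetical module that centres data using a running mean kept in state."""
    del params, rng  # unused in this sketch
    # `training` is a Python bool here; under jax.jit it would need to be static.
    if training:
        count = state["count"] + 1.0
        mean = state["mean"] + (jnp.mean(data, axis=0) - state["mean"]) / count
        state = {"count": count, "mean": mean}
    return data - state["mean"], state


state = {"count": jnp.array(0.0), "mean": jnp.zeros(2)}
data = jnp.array([[1.0, 2.0], [3.0, 4.0]])

output, new_state = running_mean((), data, True, state)

# Values are updated, but structure and leaf shapes are unchanged.
same_structure = (
    jax.tree_util.tree_structure(state) == jax.tree_util.tree_structure(new_state)
)
old_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(state)]
new_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(new_state)]
print(same_structure, old_shapes == new_shapes)
```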