typing Module#

ArrayData#

A special case of Data where the input data is represented as a single JAX array.

ArrayDataShape#

A special case of DataShape where the input data shape is represented as a single tuple of integers.

alias of tuple[int | None, ...]

BatchlessComplexDataFixed#

A PyTree with guaranteed structure and shape whose leaves contain only complex numbers and have no batch dimension.

BatchlessDataFixed#

A PyTree with guaranteed structure and shape whose leaves are numerical arrays with no batch dimension.

BatchlessRealDataFixed#

A PyTree with guaranteed structure and shape whose leaves contain only real numbers and have no batch dimension.

ComplexDataFixed#

A PyTree with guaranteed structure and shape whose leaves contain only complex numbers.

Data: TypeAlias = jaxtyping.PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | jaxtyping.Inexact[Array, 'batch_size ...']#

A PyTree or JAX array representing input data. If a PyTree, each leaf is a numerical JAX array with a leading batch dimension. The batch dimension must not change during evaluation of a model or module, but all other dimensions, including their number and the structure of the PyTree, may vary. Alternatively, a single JAX array with a leading batch dimension can be used. A module or model may take as input either a PyTree or a single JAX array, and may return either a PyTree or a single JAX array as output, regardless of the input type.

The structure of the PyTree can change throughout evaluation, so it is not specified in the type alias.
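As a minimal sketch of the two accepted forms, the snippet below builds one PyTree input and one single-array input and checks that every leaf shares the same leading batch dimension. The helper `batch_size` is hypothetical and written only for this illustration; it is not part of the documented API.

```python
import jax
import jax.numpy as jnp


def batch_size(data):
    """Return the leading dimension shared by every leaf, or raise if they differ."""
    leaves = jax.tree_util.tree_leaves(data)
    sizes = {leaf.shape[0] for leaf in leaves}
    if len(sizes) != 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
    return sizes.pop()


# PyTree form: leaves have different trailing shapes but the same batch size.
tree_data = {"image": jnp.zeros((4, 8, 8)), "label": jnp.zeros((4,))}

# Single-array form: one array with a leading batch dimension.
array_data = jnp.ones((4, 3))

tree_batch = batch_size(tree_data)
array_batch = batch_size(array_data)
```

Because a single array is itself a one-leaf PyTree, the same helper covers both forms.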

DataFixed#

A PyTree with guaranteed structure and shape whose leaves are numerical arrays.

DataShape: TypeAlias = jaxtyping.PyTree[tuple[int | None, ...]] | tuple[int | None, ...]#

A PyTree representing the shape of input data. Each leaf is a tuple of integers representing the shape of the corresponding leaf in a Data PyTree, excluding the leading batch dimension. Alternatively, a single tuple of integers can be used to represent the shape of a single JAX array.
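One way to picture the relationship between Data and DataShape is to derive the shape PyTree from a batch of data by dropping the leading dimension of each leaf. The helper `shape_of` below is a hypothetical illustration, not a function provided by this module.

```python
import jax
import jax.numpy as jnp


def shape_of(data):
    """Map each leaf array to its shape with the leading batch dimension dropped."""
    return jax.tree_util.tree_map(lambda leaf: tuple(leaf.shape[1:]), data)


# PyTree input: the result mirrors the structure, with one shape tuple per leaf.
tree_shape = shape_of({"image": jnp.zeros((4, 8, 8)), "label": jnp.zeros((4,))})

# Single-array input: the result is a single tuple.
array_shape = shape_of(jnp.ones((4, 3)))
```

A leaf that has only a batch dimension maps to the empty tuple `()`.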

DictParams#

A special case of Params where the parameters are represented as a dictionary of JAX arrays.

alias of dict[str, ']]

DictState#

A special case of State where the state is represented as a dictionary of JAX arrays.

alias of dict[str, Num[Array, '*?d']]

HyperParams#

A dictionary representing hyperparameters for model configuration. The keys are strings representing hyperparameter names, and the values can be of any (ideally serializable) type. If the values are not serializable, then default implementations for saving and loading models and modules to file will need to be overridden in subclasses.

alias of dict[str, Any]
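The serializability advice can be checked ahead of time. The sketch below assumes JSON as the storage format purely for illustration; the format used by the default save and load implementations is not stated here, and `is_json_serializable` is a hypothetical helper.

```python
import json
from typing import Any

HyperParams = dict[str, Any]  # mirrors the alias documented above

hyperparams: HyperParams = {"hidden_size": 64, "activation": "tanh", "dropout": 0.1}

# JSON-compatible values survive a save/load round trip unchanged.
restored = json.loads(json.dumps(hyperparams))


def is_json_serializable(value: Any) -> bool:
    """Return True if `value` can be encoded as JSON (the assumed format)."""
    try:
        json.dumps(value)
        return True
    except (TypeError, ValueError):
        return False
```

A value such as a Python function fails this check, which is the situation in which custom save and load logic would be needed.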

ListParams#

A special case of Params where the parameters are represented as a list of JAX arrays.

alias of list[']]

ListState#

A special case of State where the state is represented as a list of JAX arrays.

alias of list[Num[Array, '*?d']]

ModuleCallable#

A Callable that represents the forward pass of a module. The Callable must be JAX-jittable, pure (i.e., free of side effects), and JAX-differentiable. It takes the following arguments:

  • params: A PyTree of model parameters.

  • data: A PyTree or JAX array of input data.

  • training: A boolean flag indicating whether the module is being used for training or evaluation. Useful for modules that behave differently during training (e.g., dropout, batch normalization).

  • state: A PyTree representing the current state of the module.

  • rng: An optional JAX random key for stochastic operations.

The Callable returns a tuple containing:

  • data: A PyTree or JAX array of output data.

  • state: A PyTree representing the updated state of the module.

If the module does not maintain any state, the state argument can be passed as an empty tuple () and the returned state will also be an empty tuple.

alias of Callable[['], 'Params'], ']] | '], bool, Num[Array, '*?d'], 'State'], Any], tuple[']] | '], Num[Array, '*?d'], 'State']]]
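A minimal sketch of a function matching this description is shown below. The name `linear_module`, the parameter keys, and the choice to mark `training` as a static argument when jitting are all assumptions made for illustration; they are not taken from the library.

```python
import jax
import jax.numpy as jnp


def linear_module(params, data, training, state, rng=None):
    """Hypothetical stateless linear layer using the documented argument order."""
    out = data @ params["weights"] + params["bias"]
    return out, state  # no state is kept, so it is passed through unchanged


params = {
    "weights": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "bias": jnp.array([1.0, 2.0]),
}
data = jnp.ones((3, 2))  # batch of 3 samples with 2 features each

# Assumption: `training` is a Python bool, so it is treated as static under jit.
jitted = jax.jit(linear_module, static_argnames="training")
out, new_state = jitted(params, data, training=False, state=())


def loss(p):
    """Scalar loss used only to show that the module is differentiable."""
    y, _ = linear_module(p, data, False, ())
    return jnp.sum(y**2)


grads = jax.grad(loss)(params)  # same PyTree structure as `params`
```

The function is pure, so it can be passed to both `jax.jit` and `jax.grad`, and the empty tuple serves as the state on the way in and on the way out.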

Params#

A PyTree representing model parameters. Each leaf is a numerical JAX array.

Dict Example:

params: Params = {
    "weights": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "bias": jnp.array([1.0, 2.0]),
}

Tuple Example:

params: Params = (jnp.array([1.0, 2.0]), jnp.array([0.5]))

Arbitrary PyTree Example:

params: Params = {
    "w": (jnp.array([1.0, 1.0]), jnp.array([0.0])),
    "b": jnp.array([0.5]),
    "a": [jnp.array([[1.0]]), jnp.array([[2.0]])],
}

RealDataFixed#

A PyTree with guaranteed structure and shape whose leaves contain only real numbers.

State#

A PyTree representing the private and persistent state of a module. Each leaf is a numerical JAX array of arbitrary shape. The structure of the PyTree and shape of the arrays must not change during evaluation of a module.
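To illustrate the fixed-structure requirement, the sketch below uses a hypothetical module, `counting_module`, whose state holds a single counter. After one call it compares the old and new state to confirm that the PyTree structure and leaf shapes are unchanged. The module and its state layout are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def counting_module(params, data, training, state, rng=None):
    """Hypothetical module whose state counts how many samples it has seen."""
    new_state = {"seen": state["seen"] + data.shape[0]}
    return data * params["scale"], new_state


params = {"scale": jnp.array(2.0)}
state = {"seen": jnp.array(0)}
data = jnp.ones((4, 3))  # batch of 4 samples

out, new_state = counting_module(params, data, False, state)

# The updated state must have the same PyTree structure as the old one...
same_structure = (
    jax.tree_util.tree_structure(state) == jax.tree_util.tree_structure(new_state)
)
# ...and every leaf must keep its shape.
same_shapes = all(
    a.shape == b.shape
    for a, b in zip(
        jax.tree_util.tree_leaves(state), jax.tree_util.tree_leaves(new_state)
    )
)
```

The state is returned as a new PyTree rather than modified in place, which keeps the module pure.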

TupleParams#

A special case of Params where the parameters are represented as a tuple of JAX arrays.

alias of tuple['], …]

TupleState#

A special case of State where the state is represented as a tuple of JAX arrays.

alias of tuple[Num[Array, '*?d'], ...]