Einsum#

parametricmatrixmodels.modules.Einsum

class Einsum(einsum_str=None, params=None, dim_map=None, trainable=False, init_magnitude=0.01, real=True)[source]#

Bases: BaseModule

Module that implements einsum operations between array leaves in a PyTree input, optionally with additional leading trainable arrays. The einsum operation is defined by the provided einsum string, which must specify all indices except for the batch index. The batch index should not be included in the provided einsum string; it will be automatically added as the leading index.

Examples

Taking in a PyTree with two leaves, each being a 1D array (excluding the batch dimension), the dot product can be expressed as:

>>> einsum_module = Einsum("i,i->")
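With the automatically added batch index (written ``n`` below), this operation reduces to a batched dot product. A minimal NumPy sketch of the equivalent computation (an illustration of the semantics, not the module API):

```python
import numpy as np

# Assumption: the module prepends a batch index to every input leaf and
# to the output, so Einsum("i,i->") computes a per-sample dot product,
# equivalent to np.einsum("ni,ni->n", ...).
x = np.arange(6.0).reshape(2, 3)   # batch of two 3-vectors
y = np.ones((2, 3))
out = np.einsum("ni,ni->n", x, y)  # shape (2,), one dot product per sample
```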

Taking in a PyTree with three leaves, each of which is a matrix (excluding the batch dimension), computing the matrix multiplication of a trainable matrix with the elementwise product of these three input matrices can be expressed as:

>>> einsum_module = Einsum("ij,jk,jk,jk->ik")

Since the einsum string in this example has four comma-separated input terms, when the module is compiled with only three input leaves, it will infer that the leading array is trainable and initialize it randomly during compilation. This array can be directly specified by calling set_params or by providing it in the params argument during initialization.
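Assuming the batch index (``n``) is added to every input leaf and to the output, but not to the leading array, the compiled operation is equivalent to the following NumPy sketch:

```python
import numpy as np

# Illustration of the semantics of "ij,jk,jk,jk->ik" with one leading
# (batch-free) array W: per sample, W @ (a * b * c).
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))                            # leading array "ij"
a, b, c = (rng.normal(size=(2, 3, 5)) for _ in range(3))  # leaves "jk", batched
out = np.einsum("ij,njk,njk,njk->nik", W, a, b, c)     # shape (2, 4, 5)
```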

Alternatively, to perform the same operation but with the leading array fixed (not trainable), the module must be initialized with this array specified in the params argument and trainable set to False:

>>> import numpy as np
>>> input_dim, output_dim = 3, 4  # illustrative sizes
>>> W = np.random.normal(size=(input_dim, output_dim))
>>> einsum_module = Einsum(
...     "ij,jk,jk,jk->ik",
...     params=W,
...     trainable=False)

In this case, the fixed arrays can be specified later by calling set_hyperparameters with a dictionary containing the key params.

Any additional trainable or fixed arrays will always be treated as leading arrays in the einsum operation.

If no additional fixed or trainable arrays are to be used, the einsum string can alternatively be provided as a two-element tuple consisting of a PyTree of strings with the same structure as the input data and a string representing the output of the einsum; the output string can be omitted if it is to be inferred. For example, consider a PyTree input with structure PyTree([*, (*, *)]), with each leaf being a 1D array. To specify the three-way outer product of the three leaves in the order PyTree([2, (0, 1)]), the einsum string can be provided in any of the following equivalent ways:

>>> Einsum('c,a,b->abc') # full string
>>> Einsum('c,a,b') # output inferred 'abc'
>>> Einsum(['k', ('i', 'j')]) # output inferred 'ijk'
>>> Einsum((['k', ('i', 'j')], 'ijk')) # full tuple
>>> Einsum((['a', ('b', 'c')], 'bca')) # full tuple
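With the batch index (``n``) added automatically, all of the forms above specify the same per-sample three-way outer product. A NumPy sketch of the equivalent computation (illustration only):

```python
import numpy as np

# Assumption: batch index "n" is prepended to every input leaf and the
# output, so Einsum("c,a,b->abc") is equivalent to
# np.einsum("nc,na,nb->nabc", ...).
rng = np.random.default_rng(1)
z = rng.normal(size=(2, 4))  # leaf mapped to index "c"
x = rng.normal(size=(2, 3))  # leaf mapped to index "a"
y = rng.normal(size=(2, 5))  # leaf mapped to index "b"
out = np.einsum("nc,na,nb->nabc", z, x, y)  # shape (2, 3, 5, 4)
```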

If additional fixed or trainable arrays are to be used, the einsum string can be provided as a three-element tuple where the first element is an einsum str for the additional arrays, the second element is a PyTree of strings with the same structure as the input data, and the third element is the output string, which can be omitted if it is to be inferred. For example, to perform the same three-way outer product as above but with a leading trainable array, the einsum string can be provided in any of the following equivalent ways:

>>> Einsum('ab,c,a,b->c') # full string
>>> Einsum('ab,c,a,b') # output inferred 'c'
>>> Einsum(('ij', ['k', ('i', 'j')])) # output inferred 'k'
>>> Einsum(('ij', ['k', ('i', 'j')], 'k')) # full tuple

For multiple leading arrays the following are equivalent:

>>> Einsum('ab,cd,c,a,b->d') # full string
>>> Einsum('ab,cd,c,a,b') # output inferred 'd'
>>> Einsum(('ab,cd', ['c', ('a', 'b')])) # output inferred 'd'
>>> Einsum(('ab,cd', ['c', ('a', 'b')], 'd')) # full tuple
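The semantics of the multiple-leading-arrays case can be sketched with plain NumPy, assuming leading arrays receive no batch index while the input leaves and the output do:

```python
import numpy as np

# Illustration of "ab,cd,c,a,b->d" with two leading (batch-free) arrays:
# equivalent to np.einsum("ab,cd,nc,na,nb->nd", ...).
rng = np.random.default_rng(2)
U = rng.normal(size=(3, 5))   # leading array "ab"
V = rng.normal(size=(4, 6))   # leading array "cd"
zc = rng.normal(size=(2, 4))  # input leaf "c", batched
xa = rng.normal(size=(2, 3))  # input leaf "a", batched
yb = rng.normal(size=(2, 5))  # input leaf "b", batched
out = np.einsum("ab,cd,nc,na,nb->nd", U, V, zc, xa, yb)  # shape (2, 6)
```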
__init__(einsum_str=None, params=None, dim_map=None, trainable=False, init_magnitude=0.01, real=True)[source]#

Initialize an Einsum module.

Parameters:
  • einsum_str (tuple[str, PyTree[str], str] | tuple[str, PyTree[str]] | tuple[PyTree[str], str] | PyTree[str] | str | None) – The einsum string defining the operation. The batch index should not be included in the provided einsum string, as it will be automatically added as the leading index later. Can be provided as a single string, a PyTree of input_strings, a Tuple of (PyTree[input_strings], output_string), or a tuple of (leading_arrays_einsum_str, PyTree[input_strings], output_string). The input_strings in a PyTree must have the same structure as the input data. output_string can be omitted to have it inferred. If None, it must be set before compilation via set_hyperparameters with the einsum_str key. Default is None.

  • params (PyTree[jaxtyping.Inexact[Array, '...']] | Inexact[Array, '...'] | None) – Optional additional leading arrays for the einsum operation. If trainable is True, these will be treated as the initial values for trainable arrays. If False, they will be treated as fixed arrays. If None and trainable is True, the leading arrays will be initialized randomly during compilation. Default is None. Can be provided later via set_hyperparameters with the params key if trainable is False, or via set_params if trainable is True.

  • dim_map (dict[str, int] | None) – Dictionary mapping einsum indices (characters) to integer sizes for the array dimensions. Only entries for indices that cannot be inferred from the input data shapes or parameter shapes need to be provided. Default is None.

  • trainable (bool) – Whether the provided params are trainable or fixed. If True, the arrays in params will be treated as initial values for trainable arrays. If False, they will be treated as fixed arrays. Default is False.

  • init_magnitude (float) – Magnitude for the random initialization of weights. Default is 1e-2.

  • real (bool) – Ignored when there are no trainable arrays. If True, the trainable arrays will be real-valued. If False, they will be complex-valued. Default is True.

Return type:

None

__call__

Call the module with the current parameters and given input, state, and rng.

_get_callable

Returns a jax.jit-able and jax.grad-able callable that represents the module's forward pass.

astype

Convenience wrapper to set_precision using the dtype argument, returns self.

compile

Compile the module to be used with the given input shape.

deserialize

Deserialize the module from a dictionary.

get_hyperparameters

Get the hyperparameters of the module.

get_num_trainable_floats

Returns the number of trainable floats in the module.

get_output_shape

Get the output shape of the module given the input shape.

get_params

Get the current trainable parameters of the module.

get_state

Get the current state of the module.

is_ready

Return True if the module is initialized and ready for training or inference.

serialize

Serialize the module to a dictionary.

set_hyperparameters

Set the hyperparameters of the module.

set_params

Set the trainable parameters of the module.

set_precision

Set the precision of the module parameters and state.

set_state

Set the state of the module.

name

Returns the name of the module; unless overridden, this is the class name.

__call__(data, /, *, training=False, state=(), rng=None)#

Call the module with the current parameters and given input, state, and rng.

Parameters:
  • data (pmm.typing.Data) – PyTree of input arrays of shape (num_samples, …). Only the first dimension (num_samples) is guaranteed to be the same for all input arrays.

  • training (bool) – Whether the module is in training mode, by default False.

  • state (pmm.typing.State) – State of the module, by default ().

  • rng (Any) – JAX random key, by default None.

Returns:

Output array of shape (num_samples, num_output_features) and new state.

Raises:

ValueError – If the module is not ready (i.e., compile() has not been called).

Return type:

tuple[pmm.typing.Data, pmm.typing.State]

See also

_get_callable

Returns a callable that can be used to compute the output and new state given the parameters, input, training flag, state, and rng.

Params

Typing for the module parameters.

Data

Typing for the input and output data.

State

Typing for the module state.

_get_callable()[source]#

Returns a jax.jit-able and jax.grad-able callable that represents the module’s forward pass.

This method must be implemented by all subclasses and must return a jax-jit-able and jax-grad-able callable in the form of

module_callable(
    params: parametricmatrixmodels.typing.Params,
    data: parametricmatrixmodels.typing.Data,
    training: bool,
    state: parametricmatrixmodels.typing.State,
    rng: Any,
) -> (
    output: parametricmatrixmodels.typing.Data,
    new_state: parametricmatrixmodels.typing.State,
)

That is, all hyperparameters are traced out and the callable depends explicitly only on

  • the module’s parameters, as a PyTree with leaf nodes as JAX arrays,

  • the input data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),

  • the training flag, as a boolean,

  • the module’s state, as a PyTree with leaf nodes as JAX arrays

and returns

  • the output data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),

  • the new module state, as a PyTree with leaf nodes as JAX arrays. The PyTree structure must match that of the input state and additionally all leaf nodes must have the same shape as the input state leaf nodes.

The training flag will be traced out, so it doesn’t need to be jittable.

Returns:

A callable that takes the module’s parameters, input data, training flag, state, and rng key and returns the output data and new state.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

Callable[[PyTree[jaxtyping.Inexact[Array, '...'], 'Params'], PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], bool, PyTree[jaxtyping.Num[Array, '*?d'], 'State'], Any], tuple[PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], PyTree[jaxtyping.Num[Array, '*?d'], 'State']]]

See also

__call__

Calls the module with the current parameters and given input, state, and rng.

ModuleCallable

Typing for the callable returned by this method.

Params

Typing for the module parameters.

Data

Typing for the input and output data.

State

Typing for the module state.

_get_concrete_einsum_str(input_shape)[source]#

Get the concrete einsum string by parsing the provided einsum string, adding a batch index to account for batch dimensions, and inferring any missing output indices. If the einsum string is a PyTree, it must have the same structure as the input data.

Parameters:

input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – The shape of the input data, used to infer any missing output indices as well as validate existing indices. Should not include the batch dimension.

Return type:

str

Examples

>>> m = Einsum("ij,jk->ik")
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
'aij,ajk->aik' # leading index 'a' added for batch dimension
>>> m = Einsum((('ij', 'jk'), 'ik'))
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
'aij,ajk->aik'
>>> m = Einsum("ab,bc->ac")
>>> m._get_concrete_einsum_str(((5, 2), (2, 4)))
'dab,dbc->dac' # leading index 'd' added for batch dimension
>>> m = Einsum((['ij', 'jk'], 'ik'))
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
ValueError: The structure of the einsum_str PyTree must match that
of the input data.
# (since the input data is a Tuple of two arrays, not a List)
>>> m = Einsum(('ij,jk', {'x': 'ik', 'y': 'ab'}, 'ab'))
>>> m._get_concrete_einsum_str({'x': (2, 3), 'y': (3, 4)})
'ij,jk,cik,cab->cab' # leading arrays don't have batch index
# all arrays from PyTrees are inserted in the same order as the
# list from jax.tree.leaves(...)
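The batch-index selection visible in these examples (the first letter not already used in the einsum string) can be sketched as a small helper; ``add_batch_index`` and its ``n_leading`` argument are illustrative names, not part of the module's API:

```python
import string

def add_batch_index(einsum_str: str, n_leading: int = 0) -> str:
    """Hypothetical sketch of the behaviour shown above: prefix every input
    term (except the first n_leading leading arrays) and the output with the
    first letter not already used. Assumes an explicit output ("->")."""
    inputs, output = einsum_str.split("->")
    terms = inputs.split(",")
    # Pick the first lowercase letter not appearing anywhere in the string.
    used = set(einsum_str)
    batch = next(c for c in string.ascii_lowercase if c not in used)
    # Leading arrays keep their terms; input leaves get the batch index.
    terms = [t if i < n_leading else batch + t for i, t in enumerate(terms)]
    return ",".join(terms) + "->" + batch + output
```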
_get_dimension_map(concrete_einsum_str, input_shape)[source]#

Fill in the dimension map by inferring sizes from the input shapes and parameter shapes based on the provided concrete einsum string and self.params if applicable.

Parameters:
  • concrete_einsum_str (str) – The concrete einsum string with all indices specified, including the output indices and batch index.

  • input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – The shape of the input data, used to infer dimension sizes. Should not include the batch dimension.

Returns:

A complete dimension map with sizes for all indices in the einsum string.

Return type:

dict[str, int]
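The inference step can be sketched by pairing each index character of the concrete einsum string with the corresponding entry of each array's shape; ``infer_dim_map`` is a hypothetical helper, not the module's implementation:

```python
def infer_dim_map(concrete_einsum_str: str, shapes) -> dict:
    """Sketch: build a {index: size} map from input terms and shapes,
    raising on inconsistent sizes for the same index."""
    inputs = concrete_einsum_str.split("->")[0].split(",")
    dim_map = {}
    for term, shape in zip(inputs, shapes):
        for idx, size in zip(term, shape):
            # setdefault records the first size seen; later terms must agree.
            if dim_map.setdefault(idx, size) != size:
                raise ValueError(f"inconsistent size for index {idx!r}")
    return dim_map
```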

astype(dtype, /)#

Convenience wrapper to set_precision using the dtype argument, returns self.

Parameters:

dtype (str | type[Any] | dtype | SupportsDType) – Precision to set for the module parameters. Valid options for 32-bit precision (all equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32. Valid options for 64-bit precision (all equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Returns:

BaseModule – The module itself, with updated precision.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

Return type:

BaseModule

See also

set_precision

Sets the precision of the module parameters and state.

compile(rng, input_shape)[source]#

Compile the module to be used with the given input shape.

This method initializes the module’s parameters and state based on the input shape and random key.

This is needed since Models are built before the input data is given; before training or inference can be done, the module needs to be compiled, and each module passes its output shape to the next module’s compile method.

The RNG key is used to initialize random parameters, if needed.

This is not used to trace or jit the module’s callable; that is done automatically later.

Parameters:
  • rng (Any) – JAX random key.

  • input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – PyTree of input shape tuples, e.g. ((num_features,),), to compile the module for. All data passed to the module later must have the same PyTree structure and shape in all leaf array dimensions except the leading batch dimension.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

None

See also

DataShape

Typing for the input shape.

get_output_shape

Get the output shape of the module

deserialize(data, /)#

Deserialize the module from a dictionary.

This method sets the module’s parameters and state based on the provided dictionary.

The default implementation expects the dictionary to contain the module’s name, trainable parameters, and state.

Parameters:

data (dict[str, Any]) – Dictionary containing the serialized module data.

Raises:

ValueError – If the serialized data does not contain the expected keys or if the version of the serialized data is not compatible with the current package version.

Return type:

None

get_hyperparameters()[source]#

Get the hyperparameters of the module.

Hyperparameters are used to configure the module and are not trainable. They can be set via set_hyperparameters.

Returns:

Dictionary containing the hyperparameters of the module.

Return type:

dict[str, Any]

See also

set_hyperparameters

Set the hyperparameters of the module.

HyperParams

Typing for the hyperparameters. Simply an alias for Dict[str, Any].

get_num_trainable_floats()#

Returns the number of trainable floats in the module. If the module does not have trainable parameters, returns 0. If the module is not ready, returns None.

Returns:

Number of trainable floats in the module, or None if the module is not ready.

Return type:

int | None

get_output_shape(input_shape)[source]#

Get the output shape of the module given the input shape.

Parameters:

input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – PyTree of input shape tuples, e.g. ((num_features,),), to get the output shape for.

Returns:

PyTree of output shape tuples, e.g. ((num_output_features,),), corresponding to the output shape of the module for the given input shape.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

PyTree[tuple[int | None, …]] | tuple[int | None, …]
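For an einsum module, the output shape follows directly from the output indices of the concrete einsum string together with a filled dimension map. A hypothetical sketch (``output_shape`` is an illustrative helper, not part of the API):

```python
def output_shape(concrete_einsum_str: str, dim_map: dict) -> tuple:
    """Sketch: look up the size of each output index in the dimension map."""
    out_indices = concrete_einsum_str.split("->")[1]
    return tuple(dim_map[idx] for idx in out_indices)

# Example using the concrete string and dimension map from the sections above.
shape = output_shape("aij,ajk->aik", {"a": 2, "i": 5, "j": 3, "k": 4})
```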

See also

DataShape

Typing for the input and output shape.

get_params()[source]#

Get the current trainable parameters of the module. If the module has no trainable parameters, this method should return an empty tuple.

Returns:

PyTree with leaf nodes as JAX arrays representing the module’s trainable parameters.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

PyTree[jaxtyping.Inexact[Array, ’…’], ’Params’]

See also

set_params

Set the trainable parameters of the module.

Params

Typing for the module parameters.

get_state()#

Get the current state of the module.

States are used to store “memory” or other information that is not passed between modules, is not trainable, but may be updated during either training or inference. e.g. batch normalization state.

The state is optional, in which case this method should return the empty tuple.

Returns:

PyTree with leaf nodes as JAX arrays representing the module’s state.

Return type:

pmm.typing.State

See also

set_state

Set the state of the module.

State

Typing for the module state.

is_ready()[source]#

Return True if the module is initialized and ready for training or inference.

Returns:

True if the module is ready, False otherwise.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

bool

property name: str#

Returns the name of the module; unless overridden, this is the class name.

Returns:

Name of the module.

serialize()#

Serialize the module to a dictionary.

This method returns a dictionary representation of the module, including its parameters and state.

The default implementation serializes the module’s name, hyperparameters, trainable parameters, and state via a simple dictionary.

This only works if the module’s hyperparameters are auto-serializable. This includes basic types as well as numpy arrays.

Returns:

Dictionary containing the serialized module data.

Return type:

dict[str, Any]

set_hyperparameters(hyperparams)[source]#

Set the hyperparameters of the module.

Hyperparameters are used to configure the module and are not trainable. They can be set via this method.

The default implementation uses setattr to set the hyperparameters as attributes of the class instance.

Parameters:

hyperparams (dict[str, Any]) – Dictionary containing the hyperparameters to set.

Raises:

TypeError – If hyperparams is not a dictionary.

Return type:

None

See also

get_hyperparameters

Get the hyperparameters of the module.

HyperParams

Typing for the hyperparameters. Simply an alias for Dict[str, Any].

set_params(params)[source]#

Set the trainable parameters of the module.

Parameters:

params (PyTree[jaxtyping.Inexact[Array, '...'], 'Params']) – PyTree with leaf nodes as JAX arrays representing the new trainable parameters of the module.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

None

See also

get_params

Get the trainable parameters of the module.

Params

Typing for the module parameters.

set_precision(prec, /)#

Set the precision of the module parameters and state.

Parameters:

prec (Any | str | int) – Precision to set for the module parameters. Valid options for 32-bit precision (all equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32. Valid options for 64-bit precision (all equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

Return type:

None

See also

astype

Convenience wrapper to set_precision using the dtype argument, returns self.

set_state(state, /)#

Set the state of the module.

This method is optional.

Parameters:

state (pmm.typing.State) – PyTree with leaf nodes as JAX arrays representing the new state of the module.

Return type:

None

See also

get_state

Get the state of the module.

State

Typing for the module state.