MatMul
parametricmatrixmodels.modules.MatMul
- class MatMul(params=None, output_shape=None, path_order=None, trainable=False, init_magnitude=0.01, real=True, separator='.')
Bases: Einsum
Module for optionally trainable matrix multiplication; a special case of Einsum. Computes the matrix product of all leaves in an input PyTree, optionally with a trainable or fixed parameter matrix.
That is, for a PyTree input with matrix leaves (A, B, C) and an optional trainable or fixed parameter matrix M, the output is the single array M @ A @ B @ C.
The order of multiplication can be changed by passing the path of each matrix in the path_order parameter. Paths are period-separated strings of keys/indices that reach each matrix in the input PyTree. A double period (..) denotes the trainable/fixed parameter matrix, if applicable.
The final input array may be a vector instead of a matrix.
All operations are applied independently to each sample along the batch dimension.
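The batched semantics described above can be sketched in plain NumPy. This is an illustration of M @ A @ B @ C applied per sample, not the library's implementation; the helper name matmul_chain is hypothetical, and it assumes every leaf carries a leading batch axis while M does not.

```python
import numpy as np


def matmul_chain(leaves, M=None):
    """Hypothetical sketch: compute M @ A @ B @ ... independently per sample.

    leaves: batched arrays in multiplication order, each (batch, rows, cols);
            the last one may be a batched vector of shape (batch, n).
    M: optional parameter matrix without a batch axis, applied on the left.
    """
    *mats, last = leaves
    is_vector = last.ndim == 2  # (batch, n): one vector per sample
    # Promote a batched vector to (batch, n, 1) so np.matmul treats it as a column.
    out = last[..., None] if is_vector else last
    for A in reversed(mats):
        out = A @ out  # np.matmul broadcasts over the leading batch axis
    if M is not None:
        out = M @ out  # (p, n) @ (batch, n, k) -> (batch, p, k)
    return out[..., 0] if is_vector else out
```

For instance, matmul_chain([x], W) with x of shape (batch, 4) and W of shape (2, 4) mirrors the MatMul(output_shape=2, trainable=True) example: a vector of size 4 in, a vector of size 2 out for every sample.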
Examples
To create a module that takes a single vector input (ignoring batch dim) and multiplies it by a trainable weight matrix:
>>> m = MatMul(output_shape=2, trainable=True)
>>> m(np.ones((batch_dim, 4)))  # vec of size 4 in, vec of size 2 out
To create a module that multiplies two input matrices together in an order different from the default:
>>> input_data = {
...     'x': np.ones((batch_dim, 3, 4)),
...     'y': [np.ones((batch_dim, 5, 3)),],
... }
>>> m = MatMul(path_order=['y.0', 'x'])
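The path lookup behind path_order can be sketched as follows. get_by_path and ordered_matmul are hypothetical helpers, and treating an entry equal to the double separator ('..') as the slot for the parameter matrix is an assumption based on the description of path_order. With the example data, 'y.0' selects the (5, 3) matrix and 'x' the (3, 4) matrix, so the product is (5, 4) per sample; the default jax.tree.leaves order ('x' before 'y') would instead try (3, 4) @ (5, 3), which is invalid.

```python
import numpy as np


def get_by_path(tree, path, sep="."):
    """Hypothetical helper: follow dict keys / list indices along `path`."""
    node = tree
    for part in path.split(sep):
        if isinstance(node, (list, tuple)):
            node = node[int(part)]  # list/tuple entries are indexed by integer
        else:
            node = node[part]  # dict entries are indexed by key
    return node


def ordered_matmul(tree, path_order, M=None, sep="."):
    """Multiply leaves in the order given by `path_order` (sketch only).

    Assumption: an entry equal to the double separator marks the unbatched
    parameter matrix M; every other entry is a path into `tree`.
    """
    out = None
    for path in path_order:
        factor = M if path == sep * 2 else get_by_path(tree, path, sep)
        out = factor if out is None else out @ factor  # batch axis broadcasts
    return out
```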
- __init__(params=None, output_shape=None, path_order=None, trainable=False, init_magnitude=0.01, real=True, separator='.')
Initialize the MatMul module.
- Parameters:
params (Inexact[Array, '...'] | None) – The parameter matrix to use for multiplication. If None and trainable is False (the default), then no parameter matrix is used. If None and trainable is True, then a randomly initialized trainable matrix is created during compilation.
output_shape (tuple[int] | tuple[None, int] | tuple[int, int] | tuple[int, None] | int | None) – The shape of the output matrix/vector (excluding the batch dimension). If an integer is provided, it is treated as the size of a vector output. If None (the default), the output size is inferred during compilation from the input shapes and, if applicable, the parameter matrix shape. Can be a tuple with None in one position to indicate that the size in that dimension should be inferred.
path_order (list[str] | None) – A list of period-separated path strings giving the order in which the PyTree leaves are multiplied. A double separator (..) indicates the position of the parameter matrix, if applicable. If None (the default), the order is the parameter matrix (if applicable) followed by the input PyTree leaves in the order returned by jax.tree.leaves. See jax.tree_util.keystr for more details on path strings.
trainable (bool) – Whether the parameter matrix is trainable.
init_magnitude (float) – The magnitude of the random initialization for the trainable matrix if applicable.
real (bool) – Whether to use real or complex parameters for the trainable matrix if applicable.
separator (str) – The separator to use for path strings. Default is period (‘.’).
- Return type:
None
Methods:
- __call__ – Call the module with the current parameters and given input, state, and rng.
- _get_callable – Returns a jax.jit-able and jax.grad-able callable that represents the module's forward pass.
- astype – Convenience wrapper to set_precision using the dtype argument, returns self.
- compile – Compile the module to be used with the given input shape.
- deserialize – Deserialize the module from a dictionary.
- get_hyperparameters – Get the hyperparameters of the module.
- get_num_trainable_floats – Returns the number of trainable floats in the module.
- get_output_shape – Get the output shape of the module given the input shape.
- get_params – Get the current trainable parameters of the module.
- get_state – Get the current state of the module.
- is_ready – Return True if the module is initialized and ready for training or inference.
- serialize – Serialize the module to a dictionary.
- set_hyperparameters – Set the hyperparameters of the module.
- set_params – Set the trainable parameters of the module.
- set_precision – Set the precision of the module parameters and state.
- set_state – Set the state of the module.
Attributes:
- name – Returns the name of the module; unless overridden, this is the class name.
- __call__(data, /, *, training=False, state=(), rng=None)
Call the module with the current parameters and given input, state, and rng.
- Parameters:
data (pmm.typing.Data) – PyTree of input arrays of shape (num_samples, …). Only the first dimension (num_samples) is guaranteed to be the same for all input arrays.
training (bool) – Whether the module is in training mode, by default False.
state (pmm.typing.State) – State of the module, by default ().
rng (Any) – JAX random key, by default None.
- Returns:
Output array of shape (num_samples, num_output_features) and new state.
- Raises:
ValueError – If the module is not ready (i.e., compile() has not been called).
- Return type:
tuple[pmm.typing.Data, pmm.typing.State]
See also
_get_callable – Returns a callable that can be used to compute the output and new state given the parameters, input, training flag, state, and rng.
Params – Typing for the module parameters.
Data – Typing for the input and output data.
State – Typing for the module state.
- _get_callable()
Returns a jax.jit-able and jax.grad-able callable that represents the module’s forward pass.
This method must be implemented by all subclasses and must return a jax.jit-able and jax.grad-able callable of the form
module_callable(
    params: parametricmatrixmodels.typing.Params,
    data: parametricmatrixmodels.typing.Data,
    training: bool,
    state: parametricmatrixmodels.typing.State,
    rng: Any,
) -> (
    output: parametricmatrixmodels.typing.Data,
    new_state: parametricmatrixmodels.typing.State,
)
That is, all hyperparameters are traced out and the callable depends explicitly only on
- the module’s parameters, as a PyTree with leaf nodes as JAX arrays,
- the input data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),
- the training flag, as a boolean,
- the module’s state, as a PyTree with leaf nodes as JAX arrays,
- the JAX random key (rng),
and returns
- the output data, as a PyTree with leaf nodes as JAX arrays, each of which has shape (num_samples, …),
- the new module state, as a PyTree with leaf nodes as JAX arrays. The PyTree structure must match that of the input state, and every leaf node must have the same shape as the corresponding input state leaf node.
The training flag is traced out, so it does not need to be jittable.
- Returns:
A callable that takes the module’s parameters, input data, training flag, state, and rng key and returns the output data and new state.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
Callable[[PyTree[jaxtyping.Inexact[Array, '...'], 'Params'], PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], bool, PyTree[jaxtyping.Num[Array, '*?d'], 'State'], Any], tuple[PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], PyTree[jaxtyping.Num[Array, '*?d'], 'State']]]
See also
__call__ – Calls the module with the current parameters and given input, state, and rng.
ModuleCallable – Typing for the callable returned by this method.
Params – Typing for the module parameters.
Data – Typing for the input and output data.
State – Typing for the module state.
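To make the contract described for _get_callable concrete, here is a hand-written callable with the documented signature, written in JAX. It is a sketch, not the library's MatMul implementation: the params layout (a 1-tuple holding M) and the data layout (a tuple (A, x) with A of shape (batch, n, m) and x of shape (batch, m)) are assumptions for illustration. Because training is a plain Python bool, the sketch marks it static for jax.jit.

```python
import jax
import jax.numpy as jnp


def module_callable(params, data, training, state, rng):
    """Illustrative forward pass following the documented contract.

    Assumptions (not from the library): params is a 1-tuple (M,) and data is
    a tuple (A, x) with A of shape (batch, n, m) and x of shape (batch, m).
    """
    (M,) = params
    A, x = data
    # M @ A @ x for every sample; 'b' is the batch index, M itself is unbatched.
    out = jnp.einsum("pn,bnm,bm->bp", M, A, x)
    return out, state  # stateless: the new state is the input state unchanged


# The training flag is static, so each value triggers its own trace.
fwd = jax.jit(module_callable, static_argnames=("training",))


def loss(params, data):
    out, _ = fwd(params, data, training=False, state=(), rng=None)
    return jnp.sum(out ** 2)


# Gradient with respect to the params PyTree (here a 1-tuple).
grad_fn = jax.grad(loss)
```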
- _get_concrete_einsum_str(input_shape)
Get the concrete einsum string by parsing the provided einsum string, adding a batch index to account for the batch dimension, and inferring any missing output indices. If the einsum string is a PyTree, it must have the same structure as the input data.
- Parameters:
input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – The shape of the input data, used to infer any missing output indices as well as validate existing indices. Should not include the batch dimension.
- Return type:
str
Examples
>>> m = Einsum("ij,jk->ik")
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
'aij,ajk->aik'  # leading index 'a' added for batch dimension
>>> m = Einsum((('ij', 'jk'), 'ik'))
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
'aij,ajk->aik'
>>> m = Einsum("ab,bc->ac")
>>> m._get_concrete_einsum_str(((5, 2), (2, 4)))
'dab,dbc->dac'  # leading index 'd' added for batch dimension
>>> m = Einsum((['ij', 'jk'], 'ik'))
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
ValueError: The structure of the einsum_str PyTree must match that of the input data.
# (the input shape is a tuple of two shapes, not a list)
>>> m = Einsum(('ij,jk', {'x': 'ik', 'y': 'ab'}, 'ab'))
>>> m._get_concrete_einsum_str({'x': (2, 3), 'y': (3, 4)})
'ij,jk,cik,cab->cab'
# leading arrays don't have a batch index;
# all arrays from PyTrees are inserted in the same order as the
# list from jax.tree.leaves(...)
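The batch-index rule visible in the examples above (prefix every batched operand and the output with the first lowercase letter not already used: 'a' for "ij,jk->ik", 'd' for "ab,bc->ac", 'c' in the last example) can be reproduced for the plain-string case with a small helper. add_batch_index is hypothetical; it ignores PyTree einsum strings, unbatched leading operands, and output-index inference.

```python
import string

import numpy as np


def add_batch_index(einsum_str):
    """Hypothetical helper: add a batch index to an explicit einsum string.

    The batch index is the first lowercase letter not already used, and it
    is prefixed to every input term and to the output term.
    """
    inputs, output = einsum_str.split("->")
    used = set(einsum_str)
    batch = next(c for c in string.ascii_lowercase if c not in used)
    terms = [batch + term for term in inputs.split(",")]
    return ",".join(terms) + "->" + batch + output
```

The resulting string can be passed directly to np.einsum; for "ij,jk->ik" it computes a batched matrix product equivalent to A @ B on arrays with a leading batch axis.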
- _get_dimension_map(concrete_einsum_str, input_shape)
Fill in the dimension map by inferring sizes from the input shapes and parameter shapes based on the provided concrete einsum string and self.params, if applicable.
- Parameters:
concrete_einsum_str (str) – The concrete einsum string with all indices specified, including the output indices and batch index.
input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – The shape of the input data, used to infer dimension sizes. Should not include the batch dimension.
- Returns:
A complete dimension map with sizes for all indices in the einsum string.
- Return type:
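The size inference described for _get_dimension_map can be sketched as follows. build_dimension_map is a hypothetical helper; for simplicity it takes full operand shapes, with None marking the unknown batch size, rather than the batch-free input_shape PyTree the real method receives.

```python
def build_dimension_map(concrete_einsum_str, shapes):
    """Hypothetical sketch: map each einsum index to its size.

    shapes: one shape tuple per input term, in order; None marks a size that
    is not known yet (e.g. the batch dimension). Conflicting sizes raise.
    """
    inputs, _ = concrete_einsum_str.split("->")
    dim_map = {}
    for term, shape in zip(inputs.split(","), shapes):
        if len(term) != len(shape):
            raise ValueError(f"term {term!r} does not match shape {shape}")
        for idx, size in zip(term, shape):
            known = dim_map.get(idx)
            if size is None:
                dim_map.setdefault(idx, None)  # record the index, size unknown
            elif known is None:
                dim_map[idx] = size  # first concrete size seen for this index
            elif known != size:
                raise ValueError(f"index {idx!r}: size {size} != {known}")
    return dim_map
```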
- astype(dtype, /)
Convenience wrapper to set_precision using the dtype argument, returns self.
- Parameters:
dtype (str | type[Any] | dtype | SupportsDType) – Precision to set for the module parameters. Valid options are:
For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32.
For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.
- Returns:
BaseModule – The module itself, with updated precision.
- Raises:
ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.
RuntimeError – If the module is not ready (i.e., compile() has not been called).
- Return type:
BaseModule
See also
set_precision – Sets the precision of the module parameters and state.
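The equivalence classes of precision options listed for astype (and set_precision) can be captured in a small lookup. precision_bits is a hypothetical helper that only normalizes the documented aliases to 32 or 64; it does not check JAX_ENABLE_X64 or touch any module.

```python
import numpy as np

# Documented options; every value in a row selects the same precision.
PRECISION_ALIASES = {
    32: (np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32),
    64: (np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64),
}


def precision_bits(prec):
    """Hypothetical helper: map any documented alias to 32 or 64."""
    for bits, aliases in PRECISION_ALIASES.items():
        if prec in aliases:
            return bits
    raise ValueError(f"invalid precision: {prec!r}")
```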
- compile(rng, input_shape)
Compile the module to be used with the given input shape.
This method initializes the module’s parameters and state based on the input shape and random key.
This is needed because Models are built before the input data is given: before training or inference can be done, each module must be compiled, and each module passes its output shape to the next module’s compile method.
The RNG key is used to initialize random parameters, if needed.
This is not used to trace or jit the module’s callable, that is done automatically later.
- Parameters:
rng (Any) – JAX random key.
input_shape (pmm.typing.DataShape) – PyTree of input shape tuples, e.g. ((num_features,),), to compile the module for. All data passed to the module later must have the same PyTree structure and shape in all leaf array dimensions except the leading batch dimension.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
None
See also
DataShape – Typing for the input shape.
get_output_shape – Get the output shape of the module.
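The shape threading described for compile (each module is compiled for the shape it will receive, then hands its output shape to the next module) can be sketched with toy classes. ToyScale and compile_sequence are hypothetical stand-ins, not parametricmatrixmodels classes, and a NumPy Generator replaces the JAX random key to keep the sketch dependency-free.

```python
import numpy as np


class ToyScale:
    """Hypothetical stand-in module: a fixed-size linear map on vectors."""

    def __init__(self, out_features, init_magnitude=0.01):
        self.out_features = out_features
        self.init_magnitude = init_magnitude
        self.weight = None  # created at compile time, once the input size is known

    def get_output_shape(self, input_shape):
        return (self.out_features,)

    def compile(self, rng, input_shape):
        (n,) = input_shape  # per-sample shape, no batch dimension
        self.weight = rng.normal(size=(self.out_features, n)) * self.init_magnitude


def compile_sequence(modules, rng, input_shape):
    """Compile each module for the shape produced by the previous one."""
    shape = input_shape
    for module in modules:
        module.compile(rng, shape)
        shape = module.get_output_shape(shape)
    return shape  # output shape of the whole sequence
```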
- deserialize(data, /)
Deserialize the module from a dictionary.
This method sets the module’s parameters and state based on the provided dictionary.
The default implementation expects the dictionary to contain the module’s name, trainable parameters, and state.
- Parameters:
data (dict[str, Any]) – Dictionary containing the serialized module data.
- Raises:
ValueError – If the serialized data does not contain the expected keys or if the version of the serialized data is not compatible with the current package version.
- Return type:
None
- get_hyperparameters()
Get the hyperparameters of the module.
Hyperparameters are used to configure the module and are not trainable. They can be set via set_hyperparameters.
- Returns:
Dictionary containing the hyperparameters of the module.
- Return type:
pmm.typing.HyperParams
See also
set_hyperparameters – Set the hyperparameters of the module.
HyperParams – Typing for the hyperparameters. Simply an alias for Dict[str, Any].
- get_num_trainable_floats()
Returns the number of trainable floats in the module. If the module does not have trainable parameters, returns 0. If the module is not ready, returns None.
- Returns:
Number of trainable floats in the module, or None if the module is not ready.
- Return type:
int | None
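A rough sketch of what get_num_trainable_floats counts, under stated assumptions: the parameters form a PyTree of dicts, lists and tuples with array leaves, and a complex entry counts as two floats (real and imaginary parts). Both helpers are hypothetical, and the library's actual counting rule for complex parameters is not stated here.

```python
import numpy as np


def iter_leaves(tree):
    """Minimal PyTree flattener for dicts, lists, and tuples (hypothetical)."""
    if isinstance(tree, dict):
        for key in sorted(tree):  # sorted keys, like JAX's dict flattening
            yield from iter_leaves(tree[key])
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from iter_leaves(item)
    elif tree is not None:
        yield np.asarray(tree)


def count_trainable_floats(params):
    # Assumption: a complex entry counts as two floats (real + imaginary).
    return sum(
        leaf.size * (2 if np.iscomplexobj(leaf) else 1)
        for leaf in iter_leaves(params)
    )
```

An empty params tuple gives 0, matching the documented behaviour for modules without trainable parameters.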
- get_output_shape(input_shape)
Get the output shape of the module given the input shape.
- Parameters:
input_shape (pmm.typing.DataShape) – PyTree of input shape tuples, e.g. ((num_features,),), to get the output shape for.
- Returns:
PyTree of output shape tuples, e.g. ((num_output_features,),), corresponding to the output shape of the module for the given input shape.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
pmm.typing.DataShape
See also
DataShape – Typing for the input and output shape.
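For a matrix chain such as MatMul, the output shape follows from the per-sample input shapes alone: each inner dimension must match, and the result keeps the rows of the first factor and the trailing dimensions of the last. matmul_output_shape is a hypothetical helper that works on a flat list of shapes (no batch dimension, in multiplication order) rather than a DataShape PyTree.

```python
def matmul_output_shape(shapes):
    """Hypothetical sketch: per-sample output shape of a matrix chain.

    shapes: per-sample shapes in multiplication order; only the last entry
    may be 1-D, i.e. a vector.
    """
    out = tuple(shapes[0])
    for shape in shapes[1:]:
        if out[-1] != shape[0]:
            raise ValueError(f"cannot multiply {out} by {tuple(shape)}")
        out = out[:-1] + tuple(shape[1:])  # contract the shared dimension
    return out
```

For the path_order example, matmul_output_shape([(5, 3), (3, 4)]) gives (5, 4); a (2, 4) parameter matrix applied to a (4,) vector gives (2,).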
- get_params()
Get the current trainable parameters of the module. If the module has no trainable parameters, this method should return an empty tuple.
- Returns:
PyTree with leaf nodes as JAX arrays representing the module’s trainable parameters.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
PyTree[jaxtyping.Inexact[Array, '...'], 'Params']
See also
set_params – Set the trainable parameters of the module.
Params – Typing for the module parameters.
- get_state()
Get the current state of the module.
States are used to store “memory” or other information that is not passed between modules and is not trainable, but may be updated during either training or inference (e.g., batch normalization state).
The state is optional; if the module has none, this method should return the empty tuple.
- Returns:
PyTree with leaf nodes as JAX arrays representing the module’s state.
- Return type:
pmm.typing.State
See also
set_state – Set the state of the module.
State – Typing for the module state.
- is_ready()
Return True if the module is initialized and ready for training or inference.
- Returns:
True if the module is ready, False otherwise.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
bool
- property name: str
Returns the name of the module; unless overridden, this is the class name.
- Returns:
Name of the module.
- serialize()
Serialize the module to a dictionary.
This method returns a dictionary representation of the module, including its parameters and state.
The default implementation serializes the module’s name, hyperparameters, trainable parameters, and state via a simple dictionary.
This only works if the module’s hyperparameters are auto-serializable. This includes basic types as well as numpy arrays.
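The default serialize/deserialize round trip described above (name, hyperparameters, trainable parameters and state in a plain dictionary, with a version check on the way back in) can be sketched with a toy class. ToyModule, its dictionary keys and its VERSION tag are hypothetical; the library's actual serialized layout is not shown here.

```python
import numpy as np


class ToyModule:
    """Hypothetical module illustrating a dict-based (de)serialization scheme."""

    VERSION = "1.0"  # assumed version tag, checked on deserialization

    def __init__(self, trainable=False):
        self.trainable = trainable
        self.params = (np.zeros((2, 2)),)
        self.state = ()

    @property
    def name(self):
        return type(self).__name__  # defaults to the class name

    def serialize(self):
        return {
            "name": self.name,
            "version": self.VERSION,
            "hyperparameters": {"trainable": self.trainable},
            "params": self.params,
            "state": self.state,
        }

    def deserialize(self, data):
        missing = {"name", "version", "params", "state"} - set(data)
        if missing:
            raise ValueError(f"missing keys: {sorted(missing)}")
        if data["version"] != self.VERSION:
            raise ValueError("incompatible serialized version")
        self.params = data["params"]
        self.state = data["state"]
```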
- set_hyperparameters(hyperparams)
Set the hyperparameters of the module.
Hyperparameters are used to configure the module and are not trainable. They can be set via this method.
The default implementation uses setattr to set the hyperparameters as attributes of the class instance.
- Parameters:
hyperparams (pmm.typing.HyperParams) – Dictionary containing the hyperparameters to set.
- Raises:
TypeError – If hyperparams is not a dictionary.
- Return type:
None
See also
get_hyperparameters – Get the hyperparameters of the module.
HyperParams – Typing for the hyperparameters. Simply an alias for Dict[str, Any].
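The default behaviour described for set_hyperparameters (reject non-dictionaries, then setattr each entry on the instance) pairs naturally with get_hyperparameters returning a plain dictionary. ToyHyperModule is a hypothetical class that mirrors that description; the hyperparameter names are borrowed from MatMul's constructor for illustration only.

```python
class ToyHyperModule:
    """Hypothetical module mirroring the documented default behaviour."""

    def __init__(self):
        self.trainable = False
        self.init_magnitude = 0.01

    def get_hyperparameters(self):
        return {"trainable": self.trainable, "init_magnitude": self.init_magnitude}

    def set_hyperparameters(self, hyperparams):
        if not isinstance(hyperparams, dict):
            raise TypeError("hyperparams must be a dictionary")
        for key, value in hyperparams.items():
            setattr(self, key, value)  # each entry becomes an instance attribute
```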
- set_params(params)
Set the trainable parameters of the module.
- Parameters:
params (PyTree[jaxtyping.Inexact[Array, '...'], 'Params']) – PyTree with leaf nodes as JAX arrays representing the new trainable parameters of the module.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
None
See also
get_params – Get the trainable parameters of the module.
Params – Typing for the module parameters.
- set_precision(prec, /)
Set the precision of the module parameters and state.
- Parameters:
prec (Any | str | int) – Precision to set for the module parameters. Valid options are:
For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32.
For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.
- Raises:
ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.
RuntimeError – If the module is not ready (i.e., compile() has not been called).
- Return type:
None
See also
astype – Convenience wrapper to set_precision using the dtype argument, returns self.