FuncBase

parametricmatrixmodels.modules.FuncBase

class FuncBase

Bases: BaseModule

Base class for simple non-trainable function modules. Not to be instantiated directly.
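The sketch below makes the subclassing contract concrete without any dependencies. `SketchFuncBase` and `Square` are hypothetical names that are not part of parametricmatrixmodels. The sketch only mirrors behaviour documented on this page: subclasses provide f and get_hyperparameters, while compile, is_ready, get_params, and get_num_trainable_floats are fixed. A real subclass would inherit from parametricmatrixmodels.modules.FuncBase and would typically operate on JAX arrays rather than Python lists.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchFuncBase(ABC):
    """Hypothetical stand-in mirroring the documented FuncBase contract."""

    # Subclasses must supply the function itself and its hyperparameters.
    @abstractmethod
    def f(self, data: Any) -> Any: ...

    @abstractmethod
    def get_hyperparameters(self) -> dict[str, Any]: ...

    # Everything else is fixed, matching the documented behaviour.
    def compile(self, rng: Any, input_shape: Any) -> None:
        return None  # nothing to initialize

    def is_ready(self) -> bool:
        return True

    def get_params(self) -> tuple:
        return ()

    def get_num_trainable_floats(self) -> int:
        return 0


class Square(SketchFuncBase):
    """Hypothetical subclass that squares each value of a flat list input."""

    def f(self, data):
        return [x * x for x in data]

    def get_hyperparameters(self):
        return {}


sq = Square()
sq.compile(rng=None, input_shape=(None, 3))
print(sq.is_ready(), sq.get_params(), sq.get_num_trainable_floats())  # True () 0
print(sq.f([1.0, 2.0, 3.0]))  # [1.0, 4.0, 9.0]
```

Because f and get_hyperparameters are abstract, instantiating the base class directly raises a TypeError, which matches the "not to be instantiated directly" rule.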

__init__()

Initialize the function module.

__call__

Call the module with the current parameters and given input, state, and rng.

_get_callable

Get the callable for the function module.

astype

Convenience wrapper around set_precision that takes a dtype argument and returns self.

compile

Compile the function module.

copy

Create a deep copy of the module.

deepcopy

Create a deep copy of the module.

deserialize

Deserialize the module from a dictionary.

f

Apply the function to the input data.

freeze

Freeze the module parameters by setting trainable to False.

get_hyperparameters

Get the hyperparameters of the function module.

get_num_trainable_floats

Funcs do not have trainable parameters.

get_output_shape

Get the output shape of the function given an input shape.

get_params

Get the parameters of the function module, of which there are none.

get_state

Get the current state of the module.

is_ready

Funcs are always ready to be used.

serialize

Serialize the module to a dictionary.

set_hyperparameters

Set the hyperparameters of the module.

set_params

Set the trainable parameters of the module.

set_precision

Set the precision of the module parameters and state.

set_state

Set the state of the module.

unfreeze

Unfreeze the module parameters by setting trainable to True.

upgrade

Upgrade serialized module data to the current version.

name

Returns the name of the module. Unless overridden, this is the class name.

trainable

Whether the module is trainable (i.e., whether its parameters should be updated during training).

__call__(data, /, *, training=False, state=(), rng=None)

Call the module with the current parameters and given input, state, and rng.

Parameters:
  • data (pmm.typing.Data) – PyTree of input arrays of shape (num_samples, …). Only the first dimension (num_samples) is guaranteed to be the same for all input arrays.

  • training (bool) – Whether the module is in training mode, by default False.

  • state (pmm.typing.State) – State of the module, by default ().

  • rng (Any) – JAX random key, by default None.

Returns:

Output array of shape (num_samples, num_output_features) and the new state.

Raises:

ValueError – If the module is not ready (i.e., compile() has not been called).

Return type:

tuple[pmm.typing.Data, pmm.typing.State]

See also

_get_callable

Returns a callable that can be used to compute the output and new state given the parameters, input, training flag, state, and rng.

Params

Typing for the module parameters.

Data

Typing for the input and output data.

State

Typing for the module state.

__trainable: bool
final _get_callable()

Get the callable for the function module.

Returns:

A callable that applies the function to the input data in the form the PMM library expects.

Return type:

Callable[[PyTree[jaxtyping.Inexact[Array, '...'], 'Params'], PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], bool, PyTree[jaxtyping.Num[Array, '*?d'], 'State'], Any], tuple[PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], PyTree[jaxtyping.Num[Array, '*?d'], 'State']]]
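The return type says the callable takes (params, data, training, state, rng) and returns (output, state). Here is a minimal sketch of how a plain f(data) could be adapted to that shape. It assumes, without confirmation from this page, that function modules ignore params, training, and rng and return the incoming state unchanged. make_callable is a hypothetical helper, not a library function.

```python
from typing import Any, Callable


def make_callable(f: Callable[[Any], Any]) -> Callable[..., tuple[Any, Any]]:
    """Hypothetical helper: adapt f(data) to the module-callable signature."""

    def module_callable(params, data, training, state, rng):
        # Assumed behaviour: params, training, and rng are unused, and the
        # state is passed through unchanged.
        return f(data), state

    return module_callable


double = make_callable(lambda data: [2 * x for x in data])
out, new_state = double((), [1, 2, 3], False, (), None)
print(out, new_state)  # [2, 4, 6] ()
```

Keeping the full signature, even when most arguments are unused, lets stateless and stateful modules be composed by the same calling code.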

astype(dtype, /)

Convenience wrapper around set_precision that takes a dtype argument and returns self.

Parameters:

dtype (str | type[Any] | dtype | SupportsDType) – Precision to set for the module parameters. Valid options are, for 32-bit precision (all equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32; for 64-bit precision (all equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Returns:

BaseModule – The module itself, with updated precision.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

Return type:

BaseModule

See also

set_precision

Sets the precision of the module parameters and state.
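The option lists for dtype (and for prec in set_precision) collapse many spellings into two precisions. The following is a hedged sketch of that normalization that handles only the string and integer spellings. normalize_precision is a hypothetical helper. NumPy dtype objects are deliberately left out so the sketch has no dependencies, and it does not model the JAX_ENABLE_X64 check.

```python
def normalize_precision(prec) -> int:
    """Hypothetical helper: map a precision specifier to 32 or 64 bits."""
    aliases = {
        32: {"float32", "complex64", "single", "f32", "c64", "32"},
        64: {"float64", "complex128", "double", "f64", "c128", "64"},
    }
    key = str(prec).lower()  # integers 32/64 become "32"/"64"
    for bits, names in aliases.items():
        if key in names:
            return bits
    raise ValueError(f"Invalid precision: {prec!r}")


print(normalize_precision("single"), normalize_precision(64))  # 32 64
```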

final compile(rng, input_shape)

Compile the function module. No action is needed for function modules.

Parameters:
  • rng (Any) – Random number generator state.

  • input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – Shape of the input arrays.

Return type:

None

copy()

Create a deep copy of the module.

Returns:

A deep copy of the module.

Return type:

BaseModule

deepcopy()

Create a deep copy of the module.

Returns:

A deep copy of the module.

Return type:

BaseModule

deserialize(data, /, *, strict_package_version=False)

Deserialize the module from a dictionary.

This method sets the module’s parameters and state based on the provided dictionary.

The default implementation expects the dictionary to contain the module’s name, trainable parameters, and state.

Parameters:
  • data (dict[str, Any]) – Dictionary containing the serialized module data.

  • strict_package_version – If True, raises an error if the package version used to serialize the model does not match the current package version. Default is False.

Raises:

ValueError – If the serialized data does not contain the expected keys or if the version of the serialized data is not compatible with the current package version.

Return type:

None

abstractmethod f(data)

Apply the function to the input data.

Parameters:

data (PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...']) – Input Data (PyTree of arrays).

Returns:

Output Data (PyTree of arrays).

Return type:

PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...']
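Since f receives Data, which may be a single array or a PyTree of arrays, an implementation usually applies its operation leaf by leaf. The sketch below assumes pure-Python containers (dict, list, tuple) as the tree structure. With JAX one would normally use jax.tree_util.tree_map on arrays of shape (batch_size, ...) instead of the hand-written tree_map here. Negate is a hypothetical module.

```python
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Minimal pure-Python stand-in for a PyTree map over dicts, lists, and tuples."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)  # anything else is treated as a leaf


class Negate:
    """Hypothetical function module whose f negates every leaf of its input."""

    def f(self, data):
        return tree_map(lambda leaf: -leaf, data)


data = {"x": (1.0, 2.0), "y": [3.0]}
print(Negate().f(data))  # {'x': (-1.0, -2.0), 'y': [-3.0]}
```

The output keeps the input's tree structure, which is what lets downstream modules consume it without reshaping.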

freeze()

Freeze the module parameters by setting trainable to False.

Returns:

The module itself, with trainable set to False.

Return type:

BaseModule

abstractmethod get_hyperparameters()

Get the hyperparameters of the function module.

Returns:

An empty dictionary, as function modules do not have hyperparameters.

Return type:

dict[str, Any]

final get_num_trainable_floats()

Funcs do not have trainable parameters.

Returns:

Always returns 0.

Return type:

int | None

final get_output_shape(input_shape)

Get the output shape of the function given an input shape.

Parameters:

input_shape (PyTree[tuple[int | None, ...]] | tuple[int | None, ...]) – Shape of the input arrays.

Returns:

Output shapes after applying the function.

Return type:

PyTree[tuple[int | None, ...]] | tuple[int | None, ...]

final get_params()

Get the parameters of the function module, of which there are none.

Returns:

An empty tuple, as function modules do not have parameters.

Return type:

PyTree[jaxtyping.Inexact[Array, '...'], 'Params']

get_state()

Get the current state of the module.

States store “memory” or other information that is not passed between modules and is not trainable, but may be updated during training or inference, such as batch normalization statistics.

State is optional; a module without state should return the empty tuple.

Returns:

PyTree with leaf nodes as JAX arrays representing the module’s state.

Return type:

pmm.typing.State

See also

set_state

Set the state of the module.

State

Typing for the module state.

final is_ready()

Funcs are always ready to be used.

Returns:

Always returns True.

Return type:

bool

property name: str

Returns the name of the module. Unless overridden, this is the class name.

Returns:

Name of the module.
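A sketch of the default-name behaviour, assuming the default is derived from type(self).__name__ (this page states only that it is the class name). Module, Identity, and Renamed are hypothetical classes.

```python
class Module:
    """Hypothetical base with a name property that defaults to the class name."""

    @property
    def name(self) -> str:
        return type(self).__name__


class Identity(Module):
    pass


class Renamed(Module):
    @property
    def name(self) -> str:
        return "my_identity"  # subclasses may override the default


print(Identity().name, Renamed().name)  # Identity my_identity
```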

serialize()

Serialize the module to a dictionary.

This method returns a dictionary representation of the module, including its parameters and state.

The default implementation serializes the module’s name, hyperparameters, trainable parameters, and state via a simple dictionary.

This only works if the module’s hyperparameters are auto-serializable, which includes basic Python types as well as NumPy arrays.

Returns:

Dictionary containing the serialized module data.

Return type:

dict[str, PyTree[int] | PyTree[float] | PyTree[bool] | PyTree[str] | PyTree[complex] | PyTree[jaxtyping.Shaped[Array, '...']] | PyTree[ndarray, '...']]
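The serialize/deserialize pair is easiest to follow as a round trip. The following is a minimal sketch. The dictionary keys, PACKAGE_VERSION, REQUIRED_KEYS, and the Exp class are all hypothetical and are not the library's actual format. The sketch stores the name, a package version, hyperparameters, and the (empty) params and state. It restores hyperparameters with setattr and applies the strict_package_version check described for deserialize.

```python
from typing import Any

PACKAGE_VERSION = "0.0.0"  # hypothetical version string for the sketch
REQUIRED_KEYS = {"name", "package_version", "hyperparameters", "params", "state"}


class Exp:
    """Hypothetical function module used to illustrate a serialization round trip."""

    def __init__(self, base: float = 2.0):
        self.base = base

    @property
    def name(self) -> str:
        return type(self).__name__

    def serialize(self) -> dict[str, Any]:
        # Key names are assumptions for this sketch, not the library's format.
        return {
            "name": self.name,
            "package_version": PACKAGE_VERSION,
            "hyperparameters": {"base": self.base},
            "params": (),  # function modules have no parameters
            "state": (),  # and no state either
        }

    def deserialize(
        self, data: dict[str, Any], /, *, strict_package_version: bool = False
    ) -> None:
        missing = REQUIRED_KEYS.difference(data)
        if missing:
            raise ValueError(f"missing keys: {sorted(missing)}")
        if strict_package_version and data["package_version"] != PACKAGE_VERSION:
            raise ValueError("package version mismatch")
        # Restore hyperparameters as attributes, like set_hyperparameters does.
        for key, value in data["hyperparameters"].items():
            setattr(self, key, value)


blob = Exp(base=10.0).serialize()
restored = Exp()
restored.deserialize(blob, strict_package_version=True)
print(restored.base)  # 10.0
```

With strict_package_version left at False, a payload from a different version is still accepted. In the library, upgrade is the documented hook for adapting such payloads.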

set_hyperparameters(hyperparameters, /)

Set the hyperparameters of the module.

Hyperparameters are used to configure the module and are not trainable. They can be set via this method.

The default implementation uses setattr to set the hyperparameters as attributes of the class instance.

Parameters:

hyperparameters (pmm.typing.HyperParams) – Dictionary containing the hyperparameters to set.

Raises:

TypeError – If hyperparameters is not a dictionary.

Return type:

None

See also

get_hyperparameters

Get the hyperparameters of the module.

HyperParams

Typing for the hyperparameters. Simply an alias for Dict[str, Any].
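A sketch of the get/set hyperparameter pair for a function module that does have a hyperparameter. Scale and its factor attribute are hypothetical. set_hyperparameters mirrors the documented default: it raises TypeError for non-dictionaries and otherwise uses setattr for each entry.

```python
from typing import Any


class Scale:
    """Hypothetical function module with one hyperparameter, factor."""

    def __init__(self, factor: float = 1.0):
        self.factor = factor

    def f(self, data):
        return [self.factor * x for x in data]

    def get_hyperparameters(self) -> dict[str, Any]:
        return {"factor": self.factor}

    def set_hyperparameters(self, hyperparameters: dict[str, Any], /) -> None:
        # Mirror the documented default: reject non-dicts, then setattr each entry.
        if not isinstance(hyperparameters, dict):
            raise TypeError("hyperparameters must be a dictionary")
        for key, value in hyperparameters.items():
            setattr(self, key, value)


m = Scale()
m.set_hyperparameters({"factor": 3.0})
print(m.get_hyperparameters(), m.f([1.0, 2.0]))  # {'factor': 3.0} [3.0, 6.0]
```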

final set_params(params)

Set the trainable parameters of the module.

Parameters:

params (PyTree[jaxtyping.Inexact[Array, '...'], 'Params']) – PyTree with leaf nodes as JAX arrays representing the new trainable parameters of the module.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

None

See also

get_params

Get the trainable parameters of the module.

Params

Typing for the module parameters.

set_precision(prec, /)

Set the precision of the module parameters and state.

Parameters:

prec (Any | str | int) – Precision to set for the module parameters. Valid options are, for 32-bit precision (all equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32; for 64-bit precision (all equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

Return type:

None

See also

astype

Convenience wrapper around set_precision that takes a dtype argument and returns self.

set_state(state, /)

Set the state of the module.

This method is optional.

Parameters:

state (pmm.typing.State) – PyTree with leaf nodes as JAX arrays representing the new state of the module.

Return type:

None

See also

get_state

Get the state of the module.

State

Typing for the module state.
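Function modules are stateless, but the state contract described under get_state and set_state is easiest to see with a stateful module. Here is a sketch with two hypothetical classes. Stateless returns the empty tuple. CallCounter threads a counter through __call__ and uses the documented (output, new_state) return convention, with the caller storing the updated state.

```python
class Stateless:
    """Hypothetical function-style module: it keeps no state."""

    def get_state(self):
        return ()  # the empty tuple signals "no state"


class CallCounter:
    """Hypothetical stateful module that counts the batches it has processed."""

    def __init__(self):
        self._state = (0,)

    def get_state(self):
        return self._state

    def set_state(self, state, /):
        self._state = state

    def __call__(self, data, /, *, training=False, state=(), rng=None):
        # Follow the (output, new_state) convention; the input passes through.
        (count,) = state
        return data, (count + 1,)


print(Stateless().get_state())  # ()

counter = CallCounter()
_, new_state = counter([1, 2], state=counter.get_state())
counter.set_state(new_state)  # the caller stores the updated state
print(counter.get_state())  # (1,)
```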

property trainable: bool

Whether the module is trainable (i.e., whether its parameters should be updated during training).

Returns:

True if the module is trainable, False otherwise.

unfreeze()

Unfreeze the module parameters by setting trainable to True.

Returns:

The module itself, with trainable set to True.

Return type:

BaseModule
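freeze and unfreeze return the module itself, so calls can be chained and the trainable property can be read immediately. The following is a minimal sketch with a hypothetical Freezable class that only models this flag. For FuncBase the flag has no practical effect, since there are no parameters to update.

```python
class Freezable:
    """Hypothetical stand-in for the trainable / freeze / unfreeze trio."""

    def __init__(self):
        self._trainable = True

    @property
    def trainable(self) -> bool:
        return self._trainable

    def freeze(self):
        self._trainable = False
        return self  # returning self allows chaining

    def unfreeze(self):
        self._trainable = True
        return self


m = Freezable()
print(m.freeze().trainable, m.unfreeze().trainable)  # False True
```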

upgrade(data, /)

Upgrade serialized module data to the current version.

This method can be overridden by subclasses to implement custom upgrade logic when the module’s serialization format changes between versions.

The default implementation simply returns the input data unchanged.

Parameters:

data (dict[str, Any]) – Dictionary containing the serialized module data.

Returns:

Upgraded dictionary containing the serialized module data.

Return type:

dict[str, Any]
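A subclass override of upgrade typically rewrites old payloads into the current layout and leaves current ones alone. Here is a hedged sketch with an invented history: LegacyAwareFunc and the rename from "hparams" to "hyperparameters" are hypothetical and do not describe the library's real format changes.

```python
from typing import Any


class LegacyAwareFunc:
    """Hypothetical module whose serialization format renamed one key."""

    def upgrade(self, data: dict[str, Any], /) -> dict[str, Any]:
        # Assumed history for the sketch: old payloads used "hparams",
        # current ones use "hyperparameters".
        upgraded = dict(data)  # shallow copy; the caller's dict is left untouched
        if "hparams" in upgraded and "hyperparameters" not in upgraded:
            upgraded["hyperparameters"] = upgraded.pop("hparams")
        return upgraded


old = {"name": "LegacyAwareFunc", "hparams": {"base": 2.0}}
new = LegacyAwareFunc().upgrade(old)
print(new)  # {'name': 'LegacyAwareFunc', 'hyperparameters': {'base': 2.0}}
```

Payloads already in the current format come back unchanged, which matches what the default implementation does for every input.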