ArcCos

parametricmatrixmodels.modules.ArcCos

class ArcCos(*args, **kwargs)

Bases: ActivationBase

Elementwise activation function that applies jax.numpy.arccos.

See also

jax.numpy.arccos

The function used for the elementwise activation.
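As a quick illustration of the underlying function, the sketch below calls jax.numpy.arccos directly rather than going through the ArcCos module. It shows that real inputs in [-1, 1] map to angles in [0, π], and that real inputs outside that interval produce NaN instead of raising an error.

```python
import jax.numpy as jnp

# Real inputs inside the domain [-1, 1] map to angles in [0, pi].
x = jnp.array([-1.0, 0.0, 1.0])
y = jnp.arccos(x)

# Real inputs outside [-1, 1] yield NaN rather than raising an error.
out_of_domain = jnp.arccos(jnp.array([1.5]))
```

If upstream modules can produce values slightly outside [-1, 1], clipping the input before this activation is one way to avoid NaN outputs.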

__init__(*args, **kwargs)

Initialize the elementwise activation function module.

Parameters:
  • args – Positional arguments for the activation function, starting with the second argument, since the first is the input array.

  • kwargs – Keyword arguments for the activation function.
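The convention for args can be sketched with a small closure. Everything here is hypothetical: make_elementwise is not part of the library, and jnp.clip merely stands in for an activation that takes extra arguments (jax.numpy.arccos itself needs none beyond the input).

```python
import jax.numpy as jnp

def make_elementwise(fn, *args, **kwargs):
    """Hypothetical helper: bind every argument except the input array."""
    def apply(x):
        # x fills the first positional slot; the stored args start at the second.
        return fn(x, *args, **kwargs)
    return apply

# jnp.clip(x, -1.0, 1.0): the two bounds are the second and third arguments.
clip_to_unit = make_elementwise(jnp.clip, -1.0, 1.0)
clipped = clip_to_unit(jnp.array([-2.0, 0.5, 3.0]))
```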

__call__

Call the module with the current parameters and given input, state, and rng.

_get_callable

Get the callable for the activation function.

astype

Convenience wrapper around set_precision that takes a dtype argument and returns self.

compile

Compile the activation function module.

deserialize

Deserialize the module from a dictionary.

func

Apply the activation function to the input array.

get_hyperparameters

Get the hyperparameters of the activation function module.

get_num_trainable_floats

Get the number of trainable floats in the module.

get_output_shape

Get the output shape of the activation function given the input shape.

get_params

Get the parameters of the activation function module.

get_state

Get the current state of the module.

is_ready

Check if the module is ready to be used.

name

Return the name of the module. Unless overridden, this is the class name.

serialize

Serialize the module to a dictionary.

set_hyperparameters

Set the hyperparameters of the module.

set_params

Set the trainable parameters of the module.

set_precision

Set the precision of the module parameters and state.

set_state

Set the state of the module.

__call__(input_NF, training=False, state=(), rng=None)

Call the module with the current parameters and given input, state, and rng.

Parameters:
  • input_NF (Array) – Input array of shape (num_samples, num_features).

  • training (bool) – Whether the module is in training mode, by default False.

  • state (tuple[Array, ...]) – State of the module, by default ().

  • rng (Any) – JAX random key, by default None.

Return type:

tuple[Array, tuple[Array, ...]]

Returns:

Output array of shape (num_samples, num_output_features) and new state.

Raises:

ValueError – If the module is not ready (i.e., compile() has not been called).

See also

_get_callable

Returns a callable that can be used to compute the output and new state given the parameters, input, training flag, state, and rng.

_get_callable()

Get the callable for the activation function.

Return type:

Callable[[tuple[Array, ...], Array, bool, tuple[Array, ...], Any], tuple[Array, tuple[Array, ...]]]

Returns:

The activation function callable in the form the PMM library expects.
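To make the return type concrete, here is a minimal sketch of a function with that signature. The argument order (parameters, input, training flag, state, rng) follows the description of _get_callable given under __call__. The name arccos_callable is hypothetical, and this is an illustration rather than the library's implementation.

```python
from typing import Any

import jax
import jax.numpy as jnp

def arccos_callable(
    params: tuple,
    input_NF: jax.Array,
    training: bool,
    state: tuple,
    rng: Any,
):
    # arccos needs no trainable parameters and no randomness, so params,
    # training, and rng are ignored and the state is passed through unchanged.
    return jnp.arccos(input_NF), state

output, new_state = arccos_callable((), jnp.array([[0.0, 1.0]]), False, (), None)
```

A pure function of this form, with parameters and state passed explicitly, is the usual shape for code that is meant to be transformed with jax.jit or jax.grad.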

astype(dtype)

Convenience wrapper around set_precision that takes a dtype argument and returns self.

Parameters:

dtype (dtype | str) – Precision to set for the module parameters. Valid options are:

  • For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32.

  • For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Return type:

BaseModule

Returns:

BaseModule – The module itself, with updated precision.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

See also

set_precision

Sets the precision of the module parameters and state.

compile(rng, input_shape)

Compile the activation function module. This method is a no-op for activation functions.

Parameters:
  • rng (Any) – Random number generator state.

  • input_shape (tuple[int, ...]) – Shape of the input array.

Return type:

None

deserialize(data)

Deserialize the module from a dictionary.

This method sets the module’s parameters and state based on the provided dictionary.

The default implementation expects the dictionary to contain the module’s name, trainable parameters, and state.

Parameters:

data (dict[str, Any]) – Dictionary containing the serialized module data.

Raises:

ValueError – If the serialized data does not contain the expected keys or if the version of the serialized data is not compatible with the current package version.

Return type:

None

func(x)

Apply the activation function to the input array.

Parameters:

x (Array) – Input array to the activation function.

Return type:

Array

Returns:

Output array after applying the activation function.

get_hyperparameters()

Get the hyperparameters of the activation function module.

Return type:

dict[str, Any]

Returns:

Hyperparameters of the activation function.

get_num_trainable_floats()

Get the number of trainable floats in the module. Activation functions do not have trainable parameters.

Return type:

int | None

Returns:

Always returns 0.

get_output_shape(input_shape)

Get the output shape of the activation function given the input shape.

Parameters:

input_shape (tuple[int, ...]) – Shape of the input array.

Return type:

tuple[int, ...]

Returns:

Output shape after applying the activation function.
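Because the activation is elementwise, the output of the underlying jax.numpy.arccos has the same shape as its input. One way to confirm this without computing any values is jax.eval_shape. The sketch below works on the JAX function directly and makes no assumption about how get_output_shape is implemented.

```python
import jax
import jax.numpy as jnp

# Describe an input of shape (num_samples, num_features) without allocating it.
spec = jax.ShapeDtypeStruct((4, 3), jnp.float32)

# eval_shape traces the function abstractly and reports the output shape and dtype.
out_spec = jax.eval_shape(jnp.arccos, spec)
```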

get_params()

Get the parameters of the activation function module.

Return type:

tuple[Array, ...]

Returns:

An empty tuple, as activation functions do not have parameters.

get_state()

Get the current state of the module.

States are used to store “memory” or other information that is not passed between modules and is not trainable, but may be updated during either training or inference (e.g., batch normalization state).

The state is optional; a module without state should return the empty tuple.

Return type:

tuple[Array, ...]

Returns:

Tuple of numpy arrays representing the module’s state.

is_ready()

Check if the module is ready to be used. Activation functions are always ready.

Return type:

bool

Returns:

Always returns True.

name()

Return the name of the module. Unless overridden, this is the class name.

Return type:

str

Returns:

Name of the module.

serialize()

Serialize the module to a dictionary.

This method returns a dictionary representation of the module, including its parameters and state.

The default implementation serializes the module’s name, hyperparameters, trainable parameters, and state via a simple dictionary.

This only works if the module’s hyperparameters are auto-serializable. This includes basic types as well as numpy arrays.

Return type:

dict[str, Any]

Returns:

Dictionary containing the serialized module data.

set_hyperparameters(hyperparameters)

Set the hyperparameters of the module.

Hyperparameters are used to configure the module and are not trainable. They can be set via this method.

The default implementation uses setattr to set the hyperparameters as attributes of the class instance.

Parameters:

hyperparameters (dict[str, Any]) – Dictionary containing the hyperparameters to set.

Raises:

TypeError – If hyperparameters is not a dictionary.

Return type:

None
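The setattr-based behaviour described above can be sketched as follows. HyperparameterHolder is a hypothetical stand-in class and "scale" is a made-up hyperparameter name; neither comes from the library, and the real default implementation may differ in detail.

```python
from typing import Any

class HyperparameterHolder:
    """Hypothetical stand-in for a module; not a library class."""

    def set_hyperparameters(self, hyperparameters: dict[str, Any]) -> None:
        # Mirror the documented TypeError for non-dictionary input.
        if not isinstance(hyperparameters, dict):
            raise TypeError("hyperparameters must be a dictionary")
        # Each key becomes an attribute on the instance.
        for key, value in hyperparameters.items():
            setattr(self, key, value)

holder = HyperparameterHolder()
holder.set_hyperparameters({"scale": 2.0})
```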

set_params(params)

Set the trainable parameters of the module.

Parameters:

params (tuple[Array, ...]) – Tuple of numpy arrays representing the new parameters.

Raises:

NotImplementedError – If the method is not implemented in the subclass.

Return type:

None

set_precision(prec)

Set the precision of the module parameters and state.

Parameters:

prec (dtype | str | int) – Precision to set for the module parameters. Valid options are:

  • For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32.

  • For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.

Raises:
  • ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.

  • RuntimeError – If the module is not ready (i.e., compile() has not been called).

Return type:

None

See also

astype

Convenience wrapper around set_precision that takes a dtype argument and returns self.
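The equivalence between the aliases listed for prec can be illustrated with a small lookup. This is only a sketch: normalize_precision is a hypothetical helper, not a library function, and it does not perform the JAX_ENABLE_X64 check mentioned under Raises.

```python
import numpy as np

# Aliases copied from the list of valid options for prec.
_BITS_32 = {"float32", "complex64", "single", "f32", "c64", 32}
_BITS_64 = {"float64", "complex128", "double", "f64", "c128", 64}

def normalize_precision(prec) -> int:
    """Hypothetical helper: map any accepted alias to 32 or 64."""
    if isinstance(prec, (str, int)):
        # Strings and integers are looked up as given.
        key = prec
    else:
        try:
            # NumPy scalar types and dtype objects are reduced to their names.
            key = np.dtype(prec).name
        except TypeError as err:
            raise ValueError(f"invalid precision: {prec!r}") from err
    if key in _BITS_32:
        return 32
    if key in _BITS_64:
        return 64
    raise ValueError(f"invalid precision: {prec!r}")

bits = normalize_precision("f32")
```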

set_state(state)

Set the state of the module.

This method is optional.

Parameters:

state (tuple[Array, ...]) – Tuple of numpy arrays representing the new state.

Return type:

None