SoftSign#
parametricmatrixmodels.modules.SoftSign
- class SoftSign(*args, **kwargs)[source]#
Bases: ActivationBase
Elementwise activation function for jax.nn.soft_sign.
See also
jax.nn.soft_sign — The function used for the elementwise activation.
- __init__(*args, **kwargs)[source]#
Initialize the elementwise activation function module.
- Parameters:
args – Positional arguments passed to the activation function, starting from its second argument (the first argument is the input array).
kwargs – Keyword arguments for the activation function.
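A minimal usage sketch (hedged: it assumes SoftSign needs no extra constructor arguments, since jax.nn.soft_sign takes none beyond the input array):

```python
import jax.numpy as jnp
from parametricmatrixmodels.modules import SoftSign

act = SoftSign()                    # soft_sign needs no extra arguments
x = jnp.array([[-2.0, 0.0, 3.0]])   # shape (num_samples, num_features)
y, state = act(x)                   # applies soft_sign(x) = x / (1 + |x|)
# y ≈ [[-0.6667, 0.0, 0.75]]
```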
Methods
- __call__ — Call the module with the current parameters and given input, state, and rng.
- _get_callable — Get the callable for the activation function.
- astype — Convenience wrapper to set_precision using the dtype argument; returns self.
- compile — Compile the activation function module.
- deserialize — Deserialize the module from a dictionary.
- … — Apply the activation function to the input array.
- get_hyperparameters — Get the hyperparameters of the activation function module.
- get_num_trainable_floats — Get the number of trainable floats in the module.
- get_output_shape — Get the output shape of the activation function given the input shape.
- get_params — Get the parameters of the activation function module.
- get_state — Get the current state of the module.
- is_ready — Check if the module is ready to be used.
- serialize — Serialize the module to a dictionary.
- set_hyperparameters — Set the hyperparameters of the module.
- set_params — Set the trainable parameters of the module.
- set_precision — Set the precision of the module parameters and state.
- set_state — Set the state of the module.
- name — Get the name of the activation function module.
- __call__(data, /, *, training=False, state=(), rng=None)#
Call the module with the current parameters and given input, state, and rng.
- Parameters:
data (pmm.typing.Data) – PyTree of input arrays of shape (num_samples, …). Only the first dimension (num_samples) is guaranteed to be the same for all input arrays.
training (bool) – Whether the module is in training mode, by default False.
state (pmm.typing.State) – State of the module, by default ().
rng (Any) – JAX random key, by default None.
- Returns:
Output array of shape (num_samples, num_output_features) and the new state.
- Raises:
ValueError – If the module is not ready (i.e., compile() has not been called).
- Return type:
tuple[pmm.typing.Data, pmm.typing.State]
See also
_get_callable — Returns a callable that can be used to compute the output and new state given the parameters, input, training flag, state, and rng.
Params — Typing for the module parameters.
Data — Typing for the input and output data.
State — Typing for the module state.
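A sketch of a full __call__ invocation with the optional keyword arguments spelled out (the keyword signature is as documented above):

```python
import jax
import jax.numpy as jnp
from parametricmatrixmodels.modules import SoftSign

act = SoftSign()
x = jnp.ones((8, 4))    # (num_samples, ...)
y, new_state = act(x, training=False, state=(), rng=jax.random.PRNGKey(0))
assert y.shape == x.shape   # elementwise: the shape is preserved
# new_state is presumably () here, since activations carry no state
```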
- _get_callable()#
Get the callable for the activation function.
- Returns:
The activation function callable in the form the PMM library expects.
- Return type:
Callable[[PyTree[jaxtyping.Inexact[Array, '...'], 'Params'], PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], bool, PyTree[jaxtyping.Num[Array, '*?d'], 'State'], Any], tuple[PyTree[jaxtyping.Inexact[Array, 'batch_size ...']] | Inexact[Array, 'batch_size ...'], PyTree[jaxtyping.Num[Array, '*?d'], 'State']]]
- astype(dtype, /)#
Convenience wrapper to set_precision using the dtype argument; returns self.
- Parameters:
dtype (str | type[Any] | dtype | SupportsDType) – Precision to set for the module parameters. Valid options are:
For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32.
For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.
- Returns:
BaseModule – The module itself, with updated precision.
- Raises:
ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.
RuntimeError – If the module is not ready (i.e., compile() has not been called).
- Return type:
BaseModule
See also
set_precision — Sets the precision of the module parameters and state.
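Since astype returns the module itself, it can be chained at construction time; a minimal sketch:

```python
from parametricmatrixmodels.modules import SoftSign

act = SoftSign().astype("float32")   # returns self, so the call chains cleanly
```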
- compile(rng, input_shape)#
Compile the activation function module. This method is a no-op for activation functions.
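Even though compile() is a no-op here, calling it keeps activation modules interchangeable with modules that do require compilation. A hedged sketch (the input_shape convention shown is an assumption):

```python
import jax
from parametricmatrixmodels.modules import SoftSign

act = SoftSign()
act.compile(jax.random.PRNGKey(0), (4,))  # no-op for activation functions
assert act.is_ready()                     # activations are always ready
```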
- deserialize(data, /)#
Deserialize the module from a dictionary.
This method sets the module’s parameters and state based on the provided dictionary.
The default implementation expects the dictionary to contain the module’s name, trainable parameters, and state.
- Parameters:
data (dict[str, Any]) – Dictionary containing the serialized module data.
- Raises:
ValueError – If the serialized data does not contain the expected keys or if the version of the serialized data is not compatible with the current package version.
- Return type:
None
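A hedged round-trip sketch, assuming the default dict-based serialization described above:

```python
from parametricmatrixmodels.modules import SoftSign

src = SoftSign()
blob = src.serialize()   # dict holding the module's name, hyperparameters, params, state
dst = SoftSign()
dst.deserialize(blob)    # restores params and state in place; returns None
```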
- get_hyperparameters()#
Get the hyperparameters of the activation function module.
- get_num_trainable_floats()#
Get the number of trainable floats in the module. Activation functions do not have trainable parameters.
- Returns:
Always returns 0.
- Return type:
int | None
- get_output_shape(input_shape)#
Get the output shape of the activation function given the input shape.
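Because the activation is elementwise, the output shape matches the input shape; a brief sketch (the exact input_shape convention is an assumption):

```python
from parametricmatrixmodels.modules import SoftSign

act = SoftSign()
assert act.get_output_shape((16, 32)) == (16, 32)  # elementwise: shape unchanged
```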
- get_params()#
Get the parameters of the activation function module.
- Returns:
An empty tuple, as activation functions do not have parameters.
- Return type:
PyTree[jaxtyping.Inexact[Array, '...'], 'Params']
- get_state()#
Get the current state of the module.
States are used to store “memory” or other information that is not passed between modules, is not trainable, but may be updated during either training or inference. e.g. batch normalization state.
The state is optional, in which case this method should return the empty tuple.
- Returns:
PyTree with leaf nodes as JAX arrays representing the module's state.
- Return type:
pmm.typing.State
See also
set_state — Set the state of the module.
State — Typing for the module state.
- is_ready()#
Check if the module is ready to be used. Activation functions are always ready.
- Returns:
Always returns True.
- Return type:
bool
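The introspection methods above are all trivial for a fixed activation; a hedged sketch:

```python
from parametricmatrixmodels.modules import SoftSign

act = SoftSign()
assert act.is_ready()                       # always True for activations
assert act.get_num_trainable_floats() == 0  # no trainable parameters
assert act.get_params() == ()               # empty parameter PyTree
assert act.get_state() == ()                # presumably (): activations are stateless
```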
- property name: str#
Get the name of the activation function module.
- Returns:
Name of the activation function module.
- serialize()#
Serialize the module to a dictionary.
This method returns a dictionary representation of the module, including its parameters and state.
The default implementation serializes the module’s name, hyperparameters, trainable parameters, and state via a simple dictionary.
This only works if the module’s hyperparameters are auto-serializable. This includes basic types as well as numpy arrays.
- set_hyperparameters(hyperparameters, /)#
Set the hyperparameters of the module.
Hyperparameters are used to configure the module and are not trainable. They can be set via this method.
The default implementation uses setattr to set the hyperparameters as attributes of the class instance.
- Parameters:
hyperparameters (pmm.typing.HyperParams) – Dictionary containing the hyperparameters to set.
- Raises:
TypeError – If hyperparameters is not a dictionary.
- Return type:
None
See also
get_hyperparameters — Get the hyperparameters of the module.
HyperParams — Typing for the hyperparameters; simply an alias for Dict[str, Any].
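A hedged sketch of the default setattr behavior; the key used here is hypothetical, since a fixed activation needs no hyperparameters:

```python
from parametricmatrixmodels.modules import SoftSign

act = SoftSign()
act.set_hyperparameters({"example_flag": True})  # hypothetical key, stored via setattr
assert act.example_flag is True                  # now available as an attribute
```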
- set_params(params)#
Set the trainable parameters of the module.
- Parameters:
params (PyTree[jaxtyping.Inexact[Array, '...'], 'Params']) – PyTree with leaf nodes as JAX arrays representing the new trainable parameters of the module.
- Raises:
NotImplementedError – If the method is not implemented in the subclass.
- Return type:
None
See also
get_params — Get the trainable parameters of the module.
Params — Typing for the module parameters.
- set_precision(prec, /)#
Set the precision of the module parameters and state.
- Parameters:
prec (Any | str | int) – Precision to set for the module parameters. Valid options are:
For 32-bit precision (all options are equivalent): np.float32, np.complex64, "float32", "complex64", "single", "f32", "c64", 32.
For 64-bit precision (all options are equivalent): np.float64, np.complex128, "float64", "complex128", "double", "f64", "c128", 64.
- Raises:
ValueError – If the precision is invalid or if 64-bit precision is requested but JAX_ENABLE_X64 is not set.
RuntimeError – If the module is not ready (i.e., compile() has not been called).
- Return type:
None
See also
astype — Convenience wrapper to set_precision using the dtype argument; returns self.
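A hedged sketch of requesting 64-bit precision. Enabling x64 via jax.config.update("jax_enable_x64", True) is standard JAX practice and is assumed here to satisfy the JAX_ENABLE_X64 requirement:

```python
import jax
jax.config.update("jax_enable_x64", True)   # must be enabled before requesting 64-bit

from parametricmatrixmodels.modules import SoftSign

act = SoftSign()
act.set_precision(64)   # returns None; raises ValueError if x64 is disabled
```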