Source code for parametricmatrixmodels.modules.prelu

import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import jaxtyped

from parametricmatrixmodels.typing import (
    Any,
    ArrayData,
    Data,
    DataShape,
    HyperParams,
    ModuleCallable,
    Params,
    State,
    Tuple,
)

from .basemodule import BaseModule


class PReLU(BaseModule):
    r"""
    Element-wise Parametric Rectified Linear Unit (PReLU) activation.

    .. math::

        \text{PReLU}(x) = \begin{cases}
            x, & \text{ if } x \ge 0 \\
            ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable slope applied to the negative part of the
    input. Depending on ``single_parameter`` it is either one value shared by
    every input feature or one value per input feature.

    Works on PyTrees as well as on bare arrays.

    See Also
    --------
    torch.nn.PReLU
        PyTorch implementation of PReLU activation function.
    LeakyReLU
        Non-parametric ReLU activation function with a fixed negative slope.
    """

    __version__: str = "0.0.0"

    def __init__(
        self,
        single_parameter: bool = True,
        init_magnitude: float = 0.25,
        real: bool = True,
    ) -> None:
        """
        Create a new ``PReLU`` module.

        Parameters
        ----------
        single_parameter
            When ``True`` one learnable slope is shared by all input
            features; when ``False`` every input feature gets its own
            slope. Default is ``True``.
        init_magnitude
            Factor that multiplies the random draw used to initialise the
            slope(s). Default is ``0.25``.
        real
            When ``True`` the slope(s) are real-valued; when ``False`` they
            are complex-valued. Default is ``True``.
        """
        self.single_parameter = single_parameter
        self.init_magnitude = init_magnitude
        self.real = real

        # Learnable slope(s); stays ``None`` until ``compile`` or
        # ``set_params`` assigns it.
        self.a: Params | float | complex | None = None
        # Input shape; stays ``None`` until ``compile`` assigns it.
        self.input_shape: DataShape | None = None

    @property
    def name(self) -> str:
        """Short label that encodes the ``real`` and ``single`` settings."""
        return f"PReLU(real={self.real}, single={self.single_parameter})"

    def is_ready(self) -> bool:
        """Return ``True`` once both the slope(s) and input shape are set."""
        return not (self.a is None or self.input_shape is None)

    def _get_callable(self) -> ModuleCallable:
        """
        Build the function that applies the activation.

        Returns
        -------
        ModuleCallable
            A function ``(params, input_data, training, state, rng)`` that
            returns ``(output, state)``. ``state`` is handed back unchanged;
            ``training`` and ``rng`` are accepted but not used.
        """

        @jaxtyped(typechecker=beartype)
        def prelu_array(
            arr: ArrayData,
            a: np.ndarray | float | complex,
        ) -> ArrayData:
            # PReLU is a leaky ReLU whose negative slope is the parameter.
            return jax.nn.leaky_relu(arr, negative_slope=a)

        if not self.single_parameter:

            def module_callable(
                params: Params,
                input_data: Data,
                training: bool,
                state: State,
                rng: Any,
            ) -> Tuple[Data, State]:
                # ``params`` is mapped alongside ``input_data`` so every
                # input leaf is paired with its own slope array.
                activated = jax.tree.map(prelu_array, input_data, params)
                return activated, state

            return module_callable

        def module_callable(
            params: Params,
            input_data: Data,
            training: bool,
            state: State,
            rng: Any,
        ) -> Tuple[Data, State]:
            # The single slope is stored inside a container; squeeze it and
            # apply the same value to every leaf of the input.
            a = params[0].squeeze()

            def apply_shared_slope(arr):
                return prelu_array(arr, a)

            return jax.tree.map(apply_shared_slope, input_data), state

        return module_callable

    def compile(self, rng: Any, input_shape: DataShape) -> None:
        """
        Record the input shape and randomly initialise the slope(s).

        If the module is already ready, nothing is re-initialised; the call
        only verifies that ``input_shape`` equals the stored shape.

        Parameters
        ----------
        rng
            JAX random key used to draw the initial slope values.
        input_shape
            Shape of the input: a tuple of ints, or a PyTree whose leaves
            are such tuples.

        Raises
        ------
        ValueError
            If the module is already ready and ``input_shape`` differs from
            the stored input shape.
        """
        if self.is_ready():
            if self.input_shape != input_shape:
                raise ValueError(
                    "PReLU module has already been compiled with a different "
                    "input shape."
                )
            return

        self.input_shape = input_shape

        def make_a(
            key: Any,
            shape: Tuple[int, ...],
        ) -> np.ndarray:
            # Draw one slope array of the requested shape, scaled by
            # ``init_magnitude``.
            if self.real:
                return self.init_magnitude * jax.random.normal(
                    key, shape, dtype=np.float32
                )
            # Complex case: two independent draws from separate sub-keys.
            rkey, ikey = jax.random.split(key)
            first = jax.random.normal(rkey, shape, dtype=np.complex64)
            second = jax.random.normal(ikey, shape, dtype=np.complex64)
            return self.init_magnitude * (first + 1j * second)

        if self.single_parameter:
            # One shared slope, stored as a 1-tuple holding a (1,) array.
            self.a = (make_a(rng, (1,)),)
            return

        def is_shape(x: Any) -> bool:
            # A tuple made only of ints is one array shape and must be
            # treated as a leaf rather than flattened element by element.
            return isinstance(x, tuple) and all(isinstance(i, int) for i in x)

        # One random key per shape leaf, arranged in the same tree structure
        # as ``input_shape`` so keys and shapes can be mapped together.
        n_leaves = len(jax.tree.leaves(input_shape, is_leaf=is_shape))
        flat_keys = jax.random.split(rng, n_leaves)
        shape_structure = jax.tree.structure(input_shape, is_leaf=is_shape)
        key_tree = jax.tree.unflatten(shape_structure, flat_keys)
        self.a = jax.tree.map(make_a, key_tree, input_shape)

    def get_output_shape(self, input_shape: DataShape) -> DataShape:
        """Return ``input_shape`` unchanged."""
        return input_shape

    def get_hyperparameters(self) -> HyperParams:
        """Return the constructor settings together with the input shape."""
        return dict(
            single_parameter=self.single_parameter,
            init_magnitude=self.init_magnitude,
            input_shape=self.input_shape,
            real=self.real,
        )

    def set_hyperparameters(self, hyperparams: HyperParams) -> None:
        """
        Set the hyperparameters through the parent class implementation.

        Raises
        ------
        ValueError
            If the module already holds parameters.
        """
        if self.a is not None:
            raise ValueError(
                "Cannot set hyperparameters after the module has parameters"
            )
        super().set_hyperparameters(hyperparams)

    def get_params(self) -> Params:
        """Return the stored slope parameter(s)."""
        return self.a

    def set_params(self, params: Params) -> None:
        """
        Replace the slope parameter(s).

        When the module is ready, ``params`` is validated first; otherwise
        it is stored without any checks.

        Parameters
        ----------
        params
            New parameters. With a single shared slope this must be a
            length-1 container holding an array of shape ``(1,)``.
            Otherwise it must have the same tree structure as the stored
            input shape, with every leaf matching the corresponding shape.

        Raises
        ------
        ValueError
            If the module is ready and ``params`` fails validation.
        """
        if not self.is_ready():
            self.a = params
            return

        if self.single_parameter:
            required_shape = (1,)
            if not (len(params) == 1 and params[0].shape == required_shape):
                raise ValueError(
                    f"Expected single parameter of shape {required_shape},"
                    f" got {params}"
                )
            self.a = params
            return

        def is_shape(x: Any) -> bool:
            # A tuple made only of ints is one array shape, i.e. one leaf.
            return isinstance(x, tuple) and all(isinstance(i, int) for i in x)

        # The parameter tree has to mirror the tree of input shapes.
        param_structure = jax.tree.structure(params)
        expected_structure = jax.tree.structure(
            self.input_shape, is_leaf=is_shape
        )
        if param_structure != expected_structure:
            raise ValueError(
                "Expected parameters with structure"
                f" {expected_structure}, got {param_structure}"
            )

        def check_shape(param: np.ndarray, shape: Tuple[int, ...]) -> None:
            if param.shape != shape:
                raise ValueError(
                    f"Expected parameter of shape {shape}, got"
                    f" {param.shape}"
                )

        # Mapping over ``params`` first pairs each parameter leaf with the
        # whole shape tuple at the same position.
        jax.tree.map(check_shape, params, self.input_shape)

        self.a = params