from __future__ import annotations
import jax
from .activationbase import ActivationBase
[docs]
class ReLU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.relu``.

    See Also
    --------
    jax.nn.relu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ReLU"`` for this activation."""
        return "ReLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.relu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.relu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ReLU6(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.relu6``.

    See Also
    --------
    jax.nn.relu6 : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ReLU6"`` for this activation."""
        return "ReLU6"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.relu6(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.relu6
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Sigmoid(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.sigmoid``.

    See Also
    --------
    jax.nn.sigmoid : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Sigmoid"`` for this activation."""
        return "Sigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.sigmoid(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.sigmoid
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Softplus(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.softplus``.

    See Also
    --------
    jax.nn.softplus : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Softplus"`` for this activation."""
        return "Softplus"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.softplus(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.softplus
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class SparsePlus(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.sparse_plus``.

    See Also
    --------
    jax.nn.sparse_plus : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"SparsePlus"`` for this activation."""
        return "SparsePlus"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.sparse_plus(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.sparse_plus
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class SparseSigmoid(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.sparse_sigmoid``.

    See Also
    --------
    jax.nn.sparse_sigmoid : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"SparseSigmoid"`` for this activation."""
        return "SparseSigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.sparse_sigmoid(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.sparse_sigmoid
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class SoftSign(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.soft_sign``.

    See Also
    --------
    jax.nn.soft_sign : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"SoftSign"`` for this activation."""
        return "SoftSign"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.soft_sign(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.soft_sign
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class SiLU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.silu``.

    See Also
    --------
    jax.nn.silu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"SiLU"`` for this activation."""
        return "SiLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.silu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.silu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Swish(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.swish``.

    See Also
    --------
    jax.nn.swish : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Swish"`` for this activation."""
        return "Swish"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.swish(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.swish
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class LogSigmoid(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.log_sigmoid``.

    See Also
    --------
    jax.nn.log_sigmoid : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"LogSigmoid"`` for this activation."""
        return "LogSigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.log_sigmoid(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.log_sigmoid
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class LeakyReLU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.leaky_relu``.

    See Also
    --------
    jax.nn.leaky_relu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"LeakyReLU"`` for this activation."""
        return "LeakyReLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.leaky_relu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.leaky_relu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class HardSigmoid(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.hard_sigmoid``.

    See Also
    --------
    jax.nn.hard_sigmoid : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"HardSigmoid"`` for this activation."""
        return "HardSigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.hard_sigmoid(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.hard_sigmoid
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class HardSiLU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.hard_silu``.

    See Also
    --------
    jax.nn.hard_silu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"HardSiLU"`` for this activation."""
        return "HardSiLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.hard_silu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.hard_silu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class HardSwish(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.hard_swish``.

    See Also
    --------
    jax.nn.hard_swish : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"HardSwish"`` for this activation."""
        return "HardSwish"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.hard_swish(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.hard_swish
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class HardTanh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.hard_tanh``.

    See Also
    --------
    jax.nn.hard_tanh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"HardTanh"`` for this activation."""
        return "HardTanh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.hard_tanh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.hard_tanh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ELU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.elu``.

    See Also
    --------
    jax.nn.elu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ELU"`` for this activation."""
        return "ELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.elu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.elu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class CELU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.celu``.

    See Also
    --------
    jax.nn.celu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"CELU"`` for this activation."""
        return "CELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.celu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.celu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class SELU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.selu``.

    See Also
    --------
    jax.nn.selu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"SELU"`` for this activation."""
        return "SELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.selu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.selu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class GELU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.gelu``.

    See Also
    --------
    jax.nn.gelu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"GELU"`` for this activation."""
        return "GELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.gelu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.gelu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class GLU(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.glu``.

    NOTE(review): ``jax.nn.glu`` presumably splits its input along an axis,
    so it is not strictly elementwise and the output shape may differ from
    the input shape -- confirm against the JAX docs.

    See Also
    --------
    jax.nn.glu : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"GLU"`` for this activation."""
        return "GLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.glu(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.glu
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class SquarePlus(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.squareplus``.

    See Also
    --------
    jax.nn.squareplus : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"SquarePlus"`` for this activation."""
        return "SquarePlus"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.squareplus(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.squareplus
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Mish(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.mish``.

    See Also
    --------
    jax.nn.mish : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Mish"`` for this activation."""
        return "Mish"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.mish(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.mish
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Identity(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.identity``.

    NOTE(review): confirm that ``jax.nn.identity`` is available in the
    minimum JAX version this package supports.

    See Also
    --------
    jax.nn.identity : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Identity"`` for this activation."""
        return "Identity"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.identity(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.identity
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Softmax(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.softmax``.

    NOTE(review): ``jax.nn.softmax`` presumably normalizes along an axis
    rather than acting on each element independently -- confirm against the
    JAX docs.

    See Also
    --------
    jax.nn.softmax : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Softmax"`` for this activation."""
        return "Softmax"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.softmax(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.softmax
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class LogSoftmax(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.log_softmax``.

    NOTE(review): ``jax.nn.log_softmax`` presumably normalizes along an axis
    rather than acting on each element independently -- confirm against the
    JAX docs.

    See Also
    --------
    jax.nn.log_softmax : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"LogSoftmax"`` for this activation."""
        return "LogSoftmax"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.log_softmax(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.log_softmax
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class LogSumExp(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.logsumexp``.

    NOTE(review): ``jax.nn.logsumexp`` presumably reduces over an axis, so
    the result is not elementwise and may have fewer dimensions than the
    input -- confirm against the JAX docs.

    See Also
    --------
    jax.nn.logsumexp : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"LogSumExp"`` for this activation."""
        return "LogSumExp"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.logsumexp(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.logsumexp
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Standardize(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.standardize``.

    NOTE(review): ``jax.nn.standardize`` presumably uses statistics taken
    along an axis, so it is not strictly elementwise -- confirm against the
    JAX docs.

    See Also
    --------
    jax.nn.standardize : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Standardize"`` for this activation."""
        return "Standardize"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.standardize(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.standardize
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class OneHot(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.nn.one_hot``.

    NOTE(review): ``jax.nn.one_hot`` presumably needs a class count, which
    would have to arrive through ``self.args`` or ``self.kwargs``, and its
    output gains a dimension -- confirm against the JAX docs.

    See Also
    --------
    jax.nn.one_hot : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"OneHot"`` for this activation."""
        return "OneHot"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.nn.one_hot(x, *self.args, **self.kwargs)``."""
        wrapped = jax.nn.one_hot
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Abs(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.abs``.

    See Also
    --------
    jax.numpy.abs : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Abs"`` for this activation."""
        return "Abs"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.abs(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.abs
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Absolute(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.absolute``.

    See Also
    --------
    jax.numpy.absolute : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Absolute"`` for this activation."""
        return "Absolute"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.absolute(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.absolute
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ACos(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.acos``.

    See Also
    --------
    jax.numpy.acos : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ACos"`` for this activation."""
        return "ACos"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.acos(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.acos
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ACosh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.acosh``.

    See Also
    --------
    jax.numpy.acosh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ACosh"`` for this activation."""
        return "ACosh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.acosh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.acosh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class AMax(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.amax``.

    NOTE(review): ``jax.numpy.amax`` is presumably a reduction (maximum over
    an axis or the whole array) rather than an elementwise map -- confirm
    against the JAX docs.

    See Also
    --------
    jax.numpy.amax : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"AMax"`` for this activation."""
        return "AMax"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.amax(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.amax
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class AMin(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.amin``.

    NOTE(review): ``jax.numpy.amin`` is presumably a reduction (minimum over
    an axis or the whole array) rather than an elementwise map -- confirm
    against the JAX docs.

    See Also
    --------
    jax.numpy.amin : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"AMin"`` for this activation."""
        return "AMin"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.amin(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.amin
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Angle(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.angle``.

    See Also
    --------
    jax.numpy.angle : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Angle"`` for this activation."""
        return "Angle"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.angle(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.angle
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ArcCos(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.arccos``.

    See Also
    --------
    jax.numpy.arccos : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ArcCos"`` for this activation."""
        return "ArcCos"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.arccos(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.arccos
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ArcCosh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.arccosh``.

    See Also
    --------
    jax.numpy.arccosh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ArcCosh"`` for this activation."""
        return "ArcCosh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.arccosh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.arccosh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ArcSin(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.arcsin``.

    See Also
    --------
    jax.numpy.arcsin : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ArcSin"`` for this activation."""
        return "ArcSin"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.arcsin(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.arcsin
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ArcSinh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.arcsinh``.

    See Also
    --------
    jax.numpy.arcsinh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ArcSinh"`` for this activation."""
        return "ArcSinh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.arcsinh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.arcsinh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ArcTan(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.arctan``.

    See Also
    --------
    jax.numpy.arctan : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ArcTan"`` for this activation."""
        return "ArcTan"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.arctan(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.arctan
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ArcTan2(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.arctan2``.

    NOTE(review): ``jax.numpy.arctan2`` presumably takes two array operands;
    the second would have to be supplied through ``self.args`` or
    ``self.kwargs`` -- confirm against the JAX docs.

    See Also
    --------
    jax.numpy.arctan2 : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ArcTan2"`` for this activation."""
        return "ArcTan2"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.arctan2(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.arctan2
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ArcTanh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.arctanh``.

    See Also
    --------
    jax.numpy.arctanh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ArcTanh"`` for this activation."""
        return "ArcTanh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.arctanh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.arctanh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ASin(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.asin``.

    See Also
    --------
    jax.numpy.asin : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ASin"`` for this activation."""
        return "ASin"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.asin(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.asin
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ASinh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.asinh``.

    See Also
    --------
    jax.numpy.asinh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ASinh"`` for this activation."""
        return "ASinh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.asinh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.asinh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ATan(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.atan``.

    See Also
    --------
    jax.numpy.atan : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ATan"`` for this activation."""
        return "ATan"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.atan(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.atan
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class ATanh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.atanh``.

    See Also
    --------
    jax.numpy.atanh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"ATanh"`` for this activation."""
        return "ATanh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.atanh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.atanh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Cbrt(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.cbrt``.

    See Also
    --------
    jax.numpy.cbrt : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Cbrt"`` for this activation."""
        return "Cbrt"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.cbrt(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.cbrt
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Ceil(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.ceil``.

    See Also
    --------
    jax.numpy.ceil : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Ceil"`` for this activation."""
        return "Ceil"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.ceil(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.ceil
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Clip(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.clip``.

    See Also
    --------
    jax.numpy.clip : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Clip"`` for this activation."""
        return "Clip"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.clip(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.clip
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Conj(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.conj``.

    See Also
    --------
    jax.numpy.conj : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Conj"`` for this activation."""
        return "Conj"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.conj(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.conj
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Conjugate(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.conjugate``.

    See Also
    --------
    jax.numpy.conjugate : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Conjugate"`` for this activation."""
        return "Conjugate"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.conjugate(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.conjugate
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Cos(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.cos``.

    See Also
    --------
    jax.numpy.cos : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Cos"`` for this activation."""
        return "Cos"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.cos(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.cos
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Cosh(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.cosh``.

    See Also
    --------
    jax.numpy.cosh : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Cosh"`` for this activation."""
        return "Cosh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.cosh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.cosh
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Deg2Rad(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.deg2rad``.

    See Also
    --------
    jax.numpy.deg2rad : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Deg2Rad"`` for this activation."""
        return "Deg2Rad"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.deg2rad(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.deg2rad
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Degrees(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.degrees``.

    See Also
    --------
    jax.numpy.degrees : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Degrees"`` for this activation."""
        return "Degrees"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.degrees(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.degrees
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Exp(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.exp``.

    See Also
    --------
    jax.numpy.exp : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Exp"`` for this activation."""
        return "Exp"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.exp(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.exp
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Exp2(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.exp2``.

    See Also
    --------
    jax.numpy.exp2 : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Exp2"`` for this activation."""
        return "Exp2"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.exp2(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.exp2
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Expm1(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.expm1``.

    See Also
    --------
    jax.numpy.expm1 : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Expm1"`` for this activation."""
        return "Expm1"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.expm1(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.expm1
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class FAbs(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.fabs``.

    See Also
    --------
    jax.numpy.fabs : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"FAbs"`` for this activation."""
        return "FAbs"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.fabs(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.fabs
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Fix(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.fix``.

    See Also
    --------
    jax.numpy.fix : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Fix"`` for this activation."""
        return "Fix"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.fix(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.fix
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class FloatPower(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.float_power``.

    NOTE(review): ``jax.numpy.float_power`` presumably takes a base and an
    exponent; the exponent would have to be supplied through ``self.args``
    or ``self.kwargs`` -- confirm against the JAX docs.

    See Also
    --------
    jax.numpy.float_power : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"FloatPower"`` for this activation."""
        return "FloatPower"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.float_power(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.float_power
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Floor(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.floor``.

    See Also
    --------
    jax.numpy.floor : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Floor"`` for this activation."""
        return "Floor"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.floor(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.floor
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class FloorDivide(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.floor_divide``.

    NOTE(review): ``jax.numpy.floor_divide`` presumably takes a dividend and
    a divisor; the divisor would have to be supplied through ``self.args``
    or ``self.kwargs`` -- confirm against the JAX docs.

    See Also
    --------
    jax.numpy.floor_divide : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"FloorDivide"`` for this activation."""
        return "FloorDivide"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.floor_divide(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.floor_divide
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class FrExp(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.frexp``.

    NOTE(review): ``jax.numpy.frexp`` presumably returns a
    ``(mantissa, exponent)`` pair, in which case the single-array return
    annotation on :meth:`func` is too narrow -- confirm against the JAX docs.

    See Also
    --------
    jax.numpy.frexp : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"FrExp"`` for this activation."""
        return "FrExp"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.frexp(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.frexp
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class I0(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.i0``.

    See Also
    --------
    jax.numpy.i0 : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"I0"`` for this activation."""
        return "I0"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.i0(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.i0
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Imag(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.imag``.

    See Also
    --------
    jax.numpy.imag : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Imag"`` for this activation."""
        return "Imag"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.imag(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.imag
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Invert(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.invert``.

    See Also
    --------
    jax.numpy.invert : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Invert"`` for this activation."""
        return "Invert"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.invert(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.invert
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class LDExp(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.ldexp``.

    NOTE(review): ``jax.numpy.ldexp`` presumably takes two operands; the
    second would have to be supplied through ``self.args`` or
    ``self.kwargs`` -- confirm against the JAX docs.

    See Also
    --------
    jax.numpy.ldexp : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"LDExp"`` for this activation."""
        return "LDExp"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.ldexp(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.ldexp
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Log(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.log``.

    See Also
    --------
    jax.numpy.log : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Log"`` for this activation."""
        return "Log"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.log(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.log
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Log10(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.log10``.

    See Also
    --------
    jax.numpy.log10 : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Log10"`` for this activation."""
        return "Log10"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.log10(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.log10
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Log1p(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.log1p``.

    See Also
    --------
    jax.numpy.log1p : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Log1p"`` for this activation."""
        return "Log1p"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.log1p(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.log1p
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class Log2(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.log2``.

    See Also
    --------
    jax.numpy.log2 : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"Log2"`` for this activation."""
        return "Log2"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.log2(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.log2
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class NaNToNum(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.nan_to_num``.

    NOTE(review): this file also defines ``NanToNum``, which wraps the same
    function under a different capitalization -- confirm both spellings are
    intended.

    See Also
    --------
    jax.numpy.nan_to_num : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"NaNToNum"`` for this activation."""
        return "NaNToNum"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.nan_to_num(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.nan_to_num
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class NanToNum(ActivationBase):
    """
    Activation wrapper that delegates to ``jax.numpy.nan_to_num``.

    NOTE(review): this file also defines ``NaNToNum``, which wraps the same
    function under a different capitalization -- confirm both spellings are
    intended.

    See Also
    --------
    jax.numpy.nan_to_num : The JAX function applied by :meth:`func`.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to the parent initializer."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the label ``"NanToNum"`` for this activation."""
        return "NanToNum"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Evaluate ``jax.numpy.nan_to_num(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.nan_to_num
        return wrapped(x, *self.args, **self.kwargs)
[docs]
class NextAfter(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.nextafter``.

    See Also
    --------
    jax.numpy.nextafter : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"NextAfter"``."""
        label = "NextAfter"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.nextafter(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.nextafter
        return target(x, *self.args, **self.kwargs)
[docs]
class Packbits(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.packbits``.

    See Also
    --------
    jax.numpy.packbits : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Packbits"``."""
        label = "Packbits"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.packbits(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.packbits
        return target(x, *self.args, **self.kwargs)
[docs]
class Piecewise(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.piecewise``.

    See Also
    --------
    jax.numpy.piecewise : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Piecewise"``."""
        label = "Piecewise"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.piecewise(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.piecewise
        return target(x, *self.args, **self.kwargs)
[docs]
class Positive(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.positive``.

    See Also
    --------
    jax.numpy.positive : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Positive"``."""
        label = "Positive"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.positive(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.positive
        return target(x, *self.args, **self.kwargs)
[docs]
class Pow(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.pow``.

    See Also
    --------
    jax.numpy.pow : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Pow"``."""
        label = "Pow"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.pow(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.pow
        return target(x, *self.args, **self.kwargs)
[docs]
class Power(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.power``.

    See Also
    --------
    jax.numpy.power : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Power"``."""
        label = "Power"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.power(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.power
        return target(x, *self.args, **self.kwargs)
[docs]
class Rad2Deg(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.rad2deg``.

    See Also
    --------
    jax.numpy.rad2deg : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Rad2Deg"``."""
        label = "Rad2Deg"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.rad2deg(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.rad2deg
        return target(x, *self.args, **self.kwargs)
[docs]
class Radians(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.radians``.

    See Also
    --------
    jax.numpy.radians : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Radians"``."""
        label = "Radians"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.radians(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.radians
        return target(x, *self.args, **self.kwargs)
[docs]
class Real(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.real``.

    See Also
    --------
    jax.numpy.real : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Real"``."""
        label = "Real"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.real(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.real
        return target(x, *self.args, **self.kwargs)
[docs]
class Reciprocal(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.reciprocal``.

    See Also
    --------
    jax.numpy.reciprocal : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Reciprocal"``."""
        label = "Reciprocal"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.reciprocal(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.reciprocal
        return target(x, *self.args, **self.kwargs)
[docs]
class RInt(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.rint``.

    See Also
    --------
    jax.numpy.rint : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"RInt"``."""
        label = "RInt"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.rint(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.rint
        return target(x, *self.args, **self.kwargs)
[docs]
class Round(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.round``.

    See Also
    --------
    jax.numpy.round : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Round"``."""
        label = "Round"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.round(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.round
        return target(x, *self.args, **self.kwargs)
[docs]
class Sign(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.sign``.

    See Also
    --------
    jax.numpy.sign : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Sign"``."""
        label = "Sign"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sign(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.sign
        return target(x, *self.args, **self.kwargs)
[docs]
class Signbit(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.signbit``.

    See Also
    --------
    jax.numpy.signbit : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Signbit"``."""
        label = "Signbit"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.signbit(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.signbit
        return target(x, *self.args, **self.kwargs)
[docs]
class Sin(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.sin``.

    See Also
    --------
    jax.numpy.sin : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Sin"``."""
        label = "Sin"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sin(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.sin
        return target(x, *self.args, **self.kwargs)
[docs]
class Sinc(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.sinc``.

    See Also
    --------
    jax.numpy.sinc : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Sinc"``."""
        label = "Sinc"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sinc(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.sinc
        return target(x, *self.args, **self.kwargs)
[docs]
class Sinh(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.sinh``.

    See Also
    --------
    jax.numpy.sinh : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Sinh"``."""
        label = "Sinh"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sinh(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.sinh
        return target(x, *self.args, **self.kwargs)
[docs]
class Sqrt(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.sqrt``.

    See Also
    --------
    jax.numpy.sqrt : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Sqrt"``."""
        label = "Sqrt"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sqrt(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.sqrt
        return target(x, *self.args, **self.kwargs)
[docs]
class Square(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.square``.

    See Also
    --------
    jax.numpy.square : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Square"``."""
        label = "Square"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.square(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.square
        return target(x, *self.args, **self.kwargs)
[docs]
class Tan(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.tan``.

    See Also
    --------
    jax.numpy.tan : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Tan"``."""
        label = "Tan"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.tan(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.tan
        return target(x, *self.args, **self.kwargs)
[docs]
class Tanh(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.tanh``.

    See Also
    --------
    jax.numpy.tanh : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Tanh"``."""
        label = "Tanh"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.tanh(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.tanh
        return target(x, *self.args, **self.kwargs)
[docs]
class Trunc(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.trunc``.

    See Also
    --------
    jax.numpy.trunc : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Trunc"``."""
        label = "Trunc"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.trunc(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.trunc
        return target(x, *self.args, **self.kwargs)
[docs]
class Unpackbits(ActivationBase):
    """
    Activation that delegates to ``jax.numpy.unpackbits``.

    See Also
    --------
    jax.numpy.unpackbits : The function this activation calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return this activation's name, ``"Unpackbits"``."""
        label = "Unpackbits"
        return label

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.unpackbits(x, *self.args, **self.kwargs)``."""
        target = jax.numpy.unpackbits
        return target(x, *self.args, **self.kwargs)
# This file is autogenerated. Do not edit manually.