Source code for parametricmatrixmodels.modules._autogenerated_activationfuncs

from __future__ import annotations

import jax

from .activationbase import ActivationBase


class ReLU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.relu``.

    See Also
    --------
    jax.nn.relu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ReLU"``."""
        return "ReLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.relu(x, *extra_args, **extra_kwargs)
class ReLU6(ActivationBase):
    """
    Activation module wrapping ``jax.nn.relu6``.

    See Also
    --------
    jax.nn.relu6 : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ReLU6"``."""
        return "ReLU6"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.relu6(x, *extra_args, **extra_kwargs)
class Sigmoid(ActivationBase):
    """
    Activation module wrapping ``jax.nn.sigmoid``.

    See Also
    --------
    jax.nn.sigmoid : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Sigmoid"``."""
        return "Sigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.sigmoid(x, *extra_args, **extra_kwargs)
class Softplus(ActivationBase):
    """
    Activation module wrapping ``jax.nn.softplus``.

    See Also
    --------
    jax.nn.softplus : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Softplus"``."""
        return "Softplus"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.softplus(x, *extra_args, **extra_kwargs)
class SparsePlus(ActivationBase):
    """
    Activation module wrapping ``jax.nn.sparse_plus``.

    See Also
    --------
    jax.nn.sparse_plus : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"SparsePlus"``."""
        return "SparsePlus"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.sparse_plus(x, *extra_args, **extra_kwargs)
class SparseSigmoid(ActivationBase):
    """
    Activation module wrapping ``jax.nn.sparse_sigmoid``.

    See Also
    --------
    jax.nn.sparse_sigmoid : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"SparseSigmoid"``."""
        return "SparseSigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.sparse_sigmoid(x, *extra_args, **extra_kwargs)
class SoftSign(ActivationBase):
    """
    Activation module wrapping ``jax.nn.soft_sign``.

    See Also
    --------
    jax.nn.soft_sign : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"SoftSign"``."""
        return "SoftSign"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.soft_sign(x, *extra_args, **extra_kwargs)
class SiLU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.silu``.

    See Also
    --------
    jax.nn.silu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"SiLU"``."""
        return "SiLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.silu(x, *extra_args, **extra_kwargs)
class Swish(ActivationBase):
    """
    Activation module wrapping ``jax.nn.swish``.

    See Also
    --------
    jax.nn.swish : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Swish"``."""
        return "Swish"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.swish(x, *extra_args, **extra_kwargs)
class LogSigmoid(ActivationBase):
    """
    Activation module wrapping ``jax.nn.log_sigmoid``.

    See Also
    --------
    jax.nn.log_sigmoid : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"LogSigmoid"``."""
        return "LogSigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.log_sigmoid(x, *extra_args, **extra_kwargs)
class LeakyReLU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.leaky_relu``.

    See Also
    --------
    jax.nn.leaky_relu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"LeakyReLU"``."""
        return "LeakyReLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.leaky_relu(x, *extra_args, **extra_kwargs)
class HardSigmoid(ActivationBase):
    """
    Activation module wrapping ``jax.nn.hard_sigmoid``.

    See Also
    --------
    jax.nn.hard_sigmoid : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"HardSigmoid"``."""
        return "HardSigmoid"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.hard_sigmoid(x, *extra_args, **extra_kwargs)
class HardSiLU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.hard_silu``.

    See Also
    --------
    jax.nn.hard_silu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"HardSiLU"``."""
        return "HardSiLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.hard_silu(x, *extra_args, **extra_kwargs)
class HardSwish(ActivationBase):
    """
    Activation module wrapping ``jax.nn.hard_swish``.

    See Also
    --------
    jax.nn.hard_swish : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"HardSwish"``."""
        return "HardSwish"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.hard_swish(x, *extra_args, **extra_kwargs)
class HardTanh(ActivationBase):
    """
    Activation module wrapping ``jax.nn.hard_tanh``.

    See Also
    --------
    jax.nn.hard_tanh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"HardTanh"``."""
        return "HardTanh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.hard_tanh(x, *extra_args, **extra_kwargs)
class ELU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.elu``.

    See Also
    --------
    jax.nn.elu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ELU"``."""
        return "ELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.elu(x, *extra_args, **extra_kwargs)
class CELU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.celu``.

    See Also
    --------
    jax.nn.celu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"CELU"``."""
        return "CELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.celu(x, *extra_args, **extra_kwargs)
class SELU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.selu``.

    See Also
    --------
    jax.nn.selu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"SELU"``."""
        return "SELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.selu(x, *extra_args, **extra_kwargs)
class GELU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.gelu``.

    See Also
    --------
    jax.nn.gelu : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"GELU"``."""
        return "GELU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.gelu(x, *extra_args, **extra_kwargs)
class GLU(ActivationBase):
    """
    Activation module wrapping ``jax.nn.glu``.

    See Also
    --------
    jax.nn.glu : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.nn.glu`` splits its input in two along
    an axis, so it is not a purely elementwise operation and the output
    shape differs from the input shape.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"GLU"``."""
        return "GLU"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.glu(x, *extra_args, **extra_kwargs)
class SquarePlus(ActivationBase):
    """
    Activation module wrapping ``jax.nn.squareplus``.

    See Also
    --------
    jax.nn.squareplus : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"SquarePlus"``."""
        return "SquarePlus"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.squareplus(x, *extra_args, **extra_kwargs)
class Mish(ActivationBase):
    """
    Activation module wrapping ``jax.nn.mish``.

    See Also
    --------
    jax.nn.mish : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Mish"``."""
        return "Mish"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.mish(x, *extra_args, **extra_kwargs)
class Identity(ActivationBase):
    """
    Activation module wrapping ``jax.nn.identity``.

    See Also
    --------
    jax.nn.identity : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Identity"``."""
        return "Identity"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.identity(x, *extra_args, **extra_kwargs)
class Softmax(ActivationBase):
    """
    Activation module wrapping ``jax.nn.softmax``.

    See Also
    --------
    jax.nn.softmax : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.nn.softmax`` normalizes along an axis,
    so it is not a purely elementwise operation.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Softmax"``."""
        return "Softmax"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.softmax(x, *extra_args, **extra_kwargs)
class LogSoftmax(ActivationBase):
    """
    Activation module wrapping ``jax.nn.log_softmax``.

    See Also
    --------
    jax.nn.log_softmax : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.nn.log_softmax`` normalizes along an
    axis, so it is not a purely elementwise operation.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"LogSoftmax"``."""
        return "LogSoftmax"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.log_softmax(x, *extra_args, **extra_kwargs)
class LogSumExp(ActivationBase):
    """
    Activation module wrapping ``jax.nn.logsumexp``.

    See Also
    --------
    jax.nn.logsumexp : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.nn.logsumexp`` is a reduction over one
    or more axes, so it is not an elementwise operation.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"LogSumExp"``."""
        return "LogSumExp"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.logsumexp(x, *extra_args, **extra_kwargs)
class Standardize(ActivationBase):
    """
    Activation module wrapping ``jax.nn.standardize``.

    See Also
    --------
    jax.nn.standardize : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.nn.standardize`` normalizes using
    statistics computed along an axis, so it is not a purely elementwise
    operation.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Standardize"``."""
        return "Standardize"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.standardize(x, *extra_args, **extra_kwargs)
class OneHot(ActivationBase):
    """
    Activation module wrapping ``jax.nn.one_hot``.

    See Also
    --------
    jax.nn.one_hot : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.nn.one_hot`` requires ``num_classes``
    and adds a dimension to its input. :meth:`func` supplies only ``x``, so
    ``num_classes`` must come from ``self.args``/``self.kwargs``.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"OneHot"``."""
        return "OneHot"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.nn.one_hot(x, *extra_args, **extra_kwargs)
class Abs(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.abs``.

    See Also
    --------
    jax.numpy.abs : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Abs"``."""
        return "Abs"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.abs(x, *extra_args, **extra_kwargs)
class Absolute(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.absolute``.

    See Also
    --------
    jax.numpy.absolute : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Absolute"``."""
        return "Absolute"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.absolute(x, *extra_args, **extra_kwargs)
class ACos(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.acos``.

    See Also
    --------
    jax.numpy.acos : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ACos"``."""
        return "ACos"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.acos(x, *extra_args, **extra_kwargs)
class ACosh(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.acosh``.

    See Also
    --------
    jax.numpy.acosh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ACosh"``."""
        return "ACosh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.acosh(x, *extra_args, **extra_kwargs)
class AMax(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.amax``.

    See Also
    --------
    jax.numpy.amax : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.numpy.amax`` is a maximum reduction
    over one or more axes, so it is not an elementwise operation.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"AMax"``."""
        return "AMax"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.amax(x, *extra_args, **extra_kwargs)
class AMin(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.amin``.

    See Also
    --------
    jax.numpy.amin : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.numpy.amin`` is a minimum reduction
    over one or more axes, so it is not an elementwise operation.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"AMin"``."""
        return "AMin"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.amin(x, *extra_args, **extra_kwargs)
class Angle(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.angle``.

    See Also
    --------
    jax.numpy.angle : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Angle"``."""
        return "Angle"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.angle(x, *extra_args, **extra_kwargs)
class ArcCos(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.arccos``.

    See Also
    --------
    jax.numpy.arccos : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ArcCos"``."""
        return "ArcCos"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.arccos(x, *extra_args, **extra_kwargs)
class ArcCosh(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.arccosh``.

    See Also
    --------
    jax.numpy.arccosh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ArcCosh"``."""
        return "ArcCosh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.arccosh(x, *extra_args, **extra_kwargs)
class ArcSin(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.arcsin``.

    See Also
    --------
    jax.numpy.arcsin : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ArcSin"``."""
        return "ArcSin"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.arcsin(x, *extra_args, **extra_kwargs)
class ArcSinh(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.arcsinh``.

    See Also
    --------
    jax.numpy.arcsinh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ArcSinh"``."""
        return "ArcSinh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.arcsinh(x, *extra_args, **extra_kwargs)
class ArcTan(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.arctan``.

    See Also
    --------
    jax.numpy.arctan : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ArcTan"``."""
        return "ArcTan"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.arctan(x, *extra_args, **extra_kwargs)
class ArcTan2(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.arctan2``.

    See Also
    --------
    jax.numpy.arctan2 : The function that :meth:`func` applies.

    Notes
    -----
    ``jax.numpy.arctan2`` needs a second operand. :meth:`func` supplies only
    ``x`` (as the first operand), so the second must come from
    ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ArcTan2"``."""
        return "ArcTan2"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.arctan2(x, *extra_args, **extra_kwargs)
class ArcTanh(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.arctanh``.

    See Also
    --------
    jax.numpy.arctanh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ArcTanh"``."""
        return "ArcTanh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.arctanh(x, *extra_args, **extra_kwargs)
class ASin(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.asin``.

    See Also
    --------
    jax.numpy.asin : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ASin"``."""
        return "ASin"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.asin(x, *extra_args, **extra_kwargs)
class ASinh(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.asinh``.

    See Also
    --------
    jax.numpy.asinh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ASinh"``."""
        return "ASinh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.asinh(x, *extra_args, **extra_kwargs)
class ATan(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.atan``.

    See Also
    --------
    jax.numpy.atan : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ATan"``."""
        return "ATan"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.atan(x, *extra_args, **extra_kwargs)
class ATanh(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.atanh``.

    See Also
    --------
    jax.numpy.atanh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"ATanh"``."""
        return "ATanh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.atanh(x, *extra_args, **extra_kwargs)
class Cbrt(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.cbrt``.

    See Also
    --------
    jax.numpy.cbrt : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Cbrt"``."""
        return "Cbrt"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.cbrt(x, *extra_args, **extra_kwargs)
class Ceil(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.ceil``.

    See Also
    --------
    jax.numpy.ceil : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Ceil"``."""
        return "Ceil"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.ceil(x, *extra_args, **extra_kwargs)
class Clip(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.clip``.

    See Also
    --------
    jax.numpy.clip : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Clip"``."""
        return "Clip"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.clip(x, *extra_args, **extra_kwargs)
class Conj(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.conj``.

    See Also
    --------
    jax.numpy.conj : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Conj"``."""
        return "Conj"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.conj(x, *extra_args, **extra_kwargs)
class Conjugate(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.conjugate``.

    See Also
    --------
    jax.numpy.conjugate : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Conjugate"``."""
        return "Conjugate"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.conjugate(x, *extra_args, **extra_kwargs)
class Cos(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.cos``.

    See Also
    --------
    jax.numpy.cos : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Cos"``."""
        return "Cos"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.cos(x, *extra_args, **extra_kwargs)
class Cosh(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.cosh``.

    See Also
    --------
    jax.numpy.cosh : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Cosh"``."""
        return "Cosh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.cosh(x, *extra_args, **extra_kwargs)
class Deg2Rad(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.deg2rad``.

    See Also
    --------
    jax.numpy.deg2rad : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Deg2Rad"``."""
        return "Deg2Rad"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.deg2rad(x, *extra_args, **extra_kwargs)
class Degrees(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.degrees``.

    See Also
    --------
    jax.numpy.degrees : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Degrees"``."""
        return "Degrees"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.degrees(x, *extra_args, **extra_kwargs)
class Exp(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.exp``.

    See Also
    --------
    jax.numpy.exp : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Exp"``."""
        return "Exp"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.exp(x, *extra_args, **extra_kwargs)
class Exp2(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.exp2``.

    See Also
    --------
    jax.numpy.exp2 : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Exp2"``."""
        return "Exp2"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.exp2(x, *extra_args, **extra_kwargs)
class Expm1(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.expm1``.

    See Also
    --------
    jax.numpy.expm1 : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Expm1"``."""
        return "Expm1"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.expm1(x, *extra_args, **extra_kwargs)
class FAbs(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.fabs``.

    See Also
    --------
    jax.numpy.fabs : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"FAbs"``."""
        return "FAbs"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.fabs(x, *extra_args, **extra_kwargs)
class Fix(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.fix``.

    See Also
    --------
    jax.numpy.fix : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Fix"``."""
        return "Fix"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.fix(x, *extra_args, **extra_kwargs)
class FloatPower(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.float_power``.

    See Also
    --------
    jax.numpy.float_power : The function that :meth:`func` applies.

    Notes
    -----
    ``jax.numpy.float_power`` needs a second operand (the exponent).
    :meth:`func` supplies only ``x`` (as the base), so the exponent must
    come from ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"FloatPower"``."""
        return "FloatPower"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.float_power(x, *extra_args, **extra_kwargs)
class Floor(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.floor``.

    See Also
    --------
    jax.numpy.floor : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Floor"``."""
        return "Floor"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.floor(x, *extra_args, **extra_kwargs)
class FloorDivide(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.floor_divide``.

    See Also
    --------
    jax.numpy.floor_divide : The function that :meth:`func` applies.

    Notes
    -----
    ``jax.numpy.floor_divide`` needs a second operand (the divisor).
    :meth:`func` supplies only ``x`` (as the dividend), so the divisor must
    come from ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"FloorDivide"``."""
        return "FloorDivide"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.floor_divide(x, *extra_args, **extra_kwargs)
class FrExp(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.frexp``.

    See Also
    --------
    jax.numpy.frexp : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.numpy.frexp`` returns a
    ``(mantissa, exponent)`` pair of arrays rather than a single array, so
    :meth:`func` is annotated with a tuple return type.
    NOTE(review): confirm that consumers of ``ActivationBase.func`` can
    handle a tuple result.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"FrExp"``."""
        return "FrExp"

    def func(
        self, x: jax.numpy.ndarray
    ) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]:
        """
        Decompose ``x`` with ``jax.numpy.frexp``.

        Parameters
        ----------
        x
            Input array, passed as the first argument; ``self.args`` and
            ``self.kwargs`` are forwarded after it.

        Returns
        -------
        tuple
            Whatever ``jax.numpy.frexp`` returns, documented by JAX as the
            ``(mantissa, exponent)`` pair.
        """
        return jax.numpy.frexp(x, *self.args, **self.kwargs)
class I0(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.i0``.

    See Also
    --------
    jax.numpy.i0 : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"I0"``."""
        return "I0"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.i0(x, *extra_args, **extra_kwargs)
class Imag(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.imag``.

    See Also
    --------
    jax.numpy.imag : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Imag"``."""
        return "Imag"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.imag(x, *extra_args, **extra_kwargs)
class Invert(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.invert``.

    See Also
    --------
    jax.numpy.invert : The function that :meth:`func` applies.

    Notes
    -----
    Per the JAX documentation, ``jax.numpy.invert`` is a bitwise inversion
    defined for integer and boolean inputs.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Invert"``."""
        return "Invert"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.invert(x, *extra_args, **extra_kwargs)
class LDExp(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.ldexp``.

    See Also
    --------
    jax.numpy.ldexp : The function that :meth:`func` applies.

    Notes
    -----
    ``jax.numpy.ldexp`` needs a second operand (the exponent). :meth:`func`
    supplies only ``x`` (as the first operand), so the exponent must come
    from ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"LDExp"``."""
        return "LDExp"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.ldexp(x, *extra_args, **extra_kwargs)
class Log(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.log``.

    See Also
    --------
    jax.numpy.log : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Log"``."""
        return "Log"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.log(x, *extra_args, **extra_kwargs)
class Log10(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.log10``.

    See Also
    --------
    jax.numpy.log10 : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Log10"``."""
        return "Log10"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.log10(x, *extra_args, **extra_kwargs)
class Log1p(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.log1p``.

    See Also
    --------
    jax.numpy.log1p : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Log1p"``."""
        return "Log1p"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.log1p(x, *extra_args, **extra_kwargs)
class Log2(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.log2``.

    See Also
    --------
    jax.numpy.log2 : The function that :meth:`func` applies.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Log2"``."""
        return "Log2"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.log2(x, *extra_args, **extra_kwargs)
class NaNToNum(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.nan_to_num``.

    See Also
    --------
    jax.numpy.nan_to_num : The function that :meth:`func` applies.

    Notes
    -----
    ``NanToNum`` in this module wraps the same function; the two classes
    differ only in class name and in the string returned by :meth:`name`.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"NaNToNum"``."""
        return "NaNToNum"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.nan_to_num(x, *extra_args, **extra_kwargs)
class NanToNum(ActivationBase):
    """
    Activation module wrapping ``jax.numpy.nan_to_num``.

    See Also
    --------
    jax.numpy.nan_to_num : The function that :meth:`func` applies.

    Notes
    -----
    ``NaNToNum`` in this module wraps the same function; the two classes
    differ only in class name and in the string returned by :meth:`name`.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to ``ActivationBase``."""
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"NanToNum"``."""
        return "NanToNum"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Apply the wrapped function to ``x`` and the extra arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return jax.numpy.nan_to_num(x, *extra_args, **extra_kwargs)
class NextAfter(ActivationBase):
    """
    Activation backed by ``jax.numpy.nextafter``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.nextafter`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    ``jax.numpy.nextafter`` takes two operands, and ``func`` only provides
    the first one (the input). The second operand therefore has to be
    supplied through the extra arguments.

    See Also
    --------
    jax.numpy.nextafter : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"NextAfter"``."""
        return "NextAfter"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.nextafter(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.nextafter
        return wrapped(x, *self.args, **self.kwargs)
class Packbits(ActivationBase):
    """
    Activation backed by ``jax.numpy.packbits``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.packbits`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    NOTE(review): ``jax.numpy.packbits`` packs groups of bits into
    ``uint8`` values, so the result generally does not have the shape of
    the input. Confirm that callers of this module can handle that.

    See Also
    --------
    jax.numpy.packbits : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Packbits"``."""
        return "Packbits"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.packbits(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.packbits
        return wrapped(x, *self.args, **self.kwargs)
class Piecewise(ActivationBase):
    """
    Activation backed by ``jax.numpy.piecewise``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.piecewise`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    ``jax.numpy.piecewise`` needs a condition list and a function list in
    addition to the input. ``func`` only provides the input, so both have
    to be supplied through the extra arguments.

    See Also
    --------
    jax.numpy.piecewise : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Piecewise"``."""
        return "Piecewise"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.piecewise(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.piecewise
        return wrapped(x, *self.args, **self.kwargs)
class Positive(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.positive``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.positive`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    See Also
    --------
    jax.numpy.positive : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Positive"``."""
        return "Positive"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.positive(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.positive
        return wrapped(x, *self.args, **self.kwargs)
class Pow(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.pow``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.pow`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    ``jax.numpy.pow`` takes a base and an exponent. ``func`` only provides
    the base (the input), so the exponent has to be supplied through the
    extra arguments.

    See Also
    --------
    jax.numpy.pow : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Pow"``."""
        return "Pow"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.pow(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.pow
        return wrapped(x, *self.args, **self.kwargs)
class Power(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.power``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.power`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    ``jax.numpy.power`` takes a base and an exponent. ``func`` only
    provides the base (the input), so the exponent has to be supplied
    through the extra arguments.

    See Also
    --------
    jax.numpy.power : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Power"``."""
        return "Power"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.power(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.power
        return wrapped(x, *self.args, **self.kwargs)
class Rad2Deg(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.rad2deg``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.rad2deg`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    See Also
    --------
    jax.numpy.rad2deg : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Rad2Deg"``."""
        return "Rad2Deg"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.rad2deg(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.rad2deg
        return wrapped(x, *self.args, **self.kwargs)
class Radians(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.radians``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.radians`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    See Also
    --------
    jax.numpy.radians : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Radians"``."""
        return "Radians"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.radians(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.radians
        return wrapped(x, *self.args, **self.kwargs)
class Real(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.real``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.real`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.real : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Real"``."""
        return "Real"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.real(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.real
        return wrapped(x, *self.args, **self.kwargs)
class Reciprocal(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.reciprocal``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.reciprocal`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    See Also
    --------
    jax.numpy.reciprocal : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Reciprocal"``."""
        return "Reciprocal"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.reciprocal(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.reciprocal
        return wrapped(x, *self.args, **self.kwargs)
class RInt(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.rint``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.rint`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.rint : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"RInt"``."""
        return "RInt"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.rint(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.rint
        return wrapped(x, *self.args, **self.kwargs)
class Round(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.round``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.round`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.round : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Round"``."""
        return "Round"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.round(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.round
        return wrapped(x, *self.args, **self.kwargs)
class Sign(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.sign``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.sign`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.sign : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Sign"``."""
        return "Sign"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sign(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.sign
        return wrapped(x, *self.args, **self.kwargs)
class Signbit(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.signbit``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.signbit`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    See Also
    --------
    jax.numpy.signbit : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Signbit"``."""
        return "Signbit"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.signbit(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.signbit
        return wrapped(x, *self.args, **self.kwargs)
class Sin(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.sin``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.sin`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.sin : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Sin"``."""
        return "Sin"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sin(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.sin
        return wrapped(x, *self.args, **self.kwargs)
class Sinc(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.sinc``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.sinc`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.sinc : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Sinc"``."""
        return "Sinc"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sinc(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.sinc
        return wrapped(x, *self.args, **self.kwargs)
class Sinh(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.sinh``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.sinh`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.sinh : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Sinh"``."""
        return "Sinh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sinh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.sinh
        return wrapped(x, *self.args, **self.kwargs)
class Sqrt(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.sqrt``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.sqrt`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.sqrt : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Sqrt"``."""
        return "Sqrt"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.sqrt(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.sqrt
        return wrapped(x, *self.args, **self.kwargs)
class Square(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.square``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.square`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.square : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Square"``."""
        return "Square"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.square(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.square
        return wrapped(x, *self.args, **self.kwargs)
class Tan(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.tan``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.tan`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.tan : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Tan"``."""
        return "Tan"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.tan(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.tan
        return wrapped(x, *self.args, **self.kwargs)
class Tanh(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.tanh``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.tanh`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.tanh : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Tanh"``."""
        return "Tanh"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.tanh(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.tanh
        return wrapped(x, *self.args, **self.kwargs)
class Trunc(ActivationBase):
    """
    Elementwise activation backed by ``jax.numpy.trunc``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.trunc`` on the input, followed by ``self.args`` and
    ``self.kwargs``.

    See Also
    --------
    jax.numpy.trunc : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Trunc"``."""
        return "Trunc"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.trunc(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.trunc
        return wrapped(x, *self.args, **self.kwargs)
class Unpackbits(ActivationBase):
    """
    Activation backed by ``jax.numpy.unpackbits``.

    Construction arguments go straight to ``ActivationBase``. ``func``
    calls ``jax.numpy.unpackbits`` on the input, followed by ``self.args``
    and ``self.kwargs``.

    NOTE(review): ``jax.numpy.unpackbits`` expands each ``uint8`` value
    into individual bits, so the result generally does not have the shape
    of the input. Confirm that callers of this module can handle that.

    See Also
    --------
    jax.numpy.unpackbits : The function applied by ``func``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing is kept locally; the base class receives every argument.
        super().__init__(*args, **kwargs)

    def name(self) -> str:
        """Return the name of this activation, ``"Unpackbits"``."""
        return "Unpackbits"

    def func(self, x: jax.numpy.ndarray) -> jax.numpy.ndarray:
        """Return ``jax.numpy.unpackbits(x, *self.args, **self.kwargs)``."""
        wrapped = jax.numpy.unpackbits
        return wrapped(x, *self.args, **self.kwargs)
# This file is autogenerated. Do not edit manually.