Source code for parametricmatrixmodels.modules.comment
from __future__ import annotations
from typing import Any, Callable
import jax.numpy as np
from .basemodule import BaseModule
[docs]
class Comment(BaseModule):
    """
    A module that allows adding comments to ``Model`` summaries.

    The module performs no computation: its forward callable returns the
    input array and the state tuple unchanged, and its output shape equals
    its input shape. Its only purpose is to carry a ``comment`` for display.
    """

    def __init__(self, comment: str | None = None) -> None:
        """
        Create a ``Comment`` module.

        Parameters
        ----------
        comment
            Comment text to be shown in the ``Model`` summary where this module
            is placed. May be ``None`` (the default); how a ``None`` comment
            is rendered is not determined by this module.
        """
        # Annotated as ``str | None`` because ``None`` is the default; this is
        # safe at runtime thanks to ``from __future__ import annotations``.
        self.comment = comment

    def _get_callable(
        self,
    ) -> Callable[
        [
            tuple[np.ndarray, ...],
            np.ndarray,
            bool,
            tuple[np.ndarray, ...],
            Any,
        ],
        tuple[np.ndarray, tuple[np.ndarray, ...]],
    ]:
        """
        Return the forward function for this module, which is an identity.

        The returned callable takes ``(params, input_NF, training, state,
        rng)`` and returns ``(input_NF, state)``. ``params``, ``training`` and
        ``rng`` are accepted so the callable matches the signature in the
        return annotation, but they are not used.

        Returns
        -------
        Callable
            The identity forward function described above.
        """

        def identity(params, input_NF, training, state, rng):
            # A comment contributes nothing to the computation: the output is
            # the input, and the state is passed through unchanged.
            return input_NF, state

        return identity

    def get_output_shape(
        self, input_shape: tuple[int, ...]
    ) -> tuple[int, ...]:
        """
        Return the output shape of this module for a given input shape.

        Parameters
        ----------
        input_shape
            Shape of the input to this module.

        Returns
        -------
        tuple[int, ...]
            ``input_shape`` itself, since the module is an identity.
        """
        return input_shape