Source code for parametricmatrixmodels.modules.comment
from parametricmatrixmodels.typing import (
Any,
DataShape,
HyperParams,
ModuleCallable,
Params,
State,
)
from .basemodule import BaseModule
[docs]
class Comment(BaseModule):
    """
    A module that allows adding comments to ``Model`` summaries.

    The module is an identity operation on the data: its callable returns the
    input data and state unchanged. Its ``name`` is ``"# <comment>"`` (or just
    ``"#"``).
    """

    def __init__(self, comment: str | None = None) -> None:
        """
        Create a ``Comment`` module.

        Parameters
        ----------
        comment
            Comment text to be shown in the ``Model`` summary where this module
            is placed. If ``None`` or an empty string, the module's name is
            just ``"#"``.
        """
        # Text rendered after "#" in ``name``; may be None (no comment text).
        self.comment = comment

    @property
    def name(self) -> str:
        """
        Display name of this module.

        Returns ``"# <comment>"`` when a non-empty comment was given, and
        ``"#"`` otherwise (``None`` or empty string).
        """
        return f"# {self.comment}" if self.comment else "#"

    def _get_callable(
        self,
    ) -> ModuleCallable:
        """
        Return the identity callable for this module.

        The returned function ignores ``params``, ``training`` and ``rng`` and
        returns the ``(data, state)`` pair exactly as received.
        """
        return lambda params, data, training, state, rng: (
            data,  # output is the same as input
            state,  # state is unchanged
        )

    def get_output_shape(self, input_shape: DataShape) -> DataShape:
        """
        Return the output data shape for a given input shape.

        Because this module passes its data through untouched, the output
        shape is always identical to ``input_shape``.
        """
        return input_shape  # output shape is the same as input shape