Source code for parametricmatrixmodels.modules.treeflatten

import jax
from beartype import beartype
from jaxtyping import jaxtyped

from ..tree_util import is_shape_leaf
from ..typing import (
    Any,
    Data,
    DataShape,
    HyperParams,
    ModuleCallable,
    Params,
    State,
    Tuple,
)
from .basemodule import BaseModule


class TreeFlatten(BaseModule):
    """
    Module that flattens an input tree of arrays into a list of arrays.

    The module has no trainable parameters and no hyperparameters, and it
    needs no compilation step. Its forward callable returns the leaves of
    the input pytree in the order produced by ``jax.tree.leaves``, and it
    passes ``state`` through unchanged.
    """

    def __init__(self) -> None:
        """Create a ``TreeFlatten`` module; there is nothing to configure."""

    @property
    def name(self) -> str:
        """Human-readable name of this module."""
        return "TreeFlatten"

    def compile(self, rng: Any, input_shape: DataShape) -> None:
        """
        No-op. Flattening has nothing to initialize, so both ``rng`` and
        ``input_shape`` are ignored.
        """

    def is_ready(self) -> bool:
        """Report readiness; always ``True`` since nothing needs compiling."""
        return True

    def _get_callable(self) -> ModuleCallable:
        """
        Build the function that performs the flattening.

        Returns
        -------
        ModuleCallable
            A runtime-type-checked function with signature
            ``(params, data, training, state, rng) -> (leaves, state)``.
            ``leaves`` is ``jax.tree.leaves(data)``. ``params``,
            ``training`` and ``rng`` are accepted only to match the module
            callable interface and are not used. ``state`` is returned as-is.
        """

        @jaxtyped(typechecker=beartype)
        def _flatten(
            params: Params,
            data: Data,
            training: bool,
            state: State,
            rng: Any,
        ) -> Tuple[Data, State]:
            leaves = jax.tree.leaves(data)
            return leaves, state

        return _flatten

    def get_output_shape(self, input_shape: DataShape) -> DataShape:
        """
        Compute the output shape for a given input shape tree.

        The result is the list of leaf shapes of ``input_shape``, in the
        same traversal order the callable uses for the data. ``is_shape_leaf``
        decides which nodes count as leaves. Presumably it keeps each
        individual shape tuple whole rather than descending into its
        integers; confirm against ``tree_util``.
        """
        leaf_shapes = jax.tree.leaves(input_shape, is_leaf=is_shape_leaf)
        return leaf_shapes

    def get_hyperparameters(self) -> HyperParams:
        """Return the (empty) hyperparameter mapping."""
        return dict()

    def get_params(self) -> Params:
        """Return the (empty) parameter collection."""
        return tuple()

    def set_params(self, params: Params) -> None:
        """No-op. This module holds no parameters, so ``params`` is ignored."""