Source code for parametricmatrixmodels.modules.flatten

import jax

from ..tree_util import is_shape_leaf
from ..typing import (
    Any,
    DataShape,
)
from .reshape import Reshape


[docs] class Flatten(Reshape): """ Module that flattens the input to 1D. Ignores the batch dimension. Operates on all leafs of a PyTree input. """
[docs] def __init__(self) -> None: # initialize with shape=None, shape will be determined at compile time # this accounts for both array and PyTree inputs super().__init__(shape=None)
@property def name(self) -> str: return "Flatten"
[docs] def compile(self, rng: Any, input_shape: DataShape) -> None: try: len(input_shape) except TypeError: raise TypeError( "Input shape must be a tuple, list, or PyTree of shapes." ) # if input_shape is an iterable of ints, then the input is a single # array if all(isinstance(dim, int) for dim in input_shape): self.shape = (-1,) return # construct the tree of output shapes (all (-1,)) input_shapes, input_struct = jax.tree.flatten( input_shape, is_leaf=is_shape_leaf ) self.shape = jax.tree.unflatten( input_struct, [(-1,) for _ in input_shapes], )
[docs] def get_output_shape(self, input_shape: DataShape) -> DataShape: # if input_shape is an iterable of ints, then the input is a single # array if all(isinstance(dim, int) for dim in input_shape): self.shape = (-1,) return super().get_output_shape(input_shape) # construct the tree of output shapes (all (-1,)) input_shapes, input_struct = jax.tree.flatten( input_shape, is_leaf=is_shape_leaf ) self.shape = jax.tree.unflatten( input_struct, [(-1,) for _ in input_shapes], ) return super().get_output_shape(input_shape)