Source code for parametricmatrixmodels.modules.reshape

from __future__ import annotations

import math
from typing import Any, Callable

import jax.numpy as np

from .basemodule import BaseModule


class Reshape(BaseModule):
    """
    Module that reshapes the input array to a specified shape. Ignores the
    batch dimension.

    The module has no trainable parameters and passes its state through
    unchanged.
    """

    def __init__(self, shape: tuple[int, ...] | None = None) -> None:
        """
        Parameters
        ----------
        shape
            The target shape to reshape the input to, by default None.
            If None, the input shape will remain unchanged.
            Does not include the batch dimension. At most one entry may be
            ``-1``, in which case that dimension is inferred from the input.
        """
        self.shape = shape

    def name(self) -> str:
        """Return a human-readable name that includes the target shape."""
        return f"Reshape(shape={self.shape})"

    def is_ready(self) -> bool:
        """Return True; this implementation never reports itself unready."""
        return True

    def get_num_trainable_floats(self) -> int | None:
        """Return 0, since reshaping has no trainable parameters."""
        return 0

    def _get_callable(
        self,
    ) -> Callable[
        [
            tuple[np.ndarray, ...],
            np.ndarray,
            bool,
            tuple[np.ndarray, ...],
            Any,
        ],
        tuple[np.ndarray, tuple[np.ndarray, ...]],
    ]:
        """
        Build the forward function of the module.

        Returns
        -------
        A function ``(params, input_NF, training, state, rng)`` returning
        ``(output, state)``. ``params``, ``training`` and ``rng`` are not
        used. The first axis of ``input_NF`` is kept as-is and the remaining
        axes are reshaped to ``self.shape``; ``state`` is returned unchanged.
        """

        def _reshape(params, input_NF, training, state, rng):
            # self.shape is read at call time, as the original lambda did.
            # Compare against None explicitly (not truthiness) so that an
            # empty target shape ``()`` is honoured, matching what
            # get_output_shape reports for it.
            if self.shape is None:
                return input_NF, state
            return (
                input_NF.reshape(input_NF.shape[0], *self.shape),
                state,  # state is unchanged
            )

        return _reshape

    def compile(self, rng: Any, input_shape: tuple[int, ...]) -> None:
        """Do nothing; there is nothing to initialise for a reshape."""
        pass

    def get_output_shape(
        self, input_shape: tuple[int, ...]
    ) -> tuple[int, ...]:
        """
        Compute the output shape (without the batch dimension).

        Parameters
        ----------
        input_shape
            Shape of a single input sample, without the batch dimension.

        Returns
        -------
        ``input_shape`` if ``self.shape`` is None, otherwise ``self.shape``
        with a ``-1`` entry (if any) replaced by the inferred size.

        Raises
        ------
        ValueError
            If ``self.shape`` contains more than one ``-1``, or if the
            ``-1`` dimension cannot be inferred because the input size is not
            divisible by the product of the other target dimensions.
        """
        if self.shape is None:
            return input_shape

        num_inferred = sum(1 for dim in self.shape if dim == -1)
        if num_inferred == 0:
            return self.shape
        if num_inferred > 1:
            raise ValueError(
                "Reshape target shape can contain at most one -1, got "
                f"{self.shape}"
            )

        # Plain integer arithmetic: no array round-trip, no fixed-width
        # overflow, and an empty input_shape yields the int 1.
        total_size = math.prod(int(dim) for dim in input_shape)
        known_size = math.prod(int(dim) for dim in self.shape if dim != -1)
        if known_size == 0 or total_size % known_size != 0:
            raise ValueError(
                f"Cannot reshape input of shape {tuple(input_shape)} into "
                f"shape {self.shape}"
            )

        inferred = total_size // known_size
        return tuple(inferred if dim == -1 else dim for dim in self.shape)

    def get_hyperparameters(self) -> dict[str, Any]:
        """Return the hyperparameters of the module (the target shape)."""
        return {
            "shape": self.shape,
        }

    def get_params(self) -> tuple[np.ndarray, ...]:
        """Return an empty tuple, since there are no parameters."""
        return ()

    def set_params(self, params: tuple[np.ndarray, ...]) -> None:
        """Do nothing; there are no parameters to set."""
        pass