import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import jaxtyped
from ..tree_util import is_shape_leaf
from ..typing import (
Any,
ArrayData,
ArrayDataShape,
Data,
DataShape,
HyperParams,
ModuleCallable,
Params,
State,
Tuple,
)
from .basemodule import BaseModule
class Reshape(BaseModule):
    """
    Module that reshapes the input array to a specified shape. Ignores the
    batch dimension.
    """

    __version__: str = "0.0.0"

    def __init__(self, shape: DataShape = None) -> None:
        """
        Initialize a ``Reshape`` module.

        Parameters
        ----------
        shape
            The target shape to reshape the input to, by default None.
            If None, the input shape will remain unchanged.
            Does not include the batch dimension. If the input to the module is
            a PyTree, then ``shape`` should be a PyTree of matching structure.
            Any ``None`` values in the PyTree will leave the corresponding leaf
            arrays unchanged.

        Examples
        --------
        .. code-block:: python

            # Prepare to accept only bare array data (no PyTrees) and reshape
            # to (2, 3)
            reshape_module = Reshape(shape=(2, 3))

            # Prepare to accept a PyTree of arrays with structure [*, (*, *)]
            # and reshape the first leaf to (2, 3), leave the second leaf
            # unchanged, and flatten the final leaf
            reshape_module = Reshape(shape=[(2, 3), (None, (-1,))])
        """
        # Validate shape: unless it is entirely an iterable of ints, none of
        # the elements may be bare ints.
        if shape is None:
            self.shape = shape
            return
        try:
            len(shape)
        except TypeError:
            # NOTE(review): AssertionError (rather than TypeError) is kept
            # for backward compatibility with callers that may catch it.
            raise AssertionError(
                "Shape must be a tuple, list, or PyTree of shapes."
            )
        if all(isinstance(dim, int) for dim in shape):
            # shape is just an iterable of ints (a single array shape)
            pass
        elif any(isinstance(dim, int) for dim in shape):
            # shape is a PyTree, but not all of the leaves are shapes
            # (iterables themselves),
            # e.g. shape = [(2, 3), 2, (4, 5)]: the second element (2) is
            # invalid and should be (2,) instead
            raise TypeError(
                "If shape is a PyTree, all leaves must be shapes "
                "(iterables of ints)."
            )
        # At this point shape is either an iterable of ints, or a PyTree of
        # shapes or Nones.
        self.shape = shape

    @property
    def name(self) -> str:
        """Human-readable name of the module, including the target shape."""
        return f"Reshape(shape={self.shape})"

    def is_ready(self) -> bool:
        """Reshape has no parameters to initialize, so it is always ready."""
        return True

    def get_num_trainable_floats(self) -> int | None:
        """Return 0: reshaping is parameter-free."""
        return 0

    def _get_callable(
        self,
    ) -> ModuleCallable:
        """
        Build the pure function implementing this module.

        Returns
        -------
        ModuleCallable
            A ``(params, data, training, state, rng) -> (data, state)``
            function that reshapes each leaf array (preserving the leading
            batch dimension) and passes ``state`` through unchanged.
        """

        @jaxtyped(typechecker=beartype)
        def reshape_array(
            arr: ArrayData,
            shape: ArrayDataShape | None,
        ) -> ArrayData:
            # Preserve the batch (leading) dimension; only the trailing
            # dimensions are reshaped. A None target leaves the leaf as-is.
            batch_dim = arr.shape[0]
            if shape is None:
                return arr
            else:
                return np.reshape(arr, (batch_dim, *shape))

        @jaxtyped(typechecker=beartype)
        def reshape_callable(
            params: Params,
            data: Data,
            training: bool,
            state: State,
            rng: Any,
        ) -> Tuple[Data, State]:
            if self.shape is None:
                # No-op: identity on the data, state untouched.
                return data, state
            else:
                # jax.tree.map flattens self.shape up to the structure of
                # data, so a bare shape tuple pairs with a bare array and a
                # PyTree of shapes pairs leaf-wise with a PyTree of arrays.
                reshaped_data = jax.tree.map(
                    reshape_array,
                    data,
                    self.shape,
                )
                return reshaped_data, state

        return reshape_callable

    def validate_and_concretify_shape(
        self, input_shape: DataShape
    ) -> DataShape:
        """
        Check that ``input_shape`` is compatible with the configured target
        shape and return the concrete output shape (with any ``-1`` wildcard
        dimensions resolved).

        Parameters
        ----------
        input_shape
            The per-sample input shape (no batch dimension): either an
            iterable of ints or a PyTree of such shapes matching the
            structure of ``self.shape``.

        Returns
        -------
        DataShape
            The concrete output shape, mirroring the structure of
            ``input_shape``.

        Raises
        ------
        TypeError
            If ``input_shape`` is not iterable.
        AssertionError
            If the shapes are structurally or numerically incompatible.
        """
        if self.shape is None:
            # Identity module: output shape equals input shape.
            return input_shape
        assert input_shape is not None, "Input shape must not be None."
        assert not (None in input_shape), "Input shape must not contain None."
        try:
            len(input_shape)
        except TypeError:
            raise TypeError(
                "Input shape must be a tuple, list, or PyTree of shapes."
            )
        promoted = False
        # If input_shape is an iterable of ints, promote it to a
        # single-element PyTree so both shapes have PyTree structure.
        if all(isinstance(dim, int) for dim in input_shape):
            input_shape = (input_shape,)
            promoted = True
        # Same promotion for self.shape.
        if all(isinstance(dim, int) for dim in self.shape):
            selfshape = (self.shape,)
        else:
            selfshape = self.shape
        input_struct = jax.tree.structure(input_shape, is_leaf=is_shape_leaf)
        shape_struct = jax.tree.structure(selfshape, is_leaf=is_shape_leaf)
        assert input_struct == shape_struct, (
            f"Input shape structure {input_struct} does not match target shape"
            f" structure {shape_struct}"
        )

        def check_compatibility_and_concretify(
            in_shape: ArrayDataShape,
            target_shape: ArrayDataShape | None,
        ) -> ArrayDataShape:
            # Validate one leaf's shapes and resolve any -1 wildcard.
            if target_shape is None:
                # A None target leaves this leaf untouched, so its output
                # shape is its input shape. (Bug fix: previously returned
                # None here, leaking None leaves into the reported output
                # shape and breaking downstream shape validation.)
                return in_shape
            in_size = np.prod(np.array(in_shape)).item()
            if -1 in target_shape:
                # Only one wildcard dimension can be inferred.
                assert (
                    target_shape.count(-1) == 1
                ), "Target shape can only contain one -1 dimension"
                # Infer the size of the -1 dimension from the known ones.
                known_size = 1
                for dim in target_shape:
                    if dim != -1:
                        known_size *= dim
                # Guard against ZeroDivisionError when a 0 dim accompanies -1.
                assert known_size != 0, (
                    f"Cannot infer -1 dimension in target shape"
                    f" {target_shape} containing a 0 dimension"
                )
                inferred_dim = in_size // known_size
                target_size = known_size * inferred_dim
            else:
                target_size = np.prod(np.array(target_shape)).item()
            # Element counts must agree (this also catches non-divisible
            # -1 inferences, since target_size was rounded down).
            assert in_size == target_size, (
                f"Input shape {in_shape} is not compatible with target shape"
                f" {target_shape}"
            )
            if -1 in target_shape:
                # Replace -1 with the inferred dimension.
                return tuple(
                    inferred_dim if dim == -1 else dim for dim in target_shape
                )
            else:
                return target_shape

        concrete_shape = jax.tree.map(
            check_compatibility_and_concretify,
            input_shape,
            selfshape,
            is_leaf=is_shape_leaf,
        )
        if promoted:
            # Undo the single-element promotion done above.
            return concrete_shape[0]
        else:
            return concrete_shape

    def compile(self, rng: Any, input_shape: DataShape) -> None:
        """Validate shapes at compile time; Reshape has nothing to build."""
        self.validate_and_concretify_shape(input_shape)

    def get_output_shape(self, input_shape: DataShape) -> DataShape:
        """Return the concrete output shape for a given input shape."""
        return self.validate_and_concretify_shape(input_shape)

    def get_hyperparameters(self) -> HyperParams:
        """Return the module's hyperparameters (the target shape)."""
        return {
            "shape": self.shape,
        }

    def get_params(self) -> Params:
        """Reshape has no trainable parameters."""
        return ()

    def set_params(self, params: Params) -> None:
        """No-op: Reshape has no trainable parameters to set."""
        pass