from __future__ import annotations
import warnings
import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import Array, Integer, Num, PyTree, Shaped, jaxtyped
from .typing import Any, Callable, List, Tuple
[docs]
def all_equal(
    pytree1: PyTree[Any],
    pytree2: PyTree[Any],
) -> bool:
    r"""
    Check if two pytrees are equal in structure and content.

    Parameters
    ----------
    pytree1
        The first pytree to compare.
    pytree2
        The second pytree to compare.

    Returns
    -------
    True if both pytrees have the same tree structure and every pair of
    corresponding leaves is equal, False otherwise. Array-like leaves (any
    object exposing ``shape`` and ``dtype``, e.g. JAX or NumPy arrays) are
    compared with ``array_equal``, so arrays of different shape compare as
    unequal. All other leaves are compared with ``!=``.
    """

    def _is_array(obj: Any) -> bool:
        # duck-typed so that NumPy arrays and NumPy scalars are covered too
        return hasattr(obj, "shape") and hasattr(obj, "dtype")

    def _is_number(obj: Any) -> bool:
        return isinstance(obj, (int, float, complex))

    if jax.tree.structure(pytree1) != jax.tree.structure(pytree2):
        return False

    leaves1 = jax.tree.leaves(pytree1)
    leaves2 = jax.tree.leaves(pytree2)
    for leaf1, leaf2 in zip(leaves1, leaves2):
        arr1, arr2 = _is_array(leaf1), _is_array(leaf2)
        numeric1 = arr1 or _is_number(leaf1)
        numeric2 = arr2 or _is_number(leaf2)
        if (arr1 or arr2) and numeric1 and numeric2:
            # Whenever an array is involved, ``leaf1 != leaf2`` would be an
            # element-wise result whose truth value is ambiguous, so reduce
            # it explicitly.
            if not np.array_equal(leaf1, leaf2):
                return False
        elif leaf1 != leaf2:
            return False
    return True
[docs]
def all_close(
    pytree1: PyTree[Any],
    pytree2: PyTree[Any],
    rtol: float = 1e-05,
    atol: float = 1e-08,
) -> bool:
    r"""
    Check if two pytrees are close in structure and content.

    Parameters
    ----------
    pytree1
        The first pytree to compare.
    pytree2
        The second pytree to compare.
    rtol
        Relative tolerance used when comparing numeric leaves.
    atol
        Absolute tolerance used when comparing numeric leaves.

    Returns
    -------
    True if both pytrees have the same tree structure and every pair of
    corresponding leaves is close, False otherwise. Array-like leaves (any
    object exposing ``shape`` and ``dtype``) must have identical shapes and
    pass ``allclose``. Pairs of Python numbers are compared with the same
    ``|a - b| <= atol + rtol * |b|`` rule. All remaining leaves are compared
    with ``!=``.
    """
    # This module defines its own ``abs`` for pytrees, so the builtin is
    # fetched explicitly for the scalar comparison below.
    import builtins

    def _is_array(obj: Any) -> bool:
        # duck-typed so that NumPy arrays and NumPy scalars are covered too
        return hasattr(obj, "shape") and hasattr(obj, "dtype")

    def _is_number(obj: Any) -> bool:
        return isinstance(obj, (int, float, complex))

    if jax.tree.structure(pytree1) != jax.tree.structure(pytree2):
        return False

    leaves1 = jax.tree.leaves(pytree1)
    leaves2 = jax.tree.leaves(pytree2)
    for leaf1, leaf2 in zip(leaves1, leaves2):
        arr1, arr2 = _is_array(leaf1), _is_array(leaf2)
        num1, num2 = _is_number(leaf1), _is_number(leaf2)
        if (arr1 or arr2) and (arr1 or num1) and (arr2 or num2):
            # Require identical shapes: ``allclose`` would otherwise either
            # broadcast silently or raise for incompatible shapes.
            if np.shape(leaf1) != np.shape(leaf2):
                return False
            if not np.allclose(leaf1, leaf2, rtol=rtol, atol=atol):
                return False
        elif num1 and num2:
            # Pure-Python comparison so that plain scalars honour the
            # tolerances. The equality test first makes identical
            # infinities compare as close.
            if leaf1 != leaf2:
                difference = builtins.abs(leaf1 - leaf2)
                if not difference <= atol + rtol * builtins.abs(leaf2):
                    return False
        elif leaf1 != leaf2:
            return False
    return True
[docs]
def get_shapes(
    pytree: PyTree[Shaped[Array, "..."]],
    axis: int | Tuple[int, ...] | slice | None = None,
) -> PyTree[Tuple[int, ...]]:
    r"""
    Collect the shape of every leaf in a PyTree of arrays, optionally
    restricted to selected axes.

    Parameters
    ----------
    pytree
        A PyTree whose leaves expose a ``shape`` attribute.
    axis
        Which entries of each shape to keep. ``None`` keeps the full shape,
        a slice is applied to the shape directly, and an integer or a tuple
        of integers picks the listed entries in the given order.

    Returns
    -------
    The result of mapping the shape selection over ``pytree``; each leaf is
    replaced by a tuple of the selected dimension sizes.
    """
    # Resolve the selection strategy once instead of once per leaf.
    if axis is None:

        def _select(full_shape: Tuple[int, ...]) -> Tuple[int, ...]:
            return full_shape

    elif isinstance(axis, slice):

        def _select(full_shape: Tuple[int, ...]) -> Tuple[int, ...]:
            return full_shape[axis]

    else:
        wanted = (axis,) if isinstance(axis, int) else axis

        def _select(full_shape: Tuple[int, ...]) -> Tuple[int, ...]:
            return tuple(full_shape[position] for position in wanted)

    return jax.tree.map(lambda array: _select(array.shape), pytree)
[docs]
def is_single_leaf(
    pytree: PyTree[Any],
    ndim: int | None = None,
    shape: Tuple[int, ...] | None = None,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Tuple[bool, Any]:
    r"""
    Check if a pytree consists of a single leaf node, optionally verifying
    the leaf's number of dimensions and shape.

    Parameters
    ----------
    pytree
        The pytree to inspect.
    ndim
        If given, the required number of dimensions. Compared against
        ``leaf.ndim`` when the leaf has that attribute, otherwise against
        ``len(leaf)``.
    shape
        If given, the required shape. Compared against ``leaf.shape`` when
        the leaf has that attribute, otherwise against ``tuple(leaf)``.
        ``None`` entries match any size in that position.
    is_leaf
        Optional predicate forwarded to ``jax.tree.leaves``.

    Returns
    -------
    ``(True, leaf)`` if the pytree has exactly one leaf that satisfies all
    requested checks, otherwise ``(False, None)``.

    Raises
    ------
    ValueError
        If both ``ndim`` and ``shape`` are given and ``len(shape) != ndim``.
    """
    if ndim is not None and shape is not None and len(shape) != ndim:
        raise ValueError(
            f"Provided shape {shape} does not match provided ndim {ndim}."
        )

    leaves = jax.tree.leaves(pytree, is_leaf=is_leaf)
    if len(leaves) != 1:
        # Reported through the return value only, like every other failed
        # check in this function.
        return False, None
    leaf = leaves[0]

    if ndim is not None:
        if hasattr(leaf, "ndim"):
            leaf_ndim = leaf.ndim
        else:
            try:
                leaf_ndim = len(leaf)
            except TypeError:
                # an unsized leaf cannot satisfy a dimensionality request
                return False, None
        if leaf_ndim != ndim:
            return False, None

    if shape is not None:
        if hasattr(leaf, "shape"):
            leaf_shape = leaf.shape
        else:
            try:
                leaf_shape = tuple(leaf)
            except TypeError:
                # a non-iterable leaf cannot satisfy a shape request
                return False, None
        if len(leaf_shape) != len(shape):
            return False, None
        for actual, expected in zip(leaf_shape, shape):
            # None in the requested shape acts as a wildcard
            if expected is not None and actual != expected:
                return False, None

    return True, leaf
[docs]
def make_mutable(pytree: PyTree[Any]) -> PyTree[Any]:
    r"""
    Replace every tuple in a pytree with a list so the tree can be mutated.

    Parameters
    ----------
    pytree
        The pytree to convert.

    Returns
    -------
    A new pytree in which all tuples, including nested ones, have been
    turned into lists.
    """

    def _tuple_check(node: Any) -> bool:
        return isinstance(node, tuple)

    def _listify(node: Any) -> Any:
        return list(node) if _tuple_check(node) else node

    def _outermost_tuples_remain(tree: PyTree[Any]) -> bool:
        # With tuples declared as leaves, flattening stops at the outermost
        # tuple on each branch, so any tuple among the leaves means there
        # is still work to do.
        stops = jax.tree.leaves(tree, is_leaf=_tuple_check)
        return any(_tuple_check(stop) for stop in stops)

    # Each pass only converts the outermost tuples; tuples nested inside
    # them surface on the next pass.
    while _outermost_tuples_remain(pytree):
        pytree = jax.tree.map(_listify, pytree, is_leaf=_tuple_check)
    return pytree
[docs]
def getitem_by_strpath(
    pytree: PyTree[Any],
    strpath: str,
    separator: str = ".",
    allow_early_return: bool = False,
    return_remainder: bool = False,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Any:
    r"""
    Get an item from a pytree using a string path. Effectively the inverse of
    ``jax.tree_util.keystr`` with ``simple=True``. Optionally allows early
    return if the path cannot be fully traversed, returning the remaining path
    as well.

    Parameters
    ----------
    pytree
        The pytree from which to get the item.
    strpath
        The string path to the item, with keys/indexes separated by
        ``separator``.
    separator
        The separator used in the string path. Default is '.'.
    allow_early_return
        If True, traversal stops without raising when it reaches a node that
        ``is_leaf`` marks as a leaf or that is not a dict, list or tuple. The
        node reached so far is returned.
    return_remainder
        If True, returns a tuple of the found item and the part of the path
        that was not consumed (``''`` if the full path was traversed).
    is_leaf
        Optional function to determine if a node is a leaf.

    Returns
    -------
    The item at the specified path in the pytree, or a tuple
    ``(item, remaining_path)`` if ``return_remainder`` is True.

    Raises
    ------
    TypeError
        If traversal hits a leaf or an unsupported node type before the path
        is exhausted and ``allow_early_return`` is False.
    KeyError
        If a list/tuple is indexed with a non-integer key, or a dict key is
        missing.

    Examples
    --------
    >>> pytree = {'a': [1, 2, {'b': 3}], 'c': 4}
    >>> getitem_by_strpath(pytree, 'a.2.b')
    3
    >>> getitem_by_strpath(pytree, 'c', return_remainder=True)
    (4, '')
    >>> pytree = {'a': [1, 2, 3], 'c': 4}
    >>> getitem_by_strpath(
    ...     pytree,
    ...     'a.2.b',
    ...     allow_early_return=True,
    ...     return_remainder=True
    ... )
    (3, 'b')
    >>> getitem_by_strpath(
    ...     pytree,
    ...     'a.2.b.0.c',
    ...     allow_early_return=True,
    ...     return_remainder=True
    ... )
    (3, 'b.0.c')
    """
    # The path carries no type information, so the type of each key (int or
    # str) has to be inferred from the node it is applied to.
    node = pytree
    remainder = strpath
    while remainder:
        if is_leaf is not None and is_leaf(node):
            if allow_early_return:
                break
            raise TypeError(
                f"Reached leaf node of type {type(node)} before "
                f"completing strpath '{remainder}'"
            )

        key, *rest = remainder.split(separator, 1)
        if isinstance(node, dict):
            # Prefer the string key; fall back to the integer key when only
            # that one exists, since both are written identically in a path.
            if key not in node and key.isdecimal() and int(key) in node:
                node = node[int(key)]
            else:
                node = node[key]
        elif isinstance(node, (list, tuple)):
            if not key.isdigit():
                raise KeyError(
                    f"Expected integer index for list/tuple, got '{key}'"
                )
            node = node[int(key)]
        elif allow_early_return:
            break
        else:
            raise TypeError(f"Unsupported pytree node type: {type(node)}")
        remainder = rest[0] if rest else ""

    if return_remainder:
        return node, remainder
    return node
[docs]
def setitem_by_strpath(
    pytree: PyTree[Any],
    strpath: str,
    value: Any,
    separator: str = ".",
    is_leaf: Callable[[Any], bool] | None = None,
) -> None:
    r"""
    Set an item in a pytree using a string path. Effectively the inverse of
    ``jax.tree_util.keystr`` with ``simple=True``.

    Parameters
    ----------
    pytree
        The pytree in which to set the item.
    strpath
        The string path to the item, with keys/indexes separated by
        ``separator``.
    value
        The value to set at the specified path.
    separator
        The separator used in the string path. Default is '.'.
    is_leaf
        Optional function to determine if a node is a leaf.

    Returns
    -------
    None. The pytree is modified in place.

    Raises
    ------
    ValueError
        If ``strpath`` is empty, or an empty path segment remains after a
        trailing separator.
    TypeError
        If a node marked by ``is_leaf`` or a node that is not a dict, list
        or tuple is reached before the path is exhausted, or if the final
        container is a tuple.
    KeyError
        If a list/tuple is indexed with a non-integer key, or an
        intermediate dict key is missing.
    """
    node = pytree
    remaining = strpath
    while True:
        if not remaining:
            raise ValueError("strpath cannot be empty.")
        if is_leaf is not None and is_leaf(node):
            raise TypeError(
                f"Reached leaf node of type {type(node)} before "
                f"completing strpath '{remaining}'"
            )

        key, *rest = remaining.split(separator, 1)
        if isinstance(node, dict):
            # Prefer the string key; fall back to the integer key when only
            # that one exists, since both are written identically in a path.
            if key not in node and key.isdecimal() and int(key) in node:
                key = int(key)
        elif isinstance(node, (list, tuple)):
            if not key.isdigit():
                raise KeyError(
                    f"Expected integer index for list/tuple, got '{key}'"
                )
            key = int(key)
            # Tuples may be traversed but not written to.
            if not rest and isinstance(node, tuple):
                raise TypeError(
                    "Cannot set item in a tuple, as tuples are "
                    "immutable. Consider converting to a list first."
                )
        else:
            raise TypeError(f"Unsupported pytree node type: {type(node)}")

        if not rest:
            node[key] = value
            return
        node = node[key]
        remaining = rest[0]
[docs]
def extend_structure_from_strpaths(
    base_pytree: PyTree[Any] | None,
    strpaths: List[str] | Tuple[str, ...],
    separator: str = ".",
    fill_values: List[Any] | Tuple[Any, ...] | Any | None = None,
) -> PyTree[Any]:
    r"""
    Extends the structure of a base pytree by adding new branches specified by
    string paths. The new branches are initialized with given fill values or
    None.

    New branches are created as needed, with dictionaries for string keys and
    lists for integer keys. Tuples are not created automatically; any tuples
    that need to be extended in the base PyTree will have their type preserved
    however.

    Parameters
    ----------
    base_pytree
        The base pytree to be extended. Dict and list nodes are modified in
        place. If None, an empty PyTree with structure inferred from strpaths
        is created (a list if every first key is an integer, else a dict).
    strpaths
        A list or tuple of string paths specifying the new branches to be added
    separator
        The separator used in the string paths. Default is '.'.
    fill_values
        A list, tuple, or single value specifying the values to fill in the new
        branches. If a single value is provided, it is used for all new
        branches. Default is None.

    Returns
    -------
    The extended pytree with new branches added.

    Raises
    ------
    ValueError
        If ``fill_values`` is a list/tuple whose length differs from
        ``strpaths``, or if a path's final node already exists with a value
        different from a non-None fill value.
    KeyError
        If a list/tuple node is addressed with a non-integer key.
    TypeError
        If a path has to pass through a node that is not a dict, list or
        tuple.

    Examples
    --------
    >>> base_pytree = {}
    >>> strpaths = ['a.b.c', 'd.0.e']
    >>> extended_pytree = extend_structure_from_strpaths(
    ...     base_pytree,
    ...     strpaths,
    ... )
    >>> extended_pytree
    {'a': {'b': {'c': None}}, 'd': [{'e': None}]}
    >>> base_pytree = {'x': 1}
    >>> strpaths = ['y.z', 'y.w']
    >>> extended_pytree = extend_structure_from_strpaths(
    ...     base_pytree,
    ...     strpaths,
    ...     fill_values=42,
    ... )
    >>> extended_pytree
    {'x': 1, 'y': {'z': 42, 'w': 42}}
    >>> base_pytree = []
    >>> strpaths = ['0.a', '1.1.2']
    >>> fill_values = [10, 20]
    >>> extended_pytree = extend_structure_from_strpaths(
    ...     base_pytree,
    ...     strpaths,
    ...     fill_values=fill_values,
    ... )
    >>> extended_pytree
    [{'a': 10}, [None, [None, None, 20]]]
    """
    # Normalize fill_values to a list with one entry per strpath.
    if isinstance(fill_values, tuple):
        fill_values = list(fill_values)
    elif not isinstance(fill_values, list):
        fill_values = [fill_values] * len(strpaths)
    if len(fill_values) != len(strpaths):
        raise ValueError(
            f"Length of fill_values ({len(fill_values)}) does not "
            f"match length of strpaths ({len(strpaths)})."
        )

    # Without a base, infer the root container from the first key of each
    # path: a list if they all look like integers, otherwise a dict.
    if base_pytree is None:
        first_keys = [path.split(separator, 1)[0] for path in strpaths if path]
        base_pytree = [] if all(key.isdigit() for key in first_keys) else {}

    # A tuple root is handled as a list during construction and converted
    # back at the very end.
    base_tuple = isinstance(base_pytree, tuple)
    if base_tuple:
        base_pytree = list(base_pytree)

    # (parent, key) of every nested tuple that was temporarily replaced by a
    # list. Parents are always recorded before their children.
    converted: list = []

    try:
        for strpath, fill_value in zip(strpaths, fill_values):
            node = base_pytree
            parts = strpath.split(separator)
            last_index = len(parts) - 1
            for i, key in enumerate(parts):
                is_last = i == last_index

                # Interpret the key according to the container it indexes.
                if isinstance(node, dict):
                    missing = key not in node
                elif isinstance(node, list):
                    if not key.isdigit():
                        raise KeyError(
                            f"Expected integer index for list/tuple, "
                            f"got '{key}'"
                        )
                    key = int(key)
                    # pad with None so that the index exists
                    if len(node) <= key:
                        node.extend([None] * (key + 1 - len(node)))
                    # None entries in a list count as "not created yet"
                    missing = node[key] is None
                else:
                    raise TypeError(
                        f"Unsupported pytree node type: {type(node)}"
                    )

                if missing:
                    if is_last:
                        node[key] = fill_value
                    else:
                        # look ahead to pick the container for the next key
                        node[key] = [] if parts[i + 1].isdigit() else {}
                elif (
                    is_last
                    and fill_value is not None
                    and node[key] != fill_value
                ):
                    # refuse to overwrite an existing, different value
                    raise ValueError(
                        f"Node at path '{strpath}' already exists "
                        f"with value {node[key]}, cannot overwrite "
                        f"with fill_value {fill_value}."
                    )

                child = node[key]
                if not is_last and isinstance(child, tuple):
                    # A tuple we need to descend into: swap in a mutable
                    # list and remember where it lives.
                    child = list(child)
                    node[key] = child
                    converted.append((node, key))
                node = child
    finally:
        # Restore tuples innermost-first. While an inner tuple is being
        # restored its parent is still a mutable list, so tuples nested
        # directly inside tuples work. Running this in ``finally`` keeps the
        # tuple types intact if an error interrupts construction.
        for parent, key in reversed(converted):
            parent[key] = tuple(parent[key])

    if base_tuple:
        base_pytree = tuple(base_pytree)
    return base_pytree
[docs]
@jaxtyped(typechecker=beartype)
def mean(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> Num[Array, ""]:
    r"""
    Average over every element of every leaf of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    The sum of all elements across all leaves divided by the total number
    of elements.
    """
    element_count = 0
    running_total = 0
    for array in jax.tree.leaves(pytree):
        element_count = element_count + array.size
        running_total = running_total + np.sum(array)
    return running_total / element_count
[docs]
@jaxtyped(typechecker=beartype)
def add(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Add two PyTrees of arrays leaf by leaf.

    Parameters
    ----------
    pytree1
        The first PyTree where each leaf is an array.
    pytree2
        The second PyTree, with the same structure as ``pytree1``.

    Returns
    -------
    A PyTree with the structure of the inputs whose leaves are
    ``leaf1 + leaf2`` for each pair of corresponding leaves.
    """

    def _leaf_sum(left, right):
        return left + right

    return jax.tree.map(_leaf_sum, pytree1, pytree2)
[docs]
@jaxtyped(typechecker=beartype)
def sub(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Subtract one PyTree of arrays from another leaf by leaf.

    Parameters
    ----------
    pytree1
        The PyTree to subtract from; each leaf is an array.
    pytree2
        The PyTree to subtract, with the same structure as ``pytree1``.

    Returns
    -------
    A PyTree with the structure of the inputs whose leaves are
    ``leaf1 - leaf2`` for each pair of corresponding leaves.
    """

    def _leaf_difference(left, right):
        return left - right

    return jax.tree.map(_leaf_difference, pytree1, pytree2)
[docs]
@jaxtyped(typechecker=beartype)
def mul(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Multiply two PyTrees of arrays leaf by leaf.

    Parameters
    ----------
    pytree1
        The first PyTree where each leaf is an array.
    pytree2
        The second PyTree, with the same structure as ``pytree1``.

    Returns
    -------
    A PyTree with the structure of the inputs whose leaves are
    ``leaf1 * leaf2`` for each pair of corresponding leaves.
    """

    def _leaf_product(left, right):
        return left * right

    return jax.tree.map(_leaf_product, pytree1, pytree2)
[docs]
@jaxtyped(typechecker=beartype)
def div(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Divide one PyTree of arrays by another leaf by leaf.

    Parameters
    ----------
    pytree1
        The PyTree of numerators; each leaf is an array.
    pytree2
        The PyTree of denominators, with the same structure as ``pytree1``.

    Returns
    -------
    A PyTree with the structure of the inputs whose leaves are
    ``leaf1 / leaf2`` for each pair of corresponding leaves.
    """

    def _leaf_quotient(numerator, denominator):
        return numerator / denominator

    return jax.tree.map(_leaf_quotient, pytree1, pytree2)
[docs]
@jaxtyped(typechecker=beartype)
def neg(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Negate every leaf of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are
    ``-leaf``.
    """

    def _negate(leaf):
        return -leaf

    return jax.tree.map(_negate, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def abs(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Take the element-wise absolute value of every leaf of a PyTree of
    arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    element-wise absolute values of the original leaves.
    """
    # the array function is unary, so it can be mapped over leaves directly
    return jax.tree.map(np.abs, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def abs_sqr(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Take the element-wise squared absolute value of every leaf of a PyTree
    of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    element-wise absolute values of the original leaves, squared.
    """

    def _squared_magnitude(leaf):
        magnitude = np.abs(leaf)
        return magnitude**2

    return jax.tree.map(_squared_magnitude, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def sqrt(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Take the element-wise square root of every leaf of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    element-wise square roots of the original leaves.
    """
    # the array function is unary, so it can be mapped over leaves directly
    return jax.tree.map(np.sqrt, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def abs_sqrt(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Take the element-wise square root of the absolute value of every leaf
    of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    element-wise square roots of the absolute values of the original
    leaves.
    """

    def _root_of_magnitude(leaf):
        magnitude = np.abs(leaf)
        return np.sqrt(magnitude)

    return jax.tree.map(_root_of_magnitude, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def pow(
    pytree: PyTree[Num[Array, " *d"], " T"],
    exponent: float | int,
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Raise every leaf of a PyTree of arrays to a given power.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    exponent
        The power applied element-wise to each leaf.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    original leaves raised element-wise to ``exponent``.
    """

    def _raise_leaf(leaf):
        return np.power(leaf, exponent)

    return jax.tree.map(_raise_leaf, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def scalar_add(
    pytree: PyTree[Num[Array, " *d"], " T"],
    scalar: complex | float | int | Num[Array, ""],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Add one scalar to every leaf of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    scalar
        The value added to each leaf.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are
    ``leaf + scalar``.
    """

    def _shift(leaf):
        return leaf + scalar

    return jax.tree.map(_shift, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def scalar_mul(
    pytree: PyTree[Num[Array, " *d"], " T"],
    scalar: complex | float | int | Num[Array, ""],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Multiply every leaf of a PyTree of arrays by one scalar.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    scalar
        The factor applied to each leaf.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are
    ``leaf * scalar``.
    """

    def _scale(leaf):
        return leaf * scalar

    return jax.tree.map(_scale, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def astype(
    pytree: PyTree[Shaped[Array, " ?*d"], " T"],
    dtype: jax.typing.DTypeLike,
) -> PyTree[Shaped[Array, " ?*d"], " T"]:
    r"""
    Cast every leaf of a PyTree of arrays to one data type.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    dtype
        The target data type, passed unchanged to each leaf's ``astype``
        method.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    results of ``leaf.astype(dtype)``.
    """

    def _cast(leaf):
        return leaf.astype(dtype)

    return jax.tree.map(_cast, pytree)
[docs]
def is_shape_leaf(obj: Any, allow_none: bool = True) -> bool:
    """
    Tell whether an object is a shape leaf: a tuple made only of integers,
    or (optionally) None.

    Parameters
    ----------
    obj
        The object to check.
    allow_none
        If True, ``None`` is accepted as a shape leaf as well.

    Returns
    -------
    True if the object is a shape leaf, False otherwise. An empty tuple
    counts as a shape leaf.
    """
    if obj is None and allow_none:
        return True
    if isinstance(obj, tuple):
        for entry in obj:
            if not isinstance(entry, int):
                return False
        return True
    return False
[docs]
@jaxtyped(typechecker=beartype)
def shapes_equal(
    pytree1: PyTree[Shaped[Array, "..."]],
    pytree2: PyTree[Shaped[Array, "..."]],
    axis: int | Tuple[int, ...] | slice | None = None,
) -> bool:
    r"""
    Check that two PyTrees of arrays share the same structure and that
    corresponding leaves have matching shapes along the selected axes.

    Parameters
    ----------
    pytree1
        The first PyTree to compare.
    pytree2
        The second PyTree to compare.
    axis
        Which shape entries to compare. ``None`` compares the full shapes,
        a slice compares the sliced shapes, and an integer or tuple of
        integers compares the listed entries one by one.

    Returns
    -------
    True if the tree structures are identical and every pair of leaves
    agrees on the selected shape entries, False otherwise.
    """
    if jax.tree.structure(pytree1) != jax.tree.structure(pytree2):
        return False

    # Pick the mismatch test once; it is then applied to every leaf pair.
    if axis is None:

        def _differs(first_shape, second_shape):
            return first_shape != second_shape

    elif isinstance(axis, slice):

        def _differs(first_shape, second_shape):
            return first_shape[axis] != second_shape[axis]

    else:
        axes = (axis,) if isinstance(axis, int) else axis

        def _differs(first_shape, second_shape):
            # lazy, so the scan stops at the first disagreeing axis
            return any(first_shape[ax] != second_shape[ax] for ax in axes)

    leaf_pairs = zip(jax.tree.leaves(pytree1), jax.tree.leaves(pytree2))
    return not any(
        _differs(first.shape, second.shape) for first, second in leaf_pairs
    )
[docs]
@jaxtyped(typechecker=beartype)
def batch_leaves(
    pytree: PyTree[Shaped[Array, "..."], " T"],
    batch_size: int | Integer[Array, ""],
    batch_idx: int | Integer[Array, ""],
    length: int | Integer[Array, ""] | None = None,
    axis: int | Integer[Array, ""] = 0,
) -> PyTree[Shaped[Array, "..."], " T"]:
    r"""
    Slice one batch out of every leaf of a PyTree of arrays.

    Every leaf is passed to ``jax.lax.dynamic_slice_in_dim`` with start index
    ``batch_idx * batch_size`` and a slice size of ``length`` if it is given,
    otherwise ``batch_size``.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    batch_size
        The size of the batches that the leaves are divided into.
    batch_idx
        The index of the batch to extract.
    length
        Optional size of the slice. Defaults to ``batch_size`` when None.
    axis
        The axis along which to slice the leaves. Default is 0.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    requested slices of the original leaves.
    """
    offset = batch_idx * batch_size
    span = batch_size if length is None else length

    def _take_batch(leaf):
        return jax.lax.dynamic_slice_in_dim(leaf, offset, span, axis=axis)

    return jax.tree.map(_take_batch, pytree)
[docs]
@jaxtyped(typechecker=beartype)
def random_permute_leaves(
    pytree: PyTree[Shaped[Array, "..."], " T"],
    key: Any,
    axis: int = 0,
    independent_arrays: bool = False,
    independent_leaves: bool = False,
) -> PyTree[Shaped[Array, "..."], " T"]:
    r"""
    Shuffle every leaf of a PyTree of arrays along one axis.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    key
        A JAX PRNG key used for generating the permutations.
    axis
        The axis along which to permute the leaves. Default is 0.
    independent_arrays
        Forwarded as the ``independent`` argument of
        ``jax.random.permutation``. Default is False.
    independent_leaves
        If True, ``key`` is split into one sub-key per leaf so that each
        leaf is permuted with its own key. If False, the same ``key`` is
        used for every leaf. Default is False.

    Returns
    -------
    A PyTree with the same structure as the input whose leaves are the
    permuted versions of the original leaves.
    """

    def _shuffle(rng, leaf):
        return jax.random.permutation(
            rng,
            leaf,
            axis=axis,
            independent=independent_arrays,
        )

    if not independent_leaves:
        # one shared key for all leaves
        return jax.tree.map(lambda leaf: _shuffle(key, leaf), pytree)

    # One sub-key per leaf, arranged in a tree that mirrors `pytree` so the
    # two can be mapped over together.
    leaf_count = len(jax.tree.leaves(pytree))
    subkeys = jax.random.split(key, leaf_count)
    key_tree = jax.tree.unflatten(jax.tree.structure(pytree), subkeys)
    return jax.tree.map(
        lambda leaf, rng: _shuffle(rng, leaf),
        pytree,
        key_tree,
    )
[docs]
@jaxtyped(typechecker=beartype)
def safecast(
    X: PyTree[Num[Array, "..."], " T"], dtype: jax.typing.DTypeLike
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Cast every leaf of a PyTree to ``dtype``, refusing to turn complex data
    into a non-complex type and warning when the cast did not take effect.

    Parameters
    ----------
    X
        Input data to be cast.
    dtype
        Desired data type for the output.

    Returns
    -------
    A PyTree with the same structure as ``X`` whose leaves are the results of
    ``leaf.astype(dtype)``.

    Raises
    ------
    ValueError
        If a leaf has a complex dtype while ``dtype`` is not complex.

    Warns
    -----
    UserWarning
        For each leaf whose dtype after casting differs from ``dtype``.
    """

    def _guarded_cast(leaf: np.ndarray) -> np.ndarray:
        # Only inspect the target dtype once the source is known to be
        # complex.
        if np.issubdtype(leaf.dtype, np.complexfloating):
            if not np.issubdtype(dtype, np.complexfloating):
                raise ValueError(
                    f"Cannot cast complex input dtype {leaf.dtype} to "
                    f"float output dtype {dtype}."
                )
        return leaf.astype(dtype)

    converted = jax.tree.map(_guarded_cast, X)

    # Verify that the requested dtype actually stuck on every leaf.
    for leaf in jax.tree.leaves(converted):
        if leaf.dtype != dtype:
            warnings.warn(
                f"Requested dtype ({dtype}) was not successfully applied. "
                "This is most likely due to JAX_ENABLE_X64 not being set. "
                "See accompanying JAX warning for more details.",
                UserWarning,
            )
    return converted
[docs]
def strfmt_pytree(
    tree: PyTree,
    indent: int = 0,
    indentation: int = 1,
    max_leaf_chars: int | None = None,
    base_indent_str: str = "",
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> str:
    """
    Format a JAX PyTree into a nicely indented string representation.

    Dicts, lists and tuples are expanded one item per line between their
    brackets, with a trailing comma after every item. Dict values that span
    several lines start on the line after their key. Everything else is
    rendered with ``repr``.

    Parameters
    ----------
    tree
        An arbitrary JAX PyTree (dict, list, tuple, or leaf value)
    indent
        Current indentation level (used for recursion)
    indentation
        Number of spaces to indent for each level
    max_leaf_chars
        Maximum characters per line of a leaf representation before it is
        truncated and terminated with ``...``
    base_indent_str
        Base indentation string to prepend to each line
    is_leaf
        Optional function to determine if a node is a leaf

    Returns
    -------
    A formatted string representation of the PyTree. At ``indent == 0`` the
    first line is prefixed with ``base_indent_str``; at deeper levels the
    first line is left unprefixed so that the caller can place it.
    """
    indent_str = " " * indent * indentation
    # prefix of every item directly contained in this node
    item_prefix = base_indent_str + " " * (indent + 1) * indentation

    def truncate_leaf(s: str) -> str:
        """Truncate leaf representation if it exceeds max_leaf_chars."""
        if max_leaf_chars is None:
            return s
        # truncate individual lines for multi-line strings
        if "\n" in s:
            return "\n".join(truncate_leaf(line) for line in s.split("\n"))
        if len(s) > max_leaf_chars:
            # clamp so that tiny limits never produce a negative slice
            return s[: max(max_leaf_chars - 3, 0)] + "..."
        return s

    def is_container(node: Any) -> bool:
        """Whether the node is expanded rather than rendered via repr."""
        if is_leaf is not None and is_leaf(node):
            return False
        return isinstance(node, (dict, list, tuple))

    def format_child(child: Any) -> str:
        return strfmt_pytree(
            child,
            indent + 1,
            indentation,
            max_leaf_chars,
            base_indent_str,
            is_leaf,
        )

    def place_child(child: Any, text: str) -> str:
        """Indent a formatted child so that it sits at this node's item level."""
        lines = text.split("\n")
        if is_container(child):
            # A nested container already indented its items and closing
            # bracket for its own depth; only its opening line needs the
            # prefix. Indenting every line again would push deeper levels
            # further and further to the right.
            return "\n".join([item_prefix + lines[0], *lines[1:]])
        # a multi-line leaf repr carries no indentation of its own
        return "\n".join(item_prefix + line for line in lines)

    if not is_container(tree):
        ret_str = truncate_leaf(repr(tree))
    else:
        if isinstance(tree, dict):
            opener, closer = "{", "}"
        elif isinstance(tree, list):
            opener, closer = "[", "]"
        else:
            opener, closer = "(", ")"

        if not tree:
            ret_str = opener + closer
        else:
            items = []
            if isinstance(tree, dict):
                for key, value in tree.items():
                    text = format_child(value)
                    if "\n" in text:
                        # multi-line values start on the line after the key
                        items.append(
                            f"{item_prefix}{key}:\n{place_child(value, text)}"
                        )
                    else:
                        items.append(f"{item_prefix}{key}: {text}")
            else:
                for item in tree:
                    items.append(place_child(item, format_child(item)))
            ret_str = (
                opener
                + "\n"
                + ",\n".join(items)
                + f",\n{base_indent_str}{indent_str}"
                + closer
            )

    if indent == 0:
        return base_indent_str + ret_str
    return ret_str