Source code for parametricmatrixmodels.tree_util

from __future__ import annotations

import warnings

import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import Array, Integer, Num, PyTree, Shaped, jaxtyped

from .typing import Any, Callable, List, Tuple


def all_equal(
    pytree1: PyTree[Any],
    pytree2: PyTree[Any],
) -> bool:
    r"""
    Check if two pytrees are equal in structure and content.

    Parameters
    ----------
    pytree1
        The first pytree to compare.
    pytree2
        The second pytree to compare.

    Returns
    -------
    ``True`` if both pytrees have the same tree structure and every pair
    of corresponding leaves is equal, ``False`` otherwise. Pairs of
    array-like leaves are compared with ``jax.numpy.array_equal``; every
    other pair is compared with ``!=``.
    """

    def _is_array_like(leaf: Any) -> bool:
        # Duck-typed on purpose: ``isinstance(leaf, jax.numpy.ndarray)``
        # only matches JAX arrays, so NumPy array leaves used to fall
        # through to ``!=`` and fail on the ambiguous truth value of the
        # element-wise result.
        return hasattr(leaf, "shape") and hasattr(leaf, "dtype")

    if jax.tree.structure(pytree1) != jax.tree.structure(pytree2):
        return False

    leaf_pairs = zip(jax.tree.leaves(pytree1), jax.tree.leaves(pytree2))
    for leaf1, leaf2 in leaf_pairs:
        if _is_array_like(leaf1) and _is_array_like(leaf2):
            if not np.array_equal(leaf1, leaf2):
                return False
        elif leaf1 != leaf2:
            return False
    return True
def all_close(
    pytree1: PyTree[Any],
    pytree2: PyTree[Any],
    rtol: float = 1e-05,
    atol: float = 1e-08,
) -> bool:
    r"""
    Check if two pytrees are close in structure and content.

    Parameters
    ----------
    pytree1
        The first pytree to compare.
    pytree2
        The second pytree to compare.
    rtol
        Relative tolerance forwarded to ``jax.numpy.allclose``.
    atol
        Absolute tolerance forwarded to ``jax.numpy.allclose``.

    Returns
    -------
    ``True`` if both pytrees have the same tree structure and every pair
    of corresponding leaves matches, ``False`` otherwise. Pairs of
    array-like leaves are compared with ``jax.numpy.allclose``; every
    other pair is compared exactly with ``!=``.
    """

    def _is_array_like(leaf: Any) -> bool:
        # Duck-typed on purpose: ``isinstance(leaf, jax.numpy.ndarray)``
        # only matches JAX arrays, so NumPy array leaves used to skip the
        # tolerance comparison and fail on the ambiguous truth value of
        # the element-wise ``!=`` result.
        return hasattr(leaf, "shape") and hasattr(leaf, "dtype")

    if jax.tree.structure(pytree1) != jax.tree.structure(pytree2):
        return False

    leaf_pairs = zip(jax.tree.leaves(pytree1), jax.tree.leaves(pytree2))
    for leaf1, leaf2 in leaf_pairs:
        if _is_array_like(leaf1) and _is_array_like(leaf2):
            if not np.allclose(leaf1, leaf2, rtol=rtol, atol=atol):
                return False
        elif leaf1 != leaf2:
            return False
    return True
def get_shapes(
    pytree: PyTree[Shaped[Array, "..."]],
    axis: int | Tuple[int, ...] | slice | None = None,
) -> PyTree[Tuple[int, ...]]:
    r"""
    Get the shapes of all leaves in a PyTree of arrays, optionally
    selecting specific axes.

    Parameters
    ----------
    pytree
        A PyTree whose leaves expose a ``shape`` attribute.
    axis
        Which entries of each leaf shape to keep. ``None`` keeps the full
        shape, a slice is applied directly to the shape, and an integer or
        tuple of integers picks those entries in the given order.

    Returns
    -------
    A PyTree with the same structure as the input in which every leaf is
    replaced by the selected part of its shape.
    """
    # a lone integer is handled as a one-element tuple of axes
    if isinstance(axis, int):
        selected_axes = (axis,)
    else:
        selected_axes = axis

    def _leaf_shape(leaf: Shaped[Array, "..."]) -> Tuple[int, ...]:
        dims = leaf.shape
        if selected_axes is None:
            return dims
        if isinstance(selected_axes, slice):
            return dims[selected_axes]
        return tuple(dims[ax] for ax in selected_axes)

    return jax.tree.map(_leaf_shape, pytree)
def is_single_leaf(
    pytree: PyTree[Any],
    ndim: int | None = None,
    shape: Tuple[int, ...] | None = None,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Tuple[bool, Any]:
    r"""
    Check if a pytree consists of a single leaf node, optionally verifying
    the leaf's number of dimensions and shape.

    Parameters
    ----------
    pytree
        The pytree to inspect.
    ndim
        If given, the required number of dimensions. Compared against
        ``leaf.ndim`` when the leaf has that attribute, otherwise against
        ``len(leaf)``.
    shape
        If given, the required shape. Compared against ``leaf.shape`` when
        the leaf has that attribute, otherwise against ``tuple(leaf)``.
        Entries equal to ``None`` match any size.
    is_leaf
        Optional predicate forwarded to ``jax.tree.leaves``.

    Returns
    -------
    ``(True, leaf)`` when the pytree has exactly one leaf satisfying every
    requested check, ``(False, None)`` otherwise.

    Raises
    ------
    ValueError
        If both ``ndim`` and ``shape`` are given and ``len(shape) != ndim``.
    """
    if ndim is not None and shape is not None and len(shape) != ndim:
        raise ValueError(
            f"Provided shape {shape} does not match provided ndim {ndim}."
        )

    leaves = jax.tree.leaves(pytree, is_leaf=is_leaf)
    if len(leaves) != 1:
        # A predicate must not write to stdout; the boolean result already
        # carries this information.
        return False, None
    leaf = leaves[0]

    if ndim is not None:
        # arrays expose ``ndim``; anything else is measured by its length
        leaf_ndim = leaf.ndim if hasattr(leaf, "ndim") else len(leaf)
        if leaf_ndim != ndim:
            return False, None

    if shape is not None:
        # arrays expose ``shape``; anything else is read as a shape itself
        leaf_shape = leaf.shape if hasattr(leaf, "shape") else tuple(leaf)
        if len(leaf_shape) != len(shape):
            return False, None
        for actual, expected in zip(leaf_shape, shape):
            # ``None`` in the requested shape is a wildcard
            if expected is not None and actual != expected:
                return False, None

    return True, leaf
def make_mutable(pytree: PyTree[Any]) -> PyTree[Any]:
    r"""
    Convert all tuples in a pytree to lists for mutability.

    Parameters
    ----------
    pytree
        The pytree to convert.

    Returns
    -------
    A new pytree with all tuples converted to lists.
    """

    def _is_tuple(node: Any) -> bool:
        return isinstance(node, tuple)

    def _contains_tuple(tree: PyTree[Any]) -> bool:
        # tuples are treated as leaves so that they can be detected
        tuple_free = jax.tree.map(
            lambda node: not _is_tuple(node), tree, is_leaf=_is_tuple
        )
        return not jax.tree.all(tuple_free)

    def _listify(node: Any) -> Any:
        if _is_tuple(node):
            return list(node)
        return node

    # Each pass converts only the tuples visible as leaves; tuples nested
    # inside them show up on the next pass.
    while _contains_tuple(pytree):
        pytree = jax.tree.map(_listify, pytree, is_leaf=_is_tuple)

    return pytree
def getitem_by_strpath(
    pytree: PyTree[Any],
    strpath: str,
    separator: str = ".",
    allow_early_return: bool = False,
    return_remainder: bool = False,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Any:
    r"""
    Get an item from a pytree using a string path. Effectively the inverse
    of ``jax.tree_util.keystr`` with ``simple=True``. Optionally allows
    early return if the path cannot be fully traversed, returning the
    remaining path as well.

    Parameters
    ----------
    pytree
        The pytree from which to get the item.
    strpath
        The string path to the item, with keys/indexes separated by
        ``separator``.
    separator
        The separator used in the string path. Default is '.'.
    allow_early_return
        If True, traversal stops without raising when it reaches a node
        flagged by ``is_leaf`` or a node that is not a dict, list or
        tuple. The node reached so far is returned.
    return_remainder
        If True, returns a tuple of the found item and the part of the
        path that was not traversed (``''`` when the full path was
        traversed).
    is_leaf
        Optional function to determine if a node is a leaf.

    Returns
    -------
    The item at the specified path in the pytree, or a tuple
    ``(item, remainder)`` when ``return_remainder`` is True.

    Raises
    ------
    TypeError
        If traversal reaches a leaf or an unsupported node type before the
        path is exhausted and ``allow_early_return`` is False.
    KeyError
        If a list/tuple is indexed with a non-integer key, or a dict does
        not contain the requested key.

    Examples
    --------
    >>> pytree = {'a': [1, 2, {'b': 3}], 'c': 4}
    >>> getitem_by_strpath(pytree, 'a.2.b')
    3
    >>> getitem_by_strpath(pytree, 'c', return_remainder=True)
    (4, '')
    >>> pytree = {'a': [1, 2, 3], 'c': 4}
    >>> getitem_by_strpath(
    ...     pytree,
    ...     'a.2.b',
    ...     allow_early_return=True,
    ...     return_remainder=True
    ... )
    (3, 'b')
    >>> getitem_by_strpath(
    ...     pytree,
    ...     'a.2.b.0.c',
    ...     allow_early_return=True,
    ...     return_remainder=True
    ... )
    (3, 'b.0.c')
    """
    # Since the strpath is generated by jax.tree_util.keystr with
    # simple=True, the key types (int or str) have to be inferred while
    # traversing the pytree.
    node = pytree
    remaining = strpath
    while remaining:
        if is_leaf is not None and is_leaf(node):
            if allow_early_return:
                break
            raise TypeError(
                f"Reached leaf node of type {type(node)} before "
                f"completing strpath '{remaining}'"
            )

        key, _, tail = remaining.partition(separator)

        if isinstance(node, dict):
            # keystr renders integer dict keys as digits, so fall back to
            # the integer key when the string key itself is absent
            if key not in node and key.isdigit() and int(key) in node:
                node = node[int(key)]
            else:
                node = node[key]
        elif isinstance(node, (list, tuple)):
            if not key.isdigit():
                raise KeyError(
                    f"Expected integer index for list/tuple, got '{key}'"
                )
            node = node[int(key)]
        elif allow_early_return:
            break
        else:
            raise TypeError(f"Unsupported pytree node type: {type(node)}")

        remaining = tail
    else:
        # the whole path was consumed
        remaining = ""

    if return_remainder:
        return node, remaining
    return node
def setitem_by_strpath(
    pytree: PyTree[Any],
    strpath: str,
    value: Any,
    separator: str = ".",
    is_leaf: Callable[[Any], bool] | None = None,
) -> None:
    r"""
    Set an item in a pytree using a string path. Effectively the inverse of
    ``jax.tree_util.keystr`` with ``simple=True``.

    This function works recursively.

    Parameters
    ----------
    pytree
        The pytree in which to set the item.
    strpath
        The string path to the item, with keys/indexes separated by
        ``separator``.
    value
        The value to set at the specified path.
    separator
        The separator used in the string path. Default is '.'.
    is_leaf
        Optional function to determine if a node is a leaf.

    Returns
    -------
    None. The pytree is modified in place.

    Raises
    ------
    ValueError
        If ``strpath`` (or the part of it left to traverse) is empty.
    TypeError
        If a leaf or unsupported node type is reached before the path is
        exhausted, or if the final container is a tuple.
    KeyError
        If a list/tuple is indexed with a non-integer key, or an
        intermediate dict does not contain the requested key.
    """
    if not strpath:
        raise ValueError("strpath cannot be empty.")
    if is_leaf is not None and is_leaf(pytree):
        raise TypeError(
            f"Reached leaf node of type {type(pytree)} before "
            f"completing strpath '{strpath}'"
        )

    key, *rest = strpath.split(separator, 1)

    if isinstance(pytree, dict):
        resolved_key = key
        # keystr renders integer dict keys as digits, so address the
        # existing integer key when the string key itself is absent
        if key not in pytree and key.isdigit() and int(key) in pytree:
            resolved_key = int(key)
    elif isinstance(pytree, (list, tuple)):
        if not key.isdigit():
            raise KeyError(
                f"Expected integer index for list/tuple, got '{key}'"
            )
        resolved_key = int(key)
    else:
        raise TypeError(f"Unsupported pytree node type: {type(pytree)}")

    if rest:
        # more path to go: descend (tuples may be traversed, not assigned)
        setitem_by_strpath(
            pytree[resolved_key],
            rest[0],
            value,
            separator,
            is_leaf,
        )
        return

    if isinstance(pytree, tuple):
        raise TypeError(
            "Cannot set item in a tuple, as tuples are "
            "immutable. Consider converting to a list first."
        )
    pytree[resolved_key] = value
def extend_structure_from_strpaths(
    base_pytree: PyTree[Any] | None,
    strpaths: List[str] | Tuple[str, ...],
    separator: str = ".",
    fill_values: List[Any] | Tuple[Any, ...] | Any | None = None,
) -> PyTree[Any]:
    r"""
    Extends the structure of a base pytree by adding new branches specified
    by string paths. The new branches are initialized with given fill values
    or None.

    New branches are created as needed, with dictionaries for string keys
    and lists for integer keys. Tuples are not created automatically; any
    tuples that need to be extended in the base PyTree will have their type
    preserved however.

    Dict and list nodes of ``base_pytree`` are extended in place.

    Parameters
    ----------
    base_pytree
        The base pytree to be extended. If None, an empty PyTree with
        structure inferred from strpaths is created.
    strpaths
        A list or tuple of string paths specifying the new branches to be
        added
    separator
        The separator used in the string paths. Default is '.'.
    fill_values
        A list, tuple, or single value specifying the values to fill in the
        new branches. If a single value is provided, it is used for all new
        branches. Default is None.

    Returns
    -------
    The extended pytree with new branches added.

    Raises
    ------
    ValueError
        If ``fill_values`` is a list/tuple whose length differs from
        ``strpaths``, or if a path ends at an existing node whose value
        differs from a non-None fill value.
    KeyError
        If a list node is addressed with a non-integer key.
    TypeError
        If a path runs through a node that is not a dict, list or tuple.

    Examples
    --------
    >>> base_pytree = {}
    >>> strpaths = ['a.b.c', 'd.0.e']
    >>> extended_pytree = extend_structure_from_strpaths(
    ...     base_pytree,
    ...     strpaths,
    ... )
    >>> extended_pytree
    {'a': {'b': {'c': None}}, 'd': [ {'e': None} ]}
    >>> base_pytree = {'x': 1}
    >>> strpaths = ['y.z', 'y.w']
    >>> extended_pytree = extend_structure_from_strpaths(
    ...     base_pytree,
    ...     strpaths,
    ...     fill_values=42,
    ... )
    >>> extended_pytree
    {'x': 1, 'y': {'z': 42, 'w': 42}}
    >>> base_pytree = []
    >>> strpaths = ['0.a', '1.1.2']
    >>> fill_values = [10, 20]
    >>> extended_pytree = extend_structure_from_strpaths(
    ...     base_pytree,
    ...     strpaths,
    ...     fill_values=fill_values,
    ... )
    >>> extended_pytree
    [ {'a': 10}, [ None, [None, None, 20] ] ]
    """
    # Normalize fill_values to a list with one entry per strpath
    if isinstance(fill_values, tuple):
        fill_values = list(fill_values)
    elif not isinstance(fill_values, list):
        fill_values = [fill_values] * len(strpaths)

    # validate fill_values length
    if len(fill_values) != len(strpaths):
        raise ValueError(
            f"Length of fill_values ({len(fill_values)}) does not "
            f"match length of strpaths ({len(strpaths)})."
        )

    # Initialize base_pytree if None: a list when every first key looks
    # like an integer, a dict otherwise
    if base_pytree is None:
        first_keys = [path.split(separator, 1)[0] for path in strpaths if path]
        base_pytree = [] if all(key.isdigit() for key in first_keys) else {}

    # a tuple root is built as a list and converted back at the end
    base_tuple = isinstance(base_pytree, tuple)
    if base_tuple:
        base_pytree = list(base_pytree)

    # Paths of nested tuple nodes that were converted to lists for
    # mutability. An ancestor is always recorded before its descendants
    # because a descendant can only be reached through it.
    tuple_node_paths = []

    for strpath, fill_value in zip(strpaths, fill_values):
        current_node = base_pytree
        path_parts = strpath.split(separator)
        last_index = len(path_parts) - 1

        for i, key in enumerate(path_parts):
            if isinstance(current_node, tuple):
                # swap the tuple for a list inside the tree so that it can
                # be extended
                current_path = separator.join(path_parts[:i])
                tuple_node_paths.append(current_path)
                current_node = list(current_node)
                setitem_by_strpath(
                    base_pytree,
                    current_path,
                    current_node,
                    separator,
                )

            is_last = i == last_index

            if isinstance(current_node, dict):
                missing = key not in current_node
            elif isinstance(current_node, list):
                if not key.isdigit():
                    raise KeyError(
                        f"Expected integer index for list/tuple, got '{key}'"
                    )
                key = int(key)
                # pad with None so that the index exists; a None entry
                # counts as "not created yet"
                while len(current_node) <= key:
                    current_node.append(None)
                missing = current_node[key] is None
            else:
                raise TypeError(
                    f"Unsupported pytree node type: {type(current_node)}"
                )

            if missing:
                if is_last:
                    current_node[key] = fill_value
                elif path_parts[i + 1].isdigit():
                    # look ahead: an integer key next means a list node
                    current_node[key] = []
                else:
                    current_node[key] = {}
            elif (
                is_last
                and fill_value is not None
                and current_node[key] != fill_value
            ):
                # an existing value is never overwritten by a different
                # non-None fill value
                raise ValueError(
                    f"Node at path '{strpath}' already exists "
                    f"with value {current_node[key]}, cannot overwrite "
                    f"with fill_value {fill_value}."
                )

            current_node = current_node[key]

    # Convert the temporary lists back to tuples, deepest first: once an
    # ancestor is a tuple again, nothing can be assigned into it anymore.
    for path in reversed(tuple_node_paths):
        node = getitem_by_strpath(
            base_pytree, path, separator, allow_early_return=False
        )
        if not isinstance(node, list):
            raise RuntimeError(
                f"Expected list at path '{path}' to convert back to tuple, "
                f"got {type(node)}"
            )
        setitem_by_strpath(
            base_pytree,
            path,
            tuple(node),
            separator,
        )

    # convert base_pytree back to tuple if needed
    if base_tuple:
        base_pytree = tuple(base_pytree)

    return base_pytree
@jaxtyped(typechecker=beartype)
def mean(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> Num[Array, ""]:
    r"""
    Computes the mean of all elements in all leaves of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    The mean of all elements across all leaves in the PyTree, i.e. the
    sum of every element divided by the total element count.
    """
    element_count = 0
    running_total = 0
    for leaf in jax.tree.leaves(pytree):
        element_count = element_count + leaf.size
        running_total = running_total + np.sum(leaf)
    return running_total / element_count
@jaxtyped(typechecker=beartype)
def add(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Computes the element-wise addition of two PyTrees of arrays with the
    same structure.

    Parameters
    ----------
    pytree1
        The first PyTree where each leaf is an array.
    pytree2
        The second PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the inputs, where each leaf is the
    sum of the corresponding leaves of ``pytree1`` and ``pytree2``.
    """

    def _add_leaves(left, right):
        return left + right

    return jax.tree.map(_add_leaves, pytree1, pytree2)
@jaxtyped(typechecker=beartype)
def sub(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Computes the element-wise subtraction of two PyTrees of arrays with the
    same structure.

    Parameters
    ----------
    pytree1
        The first PyTree where each leaf is an array (the minuend).
    pytree2
        The second PyTree where each leaf is an array (the subtrahend).

    Returns
    -------
    A PyTree with the same structure as the inputs, where each leaf is the
    leaf of ``pytree1`` minus the corresponding leaf of ``pytree2``.
    """

    def _sub_leaves(left, right):
        return left - right

    return jax.tree.map(_sub_leaves, pytree1, pytree2)
@jaxtyped(typechecker=beartype)
def mul(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Computes the element-wise multiplication of two PyTrees of arrays with
    the same structure.

    Parameters
    ----------
    pytree1
        The first PyTree where each leaf is an array.
    pytree2
        The second PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the inputs, where each leaf is the
    product of the corresponding leaves of ``pytree1`` and ``pytree2``.
    """

    def _mul_leaves(left, right):
        return left * right

    return jax.tree.map(_mul_leaves, pytree1, pytree2)
@jaxtyped(typechecker=beartype)
def div(
    pytree1: PyTree[Num[Array, "..."], " T"],
    pytree2: PyTree[Num[Array, "..."], " T"],
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Computes the element-wise division of two PyTrees of arrays with the
    same structure.

    Parameters
    ----------
    pytree1
        The first PyTree where each leaf is an array (the numerator).
    pytree2
        The second PyTree where each leaf is an array (the denominator).

    Returns
    -------
    A PyTree with the same structure as the inputs, where each leaf is the
    leaf of ``pytree1`` divided by the corresponding leaf of ``pytree2``.
    """

    def _div_leaves(numerator, denominator):
        return numerator / denominator

    return jax.tree.map(_div_leaves, pytree1, pytree2)
@jaxtyped(typechecker=beartype)
def neg(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Computes the element-wise negation of all leaves of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    negated.
    """

    def _negate(leaf):
        return -leaf

    return jax.tree.map(_negate, pytree)
@jaxtyped(typechecker=beartype)
def abs(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Computes the element-wise absolute value of all leaves of a PyTree of
    arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    replaced by its absolute value.
    """

    # NOTE: this function shadows the builtin ``abs`` at module level, so
    # the leaf operation goes through ``jax.numpy.abs`` explicitly.
    def _magnitude(leaf):
        return np.abs(leaf)

    return jax.tree.map(_magnitude, pytree)
@jaxtyped(typechecker=beartype)
def abs_sqr(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Computes the element-wise squared absolute value of all leaves of a
    PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    replaced by its squared absolute value.
    """

    def _squared_magnitude(leaf):
        magnitude = np.abs(leaf)
        return magnitude**2

    return jax.tree.map(_squared_magnitude, pytree)
@jaxtyped(typechecker=beartype)
def sqrt(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Computes the element-wise square root of all leaves of a PyTree of
    arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    replaced by its square root.
    """

    def _root(leaf):
        return np.sqrt(leaf)

    return jax.tree.map(_root, pytree)
@jaxtyped(typechecker=beartype)
def abs_sqrt(
    pytree: PyTree[Num[Array, " *d"], " T"],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Computes the element-wise square root of the absolute value of all
    leaves of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    replaced by the square root of its absolute value.
    """

    def _root_of_magnitude(leaf):
        magnitude = np.abs(leaf)
        return np.sqrt(magnitude)

    return jax.tree.map(_root_of_magnitude, pytree)
@jaxtyped(typechecker=beartype)
def pow(
    pytree: PyTree[Num[Array, " *d"], " T"],
    exponent: float | int,
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Raises all leaves of a PyTree of arrays to a specified power.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    exponent
        The power to which each leaf should be raised.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    raised to the specified power.
    """

    # NOTE: this function shadows the builtin ``pow`` at module level, so
    # the leaf operation goes through ``jax.numpy.power`` explicitly.
    def _raise_leaf(leaf):
        return np.power(leaf, exponent)

    return jax.tree.map(_raise_leaf, pytree)
@jaxtyped(typechecker=beartype)
def scalar_add(
    pytree: PyTree[Num[Array, " *d"], " T"],
    scalar: complex | float | int | Num[Array, ""],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Adds a scalar value to all leaves of a PyTree of arrays.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    scalar
        The scalar value to add to each leaf.

    Returns
    -------
    A PyTree with the same structure as the input, but with the scalar
    added to each leaf.
    """

    def _shift(leaf):
        return leaf + scalar

    return jax.tree.map(_shift, pytree)
@jaxtyped(typechecker=beartype)
def scalar_mul(
    pytree: PyTree[Num[Array, " *d"], " T"],
    scalar: complex | float | int | Num[Array, ""],
) -> PyTree[Num[Array, " *d"], " T"]:
    r"""
    Multiplies all leaves of a PyTree of arrays by a scalar value.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    scalar
        The scalar value to multiply each leaf by.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    multiplied by the scalar.
    """

    def _scale(leaf):
        return leaf * scalar

    return jax.tree.map(_scale, pytree)
@jaxtyped(typechecker=beartype)
def astype(
    pytree: PyTree[Shaped[Array, " ?*d"], " T"],
    dtype: jax.typing.DTypeLike,
) -> PyTree[Shaped[Array, " ?*d"], " T"]:
    r"""
    Casts all leaves of a PyTree of arrays to a specified data type.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array.
    dtype
        The target data type to which each leaf should be cast. Can be a
        JAX numpy dtype or a string representing the dtype.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf cast
    to the specified data type via its ``astype`` method.
    """

    def _cast(leaf):
        return leaf.astype(dtype)

    return jax.tree.map(_cast, pytree)
def is_shape_leaf(obj: Any, allow_none: bool = True) -> bool:
    """
    Check if the given object is a shape leaf, i.e., a tuple of integers
    representing the shape of an array, or None.

    Parameters
    ----------
    obj
        The object to check.
    allow_none
        Whether ``None`` counts as a shape leaf. Default is True.

    Returns
    -------
    True if the object is a shape leaf, False otherwise.
    """
    if obj is None:
        # None is only accepted when explicitly allowed
        return bool(allow_none)
    return isinstance(obj, tuple) and all(
        isinstance(entry, int) for entry in obj
    )
@jaxtyped(typechecker=beartype)
def has_uniform_leaf_shapes(
    pytree: PyTree[Shaped[Array, "..."]],
    axis: int | Tuple[int, ...] | slice | None = None,
) -> bool:
    r"""
    Checks if all leaves of a PyTree of arrays have the same shape along the
    specified axes. Returns False if any leaves are not arrays.

    Parameters
    ----------
    pytree
        The PyTree to check.
    axis
        The axes along which to check for uniformity. Can be an integer, a
        tuple of integers, a slice, or None. If None, all axes are checked.

    Returns
    -------
    True if all leaves have the same shape along the specified axes, False
    otherwise. A PyTree without leaves is considered uniform.
    """
    leaves = jax.tree.leaves(pytree)
    if len(leaves) == 0:
        return True

    # every leaf has to be an array
    if not all(isinstance(leaf, np.ndarray) for leaf in leaves):
        return False

    # a lone integer is handled as a one-element tuple of axes
    if isinstance(axis, int):
        axes = (axis,)
    else:
        axes = axis

    first_leaf, *other_leaves = leaves
    reference_shape = first_leaf.shape

    for leaf in other_leaves:
        shape = leaf.shape
        if axes is None:
            differs = shape != reference_shape
        elif isinstance(axes, slice):
            differs = shape[axes] != reference_shape[axes]
        else:
            differs = any(shape[ax] != reference_shape[ax] for ax in axes)
        if differs:
            return False

    return True
@jaxtyped(typechecker=beartype)
def uniform_leaf_shapes_equal(
    pytree1: PyTree[Shaped[Array, "..."]],
    pytree2: PyTree[Shaped[Array, "..."]],
    axis: int | Tuple[int, ...] | slice | None = None,
) -> bool:
    r"""
    Checks if the shapes of the leaves of two PyTrees with array leaves
    match along the specified axes. First checks that both PyTrees have
    uniform leaf shapes along the specified axes.

    Parameters
    ----------
    pytree1
        The first PyTree to compare.
    pytree2
        The second PyTree to compare.
    axis
        The axes along which to compare the shapes of the leaves. Can be an
        integer, a tuple of integers, a slice, or None. If None, all axes
        are compared.

    Returns
    -------
    True if the shapes of the leaves match along the specified axes, False
    otherwise. Two PyTrees without leaves are considered matching; a
    PyTree without leaves never matches one that has leaves.
    """
    if not has_uniform_leaf_shapes(
        pytree1, axis
    ) or not has_uniform_leaf_shapes(pytree2, axis):
        return False

    leaves1 = jax.tree.leaves(pytree1)
    leaves2 = jax.tree.leaves(pytree2)

    # ``has_uniform_leaf_shapes`` accepts leafless trees, so there may be
    # no representative leaf to read a shape from.
    if not leaves1 or not leaves2:
        return not leaves1 and not leaves2

    # shapes are uniform within each tree, so the first leaf of each is
    # representative
    shape1 = leaves1[0].shape
    shape2 = leaves2[0].shape

    if axis is None:
        return shape1 == shape2
    if isinstance(axis, slice):
        return shape1[axis] == shape2[axis]

    axes = (axis,) if isinstance(axis, int) else axis
    for ax in axes:
        if shape1[ax] != shape2[ax]:
            return False
    return True
@jaxtyped(typechecker=beartype)
def shapes_equal(
    pytree1: PyTree[Shaped[Array, "..."]],
    pytree2: PyTree[Shaped[Array, "..."]],
    axis: int | Tuple[int, ...] | slice | None = None,
) -> bool:
    r"""
    Checks if both the structure and the shapes of the leaves of two PyTrees
    with array leaves match along the specified axes.

    Parameters
    ----------
    pytree1
        The first PyTree to compare.
    pytree2
        The second PyTree to compare.
    axis
        The axes along which to compare the shapes of the leaves. Can be an
        integer, a tuple of integers, a slice, or None. If None, all axes
        are compared.

    Returns
    -------
    True if both the structure and the shapes of the leaves match along the
    specified axes, False otherwise.
    """
    if jax.tree.structure(pytree1) != jax.tree.structure(pytree2):
        return False

    # a lone integer is handled as a one-element tuple of axes
    if isinstance(axis, int):
        axes = (axis,)
    else:
        axes = axis

    leaf_pairs = zip(jax.tree.leaves(pytree1), jax.tree.leaves(pytree2))
    for first_leaf, second_leaf in leaf_pairs:
        first_shape = first_leaf.shape
        second_shape = second_leaf.shape
        if axes is None:
            differs = first_shape != second_shape
        elif isinstance(axes, slice):
            differs = first_shape[axes] != second_shape[axes]
        else:
            differs = any(
                first_shape[ax] != second_shape[ax] for ax in axes
            )
        if differs:
            return False

    return True
@jaxtyped(typechecker=beartype)
def batch_leaves(
    pytree: PyTree[Shaped[Array, "..."], " T"],
    batch_size: int | Integer[Array, ""],
    batch_idx: int | Integer[Array, ""],
    length: int | Integer[Array, ""] | None = None,
    axis: int | Integer[Array, ""] = 0,
) -> PyTree[Shaped[Array, "..."], " T"]:
    r"""
    Extracts a batch of values from all leaves of a PyTree of arrays.

    Each leaf is sliced along the specified axis with
    ``jax.lax.dynamic_slice_in_dim``, starting at
    ``batch_idx * batch_size`` and spanning ``length`` entries if
    ``length`` is provided, otherwise ``batch_size`` entries.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array with a matching size along the
        specified axis.
    batch_size
        The size of the batches that the leaves are divided into.
    batch_idx
        The index of the batch to extract.
    length
        Optional length of the slice to extract. If not provided, defaults
        to `batch_size`. Useful for getting the last batch which may be
        smaller than `batch_size`.
    axis
        The axis along which to slice the leaves. Default is 0.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    sliced to contain only the specified batch.
    """
    offset = batch_idx * batch_size
    if length is None:
        span = batch_size
    else:
        span = length

    def _slice_leaf(leaf):
        return jax.lax.dynamic_slice_in_dim(leaf, offset, span, axis=axis)

    return jax.tree.map(_slice_leaf, pytree)
@jaxtyped(typechecker=beartype)
def random_permute_leaves(
    pytree: PyTree[Shaped[Array, "..."], " T"],
    key: Any,
    axis: int = 0,
    independent_arrays: bool = False,
    independent_leaves: bool = False,
) -> PyTree[Shaped[Array, "..."], " T"]:
    r"""
    Randomly permutes the arrays in the leaves of a PyTree of arrays along a
    specified axis.

    Parameters
    ----------
    pytree
        A PyTree where each leaf is an array with a matching size along the
        specified axis.
    key
        A JAX PRNG key used for generating random permutations.
    axis
        The axis along which to permute the leaves. Default is 0.
    independent_arrays
        If True, each individual vector along the given axis is shuffled
        independently. Default is False. See the ``independent`` argument of
        ``jax.random.permutation`` for more details.
    independent_leaves
        If True, the key is split into one sub-key per leaf and each leaf
        is permuted with its own sub-key. If False (default), every leaf is
        permuted with the same ``key``.

    Returns
    -------
    A PyTree with the same structure as the input, but with each leaf
    randomly permuted along the specified axis.
    """

    def _permute(leaf, leaf_key):
        return jax.random.permutation(
            leaf_key,
            leaf,
            axis=axis,
            independent=independent_arrays,
        )

    if not independent_leaves:
        # one shared key for every leaf
        return jax.tree.map(lambda leaf: _permute(leaf, key), pytree)

    # one sub-key per leaf, paired up in flattening order
    leaves, treedef = jax.tree.flatten(pytree)
    leaf_keys = jax.random.split(key, len(leaves))
    permuted = [
        _permute(leaf, leaf_key) for leaf, leaf_key in zip(leaves, leaf_keys)
    ]
    return jax.tree.unflatten(treedef, permuted)
@jaxtyped(typechecker=beartype)
def safecast(
    X: PyTree[Num[Array, "..."], " T"], dtype: jax.typing.DTypeLike
) -> PyTree[Num[Array, "..."], " T"]:
    r"""
    Safely cast input data to a specified dtype, ensuring that complex types
    are not inadvertently cast to float types. Also issues a warning if the
    requested dtype was not successfully applied, usually due to JAX
    settings.

    Parameters
    ----------
    X
        Input data to be cast.
    dtype
        Desired data type for the output.

    Returns
    -------
    A PyTree with the same structure as ``X`` whose leaves are the result
    of ``leaf.astype(dtype)``.

    Raises
    ------
    ValueError
        If a leaf has a complex dtype and ``dtype`` is not complex.
    """

    def _cast_leaf(leaf: np.ndarray) -> np.ndarray:
        # refuse to silently drop the imaginary part
        source_is_complex = np.issubdtype(leaf.dtype, np.complexfloating)
        if source_is_complex and not np.issubdtype(
            dtype, np.complexfloating
        ):
            raise ValueError(
                f"Cannot cast complex input dtype {leaf.dtype} to "
                f"float output dtype {dtype}."
            )
        return leaf.astype(dtype)

    casted = jax.tree.map(_cast_leaf, X)

    # the cast may not have produced the requested dtype; tell the user
    for leaf in jax.tree.leaves(casted):
        if leaf.dtype != dtype:
            warnings.warn(
                f"Requested dtype ({dtype}) was not successfully applied. "
                "This is most likely due to JAX_ENABLE_X64 not being set. "
                "See accompanying JAX warning for more details.",
                UserWarning,
            )

    return casted
def strfmt_pytree(
    tree: PyTree,
    indent: int = 0,
    indentation: int = 1,
    max_leaf_chars: int | None = None,
    base_indent_str: str = "",
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> str:
    """
    Format a JAX PyTree into a nicely indented string representation.

    Parameters
    ----------
    tree
        An arbitrary JAX PyTree (dict, list, tuple, or leaf value)
    indent
        Current indentation level (used for recursion)
    indentation
        Number of spaces to indent for each level
    max_leaf_chars
        Maximum characters for leaf value representation before truncation
    base_indent_str
        Base indentation string to prepend to each line
    is_leaf
        Optional function to determine if a node is a leaf

    Returns
    -------
    A formatted string representation of the PyTree. Dicts, lists and
    tuples are rendered one entry per line between their brackets; every
    other object (and every node flagged by ``is_leaf``) is rendered with
    ``repr``.
    """
    indent_str = " " * indent * indentation
    next_indent_str = " " * (indent + 1) * indentation
    # prefix put in front of every line that belongs to a child entry
    child_prefix = f"{base_indent_str}{next_indent_str}"

    def truncate_leaf(s: str) -> str:
        """Truncate leaf representation if it exceeds max_leaf_chars."""
        if max_leaf_chars is None:
            return s
        # truncate individual lines for multi-line strings
        if "\n" in s:
            return "\n".join(truncate_leaf(line) for line in s.split("\n"))
        if len(s) > max_leaf_chars:
            return s[: max_leaf_chars - 3] + "..."
        return s

    def format_child(child: PyTree) -> str:
        """Format a child node one indentation level deeper."""
        return strfmt_pytree(
            child,
            indent + 1,
            indentation,
            max_leaf_chars,
            base_indent_str,
            is_leaf,
        )

    def prefix_lines(block: str) -> str:
        """Prepend the child prefix to every line of a multi-line block."""
        return "\n".join(
            f"{child_prefix}{line}" for line in block.split("\n")
        )

    def enclose(items: list, opener: str, closer: str) -> str:
        """Join entries one per line between the given brackets."""
        return (
            f"{opener}\n"
            + ",\n".join(items)
            + f",\n{base_indent_str}{indent_str}{closer}"
        )

    def format_sequence(opener: str, closer: str) -> str:
        """Format a non-empty list or tuple."""
        items = []
        for item in tree:
            formatted_item = format_child(item)
            if "\n" in formatted_item:
                items.append(prefix_lines(formatted_item))
            else:
                items.append(f"{child_prefix}{formatted_item}")
        return enclose(items, opener, closer)

    # custom leaf detection takes precedence over the container checks
    if is_leaf is not None and is_leaf(tree):
        ret_str = truncate_leaf(repr(tree))
    elif isinstance(tree, dict):
        if not tree:
            ret_str = "{}"
        else:
            items = []
            for key, value in tree.items():
                formatted_value = format_child(value)
                if "\n" in formatted_value:
                    # multi-line values start on the line after the key
                    items.append(
                        f"{child_prefix}{key}:"
                        f"\n{prefix_lines(formatted_value)}"
                    )
                else:
                    items.append(f"{child_prefix}{key}: {formatted_value}")
            # a single opening brace, matching the single closing brace
            ret_str = enclose(items, "{", "}")
    elif isinstance(tree, list):
        ret_str = format_sequence("[", "]") if tree else "[]"
    elif isinstance(tree, tuple):
        ret_str = format_sequence("(", ")") if tree else "()"
    else:
        ret_str = truncate_leaf(repr(tree))

    # only the outermost call carries the base indentation itself
    if indent == 0:
        return base_indent_str + ret_str
    return ret_str