tree_util Module
- abs_sqr(pytree)
Computes the element-wise squared absolute value of all leaves of a PyTree of arrays.
- abs_sqrt(pytree)
Computes the element-wise square root of the absolute value of all leaves of a PyTree of arrays.
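A minimal sketch of how abs_sqr and abs_sqrt could be expressed with jax.tree_util.tree_map, assuming every leaf is a jax.numpy array. The *_sketch names are hypothetical and are not part of this module. Because the absolute value is taken first, complex leaves produce real outputs and negative leaves do not produce NaN under the square root.

```python
import jax
import jax.numpy as jnp


def abs_sqr_sketch(pytree):
    # |x|**2 for every element of every leaf; complex leaves become real.
    return jax.tree_util.tree_map(lambda x: jnp.abs(x) ** 2, pytree)


def abs_sqrt_sketch(pytree):
    # sqrt(|x|) for every element; taking abs first avoids NaN for negative inputs.
    return jax.tree_util.tree_map(lambda x: jnp.sqrt(jnp.abs(x)), pytree)


tree = {"a": jnp.array([3.0 + 4.0j]), "b": [jnp.array([-4.0, 9.0])]}
sq = abs_sqr_sketch(tree)  # leaf "a" is |3+4j|**2, i.e. about 25.0
rt = abs_sqrt_sketch(tree)
```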
- add(pytree1, pytree2)
Computes the element-wise addition of two PyTrees of arrays with the same structure.
- Parameters:
pytree1 – The first PyTree, where each leaf is an array.
pytree2 – The second PyTree, where each leaf is an array.
- Returns:
A PyTree with the same structure as the inputs, where each leaf is the
element-wise addition of the corresponding leaves from the input PyTrees.
- Return type:
PyTree[jaxtyping.Num[Array, '...'], 'T']
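A minimal sketch of add, assuming a plain jax.tree_util.tree_map over both trees (add_sketch is a hypothetical name). The sibling functions sub, mul and div would presumably follow the same pattern with jnp.subtract, jnp.multiply and jnp.divide.

```python
import jax
import jax.numpy as jnp


def add_sketch(pytree1, pytree2):
    # tree_map visits matching leaves of both trees in lockstep and
    # raises an error if pytree2 does not follow pytree1's structure.
    return jax.tree_util.tree_map(jnp.add, pytree1, pytree2)


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
y = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(1.5)}
z = add_sketch(x, y)
```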
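- all_close(pytree1, pytree2, rtol=1e-05, atol=1e-08)
Checks whether two PyTrees have the same structure and whether their corresponding leaves are element-wise close, within relative tolerance rtol and absolute tolerance atol.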
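A sketch of one plausible reading of all_close: compare the tree structures first, then compare each pair of leaves with jnp.allclose using the same rtol/atol defaults. Requiring equal leaf shapes (rather than letting jnp.allclose broadcast) is an assumption of this sketch; all_close_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def all_close_sketch(pytree1, pytree2, rtol=1e-05, atol=1e-08):
    # 1) Structures must match exactly (same containers, keys and leaf count).
    if jax.tree_util.tree_structure(pytree1) != jax.tree_util.tree_structure(pytree2):
        return False
    leaves1 = jax.tree_util.tree_leaves(pytree1)
    leaves2 = jax.tree_util.tree_leaves(pytree2)
    # 2) Each pair of leaves must have the same shape and be numerically close.
    return all(
        jnp.shape(a) == jnp.shape(b) and bool(jnp.allclose(a, b, rtol=rtol, atol=atol))
        for a, b in zip(leaves1, leaves2)
    )


a = {"x": jnp.array([1.0, 2.0])}
```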
- astype(pytree, dtype)
Casts all leaves of a PyTree of arrays to a specified data type.
- Parameters:
pytree – A PyTree where each leaf is an array.
dtype – The data type to cast each leaf to.
- Returns:
A PyTree with the same structure as the input, but with each leaf
cast to the specified data type.
- Return type:
PyTree[jaxtyping.Shaped[Array, '?*d'], 'T']
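A minimal sketch of astype using Array.astype on each leaf (astype_sketch is a hypothetical name). Note that in JAX, requesting a 64-bit dtype while 64-bit mode is disabled typically falls back to the 32-bit type with a warning, which is the situation safecast is documented to report.

```python
import jax
import jax.numpy as jnp


def astype_sketch(pytree, dtype):
    # Cast each leaf independently; the tree structure is unchanged.
    return jax.tree_util.tree_map(lambda x: x.astype(dtype), pytree)


params = {"w": jnp.ones((2, 2), dtype=jnp.float32), "n": jnp.array([1, 2], dtype=jnp.int32)}
half = astype_sketch(params, jnp.float16)
```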
- batch_leaves(pytree, batch_size, batch_idx, length=None, axis=0)
Extracts a batch of values from all leaves of a PyTree of arrays.
Each leaf is sliced along the specified axis by
[batch_idx * batch_size : batch_idx * batch_size + length] if length is provided, otherwise by [batch_idx * batch_size : (batch_idx + 1) * batch_size].
- Parameters:
pytree (PyTree[jaxtyping.Shaped[Array, '...'], 'T']) – A PyTree where each leaf is an array with a matching size along the specified axis.
batch_size (int | Integer[Array, '']) – The size of the batches that the leaves are divided into.
batch_idx (int | Integer[Array, '']) – The index of the batch to extract.
length (int | Integer[Array, ''] | None) – Optional length of the slice to extract. If not provided, defaults to batch_size. Useful for getting the last batch which may be smaller than batch_size.
axis (int | Integer[Array, '']) – The axis along which to slice the leaves. Default is 0.
- Returns:
A PyTree with the same structure as the input, but with each leaf
sliced to contain only the specified batch.
- Return type:
PyTree[jaxtyping.Shaped[Array, '...'], 'T']
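A sketch of the slicing rule above, assuming jax.lax.dynamic_slice_in_dim is used so that the start index may be traced (batch_leaves_sketch is a hypothetical name). In this sketch the slice size (length or batch_size) must be a static Python int, although the documented signature also accepts scalar integer arrays. Also note that dynamic_slice_in_dim clamps the start index to keep the slice in bounds: for a short final batch without length you would get the last batch_size rows, which is why passing length matters.

```python
import jax
import jax.numpy as jnp


def batch_leaves_sketch(pytree, batch_size, batch_idx, length=None, axis=0):
    start = batch_idx * batch_size
    size = batch_size if length is None else length
    # start may be traced under jit; size must be a static int here.
    return jax.tree_util.tree_map(
        lambda x: jax.lax.dynamic_slice_in_dim(x, start, size, axis=axis), pytree
    )


data = {"x": jnp.arange(10), "y": jnp.arange(20).reshape(10, 2)}
b1 = batch_leaves_sketch(data, batch_size=4, batch_idx=1)              # rows 4..7
last = batch_leaves_sketch(data, batch_size=4, batch_idx=2, length=2)  # rows 8..9
```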
- div(pytree1, pytree2)
Computes the element-wise division of two PyTrees of arrays with the same structure.
- Parameters:
pytree1 – The first PyTree, where each leaf is an array.
pytree2 – The second PyTree, where each leaf is an array.
- Returns:
A PyTree with the same structure as the inputs, where each leaf is the
element-wise division of the corresponding leaves from the input PyTrees.
- Return type:
PyTree[jaxtyping.Num[Array, '...'], 'T']
- extend_structure_from_strpaths(base_pytree, strpaths, separator='.', fill_values=None)
Extends the structure of a base pytree by adding new branches specified by string paths. The new branches are initialized with the given fill values, or with None if no fill values are given.
New branches are created as needed, with dictionaries for string keys and lists for integer keys. Tuples are never created automatically; however, any existing tuples in the base PyTree that need to be extended keep their tuple type.
- Parameters:
base_pytree (PyTree[Any] | None) – The base pytree to be extended. If None, an empty PyTree with structure inferred from strpaths is created.
strpaths (list[str] | tuple[str, ...]) – A list or tuple of string paths specifying the new branches to be added.
separator (str) – The separator used in the string paths. Default is '.'.
fill_values (list[Any] | tuple[Any, ...] | Any | None) – A list, tuple, or single value specifying the values to fill in the new branches. If a single value is provided, it is used for all new branches. Default is None.
- Returns:
The extended pytree with new branches added.
- Return type:
PyTree[Any]
Examples
>>> base_pytree = {}
>>> strpaths = ['a.b.c', 'd.0.e']
>>> extended_pytree = extend_structure_from_strpaths(
...     base_pytree,
...     strpaths,
... )
>>> extended_pytree
{'a': {'b': {'c': None}}, 'd': [{'e': None}]}

>>> base_pytree = {'x': 1}
>>> strpaths = ['y.z', 'y.w']
>>> extended_pytree = extend_structure_from_strpaths(
...     base_pytree,
...     strpaths,
...     fill_values=42,
... )
>>> extended_pytree
{'x': 1, 'y': {'z': 42, 'w': 42}}

>>> base_pytree = []
>>> strpaths = ['0.a', '1.1.2']
>>> fill_values = [10, 20]
>>> extended_pytree = extend_structure_from_strpaths(
...     base_pytree,
...     strpaths,
...     fill_values=fill_values,
... )
>>> extended_pytree
[{'a': 10}, [None, [None, None, 20]]]
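A pure-Python sketch that reproduces the three documented examples. It assumes that digit-only path segments are list indices, that lists are padded with None up to the required index, and that a value already present at the end of a path is kept. extend_structure_sketch, _key and _extend are hypothetical names, and the real function may differ in these details.

```python
from typing import Any


def _key(part: str) -> Any:
    # Digit-only path segments index into lists; everything else is a dict key.
    return int(part) if part.isdigit() else part


def _extend(node: Any, keys: list, fill: Any) -> Any:
    if not keys:
        # End of the path: keep an existing value, otherwise place the fill value.
        return fill if node is None else node
    key, rest = keys[0], keys[1:]
    if node is None:
        # Create a missing branch: list for an integer key, dict for a string key.
        node = [] if isinstance(key, int) else {}
    if isinstance(node, dict):
        node = dict(node)
        node[key] = _extend(node.get(key), rest, fill)
        return node
    # List or tuple: pad with None up to the required index, then recurse.
    items = list(node)
    items.extend([None] * (key + 1 - len(items)))
    items[key] = _extend(items[key], rest, fill)
    # Existing tuples keep their type; new sequences are always lists.
    return tuple(items) if isinstance(node, tuple) else items


def extend_structure_sketch(base_pytree, strpaths, separator=".", fill_values=None):
    if not isinstance(fill_values, (list, tuple)):
        # A single fill value is reused for every path.
        fill_values = [fill_values] * len(strpaths)
    tree = base_pytree
    for path, fill in zip(strpaths, fill_values):
        tree = _extend(tree, [_key(p) for p in path.split(separator)], fill)
    return tree
```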
- get_shapes(pytree, axis=None)
Get the shapes of all leaves in a PyTree of arrays, optionally selecting specific axes.
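A sketch of get_shapes, assuming axis accepts the same forms documented for shapes_equal (an int, a tuple of ints, a slice, or None). Returning a 1-tuple for an integer axis is a choice of this sketch; get_shapes_sketch and _select are hypothetical names. Re-flattening the result with JAX would split each shape tuple into integer leaves, so an is_leaf predicate such as is_shape_leaf is one way to keep them whole.

```python
import jax
import jax.numpy as jnp


def _select(shape, axis):
    # Pick the requested axes out of a shape tuple.
    if axis is None:
        return tuple(shape)
    if isinstance(axis, int):
        return (shape[axis],)
    if isinstance(axis, slice):
        return tuple(shape[axis])
    return tuple(shape[a] for a in axis)


def get_shapes_sketch(pytree, axis=None):
    return jax.tree_util.tree_map(lambda x: _select(jnp.shape(x), axis), pytree)


tree = {"a": jnp.zeros((2, 3)), "b": [jnp.zeros((4, 5, 6))]}
```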
- getitem_by_strpath(pytree, strpath, separator='.', allow_early_return=False, return_remainder=False, is_leaf=None)
Gets an item from a pytree using a string path. This is effectively the inverse of
jax.tree_util.keystr with simple=True. Optionally allows an early return if the path cannot be fully traversed, in which case the remaining path is returned as well. This function works recursively.
- Parameters:
pytree (PyTree[Any]) – The pytree from which to get the item.
strpath (str) – The string path to the item, with keys/indexes separated by separator.
separator (str) – The separator used in the string path. Default is '.'.
allow_early_return (bool) – If True, allows the function to return early if the path cannot be fully traversed. In this case, returns a tuple of the current pytree node and the remaining path as a string.
return_remainder (bool) – If True, returns a tuple of the found item and the remaining path as a string, even if the full path was traversed.
is_leaf (Callable[Any, bool] | None) – Optional function to determine if a node is a leaf.
- Returns:
The item at the specified path in the pytree, or a tuple (item, remaining path) when return_remainder is True or an early return occurs.
- Return type:
Any
Examples
>>> pytree = {'a': [1, 2, {'b': 3}], 'c': 4}
>>> getitem_by_strpath(pytree, 'a.2.b')
3
>>> getitem_by_strpath(pytree, 'c', return_remainder=True)
(4, '')

>>> pytree = {'a': [1, 2, 3], 'c': 4}
>>> getitem_by_strpath(
...     pytree,
...     'a.2.b',
...     allow_early_return=True,
...     return_remainder=True
... )
(3, 'b')
>>> getitem_by_strpath(
...     pytree,
...     'a.2.b.0.c',
...     allow_early_return=True,
...     return_remainder=True
... )
(3, 'b.0.c')
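A simplified, iterative sketch of the lookup behaviour shown in the examples above. It assumes lists and tuples are indexed by integer and every other container by string key, and it ignores is_leaf and custom PyTree node types. getitem_sketch is a hypothetical name; the documented function works recursively.

```python
def getitem_sketch(pytree, strpath, separator=".", allow_early_return=False, return_remainder=False):
    parts = strpath.split(separator) if strpath else []
    node = pytree
    for i, part in enumerate(parts):
        try:
            # Lists and tuples take integer indexes; other containers take string keys.
            key = int(part) if isinstance(node, (list, tuple)) else part
            node = node[key]
        except (KeyError, IndexError, TypeError, ValueError):
            if not allow_early_return:
                raise
            # Stop at the deepest node reached and report the unused part of the path.
            return node, separator.join(parts[i:])
    return (node, "") if return_remainder else node


tree = {"a": [1, 2, {"b": 3}], "c": 4}
tree2 = {"a": [1, 2, 3], "c": 4}
```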
- has_uniform_leaf_shapes(pytree, axis=None)
Checks if all leaves of a PyTree of arrays have the same shape along the specified axes. Returns False if any leaves are not arrays.
- Parameters:
pytree – A PyTree where each leaf is an array.
axis – The axes along which to compare the leaf shapes.
- Returns:
True if all leaves have the same shape along the specified axes, False
otherwise.
- Return type:
bool
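A sketch of has_uniform_leaf_shapes, assuming axis takes the same forms documented for shapes_equal and that an empty PyTree counts as uniform. has_uniform_leaf_shapes_sketch and _select are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _select(shape, axis):
    # Pick the requested axes out of a shape tuple.
    if axis is None:
        return tuple(shape)
    if isinstance(axis, int):
        return (shape[axis],)
    if isinstance(axis, slice):
        return tuple(shape[axis])
    return tuple(shape[a] for a in axis)


def has_uniform_leaf_shapes_sketch(pytree, axis=None):
    leaves = jax.tree_util.tree_leaves(pytree)
    # Any non-array leaf makes the answer False, as documented above.
    if not all(isinstance(leaf, (jax.Array, np.ndarray)) for leaf in leaves):
        return False
    # Uniform means at most one distinct (selected) shape among the leaves.
    return len({_select(leaf.shape, axis) for leaf in leaves}) <= 1


tree = {"a": jnp.zeros((8, 3)), "b": jnp.zeros((8, 5))}
```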
- is_shape_leaf(obj, allow_none=True)
Checks whether the given object is a shape leaf, i.e., a tuple of integers representing the shape of an array, or None (accepted only when allow_none is True).
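JAX treats tuples as containers, so a tree of shapes is normally flattened down to individual integers. A predicate like is_shape_leaf can be passed as is_leaf to keep each shape tuple whole. The sketch below assumes allow_none only controls whether None is accepted; is_shape_leaf_sketch is a hypothetical name.

```python
import jax


def is_shape_leaf_sketch(obj, allow_none=True):
    # None counts as a shape leaf only when allow_none is set.
    if obj is None:
        return allow_none
    # Otherwise: a tuple whose entries are all ints, e.g. (2, 3) or () for a scalar.
    return isinstance(obj, tuple) and all(isinstance(d, int) for d in obj)


shapes = {"a": (2, 3), "b": [(4,)]}
flat = jax.tree_util.tree_leaves(shapes)  # tuples are containers to JAX
kept = jax.tree_util.tree_leaves(shapes, is_leaf=is_shape_leaf_sketch)
```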
- is_single_leaf(pytree, ndim=None, shape=None, is_leaf=None)
Check if a pytree consists of a single leaf node, optionally verifying the leaf’s number of dimensions and shape.
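A sketch of is_single_leaf under one stated assumption: "single leaf" is read as "the flattened tree has exactly one leaf", so a container such as {'w': x} is also accepted here. The real function may be stricter and require the pytree itself to be the leaf. is_single_leaf_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def is_single_leaf_sketch(pytree, ndim=None, shape=None, is_leaf=None):
    leaves = jax.tree_util.tree_leaves(pytree, is_leaf=is_leaf)
    # Assumption: exactly one leaf after flattening.
    if len(leaves) != 1:
        return False
    leaf = leaves[0]
    # Optional checks on the leaf's number of dimensions and exact shape.
    if ndim is not None and jnp.ndim(leaf) != ndim:
        return False
    if shape is not None and tuple(jnp.shape(leaf)) != tuple(shape):
        return False
    return True
```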
- mean(pytree)
Computes the mean of all elements in all leaves of a PyTree of arrays.
- Parameters:
pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.
- Returns:
The mean of all elements across all leaves in the PyTree.
- Return type:
Num[Array, '']
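Because the mean is taken over all elements in all leaves, each leaf is effectively weighted by its element count. A minimal sketch of that pooling (mean_sketch is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def mean_sketch(pytree):
    leaves = jax.tree_util.tree_leaves(pytree)
    # Pool every element: sum all values, then divide by the total element count.
    total = sum(jnp.sum(leaf) for leaf in leaves)
    count = sum(jnp.size(leaf) for leaf in leaves)
    return total / count


tree = {"a": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array([[4.0, 5.0]])}
m = mean_sketch(tree)  # 15 / 5
```

Averaging the per-leaf means (2.0 and 4.5) would instead give 3.25, so the two approaches differ whenever leaves have different sizes.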
- mul(pytree1, pytree2)
Computes the element-wise multiplication of two PyTrees of arrays with the same structure.
- Parameters:
pytree1 – The first PyTree, where each leaf is an array.
pytree2 – The second PyTree, where each leaf is an array.
- Returns:
A PyTree with the same structure as the inputs, where each leaf is the
element-wise multiplication of the corresponding leaves from the input
PyTrees.
- Return type:
PyTree[jaxtyping.Num[Array, '...'], 'T']
- pow(pytree, exponent)
Raises all leaves of a PyTree of arrays to a specified power.
- Parameters:
pytree – A PyTree where each leaf is an array.
exponent – The power to which each leaf is raised.
- Returns:
A PyTree with the same structure as the input, but with each leaf
raised to the specified power.
- Return type:
PyTree[jaxtyping.Num[Array, '*d'], 'T']
- random_permute_leaves(pytree, key, axis=0, independent_arrays=False, independent_leaves=False)
Randomly permutes the arrays in the leaves of a PyTree of arrays along a specified axis.
- Parameters:
pytree (PyTree[jaxtyping.Shaped[Array, '...'], 'T']) – A PyTree where each leaf is an array with a matching size along the specified axis.
key (Any) – A JAX PRNG key used for generating random permutations.
axis (int) – The axis along which to permute the leaves. Default is 0.
independent_arrays (bool) – If True, each individual vector along the given axis is shuffled independently. Default is False. See the independent argument of jax.random.permutation for more details.
independent_leaves (bool) – If True, each leaf in the PyTree is permuted independently using different random keys. Default is False.
- Returns:
A PyTree with the same structure as the input, but with each leaf
randomly permuted along the specified axis.
- Return type:
PyTree[jaxtyping.Shaped[Array, '...'], 'T']
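A sketch of how the two flags could interact, assuming that without independent_leaves every leaf is permuted with the same key. With jax.random.permutation, the same key and the same length along the axis yield the same permutation, so entries that belong together (for example inputs and labels) stay aligned. With independent_leaves, the key is split into one subkey per leaf. random_permute_leaves_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def random_permute_leaves_sketch(pytree, key, axis=0, independent_arrays=False, independent_leaves=False):
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    if independent_leaves:
        # One fresh subkey per leaf: each leaf gets its own permutation.
        keys = list(jax.random.split(key, len(leaves)))
    else:
        # Shared key: leaves of equal length along `axis` get the same permutation.
        keys = [key] * len(leaves)
    permuted = [
        jax.random.permutation(k, leaf, axis=axis, independent=independent_arrays)
        for k, leaf in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, permuted)


data = {"x": jnp.arange(6), "y": jnp.arange(6) * 10}
out = random_permute_leaves_sketch(data, jax.random.PRNGKey(0))
```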
- safecast(X, dtype)
Safely casts input data to a specified dtype, ensuring that complex types are not inadvertently cast to float types. Also issues a warning if the requested dtype was not successfully applied, which usually happens because of JAX settings (for example, a 64-bit dtype requested while JAX's 64-bit mode is disabled).
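A sketch of the documented behaviour under one assumption: when complex data is cast to a real dtype, the target is promoted to a complex dtype (via np.result_type) rather than discarding the imaginary part. The real function may instead keep the input dtype or raise. safecast_sketch is a hypothetical name.

```python
import warnings

import jax.numpy as jnp
import numpy as np


def safecast_sketch(X, dtype):
    X = jnp.asarray(X)
    target = np.dtype(dtype)
    # Assumption: promote a real target dtype to a complex one that can hold X,
    # instead of silently dropping the imaginary part.
    if jnp.iscomplexobj(X) and not jnp.issubdtype(target, jnp.complexfloating):
        target = np.result_type(target, X.dtype)
    Y = X.astype(target)
    # JAX may fall back to a narrower dtype (e.g. 32-bit when x64 mode is off).
    if Y.dtype != target:
        warnings.warn(f"Requested dtype {target}, but got {Y.dtype}.")
    return Y


z = safecast_sketch(jnp.array([1.0 + 2.0j]), "float32")  # stays complex
r = safecast_sketch(jnp.array([1, 2]), "float32")
```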
- scalar_add(pytree, scalar)
Adds a scalar value to all leaves of a PyTree of arrays.
- Parameters:
pytree – A PyTree where each leaf is an array.
scalar – The scalar value added to each leaf.
- Returns:
A PyTree with the same structure as the input, but with the scalar
added to each leaf.
- Return type:
PyTree[jaxtyping.Num[Array, '*d'], 'T']
- scalar_mul(pytree, scalar)
Multiplies all leaves of a PyTree of arrays by a scalar value.
- Parameters:
pytree – A PyTree where each leaf is an array.
scalar – The scalar value by which each leaf is multiplied.
- Returns:
A PyTree with the same structure as the input, but with each leaf
multiplied by the scalar.
- Return type:
PyTree[jaxtyping.Num[Array, '*d'], 'T']
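A minimal sketch of scalar_add and scalar_mul as single-argument tree_map calls, with the scalar broadcast against every leaf. The *_sketch names are hypothetical; pow would presumably follow the same pattern with x ** exponent.

```python
import jax
import jax.numpy as jnp


def scalar_add_sketch(pytree, scalar):
    # The scalar broadcasts against each leaf, whatever its shape.
    return jax.tree_util.tree_map(lambda x: x + scalar, pytree)


def scalar_mul_sketch(pytree, scalar):
    return jax.tree_util.tree_map(lambda x: x * scalar, pytree)


grads = {"w": jnp.array([0.5, -1.0]), "b": jnp.array(2.0)}
scaled = scalar_mul_sketch(grads, 0.1)
shifted = scalar_add_sketch(grads, 1.0)
```

Combined with add, this is enough for a plain gradient step, e.g. add(params, scalar_mul(grads, -lr)).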
- setitem_by_strpath(pytree, strpath, value, separator='.', is_leaf=None)
Sets an item in a pytree using a string path. Effectively the inverse of
jax.tree_util.keystr with simple=True. This function works recursively.
- Parameters:
pytree – The pytree in which to set the item.
strpath – The string path to the item, with keys/indexes separated by separator.
value – The value to set at the specified path.
separator – The separator used in the string path. Default is '.'.
is_leaf – Optional function to determine if a node is a leaf.
- Returns:
None. The pytree is modified in place.
- Return type:
None
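Since the pytree is modified in place, the container that holds the final key must be mutable. The simplified, iterative sketch below therefore only allows a dict or a list as that last container (tuples may still appear higher up the path) and ignores is_leaf. setitem_sketch is a hypothetical name.

```python
def setitem_sketch(pytree, strpath, value, separator="."):
    *parents, last = strpath.split(separator)
    node = pytree
    # Walk down to the container that holds the final key.
    for part in parents:
        node = node[int(part)] if isinstance(node, (list, tuple)) else node[part]
    # Mutate that container in place; it must therefore be a dict or a list.
    if isinstance(node, list):
        node[int(last)] = value
    else:
        node[last] = value


tree = {"a": [1, 2, {"b": 3}], "c": 4}
setitem_sketch(tree, "a.2.b", 30)
setitem_sketch(tree, "a.0", 10)
```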
- shapes_equal(pytree1, pytree2, axis=None)
Checks if both the structure and the shapes of the leaves of two PyTrees with array leaves match along the specified axes.
- Parameters:
pytree1 (PyTree[jaxtyping.Shaped[Array, '...']]) – The first PyTree to compare.
pytree2 (PyTree[jaxtyping.Shaped[Array, '...']]) – The second PyTree to compare.
axis (int | tuple[int, ...] | slice | None) – The axes along which to compare the shapes of the leaves. Can be an integer, a tuple of integers, a slice, or None. If None, all axes are compared.
- Returns:
True if both the structure and the shapes of the leaves match along the
specified axes, False otherwise.
- Return type:
bool
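A sketch of shapes_equal: the tree definitions must match first, and then each pair of leaves is compared only on the selected axes. shapes_equal_sketch and _select are hypothetical names, and the int/tuple/slice/None handling follows the parameter description above.

```python
import jax
import jax.numpy as jnp


def _select(shape, axis):
    # Pick the requested axes out of a shape tuple.
    if axis is None:
        return tuple(shape)
    if isinstance(axis, int):
        return (shape[axis],)
    if isinstance(axis, slice):
        return tuple(shape[axis])
    return tuple(shape[a] for a in axis)


def shapes_equal_sketch(pytree1, pytree2, axis=None):
    leaves1, treedef1 = jax.tree_util.tree_flatten(pytree1)
    leaves2, treedef2 = jax.tree_util.tree_flatten(pytree2)
    # Structures must match before any shapes are compared.
    if treedef1 != treedef2:
        return False
    return all(
        _select(jnp.shape(a), axis) == _select(jnp.shape(b), axis)
        for a, b in zip(leaves1, leaves2)
    )


x = {"a": jnp.zeros((4, 3)), "b": jnp.zeros((4,))}
y = {"a": jnp.ones((4, 7)), "b": jnp.ones((4,))}
```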
- strfmt_pytree(tree, indent=0, indentation=1, max_leaf_chars=None, base_indent_str='', is_leaf=None)
Format a JAX PyTree into a nicely indented string representation.
- Parameters:
tree (PyTree) – An arbitrary JAX PyTree (dict, list, tuple, or leaf value).
indent (int) – Current indentation level (used for recursion).
indentation (int) – Number of spaces to indent for each level.
max_leaf_chars (int | None) – Maximum characters for leaf value representation before truncation.
base_indent_str (str) – Base indentation string to prepend to each line.
is_leaf (Callable[[PyTree], bool] | None) – Optional function to determine if a node is a leaf.
- Returns:
A formatted string representation of the PyTree.
- Return type:
str
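A simplified, pure-Python sketch of the recursive formatting idea, handling only dict, list and tuple containers and ignoring is_leaf and custom PyTree node types. The exact layout (trailing commas, keys printed with repr, "..." after truncated leaves) is an assumption of this sketch, not the module's actual output; strfmt_sketch is a hypothetical name.

```python
def strfmt_sketch(tree, indent=0, indentation=1, max_leaf_chars=None, base_indent_str=""):
    pad = base_indent_str + " " * (indent * indentation)
    if isinstance(tree, dict):
        items = [(repr(k) + ": ", v) for k, v in tree.items()]
        open_, close = "{", "}"
    elif isinstance(tree, (list, tuple)):
        items = [("", v) for v in tree]
        open_, close = ("[", "]") if isinstance(tree, list) else ("(", ")")
    else:
        # Leaf: use repr, truncated to max_leaf_chars if requested.
        text = repr(tree)
        if max_leaf_chars is not None and len(text) > max_leaf_chars:
            text = text[:max_leaf_chars] + "..."
        return pad + text
    inner_pad = base_indent_str + " " * ((indent + 1) * indentation)
    lines = [pad + open_]
    for prefix, value in items:
        child = strfmt_sketch(value, indent + 1, indentation, max_leaf_chars, base_indent_str)
        # The child starts with inner_pad; drop it so the key and value share a line.
        lines.append(inner_pad + prefix + child[len(inner_pad):] + ",")
    lines.append(pad + close)
    return "\n".join(lines)


print(strfmt_sketch({"a": [1, 2], "b": "x"}, indentation=2))
```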
- sub(pytree1, pytree2)
Computes the element-wise subtraction of two PyTrees of arrays with the same structure.
- Parameters:
pytree1 – The first PyTree, where each leaf is an array.
pytree2 – The second PyTree, where each leaf is an array.
- Returns:
A PyTree with the same structure as the inputs, where each leaf is the
element-wise subtraction of the corresponding leaves from the input
PyTrees.
- Return type:
PyTree[jaxtyping.Num[Array, '...'], 'T']
- uniform_leaf_shapes_equal(pytree1, pytree2, axis=None)
Checks if the shapes of the leaves of two PyTrees with array leaves match along the specified axes. First checks that both PyTrees have uniform leaf shapes along the specified axes.
- Parameters:
pytree1 (PyTree[jaxtyping.Shaped[Array, '...']]) – The first PyTree to compare.
pytree2 (PyTree[jaxtyping.Shaped[Array, '...']]) – The second PyTree to compare.
axis (int | tuple[int, ...] | slice | None) – The axes along which to compare the shapes of the leaves. Can be an integer, a tuple of integers, a slice, or None. If None, all axes are compared.
- Returns:
True if the shapes of the leaves match along the specified axes, False
otherwise.
- Return type:
bool