tree_util Module

abs(pytree)

Computes the element-wise absolute value of all leaves of a PyTree of arrays.

Parameters:

pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the input, but with each leaf replaced by its absolute value.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']

abs_sqr(pytree)

Computes the element-wise squared absolute value of all leaves of a PyTree of arrays.

Parameters:

pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the input, but with each leaf replaced by its squared absolute value.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']

abs_sqrt(pytree)

Computes the element-wise square root of the absolute value of all leaves of a PyTree of arrays.

Parameters:

pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the input, but with each leaf replaced by the square root of its absolute value.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']
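
The three element-wise functions above (abs, abs_sqr, abs_sqrt) can be sketched as thin wrappers around jax.tree_util.tree_map. This is a minimal sketch under that assumption; the names tree_abs, tree_abs_sqr and tree_abs_sqrt are hypothetical and not this module's API, and the real functions may compute, for example, the squared magnitude of complex leaves differently.

```python
import jax
import jax.numpy as jnp


def tree_abs(pytree):
    # Hypothetical stand-in for abs: jnp.abs applied to every leaf.
    return jax.tree_util.tree_map(jnp.abs, pytree)


def tree_abs_sqr(pytree):
    # Hypothetical stand-in for abs_sqr: |x|**2, real-valued even for complex leaves.
    return jax.tree_util.tree_map(lambda x: jnp.abs(x) ** 2, pytree)


def tree_abs_sqrt(pytree):
    # Hypothetical stand-in for abs_sqrt: sqrt(|x|), so negative entries do not give NaN.
    return jax.tree_util.tree_map(lambda x: jnp.sqrt(jnp.abs(x)), pytree)


tree = {"w": jnp.array([-4.0, 9.0]), "z": jnp.array([3.0 + 4.0j])}
magnitudes = tree_abs(tree)  # same dict structure, real-valued leaves
squared = tree_abs_sqr(tree)
roots = tree_abs_sqrt(tree)
```

Swapping the mapped function gives the other unary helpers in this module (neg, sqrt, pow) under the same assumption.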

add(pytree1, pytree2)

Computes the element-wise addition of two PyTrees of arrays with the same structure.

Parameters:
  • pytree1 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The first PyTree where each leaf is an array.

  • pytree2 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The second PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the inputs, where each leaf is the element-wise sum of the corresponding leaves of the input PyTrees.

Return type:

PyTree[jaxtyping.Num[Array, '...'], 'T']
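
A minimal sketch of a binary PyTree operation, assuming add is built on jax.tree_util.tree_map with two trees; tree_add is a hypothetical name. It also shows what "the same structure" means in practice: tree_map raises ValueError when the second tree does not match the first.

```python
import jax
import jax.numpy as jnp


def tree_add(pytree1, pytree2):
    # Hypothetical stand-in for add: tree_map visits matching leaves of both
    # trees and raises ValueError if the structures do not match.
    return jax.tree_util.tree_map(jnp.add, pytree1, pytree2)


a = {"x": jnp.array([1.0, 2.0]), "y": (jnp.array(3.0),)}
b = {"x": jnp.array([10.0, 20.0]), "y": (jnp.array(4.0),)}
total = tree_add(a, b)

try:
    tree_add(a, {"x": jnp.array([1.0, 2.0])})  # key "y" is missing
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

sub, mul and div presumably follow the same pattern with jnp.subtract, jnp.multiply and jnp.divide.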

all_close(pytree1, pytree2, rtol=1e-05, atol=1e-08)

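Check if two pytrees have the same structure and numerically close leaf values.

Parameters:
  • pytree1 – The first PyTree to compare.

  • pytree2 – The second PyTree to compare.

  • rtol – Relative tolerance. Default is 1e-05.

  • atol – Absolute tolerance. Default is 1e-08.

Return type: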

bool

all_equal(pytree1, pytree2)

Check if two pytrees are equal in structure and content.

Parameters:
  • pytree1 – The first PyTree to compare.

  • pytree2 – The second PyTree to compare.

Return type:

bool
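
A minimal sketch of how all_close and all_equal could be implemented, assuming "equal in structure" means identical tree definitions; trees_all_close and trees_all_equal are hypothetical names. The explicit shape check in the close variant matters because jnp.allclose broadcasts, which would otherwise let a (1,) leaf match a (3,) leaf.

```python
import jax
import jax.numpy as jnp


def trees_all_close(pytree1, pytree2, rtol=1e-05, atol=1e-08) -> bool:
    # Identical structure is required before any leaf is compared.
    if jax.tree_util.tree_structure(pytree1) != jax.tree_util.tree_structure(pytree2):
        return False
    pairs = zip(jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2))
    # The shape check stops broadcasting from hiding a shape mismatch.
    return all(
        jnp.shape(x) == jnp.shape(y) and bool(jnp.allclose(x, y, rtol=rtol, atol=atol))
        for x, y in pairs
    )


def trees_all_equal(pytree1, pytree2) -> bool:
    if jax.tree_util.tree_structure(pytree1) != jax.tree_util.tree_structure(pytree2):
        return False
    pairs = zip(jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2))
    # jnp.array_equal already returns False when shapes differ.
    return all(bool(jnp.array_equal(x, y)) for x, y in pairs)
```

Because the result is converted to a Python bool, this sketch cannot run inside jax.jit; the documented bool return type suggests the same holds for the real functions.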

astype(pytree, dtype)

Casts all leaves of a PyTree of arrays to a specified data type.

Parameters:
  • pytree (PyTree[jaxtyping.Shaped[Array, '?*d'], 'T']) – A PyTree where each leaf is an array.

  • dtype (str | type[Any] | dtype | SupportsDType) – The target data type to which each leaf should be cast. Can be a JAX numpy dtype or a string representing the dtype.

Returns:

A PyTree with the same structure as the input, but with each leaf cast to the specified data type.

Return type:

PyTree[jaxtyping.Shaped[Array, '?*d'], 'T']
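
A minimal sketch of astype, assuming it maps .astype over the leaves; tree_astype is a hypothetical name. It accepts either a JAX dtype or a dtype string, matching the dtype parameter described above.

```python
import jax
import jax.numpy as jnp


def tree_astype(pytree, dtype):
    # Containers (dicts, lists, tuples) are traversed; only the leaves are cast.
    return jax.tree_util.tree_map(lambda x: jnp.asarray(x).astype(dtype), pytree)


params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
params_bf16 = tree_astype(params, jnp.bfloat16)  # a JAX dtype
params_i32 = tree_astype(params, "int32")        # or a dtype string
```

With JAX's default settings (jax_enable_x64 disabled), requesting a 64-bit dtype yields its 32-bit counterpart instead; safecast below is documented to warn in that situation.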

batch_leaves(pytree, batch_size, batch_idx, length=None, axis=0)

Extracts a batch of values from all leaves of a PyTree of arrays.

Each leaf is sliced along the specified axis by [batch_idx * batch_size : batch_idx * batch_size + length] if length is provided, otherwise by [batch_idx * batch_size : (batch_idx + 1) * batch_size].

Parameters:
  • pytree (PyTree[jaxtyping.Shaped[Array, '...'], 'T']) – A PyTree where each leaf is an array with a matching size along the specified axis.

  • batch_size (int | Integer[Array, '']) – The size of the batches that the leaves are divided into.

  • batch_idx (int | Integer[Array, '']) – The index of the batch to extract.

  • length (int | Integer[Array, ''] | None) – Optional length of the slice to extract. If not provided, defaults to batch_size. Useful for getting the last batch which may be smaller than batch_size.

  • axis (int | Integer[Array, '']) – The axis along which to slice the leaves. Default is 0.

Returns:

A PyTree with the same structure as the input, but with each leaf sliced to contain only the specified batch.

Return type:

PyTree[jaxtyping.Shaped[Array, '...'], 'T']
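
The slicing rule above can be sketched with jax.lax.dynamic_slice_in_dim, which allows a traced start index inside jax.jit. This is an illustration, not the module's implementation: tree_batch is a hypothetical name, and unlike the documented function it requires the slice size (batch_size or length) to be a static Python int.

```python
import jax
import jax.numpy as jnp
from jax import lax


def tree_batch(pytree, batch_size: int, batch_idx, length=None, axis: int = 0):
    # Under jax.jit the slice size must be a static Python int, while the
    # start index may be a traced integer array.
    size = batch_size if length is None else length
    start = batch_idx * batch_size
    return jax.tree_util.tree_map(
        lambda x: lax.dynamic_slice_in_dim(x, start, size, axis=axis), pytree
    )


data = {"x": jnp.arange(10), "y": jnp.arange(20).reshape(10, 2)}
second = tree_batch(data, batch_size=4, batch_idx=1)          # rows 4..7
last = tree_batch(data, batch_size=4, batch_idx=2, length=2)  # rows 8..9
clamped = tree_batch(data, batch_size=4, batch_idx=2)         # rows 6..9, see note
```

dynamic_slice_in_dim clamps the start index so the window stays in bounds, so asking for a full batch past the end silently returns overlapping rows (clamped above) instead of raising. Passing length for the last, shorter batch avoids this.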

div(pytree1, pytree2)

Computes the element-wise division of two PyTrees of arrays with the same structure.

Parameters:
  • pytree1 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The first PyTree where each leaf is an array.

  • pytree2 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The second PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the inputs, where each leaf is the element-wise quotient of the corresponding leaves of the input PyTrees.

Return type:

PyTree[jaxtyping.Num[Array, '...'], 'T']

extend_structure_from_strpaths(base_pytree, strpaths, separator='.', fill_values=None)

Extends the structure of a base pytree by adding new branches specified by string paths. The new branches are initialized with the given fill values, or with None if no fill values are given.

New branches are created as needed, using dictionaries for string keys and lists for integer keys. Tuples are never created automatically; however, existing tuples in the base PyTree that need to be extended keep their type.

Parameters:
  • base_pytree (PyTree[Any] | None) – The base pytree to be extended. If None, an empty PyTree with structure inferred from strpaths is created.

  • strpaths (list[str] | tuple[str, ...]) – A list or tuple of string paths specifying the new branches to be added.

  • separator (str) – The separator used in the string paths. Default is '.'.

  • fill_values (list[Any] | tuple[Any, ...] | Any | None) – A list, tuple, or single value specifying the values to fill in the new branches. If a single value is provided, it is used for all new branches. Default is None.

Returns:

The extended pytree with new branches added.

Return type:

PyTree[Any]

Examples

>>> base_pytree = {}
>>> strpaths = ['a.b.c', 'd.0.e']
>>> extended_pytree = extend_structure_from_strpaths(
...     base_pytree,
...     strpaths,
... )
>>> extended_pytree
{'a': {'b': {'c': None}}, 'd': [{'e': None}]}
>>> base_pytree = {'x': 1}
>>> strpaths = ['y.z', 'y.w']
>>> extended_pytree = extend_structure_from_strpaths(
...     base_pytree,
...     strpaths,
...     fill_values=42,
... )
>>> extended_pytree
{'x': 1, 'y': {'z': 42, 'w': 42}}
>>> base_pytree = []
>>> strpaths = ['0.a', '1.1.2']
>>> fill_values = [10, 20]
>>> extended_pytree = extend_structure_from_strpaths(
...     base_pytree,
...     strpaths,
...     fill_values=fill_values,
... )
>>> extended_pytree
[{'a': 10}, [None, [None, None, 20]]]
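
The rules above (dicts for string keys, lists for integer keys padded with None, existing tuples kept as tuples) can be sketched in pure Python. This is a hypothetical re-implementation for illustration, not the module's code: extend_from_strpaths and _extend are made-up names, and this sketch overwrites whatever already sits at the end of a path, a case the documentation does not specify.

```python
from typing import Any


def _extend(node: Any, keys: list, value: Any) -> Any:
    """Return a copy of `node` in which the path `keys` leads to `value`."""
    if not keys:
        return value
    head, rest = keys[0], keys[1:]
    if node is None:
        # Missing branch: a list for an integer key, a dict for a string key.
        node = [] if head.isdigit() else {}
    if isinstance(node, dict):
        node = dict(node)  # copy so the caller's dict is not mutated
        node[head] = _extend(node.get(head), rest, value)
        return node
    if isinstance(node, (list, tuple)):
        idx = int(head)
        items = list(node) + [None] * max(0, idx + 1 - len(node))  # pad with None
        items[idx] = _extend(items[idx], rest, value)
        # Existing tuples keep their type; newly created sequences are lists.
        return tuple(items) if isinstance(node, tuple) else items
    raise TypeError(f"cannot descend into leaf {node!r} with key {head!r}")


def extend_from_strpaths(base_pytree, strpaths, separator=".", fill_values=None):
    if not isinstance(fill_values, (list, tuple)):
        fill_values = [fill_values] * len(strpaths)  # one value for every path
    for path, value in zip(strpaths, fill_values):
        base_pytree = _extend(base_pytree, path.split(separator), value)
    return base_pytree
```

Run on the three examples above, the sketch reproduces the documented outputs.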

get_shapes(pytree, axis=None)

Get the shapes of all leaves in a PyTree of arrays, optionally selecting specific axes.

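Parameters:
  • pytree – The PyTree of arrays whose leaf shapes are returned.

  • axis – Optional selection of axes applied to each shape. If None, the full shape of each leaf is returned.

Return type:

PyTree[tuple[int, ...]]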
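
A minimal sketch of get_shapes, assuming axis accepts an int, a tuple of ints, a slice, or None (as it does for has_uniform_leaf_shapes) and that a tuple is always returned, as the return type indicates. tree_shapes and select_axes are hypothetical names.

```python
import jax
import jax.numpy as jnp


def select_axes(shape, axis):
    # None keeps the whole shape; an int, tuple of ints, or slice picks entries.
    if axis is None:
        return tuple(shape)
    if isinstance(axis, slice):
        return tuple(shape[axis])
    if isinstance(axis, int):
        axis = (axis,)
    return tuple(shape[a] for a in axis)


def tree_shapes(pytree, axis=None):
    return jax.tree_util.tree_map(lambda x: select_axes(jnp.shape(x), axis), pytree)


tree = {"a": jnp.zeros((3, 4)), "b": [jnp.zeros((3, 5))]}
shapes = tree_shapes(tree)           # {'a': (3, 4), 'b': [(3, 5)]}
leading = tree_shapes(tree, axis=0)  # {'a': (3,), 'b': [(3,)]}
```

The shape tuples in the result are themselves pytree nodes, so jax.tree_util.tree_leaves(shapes) returns the bare integers [3, 4, 3, 5]. Passing an is_leaf predicate such as is_shape_leaf keeps each shape intact.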

getitem_by_strpath(pytree, strpath, separator='.', allow_early_return=False, return_remainder=False, is_leaf=None)

Get an item from a pytree using a string path. Effectively the inverse of jax.tree_util.keystr with simple=True. Optionally allows early return if the path cannot be fully traversed, returning the remaining path as well.

This function works recursively.

Parameters:
  • pytree (PyTree[Any]) – The pytree from which to get the item.

  • strpath (str) – The string path to the item, with keys/indexes separated by separator.

  • separator (str) – The separator used in the string path. Default is '.'.

  • allow_early_return (bool) – If True, allows the function to return early if the path cannot be fully traversed. In this case, returns a tuple of the current pytree node and the remaining path as a string.

  • return_remainder (bool) – If True, returns a tuple of the found item and the remaining path as a string, even if the full path was traversed.

  • is_leaf (Callable[Any, bool] | None) – Optional function to determine if a node is a leaf.

Returns:

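The item at the specified path in the pytree. If return_remainder is True, or if allow_early_return is True and the path could not be fully traversed, a tuple of the item (or the node reached) and the remaining path string is returned instead.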

Return type:

Any

Examples

>>> pytree = {'a': [1, 2, {'b': 3}], 'c': 4}
>>> getitem_by_strpath(pytree, 'a.2.b')
3
>>> getitem_by_strpath(pytree, 'c', return_remainder=True)
(4, '')
>>> pytree = {'a': [1, 2, 3], 'c': 4}
>>> getitem_by_strpath(
...     pytree,
...     'a.2.b',
...     allow_early_return=True,
...     return_remainder=True
... )
(3, 'b')
>>> getitem_by_strpath(
...     pytree,
...     'a.2.b.0.c',
...     allow_early_return=True,
...     return_remainder=True
... )
(3, 'b.0.c')
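
A minimal sketch of the lookup semantics described above, written iteratively rather than recursively and without is_leaf support; get_by_strpath is a hypothetical name. It assumes dict keys are strings and list/tuple indexes are integers, and raising KeyError on a leaf when early return is disabled is this sketch's own choice.

```python
from typing import Any


def get_by_strpath(pytree: Any, strpath: str, separator: str = ".",
                   allow_early_return: bool = False,
                   return_remainder: bool = False):
    keys = strpath.split(separator) if strpath else []
    node = pytree
    for i, key in enumerate(keys):
        if isinstance(node, dict):
            node = node[key]
        elif isinstance(node, (list, tuple)):
            node = node[int(key)]
        elif allow_early_return:
            # Reached a leaf before the path ended: return the unused keys too.
            return node, separator.join(keys[i:])
        else:
            raise KeyError(f"cannot index leaf {node!r} with {key!r}")
    return (node, "") if return_remainder else node
```

The four examples above give the same results with this sketch.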

has_uniform_leaf_shapes(pytree, axis=None)

Checks if all leaves of a PyTree of arrays have the same shape along the specified axes. Returns False if any leaves are not arrays.

Parameters:
  • pytree (PyTree[jaxtyping.Shaped[Array, '...']]) – The PyTree to check.

  • axis (int | tuple[int, ...] | slice | None) – The axes along which to check for uniformity. Can be an integer, a tuple of integers, a slice, or None. If None, all axes are checked.

Returns:

True if all leaves have the same shape along the specified axes, False otherwise.

Return type:

bool
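
A minimal sketch of the uniformity check, assuming "array" means jax.Array (so NumPy arrays and Python scalars count as non-arrays here); has_uniform_shapes is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def has_uniform_shapes(pytree, axis=None) -> bool:
    leaves = jax.tree_util.tree_leaves(pytree)
    # As documented, any leaf that is not an array makes the answer False.
    if not all(isinstance(x, jax.Array) for x in leaves):
        return False
    selected = set()
    for x in leaves:
        shape = x.shape
        if isinstance(axis, slice):
            shape = shape[axis]
        elif axis is not None:
            axes = (axis,) if isinstance(axis, int) else axis
            shape = tuple(shape[a] for a in axes)
        selected.add(tuple(shape))
    # Zero or one distinct (selected) shape means the leaves are uniform.
    return len(selected) <= 1
```

This sketch returns True for an empty tree; the documented behavior for that case is not stated.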

is_shape_leaf(obj, allow_none=True)

Check if the given object is a shape leaf, i.e., a tuple of integers representing the shape of an array, or None.

Parameters:
  • obj (Any) – The object to check.

  • allow_none (bool) – If True (the default), None is also accepted as a shape leaf.

Returns:

True if the object is a shape leaf, False otherwise.

Return type:

bool
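
A shape predicate like this is mainly useful as the is_leaf argument of jax.tree_util functions, because a tuple of ints is otherwise flattened into its integers and None disappears as an empty subtree. The sketch below is a hypothetical re-implementation (is_shape_leaf_sketch) of the check described above.

```python
import jax


def is_shape_leaf_sketch(obj, allow_none=True) -> bool:
    if obj is None:
        return allow_none
    # () counts as a shape: it is the shape of a scalar array.
    return isinstance(obj, tuple) and all(isinstance(d, int) for d in obj)


shapes = {"a": (3, 4), "b": [(3, 5), None]}
flat = jax.tree_util.tree_leaves(shapes)  # shape tuples flattened to ints
kept = jax.tree_util.tree_leaves(shapes, is_leaf=is_shape_leaf_sketch)
```

With the predicate, kept holds the whole tuples and the None entry as individual leaves.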

is_single_leaf(pytree, ndim=None, shape=None, is_leaf=None)

Check if a pytree consists of a single leaf node, optionally verifying the leaf’s number of dimensions and shape.

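Parameters:
  • pytree – The PyTree to check.

  • ndim – Optional expected number of dimensions of the leaf.

  • shape – Optional expected shape of the leaf.

  • is_leaf – Optional function to determine if a node is a leaf.

Return type: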

tuple[bool, Any]

make_mutable(pytree)

Convert all tuples in a pytree to lists for mutability.

Parameters:

pytree (PyTree[Any]) – The pytree to convert.

Returns:

A new pytree with all tuples converted to lists.

Return type:

PyTree[Any]
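
jax.tree_util.tree_map preserves container types, so converting tuples to lists needs an explicit recursive rebuild. A minimal pure-Python sketch; make_mutable_sketch is a hypothetical name, it only handles dicts, lists and tuples, and it would also turn namedtuples into plain lists.

```python
from typing import Any


def make_mutable_sketch(pytree: Any) -> Any:
    # Rebuild containers recursively; tuples become lists, leaves are returned as-is.
    if isinstance(pytree, dict):
        return {k: make_mutable_sketch(v) for k, v in pytree.items()}
    if isinstance(pytree, (list, tuple)):
        return [make_mutable_sketch(v) for v in pytree]
    return pytree


frozen = {"layers": ({"w": 1}, {"w": 2}), "shape": (3, 4)}
mutable = make_mutable_sketch(frozen)
mutable["layers"][0]["w"] = 10  # allowed now; the original tree is untouched
```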

mean(pytree)

Computes the mean of all elements in all leaves of a PyTree of arrays.

Parameters:

pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

Returns:

The mean of all elements across all leaves in the PyTree.

Return type:

Num[Array, '']
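
"The mean of all elements across all leaves" weights every element equally, which differs from averaging per-leaf means when leaves have different sizes. A minimal sketch under that reading; tree_mean is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_mean(pytree):
    leaves = jax.tree_util.tree_leaves(pytree)
    total = sum(jnp.sum(x) for x in leaves)   # sum of every element in every leaf
    count = sum(jnp.size(x) for x in leaves)  # total number of elements
    return total / count


tree = {"a": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array([10.0])}
overall = tree_mean(tree)  # (1 + 2 + 3 + 10) / 4
```

Averaging the per-leaf means instead would give (2 + 10) / 2 = 6 for this tree.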

mul(pytree1, pytree2)

Computes the element-wise multiplication of two PyTrees of arrays with the same structure.

Parameters:
  • pytree1 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The first PyTree where each leaf is an array.

  • pytree2 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The second PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the inputs, where each leaf is the element-wise product of the corresponding leaves of the input PyTrees.

Return type:

PyTree[jaxtyping.Num[Array, '...'], 'T']

neg(pytree)

Computes the element-wise negation of all leaves of a PyTree of arrays.

Parameters:

pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the input, but with each leaf negated.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']

pow(pytree, exponent)

Raises all leaves of a PyTree of arrays to a specified power.

Parameters:
  • pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

  • exponent (float | int) – The power to which each leaf should be raised.

Returns:

A PyTree with the same structure as the input, but with each leaf raised to the specified power.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']

random_permute_leaves(pytree, key, axis=0, independent_arrays=False, independent_leaves=False)

Randomly permutes the arrays in the leaves of a PyTree of arrays along a specified axis.

Parameters:
  • pytree (PyTree[jaxtyping.Shaped[Array, '...'], 'T']) – A PyTree where each leaf is an array with a matching size along the specified axis.

  • key (Any) – A JAX PRNG key used for generating random permutations.

  • axis (int) – The axis along which to permute the leaves. Default is 0.

  • independent_arrays (bool) – If True, each individual vector along the given axis is shuffled independently. Default is False. See the independent argument of jax.random.permutation for more details.

  • independent_leaves (bool) – If True, each leaf in the PyTree is permuted independently using different random keys. Default is False.

Returns:

A PyTree with the same structure as the input, but with each leaf randomly permuted along the specified axis.

Return type:

PyTree[jaxtyping.Shaped[Array, '...'], 'T']
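
A minimal sketch of the two key-handling modes, assuming the function wraps jax.random.permutation; tree_permute is a hypothetical name. Reusing one key for every leaf gives all leaves the same permutation, which keeps paired data (such as inputs and labels) aligned; independent_leaves splits the key so each leaf is shuffled differently.

```python
import jax
import jax.numpy as jnp


def tree_permute(pytree, key, axis=0, independent_arrays=False, independent_leaves=False):
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    if independent_leaves:
        keys = jax.random.split(key, len(leaves))
    else:
        # One shared key: every leaf gets the same permutation, so rows stay aligned.
        keys = [key] * len(leaves)
    permuted = [
        jax.random.permutation(k, x, axis=axis, independent=independent_arrays)
        for k, x in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, permuted)


data = {"x": jnp.arange(5), "y": jnp.arange(5) * 10}
shuffled = tree_permute(data, jax.random.PRNGKey(0))
indep = tree_permute(data, jax.random.PRNGKey(1), independent_leaves=True)
```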

safecast(X, dtype)

Safely cast input data to a specified dtype, ensuring that complex types are not inadvertently cast to float types. Also issues a warning if the requested dtype was not applied, usually because of JAX settings (for example, 64-bit dtypes require jax_enable_x64).

Parameters:
  • X (PyTree[jaxtyping.Num[Array, '...'], 'T']) – Input data to be cast.

  • dtype (str | type[Any] | dtype | SupportsDType) – Desired data type for the output.

Return type:

PyTree[jaxtyping.Num[Array, '...'], 'T']
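
One way to meet both documented guarantees is sketched below. It is an assumption, not the module's implementation: safecast_sketch is a hypothetical name, and promoting a non-complex target to a complex dtype via jnp.result_type is this sketch's choice for "not inadvertently cast to float".

```python
import warnings

import jax
import jax.numpy as jnp


def safecast_sketch(pytree, dtype):
    target = jnp.dtype(dtype)

    def cast(x):
        x = jnp.asarray(x)
        wanted = target
        if jnp.iscomplexobj(x) and not jnp.issubdtype(target, jnp.complexfloating):
            # Keep the imaginary part: promote the target to a complex dtype.
            wanted = jnp.result_type(target, jnp.complex64)
        y = x.astype(wanted)
        if y.dtype != wanted:
            # Typically a 64-bit request while jax_enable_x64 is disabled.
            warnings.warn(f"requested {wanted}, got {y.dtype}")
        return y

    return jax.tree_util.tree_map(cast, pytree)


mixed = {"z": jnp.array([1 + 2j], dtype=jnp.complex64), "r": jnp.array([1, 2])}
out = safecast_sketch(mixed, jnp.float32)
```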

scalar_add(pytree, scalar)

Adds a scalar value to all leaves of a PyTree of arrays.

Parameters:
  • pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

  • scalar (complex | float | int | Num[Array, '']) – The scalar value to add to each leaf.

Returns:

A PyTree with the same structure as the input, but with the scalar added to each leaf.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']

scalar_mul(pytree, scalar)

Multiplies all leaves of a PyTree of arrays by a scalar value.

Parameters:
  • pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

  • scalar (complex | float | int | Num[Array, '']) – The scalar value to multiply each leaf by.

Returns:

A PyTree with the same structure as the input, but with each leaf multiplied by the scalar.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']
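
A minimal sketch of scalar_mul as a broadcast inside tree_map (tree_scalar_mul is a hypothetical name; scalar_add presumably differs only in using +). It also illustrates JAX's weak typing of Python scalars: a float keeps float32 leaves as float32, while a complex scalar promotes them to complex64.

```python
import jax
import jax.numpy as jnp


def tree_scalar_mul(pytree, scalar):
    # The scalar broadcasts against every leaf, whatever its shape.
    return jax.tree_util.tree_map(lambda x: x * scalar, pytree)


grads = {"w": jnp.ones((2, 2), dtype=jnp.float32), "b": jnp.ones(2, dtype=jnp.float32)}
scaled = tree_scalar_mul(grads, 0.5)  # stays float32
rotated = tree_scalar_mul(grads, 1j)  # promoted to complex64
```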

setitem_by_strpath(pytree, strpath, value, separator='.', is_leaf=None)

Set an item in a pytree using a string path. Effectively the inverse of jax.tree_util.keystr with simple=True. This function works recursively.

Parameters:
  • pytree (PyTree[Any]) – The pytree in which to set the item.

  • strpath (str) – The string path to the item, with keys/indexes separated by separator.

  • value (Any) – The value to set at the specified path.

  • separator (str) – The separator used in the string path. Default is '.'.

  • is_leaf (Callable[Any, bool] | None) – Optional function to determine if a node is a leaf.

Returns:

None. The pytree is modified in place.

Return type:

None
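
Because the pytree is modified in place, every container on the path must be mutable. A minimal iterative sketch (set_by_strpath is a hypothetical name, without is_leaf support) shows that a tuple on the path raises TypeError. Converting the tree first with make_mutable is presumably the intended remedy.

```python
from typing import Any


def set_by_strpath(pytree: Any, strpath: str, value: Any, separator: str = ".") -> None:
    *parents, last = strpath.split(separator)
    node = pytree
    # Walk down to the container that holds the final key.
    for key in parents:
        node = node[int(key)] if isinstance(node, (list, tuple)) else node[key]
    if isinstance(node, (list, tuple)):
        node[int(last)] = value  # raises TypeError for tuples, which are immutable
    else:
        node[last] = value


tree = {"a": [1, 2, {"b": 3}]}
set_by_strpath(tree, "a.2.b", 30)
set_by_strpath(tree, "a.0", 10)
```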

shapes_equal(pytree1, pytree2, axis=None)

Checks if both the structure and the shapes of the leaves of two PyTrees with array leaves match along the specified axes.

Parameters:
  • pytree1 (PyTree[jaxtyping.Shaped[Array, '...']]) – The first PyTree to compare.

  • pytree2 (PyTree[jaxtyping.Shaped[Array, '...']]) – The second PyTree to compare.

  • axis (int | tuple[int, ...] | slice | None) – The axes along which to compare the shapes of the leaves. Can be an integer, a tuple of integers, a slice, or None. If None, all axes are compared.

Returns:

True if both the structure and the shapes of the leaves match along the specified axes, False otherwise.

Return type:

bool
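
A minimal sketch of the structure-plus-shape comparison, assuming the same axis conventions as has_uniform_leaf_shapes; shapes_equal_sketch and _selected are hypothetical names. Selecting an axis that a leaf does not have raises IndexError in this sketch.

```python
import jax
import jax.numpy as jnp


def _selected(shape, axis):
    if axis is None:
        return tuple(shape)
    if isinstance(axis, slice):
        return tuple(shape[axis])
    axes = (axis,) if isinstance(axis, int) else axis
    return tuple(shape[a] for a in axes)


def shapes_equal_sketch(pytree1, pytree2, axis=None) -> bool:
    # Trees with different structures never compare equal.
    if jax.tree_util.tree_structure(pytree1) != jax.tree_util.tree_structure(pytree2):
        return False
    pairs = zip(jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2))
    return all(_selected(jnp.shape(x), axis) == _selected(jnp.shape(y), axis) for x, y in pairs)
```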

sqrt(pytree)

Computes the element-wise square root of all leaves of a PyTree of arrays.

Parameters:

pytree (PyTree[jaxtyping.Num[Array, '*d'], 'T']) – A PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the input, but with each leaf replaced by its square root.

Return type:

PyTree[jaxtyping.Num[Array, '*d'], 'T']

strfmt_pytree(tree, indent=0, indentation=1, max_leaf_chars=None, base_indent_str='', is_leaf=None)

Format a JAX PyTree into a nicely indented string representation.

Parameters:
  • tree (PyTree) – An arbitrary JAX PyTree (dict, list, tuple, or leaf value).

  • indent (int) – Current indentation level (used for recursion).

  • indentation (int) – Number of spaces to indent for each level.

  • max_leaf_chars (int | None) – Maximum number of characters in a leaf's representation before it is truncated.

  • base_indent_str (str) – Base indentation string to prepend to each line.

  • is_leaf (Callable[[PyTree], bool] | None) – Optional function to determine if a node is a leaf.

Returns:

A formatted string representation of the PyTree.

Return type:

str

sub(pytree1, pytree2)

Computes the element-wise subtraction of two PyTrees of arrays with the same structure.

Parameters:
  • pytree1 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The first PyTree where each leaf is an array.

  • pytree2 (PyTree[jaxtyping.Num[Array, '...'], 'T']) – The second PyTree where each leaf is an array.

Returns:

A PyTree with the same structure as the inputs, where each leaf is the element-wise difference of the corresponding leaves of the input PyTrees.

Return type:

PyTree[jaxtyping.Num[Array, '...'], 'T']

uniform_leaf_shapes_equal(pytree1, pytree2, axis=None)

Checks if the shapes of the leaves of two PyTrees with array leaves match along the specified axes. First checks that both PyTrees have uniform leaf shapes along the specified axes.

Parameters:
  • pytree1 (PyTree[jaxtyping.Shaped[Array, '...']]) – The first PyTree to compare.

  • pytree2 (PyTree[jaxtyping.Shaped[Array, '...']]) – The second PyTree to compare.

  • axis (int | tuple[int, ...] | slice | None) – The axes along which to compare the shapes of the leaves. Can be an integer, a tuple of integers, a slice, or None. If None, all axes are compared.

Returns:

True if the shapes of the leaves match along the specified axes, False otherwise.

Return type:

bool
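
Unlike shapes_equal, this function is not documented as checking structure, so it presumably compares only the single shared leaf shape of each tree; two trees with different structures can then still match. A minimal sketch under that assumption, with hypothetical names _common_shape and uniform_leaf_shapes_equal_sketch.

```python
import jax
import jax.numpy as jnp


def _common_shape(pytree, axis):
    """Return the shared (selected) leaf shape, or None if leaves disagree."""
    shapes = set()
    for x in jax.tree_util.tree_leaves(pytree):
        shape = tuple(jnp.shape(x))
        if isinstance(axis, slice):
            shape = shape[axis]
        elif axis is not None:
            axes = (axis,) if isinstance(axis, int) else axis
            shape = tuple(shape[a] for a in axes)
        shapes.add(shape)
    return shapes.pop() if len(shapes) == 1 else None


def uniform_leaf_shapes_equal_sketch(pytree1, pytree2, axis=None) -> bool:
    # Both trees must be uniform first; then their common shapes are compared.
    s1, s2 = _common_shape(pytree1, axis), _common_shape(pytree2, axis)
    return s1 is not None and s1 == s2
```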