Manipulation Modules

ConcatenateLeaves

Module that concatenates all array leaves of a PyTree into a single array.
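The behavior can be sketched as a plain JAX function. This is a minimal sketch, not the module's actual implementation: the name `concatenate_leaves` is hypothetical, and it assumes each leaf is ravelled to 1D before concatenation. The real module may instead concatenate along an existing axis.

```python
import jax
import jax.numpy as jnp

def concatenate_leaves(tree):
    """Hypothetical sketch: flatten every array leaf to 1D and join them.

    Assumption: leaves are ravelled first, so arrays of different shapes
    can be concatenated; leaf order follows jax.tree_util.tree_leaves
    (dict keys are visited in sorted order).
    """
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

params = {"b": jnp.zeros(2), "w": jnp.ones((2, 3))}
out = concatenate_leaves(params)
print(out.shape)  # (8,)
```

Because dict keys are sorted during flattening, the two zeros from `"b"` come before the six ones from `"w"`.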

Flatten

Module that flattens the input to 1D.
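A minimal sketch of the operation, assuming "flatten to 1D" means collapsing every axis (including any batch axis) into a single vector in row-major order. The function name `flatten` is hypothetical and stands in for the module.

```python
import jax.numpy as jnp

def flatten(x):
    # Collapse all axes into one, in row-major (C) order.
    return jnp.ravel(x)

x = jnp.arange(6).reshape(2, 3)
print(flatten(x))  # [0 1 2 3 4 5]
```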

Reshape

Module that reshapes the input array to a specified shape.
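A minimal sketch, assuming the target shape is fixed when the module is constructed and follows `jnp.reshape` semantics, so a single `-1` entry is inferred from the input size. The class `ReshapeSketch` is hypothetical and only stands in for the real module.

```python
import jax.numpy as jnp

class ReshapeSketch:
    """Hypothetical stand-in: stores a target shape, applies it on call."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def __call__(self, x):
        # jnp.reshape raises if the element count does not match;
        # a single -1 entry is inferred from the remaining dimensions.
        return jnp.reshape(x, self.shape)

layer = ReshapeSketch((3, -1))
y = layer(jnp.arange(12))
print(y.shape)  # (3, 4)
```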

TreeFlatten

Module that flattens an input tree of arrays into a list of arrays.
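In plain JAX the same idea maps onto `jax.tree_util.tree_flatten`, which returns the leaves as a list together with a treedef that can rebuild the tree. Whether the module keeps the treedef is not stated here; the wrapper name `tree_flatten_to_list` is hypothetical, and this sketch simply discards it.

```python
import jax
import jax.numpy as jnp

def tree_flatten_to_list(tree):
    # tree_flatten returns (leaves, treedef); only the list of leaves is kept.
    leaves, _treedef = jax.tree_util.tree_flatten(tree)
    return leaves

tree = {"a": jnp.ones(2), "b": [jnp.zeros(3), jnp.full((1,), 7.0)]}
leaves = tree_flatten_to_list(tree)
print([leaf.shape for leaf in leaves])  # [(2,), (3,), (1,)]
```

If the original structure is needed later, keep the treedef and pass it to `jax.tree_util.tree_unflatten` with the (possibly transformed) leaves.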

TreeKey

Module that selects subtrees or leaves from an input tree at the specified keypaths.
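A minimal sketch of keypath selection, assuming each keypath is a tuple of plain dictionary keys and sequence indices applied one step at a time, and that the selections are returned as a list. JAX's own key paths use `DictKey`/`SequenceKey` entries instead, so the real module's input format may differ. The name `select_keypaths` is hypothetical.

```python
from functools import reduce
import operator

import jax.numpy as jnp

def select_keypaths(tree, keypaths):
    """Return the subtree or leaf found at each keypath, in order.

    Assumption: a keypath is a tuple like ("encoder", "layers", 0),
    resolved step by step with tree[key].
    """
    return [reduce(operator.getitem, path, tree) for path in keypaths]

tree = {
    "encoder": {"layers": [jnp.ones(2), jnp.zeros(3)]},
    "head": jnp.arange(4),
}
picked = select_keypaths(tree, [("encoder", "layers", 0), ("head",)])
print([p.shape for p in picked])  # [(2,), (4,)]
```

A shorter keypath such as `("encoder",)` returns the whole `encoder` subtree rather than a single leaf.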