import jax
import jax.numpy as np
from beartype import beartype
from jaxtyping import Array, Inexact, PyTree, jaxtyped
from ..tree_util import is_shape_leaf
from ..typing import (
Any,
Data,
DataShape,
Dict,
HyperParams,
ModuleCallable,
OrderedSet,
Params,
State,
Tuple,
)
from .basemodule import BaseModule
class Einsum(BaseModule):
"""
Module that implements Einsum operations between array leaves in a PyTree
input. Optionally with leading trainable arrays. The einsum operation is
defined by the provided einsum string, which must specify all indices
except for the batch index. The batch index should not be included in the
provided einsum string, as it will be automatically added as the leading
index later.
Examples
--------
    Taking in a PyTree with two leaves, each being a 1D array (excluding the
    batch dimension), the dot product can be expressed as:

    >>> einsum_module = Einsum("i,i->")

    Taking in a PyTree with three leaves, each of which is a matrix
    (excluding the batch dimension), computing the matrix multiplication of
    a trainable matrix with the elementwise product of these three input
    matrices can be expressed as:

    >>> einsum_module = Einsum("ij,jk,jk,jk->ik", trainable=True)

    Since the einsum string in this example has four input arrays (separated
    by commas) but the module is compiled with only three input leaves, it
    will treat the leading array as an additional parameter array and, with
    ``trainable`` set to ``True``, initialize it randomly during
    compilation. This array can be directly specified by calling
    ``set_params`` or by providing it in the ``params`` argument during
    initialization.
    Alternatively, to perform the same operation but with the leading array
    fixed (not trainable), the module must be initialized with this array
    specified in the ``params`` argument and ``trainable`` set to ``False``:

    >>> W = jax.random.normal(jax.random.PRNGKey(0), (input_dim, output_dim))
    >>> einsum_module = Einsum(
    ...     "ij,jk,jk,jk->ik",
    ...     params=W,
    ...     trainable=False)

    In this case, the fixed arrays can be specified later by calling
    ``set_hyperparameters`` with a dictionary containing the key ``params``.
    Any additional trainable or fixed arrays are always treated as leading
    arrays in the einsum operation.
    If no additional fixed or trainable arrays are to be used, the einsum
    string can alternatively be provided as a two-element tuple consisting
    of a PyTree of strings with the same structure as the input data and a
    string representing the output of the einsum; the output string can be
    omitted (by passing just the PyTree) if it is to be inferred. For
    example, for a PyTree input with structure ``PyTree([*, (*, *)])``,
    with each leaf being a 1D array, the three-way outer product between
    the three leaves in the order ``PyTree([2, (0, 1)])`` can be specified
    in any of the following equivalent ways:

    >>> Einsum('c,a,b->abc')  # full string
    >>> Einsum('c,a,b')  # output inferred 'abc'
    >>> Einsum(['k', ('i', 'j')])  # output inferred 'ijk'
    >>> Einsum((['k', ('i', 'j')], 'ijk'))  # full tuple
    >>> Einsum((['a', ('b', 'c')], 'bca'))  # full tuple
    If additional fixed or trainable arrays are to be used, the einsum
    string can be provided as a three-element tuple where the first element
    is an einsum string for the additional arrays, the second element is a
    PyTree of strings with the same structure as the input data, and the
    third element is the output string, which can be omitted if it is to be
    inferred. For example, to perform the same three-way outer product as
    above but with a leading trainable array, the einsum string can be
    provided in any of the following equivalent ways:

    >>> Einsum('ab,c,a,b->c')  # full string
    >>> Einsum('ab,c,a,b')  # output inferred 'c'
    >>> Einsum(('ij', ['k', ('i', 'j')]))  # output inferred 'k'
    >>> Einsum(('ij', ['k', ('i', 'j')], 'k'))  # full tuple

    For multiple leading arrays the following are equivalent:

    >>> Einsum('ab,cd,c,a,b->d')  # full string
    >>> Einsum('ab,cd,c,a,b')  # output inferred 'd'
    >>> Einsum(('ab,cd', ['c', ('a', 'b')]))  # output inferred 'd'
    >>> Einsum(('ab,cd', ['c', ('a', 'b')], 'd'))  # full tuple
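
    Indices that appear only in leading arrays and the output (here ``d``)
    cannot be inferred from the input data when the leading arrays are
    trainable and no initial ``params`` are given, so their sizes must be
    supplied via ``dim_map``. A minimal sketch (the size ``8`` is
    illustrative):

    >>> einsum_module = Einsum(
    ...     'ab,cd,c,a,b->d',
    ...     dim_map={'d': 8},
    ...     trainable=True)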
"""
def __init__(
self,
einsum_str: (
Tuple[str, PyTree[str], str]
| Tuple[str, PyTree[str]]
| Tuple[PyTree[str], str]
| PyTree[str]
| str
| None
) = None,
params: (
PyTree[Inexact[Array, "..."]] | Inexact[Array, "..."] | None
) = None,
dim_map: Dict[str, int] | None = None,
trainable: bool = False,
init_magnitude: float = 1e-2,
real: bool = True,
) -> None:
"""
Initialize an ``Einsum`` module.
Parameters
----------
einsum_str
The einsum string defining the operation. The batch index
should not be included in the provided einsum string, as it
will be automatically added as the leading index later. Can be
            provided as a single string, a PyTree of input strings, a tuple
            of (PyTree[input_strings], output_string), or a tuple of
            (leading_arrays_einsum_str, PyTree[input_strings],
            output_string). The input strings in a PyTree must have
            the same structure as the input data. The output string can be
            omitted to have it inferred. If ``None``, it
must be set before compilation via ``set_hyperparameters`` with
the ``einsum_str`` key. Default is ``None``.
params
Optional additional leading arrays for the einsum operation. If
trainable is ``True``, these will be treated as the initial
values for trainable arrays. If ``False``, they will be treated
as fixed arrays. If ``None`` and trainable is ``True``, the
leading arrays will be initialized randomly during compilation.
Default is ``None``. Can be provided later via
``set_hyperparameters`` with the ``params`` key if
``trainable`` is ``False``, or via ``set_params`` if
``trainable`` is ``True``.
dim_map
Dictionary mapping einsum indices (characters) to integer
sizes for the array dimensions. Only entries for indices that
cannot be inferred from the input data shapes or parameter
shapes need to be provided. Default is ``None``.
trainable
Whether the provided ``params`` are trainable or fixed. If
``True``, the arrays in ``params`` will be treated as initial
values for trainable arrays. If ``False``, they will be treated
as fixed arrays. Default is ``False``.
init_magnitude
Magnitude for the random initialization of weights.
Default is ``1e-2``.
real
Ignored when there are no trainable arrays. If ``True``, the
weights and biases will be real-valued. If ``False``, they will
be complex-valued. Default is ``True``.
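
        Examples
        --------
        A minimal sketch: a trainable leading matrix contracted with a
        single 1D input leaf. The index ``i`` appears only in the leading
        array and the output, so its size must be given via ``dim_map``
        (the size ``5`` is illustrative):

        >>> m = Einsum("ij,j->i", dim_map={"i": 5}, trainable=True)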
"""
self.einsum_str = einsum_str
self.params = params
self.dim_map = dim_map
self.trainable = trainable
self.init_magnitude = init_magnitude
self.real = real
self.input_shape: DataShape | None = None
def _get_dimension_map(
self,
concrete_einsum_str: str,
input_shape: DataShape,
) -> Dict[str, int]:
r"""
Fill in the dimension map by inferring sizes from the input shapes
and parameter shapes based on the provided concrete einsum string and
``self.params`` if applicable.
Parameters
----------
concrete_einsum_str
The concrete einsum string with all indices specified,
including the output indices and batch index.
input_shape
The shape of the input data, used to infer dimension sizes.
Should not include the batch dimension.
Returns
-------
A complete dimension map with sizes for all indices in the
einsum string.
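
        Examples
        --------
        A minimal sketch of the inference (``a`` is the already-added
        batch index):

        >>> m = Einsum("ij,jk->ik")
        >>> m._get_dimension_map("aij,ajk->aik", ((2, 3), (3, 4)))
        {'i': 2, 'j': 3, 'k': 4}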
"""
        dim_map = dict(self.dim_map) if self.dim_map else {}
# get the batch dimension character from the concrete einsum string and
# remove it
input_str, output_str = concrete_einsum_str.split("->")
batch_char = output_str[0]
input_strs = input_str.replace(batch_char, "").split(",")
input_shapes_list = jax.tree.leaves(input_shape, is_leaf=is_shape_leaf)
# split input_strs into leading and input based on number of input
# arrays
num_input_arrays = len(input_shapes_list)
leading_strs = input_strs[:-num_input_arrays]
input_strs = input_strs[-num_input_arrays:]
# infer from input arrays
for s, shape in zip(input_strs, input_shapes_list):
if len(s) != len(shape):
raise ValueError(
f"Einsum input string '{s}' has length {len(s)}, but "
f"corresponding input array has shape {shape}."
)
for char, size in zip(s, shape):
if char not in dim_map:
dim_map[char] = size
else:
if dim_map[char] != size:
raise ValueError(
f"Dimension size mismatch for index '{char}': "
f"got size {size} from input shape {shape}, "
"but previously recorded size in `dim_map` is "
f"{dim_map[char]}."
)
# infer from parameter arrays if they exist
if self.params is not None:
param_shapes_list = (
[self.params.shape]
if isinstance(self.params, np.ndarray)
else jax.tree.leaves(
self.params, is_leaf=lambda x: isinstance(x, np.ndarray)
)
)
if len(leading_strs) != len(param_shapes_list):
raise ValueError(
f"Number of leading einsum strings ({len(leading_strs)}) "
"does not match number of parameter arrays "
f"({len(param_shapes_list)})."
)
for s, shape in zip(leading_strs, param_shapes_list):
if len(s) != len(shape):
raise ValueError(
f"Einsum leading string '{s}' has length {len(s)}, "
"but corresponding parameter array has shape "
f"{shape}."
)
for char, size in zip(s, shape):
if char not in dim_map:
dim_map[char] = size
else:
if dim_map[char] != size:
raise ValueError(
"Dimension size mismatch for index "
f"'{char}': got size {size} from parameter "
f"shape {shape}, but previously recorded "
f"size in `dim_map` is {dim_map[char]}."
)
        # at this point, every index in the einsum string (including
        # leading-array and output indices) should be in dim_map, except
        # for the batch index
all_indices = OrderedSet(
concrete_einsum_str.replace(",", "")
.replace("->", "")
.replace(batch_char, "")
)
missing_indices = all_indices - OrderedSet(dim_map.keys())
if missing_indices:
raise ValueError(
f"Could not infer sizes for indices {missing_indices}. "
"Please provide their sizes in the `dim_map` argument."
)
return dim_map
def _get_concrete_einsum_str(self, input_shape: DataShape) -> str:
r"""
        Get the concrete einsum string by parsing the provided einsum
        string, adding a batch index to account for batch dimensions, and
        inferring any missing output indices. If the einsum string is a
        PyTree, it must have the same structure as the input data.

        Parameters
        ----------
        input_shape
            The shape of the input data, used to infer any missing output
            indices as well as validate existing indices. Should not
            include the batch dimension.

        Examples
        --------
>>> m = Einsum("ij,jk->ik")
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
'aij,ajk->aik' # leading index 'a' added for batch dimension
>>> m = Einsum((('ij', 'jk'), 'ik'))
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
'aij,ajk->aik'
>>> m = Einsum("ab,bc->ac")
>>> m._get_concrete_einsum_str(((5, 2), (2, 4)))
'dab,dbc->dac' # leading index 'd' added for batch dimension
>>> m = Einsum((['ij', 'jk'], 'ik'))
>>> m._get_concrete_einsum_str(((2, 3), (3, 4)))
ValueError: The structure of the einsum_str PyTree must match that
of the input data.
# (since the input data is a Tuple of two arrays, not a List)
>>> m = Einsum(('ij,jk', {'x': 'ik', 'y': 'ab'}, 'ab'))
>>> m._get_concrete_einsum_str({'x': (2, 3), 'y': (3, 4)})
'ij,jk,cik,cab->cab' # leading arrays don't have batch index
# all arrays from PyTrees are inserted in the same order as the
# list from jax.tree.leaves(...)
"""
# standardize einsum_str to the single string case
if self.einsum_str is None:
raise ValueError("einsum_str must be set before concretization.")
was_pytree = False
# if the einsum_str structure matches the input_shape structure, it's
# the PyTree case
input_struct = jax.tree.structure(input_shape, is_leaf=is_shape_leaf)
einsum_str_struct = jax.tree.structure(self.einsum_str)
if einsum_str_struct == input_struct:
leading_str = None
input_strs = self.einsum_str
output_str = None
was_pytree = True
elif isinstance(self.einsum_str, tuple):
if len(self.einsum_str) == 3:
leading_str, input_strs, output_str = self.einsum_str
elif len(self.einsum_str) == 2:
first, second = self.einsum_str
# two cases: (leading_str, input_strs) or
# (input_strs, output_str)
# can distinguish based on which is a bare string
# if both are bare strings, then it's ambiguous (input is a
# degenerate PyTree of a single array and the user should use
# the single string case instead)
first_is_str = isinstance(first, str)
second_is_str = isinstance(second, str)
if first_is_str and not second_is_str:
leading_str = first
input_strs = second
output_str = None
elif not first_is_str and second_is_str:
leading_str = None
input_strs = first
output_str = second
else:
# this case should be caught by the structure check above,
# but just in case
raise ValueError(
"If einsum_str is a tuple of length 2, one element "
"must be a bare string and the other must be a "
"non-degenerate PyTree of strings matching the "
"structure of input_shape. If input_shape is a "
"degenerate PyTree representing bare array input, use "
"the single string einsum_str format instead. E.g. "
f"use Einsum('{first},{second}') or "
f"Einsum('{first}->{second}') depending on your "
"intention instead of "
f"Einsum(('{first}', '{second}'))."
)
else:
raise ValueError(
"If einsum_str is a tuple not matching the structure of "
"input_shape, it must have length 2 or 3."
)
elif isinstance(self.einsum_str, str):
leading_str = None
input_strs = self.einsum_str
output_str = None
else: # PyTree case
# verify the structure matches input_shape
einsum_struct = jax.tree.structure(self.einsum_str)
input_struct = jax.tree.structure(
input_shape, is_leaf=is_shape_leaf
)
if einsum_struct != input_struct:
raise ValueError(
"The structure of the einsum_str PyTree must match that "
f"of the input data. Got {einsum_struct} but "
f"expected {input_struct}."
)
leading_str = None
input_strs = self.einsum_str
output_str = None
# if output_str is None, see if it is included in input_strs, which is
# only possible if input_strs is a bare string containing '->'
if (
output_str is None
and isinstance(input_strs, str)
and "->" in input_strs
):
input_strs, output_str = input_strs.split("->")
# if leading_str is None, see if it is included in input_strs, which is
# only possible if input_strs is a bare string
if leading_str is None and isinstance(input_strs, str):
input_str_list = input_strs.split(",")
# if the number of input strings is greater than the number of
# input arrays, then the leading strings are included here
num_input_arrays = len(
jax.tree.leaves(input_shape, is_leaf=is_shape_leaf)
)
num_input_strings = len(input_str_list)
if num_input_strings > num_input_arrays:
leading_str = ",".join(
input_str_list[: num_input_strings - num_input_arrays]
)
input_strs = ",".join(
input_str_list[num_input_strings - num_input_arrays :]
)
# now leading_str, input_strs, and output_str are properly separated
# validate the number of arrays in the input_strs matches input_shape
num_input_arrays = len(
jax.tree.leaves(input_shape, is_leaf=is_shape_leaf)
)
if isinstance(input_strs, str):
num_input_strings = len(input_strs.split(","))
else:
was_pytree = True
            # input strings must contain only alphabetic characters; we
            # check that here
if not jax.tree.all(
jax.tree.map(
lambda s: isinstance(s, str) and s.isalpha(),
input_strs,
)
):
raise ValueError(
"All input strings in the einsum_str PyTree must be "
"bare strings containing only alphabetic characters."
)
            num_input_strings = len(
                jax.tree.leaves(
                    input_strs, is_leaf=lambda x: isinstance(x, str)
                )
            )
        if num_input_strings != num_input_arrays:
            raise ValueError(
                f"Number of einsum input strings ({num_input_strings}) "
                f"does not match the number of input arrays "
                f"({num_input_arrays})."
            )
        # if leading_str is None then params must be None or empty;
        # if it's not None, params may still be None (to be initialized
        # later); compare explicitly against an empty tuple, since ``==``
        # on an array leaf would not return a boolean
        if leading_str is None and not (
            self.params is None
            or (isinstance(self.params, tuple) and len(self.params) == 0)
        ):
raise ValueError(
"If einsum_str does not specify leading arrays, then "
"params must be None or empty."
)
# if output_str is not None and contains either ',' or '->', raise
# error
if output_str is not None:
if ("," in output_str) or ("->" in output_str):
raise ValueError(
"output_str cannot contain ',' or '->'. Got "
f"'{output_str}'."
)
# verify that all strings contain only valid characters (a-z, A-Z, ',')
valid_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,"
def validate_chars(s: str) -> None:
invalid_chars = set(s) - set(valid_chars)
if invalid_chars:
                raise ValueError(
                    f"Einsum string '{s}' contains invalid characters: "
                    f"{invalid_chars}. Only a-z, A-Z, and ',' are allowed."
                )
if leading_str is not None:
validate_chars(leading_str)
if output_str is not None:
validate_chars(output_str)
if isinstance(input_strs, str):
validate_chars(input_strs)
else:
jax.tree.map(validate_chars, input_strs)
# if input_strs is a bare string, split on commas
if isinstance(input_strs, str):
input_strs = input_strs.split(",")
# now, leading_str and output_str are either strings or None
# and input_strs is a PyTree of strings
# put the input strings in the order of the input_shape
# so long as the original input was a PyTree
if was_pytree:
input_strs = jax.tree.map(
lambda _, s: s,
input_shape,
input_strs,
is_leaf=is_shape_leaf,
)
# reduce over input_strs to single string
def reduce_input_strs(carry: str, s: str) -> str:
if carry == "":
return s
else:
return carry + "," + s
input_str = jax.tree.reduce(
reduce_input_strs,
input_strs,
initializer="",
)
# infer output_str if None
# it should be all the indices that appear only once in
# leading_str + input_str, and be in alphabetical order, same as einsum
# would do
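        # e.g. for leading_str + input_str == 'ab,c,a,b' the counts are
        # {'a': 2, 'b': 2, 'c': 1}, so the inferred output_str is 'c'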
if output_str is None:
all_input = ""
if leading_str is not None:
all_input += leading_str
all_input += input_str
index_counts = {}
for char in all_input:
if char != ",":
if char in index_counts:
index_counts[char] += 1
else:
index_counts[char] = 1
output_indices = [
char
for char in sorted(index_counts.keys())
if index_counts[char] == 1
]
output_str = "".join(output_indices)
# find all used indices to find a batch index that doesn't conflict
# use OrderedSet to have a deterministic order
used_indices = OrderedSet()
if leading_str is not None:
used_indices.update(OrderedSet(leading_str.replace(",", "")))
used_indices.update(OrderedSet(input_str.replace(",", "")))
if output_str is not None:
used_indices.update(OrderedSet(output_str))
available_indices = (
OrderedSet(valid_chars.replace(",", "")) - used_indices
)
if not available_indices:
raise ValueError(
"No available indices to use for batch dimension. "
"Einsum strings are using all possible indices."
)
batch_index = available_indices[0]
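        # e.g. if the einsum strings use the indices {i, j, k}, the batch
        # index is 'a'; if they use {a, b, c}, it is 'd'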
# before adding batch index, validate that the input
# shapes match the input_strs
input_str_list = input_str.split(",")
def validate_shape(s: str, shape: Tuple[int, ...]) -> None:
if len(s) != len(shape):
raise ValueError(
f"Einsum input string '{s}' has length {len(s)}, but "
f"corresponding input array has shape {shape}."
)
for s, shape in zip(
input_str_list,
jax.tree.leaves(input_shape, is_leaf=is_shape_leaf),
):
validate_shape(s, shape)
# add batch index to input_str and output_str only
# leading_str does not get a batch index
input_str = ",".join([batch_index + s for s in input_str.split(",")])
output_str = batch_index + output_str
# construct full einsum string
full_einsum_str = input_str + "->" + output_str
if leading_str is not None:
full_einsum_str = leading_str + "," + full_einsum_str
return full_einsum_str
@property
def name(self) -> str:
concrete_einsum_str = (
self._get_concrete_einsum_str(self.input_shape)
if self.input_shape is not None
else self.einsum_str
)
return f"Einsum({concrete_einsum_str})"
def is_ready(self) -> bool:
return self.input_shape is not None
def _get_callable(self) -> ModuleCallable:
# set up the callable
concrete_einsum_str = self._get_concrete_einsum_str(self.input_shape)
@jaxtyped(typechecker=beartype)
def einsum_callable(
params: Params, data: Data, training: bool, state: State, rng: Any
) -> Tuple[Data, State]:
# prepare the list of arrays to einsum over
# most of this will be traced out by jax
arrays = []
# if trainable, params will be the leading arrays, if there are
# any, otherwise it will be an empty tuple
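            # leading (trainable or fixed) arrays come first, matching
            # their position in the concrete einsum string; the input data
            # leaves are appended after them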
if self.trainable:
if isinstance(params, np.ndarray):
arrays.append(params)
else:
arrays.extend(
jax.tree.leaves(
params,
is_leaf=lambda x: isinstance(x, np.ndarray),
)
)
            elif self.params is not None:
# if not trainable, params are fixed leading arrays, if any,
# and they are stored in self.params
if isinstance(self.params, np.ndarray):
arrays.append(self.params)
else:
arrays.extend(
jax.tree.leaves(
self.params,
is_leaf=lambda x: isinstance(x, np.ndarray),
)
)
# add the input data arrays
arrays.extend(
jax.tree.leaves(
data,
is_leaf=lambda x: isinstance(x, np.ndarray),
)
)
# convert all arrays to common dtype
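            # e.g. float32 inputs combined with a complex64 leading array
            # are all promoted to complex64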
dtype = np.result_type(*[a.dtype for a in arrays])
arrays = [a.astype(dtype) for a in arrays]
output = np.einsum(concrete_einsum_str, *arrays)
return output, state
return einsum_callable
def compile(self, rng: Any, input_shape: DataShape) -> None:
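        # concretize the einsum string, complete the dimension map, and
        # initialize or validate any leading parameter arrays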
if self.einsum_str is None:
raise ValueError(
"einsum_str must be set before compiling the module"
)
if self.is_ready() and self.input_shape != input_shape:
raise ValueError(
"Module has already been compiled with a different input "
"shape."
)
self.input_shape = input_shape
concrete_einsum_str = self._get_concrete_einsum_str(input_shape)
dim_map = self._get_dimension_map(
concrete_einsum_str,
input_shape,
)
# figure out how many leading arrays there are
input_str, output_str = concrete_einsum_str.split("->")
input_strs = input_str.split(",")
num_input_arrays = len(
jax.tree.leaves(input_shape, is_leaf=is_shape_leaf)
)
leading_strs = input_strs[:-num_input_arrays]
# initialize params if needed
if len(leading_strs) > 0 and self.params is None:
if not self.trainable:
raise ValueError(
"params must be provided for fixed (non-trainable) "
"leading arrays."
)
param_arrays = []
if self.real:
keys = jax.random.split(rng, len(leading_strs))
else:
rkey, ikey = jax.random.split(rng)
rkeys = jax.random.split(rkey, len(leading_strs))
ikeys = jax.random.split(ikey, len(leading_strs))
for i, s in enumerate(leading_strs):
shape = tuple(dim_map[char] for char in s)
if self.real:
param_array = self.init_magnitude * jax.random.normal(
keys[i], shape, dtype=np.float32
)
                else:
                    # sample real and imaginary parts as float32, then
                    # combine them into a complex-valued array
                    real_part = self.init_magnitude * jax.random.normal(
                        rkeys[i], shape, dtype=np.float32
                    )
                    imag_part = self.init_magnitude * jax.random.normal(
                        ikeys[i], shape, dtype=np.float32
                    )
                    param_array = real_part + 1j * imag_part
param_arrays.append(param_array)
# if there's only one leading array, store it as a single array
if len(param_arrays) == 1:
self.params = param_arrays[0]
else:
self.params = tuple(param_arrays)
if self.params is not None:
# make sure params shape matches leading_strs
expected_shapes = [
tuple(dim_map[char] for char in s) for s in leading_strs
]
param_shapes = (
[self.params.shape]
if isinstance(self.params, np.ndarray)
else [
p.shape
for p in jax.tree.leaves(
self.params,
is_leaf=lambda x: isinstance(x, np.ndarray),
)
]
)
if len(expected_shapes) != len(param_shapes):
raise ValueError(
"Number of leading arrays in einsum_str "
f"'{concrete_einsum_str}' "
f"({len(leading_strs)}) does not match number of "
f"parameter arrays ({len(param_shapes)})."
)
for i, (expected_shape, param_shape) in enumerate(
zip(expected_shapes, param_shapes)
):
if expected_shape != param_shape:
raise ValueError(
f"Parameter array {i} has shape {param_shape}, but "
f"expected shape {expected_shape} based on einsum_str."
)
def get_output_shape(self, input_shape: DataShape) -> DataShape:
concrete_einsum_str = self._get_concrete_einsum_str(input_shape)
dim_map = self._get_dimension_map(
concrete_einsum_str,
input_shape,
)
_, output_str = concrete_einsum_str.split("->")
# skip the batch dimension
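        # e.g. for 'aij,ajk->aik' with dim_map {'i': 2, 'j': 3, 'k': 4},
        # the output shape is (2, 4)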
output_shape = tuple(dim_map[char] for char in output_str[1:])
return output_shape
def get_hyperparameters(self) -> HyperParams:
# include params in hyperparameters only if they are fixed
return {
"einsum_str": self.einsum_str,
"dim_map": self.dim_map,
**({"params": self.params} if not self.trainable else {}),
"trainable": self.trainable,
"init_magnitude": self.init_magnitude,
"real": self.real,
}
def set_hyperparameters(self, hyperparams: HyperParams) -> None:
# setting the hyperparameters should require recompilation
self.input_shape = None
        if "einsum_str" in hyperparams:
            self.einsum_str = hyperparams["einsum_str"]
        if "dim_map" in hyperparams:
            self.dim_map = hyperparams["dim_map"]
        if "trainable" in hyperparams:
            self.trainable = hyperparams["trainable"]
        # only allow setting params if they are fixed
        if "params" in hyperparams and not self.trainable:
self.params = hyperparams["params"]
if "init_magnitude" in hyperparams:
self.init_magnitude = hyperparams["init_magnitude"]
if "real" in hyperparams:
self.real = hyperparams["real"]
def get_params(self) -> Params:
# return params only if they are trainable
if not self.trainable:
return ()
return self.params
def set_params(self, params: Params) -> None:
# only allow setting params if they are trainable
if self.trainable:
self.params = params