from __future__ import annotations
import random
import jax
from beartype import beartype
from jaxtyping import jaxtyped
from .graph_util import (
get_outer_connections_by_tree,
place_connections_in_tree,
resolve_connections,
)
from .model import Model
from .model_util import (
ModelCallable,
ModelModules,
ModelParams,
ModelState,
)
from .modules import BaseModule
from .tree_util import (
extend_structure_from_strpaths,
getitem_by_strpath,
is_shape_leaf,
make_mutable,
setitem_by_strpath,
)
from .typing import (
Any,
Data,
DataShape,
Dict,
HyperParams,
List,
ModuleCallable,
PyTree,
State,
Tuple,
)
class NonSequentialModel(Model):
r"""
A nonsequential model that chains modules (or other models) together with
directed connections.
For confidence intervals or uncertainty quantification, wrap a trained
model with ``ConformalModel``.
See Also
--------
jax.tree
PyTree utilities and concepts in JAX.
Model
Abstract base class for all models.
SequentialModel
A model that applies modules in sequence.
ConformalModel
Wrap a trained model to produce confidence intervals.
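Examples
--------
A minimal branching sketch (module classes are placeholders; see
``__init__`` for the full connection syntax):
>>> model = NonSequentialModel(
...     {"A": Module1(), "B": Module2(), "C": Module3()},
...     {"input": "A", "A": ["B", "C"], "B": "output.0", "C": "output.1"},
... )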
"""
def __init__(
self,
modules: ModelModules | BaseModule | None = None,
connections: (
Dict[str, str | List[str] | Tuple[str, ...]] | None
) = None,
/,
*,
rng: Any | int | None = None,
separator: str = ".",
) -> None:
r"""
Initialize a nonsequential model with a PyTree of modules and a
random key.
Parameters
----------
modules
module(s) to initialize the model with. Default is None, which
will become an empty dictionary. Can be a single module, which
will be wrapped in a list, or a PyTree of modules (e.g., nested
lists, tuples, or dictionaries).
connections
Directed connections between module outputs and module inputs in the
model. Keys are period-separated paths of module outputs, and
values are lists or tuples of period-separated paths of module
inputs that receive the output. The reserved keys "input" and
"output" refer to the model input and output, respectively. The
separator can be changed from the default period using the
``separator`` argument. Default is None, which will become an
empty dictionary.
rng
Initial random key for the model. Default is None. If None, a
new random key will be generated using JAX's ``random.key``. If
an integer is provided, it will be used as the seed to create
the key.
separator
Separator string to use for denoting paths in the connections
dictionary. Default is ".".
Examples
--------
To denote a sequential model where all modules expect a bare array and
produce a bare array:
>>> modules = [Module1(), Module2(), Module3()]
>>> connections = {
... "input": "0",
... "0": "1",
... "1": "2",
... "2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
or equivalently to name the modules:
>>> modules = {"M0": Module1(), "M1": Module2(), "M2": Module3()}
>>> connections = {
... "input": "M0",
... "M0": "M1",
... "M1": "M2",
... "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
or equivalently to use nested structures:
>>> modules = {
... "block1": [Module1(), Module2()],
... "block2": {"M3": Module3()}
... }
>>> connections = {
... "input": "block1.0",
... "block1.0": "block1.1",
... "block1.1": "block2.M3",
... "block2.M3": "output"
... }
>>> model = NonSequentialModel(modules, connections)
All three of the above will produce a model that applies the same
three modules sequentially.
If a module outputs a PyTree of arrays, or expects a PyTree of arrays
as input, the connections can specify the leaf nodes using the same
period-separated path syntax. For example, if Module1 outputs a dict
with keys "a" and "b", and Module2 expects a tuple of two arrays as
input, the connections can be specified as:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
... "input": "M1",
... "M1.a": "M2.1",
... "M1.b": "M2.0",
... "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
This will send the "a" output of Module1 to the second input of
Module2, and the "b" output of Module1 to the first input of Module2.
.. note::
If a module expects a Tuple or List as input, it is best to write the
module so that it accepts both, since whether the assembled input
arrives as a List or a Tuple cannot be inferred at compile time; a
minimal sketch follows.
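(The following is hypothetical module code, not part of this library's
API; it simply indexes positionally so either container type works.)
>>> def forward(x):
...     first, second = x[0], x[1]
...     return first + second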
If the entire model input or output is a PyTree of arrays, the
connections use the same period-separated path syntax with the reserved
keys "input" and "output". For example, if the model input is a dict
with keys "x1" and "x2", and the model output is a Tuple of two arrays,
the connections can be specified as:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
... "input.x1": "M1",
... "M1": "M2",
... "M2": "output.1",
... "input.x2": "output.0"
... }
>>> model = NonSequentialModel(modules, connections)
This will pass the "x1" input sequentially through Module1 and Module2,
sending the result to the second model output, and will pass the "x2"
input directly through to the first model output unchanged.
Connections need not spell out every leaf of a module's output PyTree
when an entire subtree is passed between modules. For example, if Module1
outputs a dict with keys "a" and "b", and Module2 expects a dict
with keys "a" and "b" as input, the connections can be specified as:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
... "input": "M1",
... "M1": "M2",
... "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
or equivalently:
>>> modules = {"M1": Module1(), "M2": Module2()}
>>> connections = {
... "input": "M1",
... "M1.a": "M2.a",
... "M1.b": "M2.b",
... "M2": "output"
... }
>>> model = NonSequentialModel(modules, connections)
Both ways will pass the entire output dict of Module1 to Module2.
Module outputs can be sent to multiple module inputs by specifying a
list or tuple of input paths in the connections dictionary. For
example, to send the output of Module1 to both Module2 and Module3:
>>> modules = {"M1": Module1(), "M2": Module2(), "M3": Module3()}
>>> connections = {
... "input": "M1",
... "M1": ["M2", "M3"],
... "M2": "output.0",
... "M3": "output.1"
... }
>>> model = NonSequentialModel(modules, connections)
This will create a model that sends the output of Module1 to both
Module2 and Module3 in parallel, and collects their outputs as a Tuple
as the model output.
The order of the connections in the dictionary does not matter, as long
as the connections form a valid directed acyclic graph from the
model input to the model output. Not every part of the model input, nor
every module, has to be used; leaving either unused will raise a
warning. Leaving parts of a module's outputs unused is allowed and will
not raise a warning, but every input of every module that does appear
must be connected.
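For instance, the following leaves ``M3`` unconnected; this is allowed
but will be warned about (a hedged sketch; module classes are
placeholders):
>>> modules = {"M1": Module1(), "M2": Module2(), "M3": Module3()}
>>> connections = {"input": "M1", "M1": "M2", "M2": "output"}
>>> model = NonSequentialModel(modules, connections)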
See Also
--------
ModelModules
Type alias for a PyTree of modules in a model.
jax.random.key
JAX function to create a random key.
jax.tree_util.keystr
JAX function to create string paths for PyTree KeyPaths in the
format expected by the connections dictionary.
"""
modules = make_mutable(modules)
super().__init__(modules if modules is not None else {}, rng=rng)
self.connections = connections if connections is not None else {}
# if "input" or "output" is in the module keys (if modules is a dict),
# then raise an error
if isinstance(self.modules, dict):
if "input" in self.modules:
raise ValueError(
"Module key 'input' is reserved for model input."
)
if "output" in self.modules:
raise ValueError(
"Module key 'output' is reserved for model output."
)
self.execution_order: List[str] | None = None
self.separator = separator
def get_execution_order(self) -> List[str]:
r"""
Resolve the connections dictionary to determine the order in which
modules are executed.
Returns
-------
Topologically sorted list of string paths giving the order in which
modules are executed; the reserved "input" and "output" keys appear
alongside the module paths.
Raises
------
ValueError
If the connections do not form a valid directed acyclic graph
from the model input to the model output.
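Examples
--------
A hedged sketch for a simple two-module chain (module classes are
placeholders):
>>> model = NonSequentialModel(
...     {"M0": Module1(), "M1": Module2()},
...     {"input": "M0", "M0": "M1", "M1": "output"},
... )
>>> model.get_execution_order()  # e.g. ['input', 'M0', 'M1', 'output']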
"""
module_connections = get_outer_connections_by_tree(
self.connections,
self.modules,
separator=self.separator,
in_key="input",
out_key="output",
)
# now we have connections in the form
# { 'input': {'<mod_path>', ...}, ... }
# now is a good time to verify that 'input' is present as a key
if "input" not in module_connections:
raise ValueError(
"Connections must include 'input' as a key "
"denoting the model input."
)
# now we resolve the execution order
topo_order, _ = resolve_connections(
module_connections,
start_key="input",
end_key="output",
)
return topo_order
def reset(self) -> None:
r"""
Reset the compiled state of the model. This will require recompilation
before the model can be used again.
"""
self.execution_order = None
super().reset()
def is_ready(self) -> bool:
r"""
Check if the model is compiled and ready for use. Overrides the base
implementation since not all modules need to be ready, as some may
not appear in the execution order.
Returns
-------
True if the model is compiled and ready, False otherwise.
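Examples
--------
>>> model.is_ready()  # False until compile() has completed successfully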
"""
# if the execution order is not set, the model is not ready
if self.execution_order is None:
return False
# if any module in the execution order is not ready, the model is not
# ready
for module_path in self.execution_order:
if module_path == "input" or module_path == "output":
continue
module = getitem_by_strpath(
self.modules,
module_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
)
if not module.is_ready():
return False
# if the input or output shapes are not set, the model is not ready
return self.input_shape is not None and self.output_shape is not None
def compile(
self,
rng: Any | int | None,
input_shape: DataShape,
/,
*,
verbose: bool = False,
) -> None:
r"""
Compile the model for training by finding the execution order of the
directed graph defined by the connections, and compiling each module
in that order.
Parameters
----------
rng
Random key for initializing the model parameters: a JAX PRNGKey, an
integer seed, or None to generate a fresh key.
input_shape
Shape of the input array, excluding the batch size.
For example, (input_features,) for a 1D input or
(input_height, input_width, input_channels) for a 3D input.
verbose
Print debug information during compilation. Default is False.
Raises
------
ValueError
If the connections do not form a valid directed acyclic graph
from the model input to the model output.
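Examples
--------
A hedged sketch compiling for a single 1D input of 8 features, using an
integer seed for the random key:
>>> model.compile(0, (8,), verbose=True)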
"""
if rng is None:
rng = jax.random.key(random.randint(0, 2**32 - 1))
elif isinstance(rng, int):
rng = jax.random.key(rng)
if self.is_ready():
# just validate that the input shape matches the compiled one
assert jax.tree.structure(self.input_shape) == jax.tree.structure(
input_shape
), (
f"{self.name} is already compiled with input shape "
f"{self.input_shape}, cannot recompile with different "
f"input shape {input_shape}."
)
assert jax.tree.all(
jax.tree.map(
lambda a, b: a == b,
self.input_shape,
input_shape,
is_leaf=is_shape_leaf,
)
), (
f"{self.name} is already compiled with input shape "
f"{self.input_shape}, cannot recompile with different "
f"input shape {input_shape}."
)
return
# resolve the execution order
self.execution_order = self.get_execution_order()
if verbose:
print(f"{self.name} execution order:")
for i, module_path in enumerate(self.execution_order):
print(f" {i}: {module_path}")
# set the input and output shapes
self.input_shape = input_shape
input_progression, output_progression, self.output_shape = (
self._get_shape_progression(input_shape)
)
if verbose:
print(f"{self.name} input shape: {self.input_shape}")
# print progression of shapes through the execution order
print(f"{self.name} shape progression:")
for module_path in self.execution_order:
if module_path == "input" or module_path == "output":
continue
module_input_shape = getitem_by_strpath(
input_progression,
module_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
is_leaf=is_shape_leaf,
)
module_output_shape = getitem_by_strpath(
output_progression,
module_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
is_leaf=is_shape_leaf,
)
print(
f" {module_path}: {module_input_shape} -> "
f"{module_output_shape}"
)
print(f"{self.name} output shape: {self.output_shape}")
# for all modules that appear in the execution order, compile them
for module_path in self.execution_order:
if module_path == "input" or module_path == "output":
continue
rng, module_rng = jax.random.split(rng)
module = getitem_by_strpath(
self.modules,
module_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
)
module_input_shape = getitem_by_strpath(
input_progression,
module_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
is_leaf=is_shape_leaf,
)
if verbose:
print(
f"Compiling module '{module_path}' with input shape "
f"{module_input_shape}."
)
try:
module.compile(
module_rng,
module_input_shape,
)
except Exception as e:
raise RuntimeError(
f"Error compiling module '{module_path}' ({module.name}) "
f"with input shape {module_input_shape}: {e}"
) from e
def _get_shape_progression(
self,
input_shape: DataShape,
/,
) -> Tuple[PyTree[DataShape | None], PyTree[DataShape | None], DataShape]:
r"""
Get the input and output shapes of every module in the execution order,
given the model input shape, along with the final model output shape.
Parameters
----------
input_shape
Shape of the input, excluding the batch dimension.
For example, (input_features,) for 1D bare-array input, or
(input_height, input_width, input_channels) for 3D bare-array
input, [(input_features1,), (input_features2,)] for a List
(PyTree) of 1D arrays, etc.
Returns
-------
input_shapes
PyTree of input shapes at each module in the execution order,
with the same structure as the modules in the execution order.
output_shapes
PyTree of output shapes at each module in the execution order,
with the same structure as the modules in the execution order.
output_shape
Shape of the output after passing through the model.
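Examples
--------
A hedged sketch for a model whose input is a single 1D array of 8
features:
>>> in_shapes, out_shapes, out_shape = model._get_shape_progression((8,))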
"""
# get the execution order
execution_order = self.get_execution_order()
# the modules execute sequentially in the execution order and each
# returns an arbitrary PyTree of arrays.
# since the inputs to each module can be arbitrary compositions of
# previous module outputs, we need to track the shapes of all
# intermediate outputs and assemble the inputs to each module.
# first we make two empty PyTrees with the same structure as the
# modules, restricted to those that appear in the execution order
input_shapes: PyTree[DataShape | None] = (
extend_structure_from_strpaths(
None,
execution_order,
separator=self.separator,
)
)
output_shapes: PyTree[DataShape | None] = (
extend_structure_from_strpaths(
None,
execution_order,
separator=self.separator,
)
)
# fill in the return shape from the "input" node
setitem_by_strpath(
output_shapes,
"input",
input_shape, # input node returns the model input shape
separator=self.separator,
is_leaf=is_shape_leaf,
)
module_input_deps, out_input_deps = (
self._get_module_input_dependencies()
)
# now we build the input and output shapes in execution order
for mod_path, req_in_paths in zip(execution_order, module_input_deps):
if mod_path == "input" or mod_path == "output":
continue
# mod_path is the path to the current module in self.modules
# req_in_paths is either a str or a list of str paths to the
# required inputs for the current module
# since we are in execution order, all required inputs will be
# available in output_shapes
in_shapes = jax.tree.map(
lambda p: getitem_by_strpath(
output_shapes,
p,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
is_leaf=is_shape_leaf,
),
req_in_paths,
)
setitem_by_strpath(
input_shapes,
mod_path,
in_shapes,
separator=self.separator,
is_leaf=is_shape_leaf,
)
module = getitem_by_strpath(
self.modules,
mod_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
)
try:
out_shape = module.get_output_shape(in_shapes)
except Exception as e:
raise RuntimeError(
f"Error getting output shape for module '{mod_path}' "
f"({module.name}) with input shape {in_shapes}: {e}"
"\nCurrent input shapes progression:"
f"\n{input_shapes}"
"\nCurrent output shapes progression:"
f"\n{output_shapes}"
) from e
setitem_by_strpath(
output_shapes,
mod_path,
out_shape,
separator=self.separator,
is_leaf=is_shape_leaf,
)
# finally, get the output shape from the "output" node
out_shapes = jax.tree.map(
lambda p: getitem_by_strpath(
output_shapes,
p,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
is_leaf=is_shape_leaf,
),
out_input_deps,
)
return input_shapes, output_shapes, out_shapes
def get_output_shape(self, input_shape: DataShape, /) -> DataShape:
r"""
Get the output shape of the model given an input shape.
Parameters
----------
input_shape
Shape of the input, excluding the batch dimension.
For example, (input_features,) for 1D bare-array input, or
(input_height, input_width, input_channels) for 3D bare-array
input, [(input_features1,), (input_features2,)] for a List
(PyTree) of 1D arrays, etc.
Returns
-------
output_shape
Shape of the output after passing through the model.
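Examples
--------
A hedged sketch for a single 1D input of 8 features (the result depends
on the modules and connections):
>>> model.get_output_shape((8,))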
"""
if self.is_ready():
return self.output_shape
_, _, output_shape = self._get_shape_progression(input_shape)
return output_shape
def _get_callable(
self,
) -> ModelCallable:
r"""
Returns a ``jax.jit``-able and ``jax.grad``-able callable that
represents the model's forward pass.
This method must be implemented by all subclasses and must return a
``jax.jit``-able and ``jax.grad``-able callable of the form
.. code-block:: python
model_callable(
params: parametricmatrixmodels.model_util.ModelParams,
data: parametricmatrixmodels.typing.Data,
training: bool,
state: parametricmatrixmodels.model_util.ModelState,
rng: Any,
) -> (
output: parametricmatrixmodels.typing.Data,
new_state: parametricmatrixmodels.model_util.ModelState,
)
That is, all hyperparameters are traced out and the callable depends
explicitly only on
* the model's parameters, as a PyTree with leaf nodes as JAX arrays,
* the input data, as a PyTree with leaf nodes as JAX arrays, each of
which has shape (num_samples, ...),
* the training flag, as a boolean,
* the model's state, as a PyTree with leaf nodes as JAX arrays
and returns
* the output data, as a PyTree with leaf nodes as JAX arrays, each of
which has shape (num_samples, ...),
* the new model state, as a PyTree with leaf nodes as JAX arrays. The
PyTree structure must match that of the input state and
additionally all leaf nodes must have the same shape as the input
state leaf nodes.
The training flag will be traced out, so it does not need to be jittable.
Returns
-------
A callable that takes the model's parameters, input data,
training flag, state, and rng key and returns the output data and
new state.
See Also
--------
__call__ : Calls the model with the current parameters and
given input, state, and rng.
ModelCallable : Typing for the callable returned by this method.
ModelParams : Typing for the model parameters.
Data : Typing for the input and output data.
State : Typing for the model state.
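Examples
--------
A hedged usage sketch, assuming the model has been compiled, ``params``
holds the model's parameter PyTree, and ``x`` is a batched input PyTree
(parameter accessors are not shown here):
>>> fn = model._get_callable()
>>> rng = jax.random.key(0)
>>> out, new_state = fn(params, x, False, model.get_state(), rng)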
"""
if not self.is_ready():
raise RuntimeError(
f"{self.name} is not ready. Call compile() first."
)
self.modules = make_mutable(self.modules)
# get the callable for each module in the execution order; a leading
# None placeholder keeps the list aligned with the execution order,
# whose first entry is the reserved "input" node
module_callables: List[ModuleCallable | None] = [None] + [
getitem_by_strpath(
self.modules,
mod_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
)._get_callable()
for mod_path in self.execution_order
if mod_path != "input" and mod_path != "output"
]
# get the input dependencies for each module
module_input_deps, out_input_deps = (
self._get_module_input_dependencies()
)
@jaxtyped(typechecker=beartype)
def nonseq_callable(
params: ModelParams,
data: Data,
training: bool,
state: ModelState,
rng: Any,
) -> Tuple[Data, ModelState]:
# initialize the intermediate outputs PyTree
intermediate_outputs: PyTree[Data] = (
extend_structure_from_strpaths(
None,
self.execution_order,
separator=self.separator,
)
)
# set the model input
setitem_by_strpath(
intermediate_outputs,
"input",
data,
separator=self.separator,
)
new_state = state
# now execute each module in the execution order, assembling the
# inputs from the intermediate outputs
for mod_path, module_callable, req_in_paths in zip(
self.execution_order,
module_callables,
module_input_deps,
):
if mod_path == "input" or mod_path == "output":
continue
# assemble the inputs for the module
in_data = jax.tree.map(
lambda p: getitem_by_strpath(
intermediate_outputs,
p,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
is_leaf=is_shape_leaf,
),
req_in_paths,
)
module_params = getitem_by_strpath(
params,
mod_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
)
module_state = getitem_by_strpath(
new_state,
mod_path,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
)
# call the module
out_data, new_module_state = module_callable(
module_params,
in_data,
training,
module_state,
rng,
)
# store the output
setitem_by_strpath(
intermediate_outputs,
mod_path,
out_data,
separator=self.separator,
)
# update the state
setitem_by_strpath(
new_state,
mod_path,
new_module_state,
separator=self.separator,
)
# assemble the model output from the intermediate outputs
out_data = jax.tree.map(
lambda p: getitem_by_strpath(
intermediate_outputs,
p,
separator=self.separator,
allow_early_return=False,
return_remainder=False,
is_leaf=is_shape_leaf,
),
out_input_deps,
)
return out_data, new_state
return nonseq_callable
def get_state(self) -> ModelState:
r"""
Get the state of all modules in the model as a PyTree.
Override of the base method in order to ignore modules that are not in
the execution order.
Returns
-------
A PyTree of the states of all modules in the model with the same
structure as the modules PyTree. Modules that are not in the
execution order will have state None.
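Examples
--------
>>> states = model.get_state()  # mirrors ``self.modules``; unused modules map to None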
"""
if not self.is_ready():
raise RuntimeError(
f"{self.name} is not ready. Call compile() first."
)
# if the model is ready, all modules in the execution order are ready;
# modules outside the execution order can be ignored (their state is None)
def get_state_or_none(module: BaseModule) -> State:
if module.is_ready():
return module.get_state()
else:
return None
return jax.tree.map(
get_state_or_none,
self.modules,
)
def get_hyperparameters(self) -> HyperParams:
return {
"connections": self.connections,
"separator": self.separator,
**super().get_hyperparameters(),
}
def set_hyperparameters(self, hyperparams: HyperParams, /) -> None:
self.connections = hyperparams.get("connections", self.connections)
self.separator = hyperparams.get("separator", self.separator)
super().set_hyperparameters(hyperparams)
def __add__(self, other: BaseModule) -> NonSequentialModel:
r"""
Overload the + operator to append a module or model to the current
model.
Parameters
----------
other
Module or model to append.
Returns
-------
New NonSequentialModel with the other module or model appended.
Raises
------
NotImplementedError
Always raised; '+' is not yet implemented for NonSequentialModel.
"""
# TODO: handle connections: anything that goes to "output" in self
# should now go to the new module/model and anything that comes from
# that module/model should now go to output
raise NotImplementedError(
"'+' not yet implemented for NonSequentialModel."
)
new_model = NonSequentialModel(
self.modules,
self.connections,
rng=self.get_rng(),
separator=self.separator,
)
new_model.append_module(other)
return new_model