Source code for parametricmatrixmodels.graph_util

from __future__ import annotations

import uuid
from collections import defaultdict, deque

from .tree_util import extend_structure_from_strpaths, getitem_by_strpath
from .typing import Any, Dict, List, OrderedSet, PyTree, Tuple


def concretize_paths(paths: List[str], separator: str = ".") -> List[str]:
    r"""
    Concretize a list of keypaths by appending indices or uuids to
    duplicate paths. Respects existing indices or keys.

    A path is concretized when it ends at a level where at least one other
    path in the list shares the same prefix (either an exact duplicate or a
    longer path continuing below it). A new child key is then appended so
    that every path becomes a distinct leaf:

    - if every existing child key at that level is a decimal integer, the
      smallest integers not yet taken are used;
    - otherwise a random ``uuid4().hex`` is used to avoid colliding with the
      existing (string) keys.

    Parameters
    ----------
    paths
        List of keypaths to normalize. The empty string denotes the root.
    separator
        The string used to separate levels in the keypaths. Default is ".".

    Returns
    -------
    List of concretized keypaths with duplicates resolved, in the same order
    as ``paths``. An empty input yields an empty list.

    Examples
    --------
    >>> paths = [
    ...     "a.b.0.1",
    ...     "a.b.0",
    ...     "a.b",
    ...     "a.1",
    ...     "a.1",
    ...     "a.2",
    ...     "a",
    ...     "",
    ... ]
    >>> concrete_paths = concretize_paths(paths)
    >>> concrete_paths
    [
        "a.b.0.1",
        "a.b.0.0",  # added .0 since a.b.0.* exists
        "a.b.1",    # added .1 since a.b.* exists and .0 is taken
        "a.1.0",    # added .0 since a.1.* exists
        "a.1.1",    # added .1 since a.1.* exists and .0 is taken
        "a.2",      # no change since a.2.* doesn't exist
        "a.<uuid>", # added .<uuid> since a.* exists & non-int keys ('b')
        "<uuid>",   # added <uuid> since * exists & non-int keys ('a')
    ]
    """
    # min()/max() below would raise on an empty sequence
    if not paths:
        return []

    split_paths: List[List[str]] = [
        p.split(separator) if p else [] for p in paths
    ]

    # Work from the shortest paths to the longest. A path is only ever
    # extended by a key that is new at its level, so extended paths are
    # unique and never need another pass beyond the original maximum depth.
    cur_len = min(len(p) for p in split_paths)
    max_len = max(len(p) for p in split_paths)
    while cur_len <= max_len:
        # group (remainder below the prefix, original index) by prefix;
        # the index lets us write results back in the original order
        prefix_to_paths: Dict[str, List[Tuple[List[str], int]]] = defaultdict(
            list
        )
        for i, p in enumerate(split_paths):
            prefix = separator.join(p[:cur_len])
            prefix_to_paths[prefix].append((p[cur_len:], i))

        for group_paths in prefix_to_paths.values():
            # nothing to do for singletons or if no path ends at this level
            if len(group_paths) <= 1 or all(rest for rest, _ in group_paths):
                continue

            # child keys already used at this level
            existing_keys = {rest[0] for rest, _ in group_paths if rest}
            # isdecimal (not isdigit) so every "numeric" key is int()-able
            is_str = any(not k.isdecimal() for k in existing_keys)

            if is_str:
                # use uuid to avoid collisions with arbitrary string keys
                for rest, i in group_paths:
                    if not rest:
                        split_paths[i] = split_paths[i] + [uuid.uuid4().hex]
            else:
                # there are always enough free indices in range(len(group))
                # since every ending path frees up one slot
                taken = {int(k) for k in existing_keys}
                available_indices = iter(
                    sorted(set(range(len(group_paths))) - taken)
                )
                for rest, i in group_paths:
                    if not rest:
                        split_paths[i] = split_paths[i] + [
                            str(next(available_indices))
                        ]
        cur_len += 1

    return [separator.join(p) for p in split_paths]
def concretize_connections(
    connections: Dict[str, List[str]],
    tree: PyTree[Any],
    separator: str = ".",
    in_key: str = "input",
    out_key: str = "output",
) -> Dict[str, List[str]]:
    r"""
    Concretize the connections dictionary by explicitly adding paths for
    leaves that have substructures but are not fully specified in the
    connections.

    The keys in the connections are never modified, since omission of
    substructures in the keys implies that the entire substructure is
    passed. Only the values are modified to explicitly specify all
    substructures. Order is preserved: keys keep their original order, and
    each value stays at its original position within its key's list.

    Parameters
    ----------
    connections
        Dictionary mapping source keypaths to a target keypath or a list of
        target keypaths.
    tree
        A PyTree representing the outer structure of the model.
    separator
        The string used to separate levels in the keypaths. Default is ".".
    in_key
        The reserved key representing the model input. Default is "input".
    out_key
        The reserved key representing the model output. Default is
        "output".

    Returns
    -------
    Concretized connections dictionary with all leaf substructures
    explicitly specified.

    Raises
    ------
    ValueError
        Propagated from ``partition_connections_by_tree`` for invalid keys or
        values.

    Examples
    --------
    All examples below assume a tree like

    >>> tree = {
    ...     "M1": "*",
    ...     "M2": "*",
    ... }

    Multiple implicit connections to the output

    >>> connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": "output",
    ...     "M2": "output"
    ... }
    >>> concretized_connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": ["output.0"],
    ...     "M2": ["output.1"]
    ... }

    Multiple implicit connections from the input and between modules

    >>> connections = {
    ...     "input.0": "M1",
    ...     "input.1": "M1",
    ...     "input.2": "M2",
    ...     "M1": "M2",
    ...     "M2": "output"
    ... }
    >>> concretized_connections = {
    ...     "input.0": ["M1.0"],
    ...     "input.1": ["M1.1"],
    ...     "input.2": ["M2.0"],
    ...     "M1": ["M2.1"],
    ...     "M2": ["output"]
    ... }

    Beyond depth-1 implicit connections

    >>> connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": "output.0",
    ...     "M2": "output.0"
    ... }
    >>> concretized_connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": ["output.0.0"],
    ...     "M2": ["output.0.1"]
    ... }

    Partially implicit connection to the output

    >>> connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": "output.1",
    ...     "M2": "output"
    ... }
    >>> concretized_connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": ["output.1"],
    ...     "M2": ["output.0"]
    ... }

    Partially implicit connections across depths

    >>> connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": "output.1.a",
    ...     "M2": "output"
    ... }
    >>> concretized_connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": ["output.1.a"],
    ...     "M2": ["output.0"]
    ... }

    Partially implicit connection with a dictionary type, in which case a
    random UUID is used to avoid collisions

    >>> connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": "output.a",
    ...     "M2": "output"
    ... }
    >>> concretized_connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": ["output.a"],
    ...     "M2": ["output.<uuid>"]
    ... }
    """
    partitioned_connections = partition_connections_by_tree(
        connections,
        tree,
        separator=separator,
        in_key=in_key,
        out_key=out_key,
    )
    double_sep = f"{separator}{separator}"

    # Group every value by its outer (tree) part. Each entry remembers its
    # origin -- (inner remainder, key, position in that key's value list) --
    # so the concretized value can be written back into the same slot. The
    # grouping order follows connection order, which determines which value
    # receives which index.
    outer_to_entries: Dict[str, List[Tuple[str, str, int]]] = {}
    slots: Dict[str, List[str]] = {}
    for key, values in partitioned_connections.items():
        slots[key] = list(values)
        for pos, value in enumerate(values):
            # partitioning guarantees exactly one double separator per value
            outer, inner = value.split(double_sep)
            outer_to_entries.setdefault(outer, []).append((inner, key, pos))

    # Resolve mixed depths and duplicates per outer leaf. concretize_paths
    # turns every inner path into a distinct leaf (exact duplicates get
    # distinct suffixes), so no further de-duplication is required.
    for outer, entries in outer_to_entries.items():
        concrete_inner = concretize_paths(
            [inner for inner, _, _ in entries], separator=separator
        )
        for inner, (_, key, pos) in zip(concrete_inner, entries):
            slots[key][pos] = outer + double_sep + inner

    # Remove the partition marker: a double separator in the middle becomes
    # a single one, and a leading/trailing one is stripped.
    return {
        key.replace(double_sep, separator).strip(separator): [
            value.replace(double_sep, separator).strip(separator)
            for value in values
        ]
        for key, values in slots.items()
    }
def get_outer_connections_by_tree(
    connections: Dict[str, List[str]],
    tree: PyTree[Any],
    separator: str = ".",
    in_key: str = "input",
    out_key: str = "output",
) -> Dict[str, List[str]]:
    r"""
    Get the connections dictionary showing only the tree structure, not
    including any substructure.

    Several connection keys that differ only in their substructure collapse
    to the same outer key; their targets are merged (first-seen order,
    without duplicates).

    Parameters
    ----------
    connections
        Dictionary mapping source keypaths to a target keypath or a list of
        target keypaths.
    tree
        A PyTree representing the outer structure to keep.
    separator
        The string used to separate levels in the keypaths. Default is ".".
    in_key
        The reserved key representing the model input. Default is "input".
    out_key
        The reserved key representing the model output. Default is
        "output".

    Returns
    -------
    Connections dictionary with only tree structure.

    Raises
    ------
    ValueError
        Propagated from ``partition_connections_by_tree`` for invalid keys or
        values.

    Examples
    --------
    Given a tree like

    >>> tree = {
    ...     "block1": {
    ...         "M1": "*"
    ...     },
    ...     "block2": {
    ...         "M2": "*"
    ...     },
    ...     "block3": {
    ...         "M3": "*"
    ...     }
    ... }

    and a connections dictionary like

    >>> connections = {
    ...     "block1.M1.a": ["block2.M2.0", "block3.M3.input"],
    ...     "block1.M1.b": "output",
    ...     "input": "block1.M1"
    ... }

    The outer tree-only connections will be (both ``block1.M1.*`` keys are
    merged into one)

    >>> outer_connections = {
    ...     "block1.M1": ["block2.M2", "block3.M3", "output"],
    ...     "input": ["block1.M1"]
    ... }
    """
    conn_separated = partition_connections_by_tree(
        connections,
        tree,
        separator=separator,
        in_key=in_key,
        out_key=out_key,
    )
    double_sep = f"{separator}{separator}"

    # After partitioning, every key and value looks like
    # '<outer><double_sep><remainder>' -- including in_key keys
    # ('<in_key><double_sep>...') and out_key values
    # ('<out_key><double_sep>...'); out_key is rejected as a key. So the outer
    # part is always what precedes the first double separator. Splitting on
    # it (instead of testing name prefixes) avoids confusing tree leaves such
    # as 'input_block' with the reserved keys.
    conn_stripped: Dict[str, OrderedSet[str]] = {}
    for key, values in conn_separated.items():
        outer_key = key.split(double_sep, 1)[0]
        if outer_key not in conn_stripped:
            conn_stripped[outer_key] = OrderedSet()
        targets = conn_stripped[outer_key]
        for value in values:
            targets.add(value.split(double_sep, 1)[0])

    # convert OrderedSets back to lists
    return {key: list(targets) for key, targets in conn_stripped.items()}
def _mark_tree_boundary(path: str, tree: PyTree[Any], separator: str) -> str:
    """Insert a double separator where the part of ``path`` inside ``tree``
    ends and the remaining substructure begins."""
    _, remainder = getitem_by_strpath(
        tree,
        path,
        separator=separator,
        allow_early_return=True,
        return_remainder=True,
    )
    # removesuffix (not rstrip) so multi-character separators are handled
    # as whole strings rather than character sets
    tree_part = path.removesuffix(remainder).removesuffix(separator)
    return f"{tree_part}{separator}{separator}{remainder}"


def partition_connections_by_tree(
    connections: Dict[str, List[str]],
    tree: PyTree[Any],
    separator: str = ".",
    in_key: str = "input",
    out_key: str = "output",
) -> Dict[str, List[str]]:
    r"""
    Process the connections dictionary to separate the tree structure of
    'tree' from the remaining structure.

    Parameters
    ----------
    connections
        Dictionary defining connections between leaves in the tree and their
        sub-structures. Keys and values are strings representing the keypaths
        to the leaves or sub-structures in the tree. A value may be a single
        string or a list/tuple of strings.
    tree
        A PyTree representing the outer structure to separate.
    separator
        The string used to separate levels in the keypaths. Default is ".".
    in_key
        The reserved key representing the model input. Default is "input".
    out_key
        The reserved key representing the model output. Default is "output".

    Returns
    -------
    Processed connections dictionary with separated structures by a double
    separator (e.g., ".." if the separator is "."). Every value is a list.

    Raises
    ------
    ValueError
        If the separator appears consecutively in any key or value, if
        ``out_key`` is used as a key, or if ``in_key`` is used as a value.

    Examples
    --------
    Given a tree like

    >>> tree = {
    ...     "block1": {
    ...         "M1": "*"
    ...     },
    ...     "block2": {
    ...         "M2": "*"
    ...     },
    ...     "block3": {
    ...         "M3": "*"
    ...     }
    ... }

    and connections dictionary like

    >>> connections = {
    ...     "block1.M1.a": ["block2.M2.0", "block3.M3.input"],
    ...     "block1.M1.b": "output",
    ...     "input": "block1.M1"
    ... }

    The partitioned connections will be (the double separator follows the
    path of the tree leaf)

    >>> partitioned_connections = {
    ...     "block1.M1..a": ["block2.M2..0", "block3.M3..input"],
    ...     "block1.M1..b": ["output.."],
    ...     "input..": ["block1.M1.."]
    ... }
    """
    # normalize mixed value types (single string or sequence) to lists
    conn: Dict[str, List[str]] = {
        key: list(value) if isinstance(value, (list, tuple)) else [value]
        for key, value in connections.items()
    }

    # the doubled separator is used as the partition marker, so it must not
    # already occur anywhere
    double_sep = f"{separator}{separator}"
    for key, values in conn.items():
        if double_sep in key:
            raise ValueError(
                f"Separator '{separator}' cannot appear "
                f"consecutively in connection keys. Found in key '{key}'."
            )
        for v in values:
            if double_sep in v:
                raise ValueError(
                    f"Separator '{separator}' cannot appear "
                    "consecutively in connection values. Found in value "
                    f"'{v}'."
                )

    in_prefix = in_key + separator
    out_prefix = out_key + separator

    conn_separated: Dict[str, List[str]] = {}
    for key, values in conn.items():
        if key == in_key or key.startswith(in_prefix):
            # the reserved input key is its own outer structure
            new_key = (
                in_key
                + double_sep
                + key.removeprefix(in_key).removeprefix(separator)
            )
        elif key == out_key or key.startswith(out_prefix):
            raise ValueError(
                f"'{out_key}' cannot be used as a key in the connections "
                "dictionary since it is reserved for model output."
            )
        else:
            new_key = _mark_tree_boundary(key, tree, separator)

        new_values: List[str] = []
        for v in values:
            if v == out_key or v.startswith(out_prefix):
                # the reserved output key is its own outer structure
                new_values.append(
                    out_key
                    + double_sep
                    + v.removeprefix(out_key).removeprefix(separator)
                )
            elif v == in_key or v.startswith(in_prefix):
                raise ValueError(
                    f"'{in_key}' cannot be used as a value in the "
                    "connections dictionary since it is reserved for "
                    "model input."
                )
            else:
                new_values.append(_mark_tree_boundary(v, tree, separator))
        conn_separated[new_key] = new_values
    return conn_separated
def place_connections_in_tree(
    connections: Dict[str, List[str]],
    tree: PyTree[Any],
    separator: str = ".",
    in_key: str = "input",
    out_key: str = "output",
) -> Tuple[PyTree[Any], PyTree[Any]]:
    r"""
    Place the concretized connections into a PyTree structure matching
    'tree'.

    Parameters
    ----------
    connections
        Connections dictionary, will be concretized internally if needed.
    tree
        A PyTree representing the outer structure to separate.
    separator
        The string used to separate levels in the keypaths. Default is ".".
    in_key
        The reserved key representing the model input. Default is "input".
    out_key
        The reserved key representing the model output. Default is
        "output".

    Returns
    -------
    A PyTree built from every non-output target path, with the source path
    of each target placed at that leaf. Also returns the output remainder:
    the single source string if the only output target is exactly
    ``out_key``, otherwise a structure built from the output sub-paths.

    Raises
    ------
    RuntimeError
        If a concretized target receives more than one source.

    Examples
    --------
    Given a tree like

    >>> tree = {
    ...     "M1": "*",
    ...     "M2": "*",
    ... }

    and concretized connections like

    >>> concretized_connections = {
    ...     "input": ["M1", "M2"],
    ...     "M1": ["output.0"],
    ...     "M2": ["output.1"]
    ... }

    this will (internally) reverse the connections to get

    >>> reversed_connections = {
    ...     "M1": ["input"],
    ...     "M2": ["input"],
    ...     "output.0": ["M1"],
    ...     "output.1": ["M2"],
    ... }

    then place the (single) source paths into the original tree structure,
    and the remainder with the end_key will be returned separately, with its
    corresponding structure

    >>> placed_tree = {
    ...     "M1": "input",
    ...     "M2": "input",
    ... }
    >>> output_remainder = [
    ...     "M1",
    ...     "M2",
    ... ]

    Further examples:

    >>> tree = {"A": "*", "B": ("*", "*")}
    >>> concretized_connections = {
    ...     "input.0": ["A.0", "B.0", "B.1"],
    ...     "input.1": ["A.1"],
    ...     "B.0": ["A.2"],
    ...     "B.1": ["output.x"],
    ...     "A": ["output.y"],
    ... }
    >>> placed_tree, output_remainder = place_connections_in_tree(
    ...     concretized_connections,
    ...     tree,
    ... )
    >>> placed_tree
    {
        "A": ["input.0", "input.1", "B.0"],
        "B": ("input.0", "input.0"),
    }
    >>> output_remainder
    {
        "x": "B.1",
        "y": "A",
    }
    """
    concretized_connections = concretize_connections(
        connections,
        tree,
        separator=separator,
        in_key=in_key,
        out_key=out_key,
    )

    # reverse the connections: target -> sources
    reversed_conn: Dict[str, List[str]] = {}
    for key, values in concretized_connections.items():
        for v in values:
            reversed_conn.setdefault(v, []).append(key)

    # after concretization each target must have exactly one source
    for key, values in reversed_conn.items():
        if len(values) > 1:
            raise RuntimeError(
                "Invalid concretized connections: multiple sources for "
                f"'{key}': {values}"
            )
    target_to_source = {k: v[0] for k, v in reversed_conn.items()}

    out_prefix = f"{out_key}{separator}"

    def _is_output(path: str) -> bool:
        # exact key or a real sub-path; a bare startswith(out_key) would also
        # match unrelated tree leaves such as 'output_layer'
        return path == out_key or path.startswith(out_prefix)

    output_keys = [k for k in target_to_source if _is_output(k)]
    if output_keys == [out_key]:
        # the whole output is a single source, no structure to build
        output_remainder = target_to_source[out_key]
    else:
        output_remainder = extend_structure_from_strpaths(
            None,
            [k.removeprefix(out_prefix).strip(separator) for k in output_keys],
            separator=separator,
            fill_values=[target_to_source[k] for k in output_keys],
        )

    # every output target (including a bare out_key) belongs to the
    # remainder, not to the placed tree
    tree_keys = [k for k in target_to_source if not _is_output(k)]
    placed_tree = extend_structure_from_strpaths(
        None,
        [k.strip(separator) for k in tree_keys],
        separator=separator,
        fill_values=[target_to_source[k] for k in tree_keys],
    )
    return placed_tree, output_remainder
def resolve_connections(
    graph: Dict[str, List[str]],
    start_key: str = "input",
    end_key: str = "output",
) -> Tuple[List[str], OrderedSet[str]]:
    r"""
    Topologically order the nodes of ``graph`` that lie on a path from
    ``start_key`` to ``end_key``.

    Nodes not reachable from ``start_key``, or from which ``end_key`` cannot
    be reached, are left out of the ordering. In-degrees, however, count
    every edge in ``graph``. A relevant node that also receives an edge from
    a node outside that subgraph (for example, a source that is not fed by
    ``start_key``) is therefore never released, and resolution fails.

    Parameters
    ----------
    graph
        Adjacency mapping from each node to the nodes it feeds.
    start_key
        Node the ordering starts from. Default is "input".
    end_key
        Node the ordering must end at. Default is "output".

    Returns
    -------
    The topological order, and the set of nodes from which ``end_key`` is
    reachable (collected by a DFS over reversed edges).
    NOTE(review): the second element is the reverse-reachable set, not its
    intersection with the forward-reachable set; confirm callers expect
    this.

    Raises
    ------
    RuntimeError
        If the order is not a valid resolution (cycle, unreleased nodes, or
        wrong start/end).
    """

    def _collect_reachable(
        adjacency: Dict[str, List[str]], root: str
    ) -> OrderedSet[str]:
        # iterative DFS; returns nodes in visitation order
        seen: OrderedSet[str] = OrderedSet()
        stack = [root]
        while stack:
            current = stack.pop()
            if current in seen:
                continue
            seen.add(current)
            stack.extend(n for n in adjacency.get(current, []) if n not in seen)
        return seen

    # reversed adjacency and in-degrees over the full graph
    reverse_graph: Dict[str, List[str]] = defaultdict(list)
    in_degree: Dict[str, int] = defaultdict(int)
    for node, neighbors in graph.items():
        for neighbor in neighbors:
            reverse_graph[neighbor].append(node)
            in_degree[neighbor] += 1

    # only nodes on some start -> end path matter
    reachable = _collect_reachable(graph, start_key)
    reverse_reachable = _collect_reachable(reverse_graph, end_key)
    all_nodes = reachable & reverse_reachable

    # Kahn's algorithm, seeded with start_key, restricted to all_nodes
    topological_order: List[str] = []
    queue = deque([start_key])
    while queue:
        node = queue.popleft()
        topological_order.append(node)
        for neighbor in graph.get(node, []):
            if neighbor not in all_nodes:
                continue
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    try:
        validate_resolution(topological_order, all_nodes, start_key, end_key)
    except RuntimeError as e:
        raise RuntimeError(
            "Failed to resolve connections into a valid topological order. "
            "Got order:\n " + "\n ".join(topological_order)
        ) from e
    return topological_order, reverse_reachable
def validate_resolution(
    order: List[str],
    all_nodes: OrderedSet[str],
    start_key: str = "input",
    end_key: str = "output",
) -> None:
    r"""
    Check that ``order`` is a complete topological resolution of
    ``all_nodes``.

    Parameters
    ----------
    order
        Candidate ordering of the nodes.
    all_nodes
        The nodes that must each appear exactly once in ``order``.
    start_key
        Required first node. Default is "input".
    end_key
        Required last node. Default is "output".

    Raises
    ------
    RuntimeError
        If ``order`` does not contain every node of ``all_nodes`` exactly
        once, is empty, does not start with ``start_key``, or does not end
        with ``end_key``.
    """
    # equal length plus equal membership rules out duplicates (cycles),
    # missing nodes and foreign nodes
    if len(order) != len(all_nodes) or set(order) != set(all_nodes):
        raise RuntimeError(
            "Topological order contains a cycle or is missing nodes."
        )
    # guard the empty case before indexing
    if not order or order[0] != start_key:
        raise RuntimeError(
            f"Topological order does not start with '{start_key}'"
        )
    if order[-1] != end_key:
        raise RuntimeError(f"Topological order does not end with '{end_key}'")