Source code for parametricmatrixmodels.modules.subsetmodule

from __future__ import annotations

import sys
from typing import Any, Callable

import jax.numpy as np

from ._helpers import subsets_to_string
from .basemodule import BaseModule


[docs] class SubsetModule(BaseModule): """ Meta-module that applies a given Module to a subset of the input data and optionally passes the remaining input data through unchanged. Diagrammatic example: [x0] -> ...................................... -> [x0] [x1] -> ...................................... -> [x1] [x2] -> [ SubsetModule([2:], APPEND, module) ] -> [module(x2, x3)] [x3] -> [ ] The output shape need not match the input shape, as the module may change the shape of the data, as in the example above, where the module takes two input features (x2, x3) and produces one output feature (module(x2, x3)). The output of the module that is applied can either be prepended or appended to the unchanged input data, as specified by the `prepend` parameter in the constructor. Additionally, the unchanged input data can be passed through unchanged alongside the output of the module, or it can be ignored, depending on the `passthrough` parameter in the constructor. A diagrammatic example of a passthrough SubsetModule is as follows: [x0] -> .............................................. -> [x0] [x1] -> .............................................. -> [x1] [x2] -> [ SubsetModule([2:], APPEND, PASSTHROUGH, m) ] -> [x2] [x3] -> [ ] -> [x3] [ ] -> [m(x2, x3)] """
[docs] def __init__( self, subset: tuple[slice, ...] | tuple[np.ndarray, ...] = None, module: BaseModule = None, prepend: bool = True, axis: int = 0, passthrough: bool = False, ): """ Parameters ---------- subset A tuple of slices or index arrays indicating which parts of the input data to apply the module to. The number of slices must match the shape of the input data, not including the batch dimension. For example, for input data of shape (num_samples, num_features), a subset of all except the first two features would be specified as (slice(2, None),). module The module to apply to the specified subset of the input data. prepend If True, the output of the module will be prepended to the unchanged input data. If False, the output will be appended. Defaults to True. axis The axis along which to concatenate the output of the module and the unchanged input data. Defaults to 0 (i.e., along the first feature dimension). This does not include the batch dimension. passthrough If True, the unchanged input data will be passed through alongside the output of the module. If False, the unchanged input data will be dropped. Defaults to False. """ if subset is not None and not isinstance(subset, tuple): raise TypeError( "Subset must be a tuple of slices or index arrays, " f"not {type(subset)}" ) self.subset = subset self.module = module self.prepend = prepend self.axis = axis self.passthrough = passthrough self.input_shape = None # will be set in compile self.output_shape = None # will be set in compile
[docs] def name(self) -> str: subname = self.module.name() subset_str = subsets_to_string(self.subset) pend_str = "PREPEND" if self.prepend else "APPEND" pass_str = "PASSTHROUGH" if self.passthrough else "CONSUME" return ( f"SubsetModule({subset_str}, {pass_str}, {pend_str}, " f"axis={self.axis}, {subname})" )
[docs] def is_ready(self) -> bool: return ( self.module.is_ready() and self.input_shape is not None and self.output_shape is not None )
[docs] def get_num_trainable_floats(self) -> int | None: return self.module.get_num_trainable_floats()
[docs] def _get_inverse_subset_mask( self, input_shape: tuple[int, ...] ) -> np.ndarray: # make a mask for the inverse of the subset # if passthrough, then the mask is just the entire input mask = np.ones(input_shape, dtype=bool) if not self.passthrough: mask = mask.at[self.subset].set(False) return mask
[docs] def _get_callable(self) -> Callable: if not self.is_ready(): raise ValueError( "SubsetModule is not ready. " "Call compile() with the input shape and rng." ) # first add a full slice to the subset to include the batch dimension full_subset = (slice(None),) + self.subset mask = self._get_inverse_subset_mask(self.input_shape) subcallable = self.module._get_callable() def _callable( params: tuple[np.ndarray, ...], input_NF: np.ndarray, training: bool = False, state: tuple[np.ndarray, ...] = (), rng: Any = None, ) -> tuple[np.ndarray, tuple[np.ndarray, ...]]: # apply the module to the subset of the input data subset_input = input_NF[full_subset] module_output, new_state = subcallable( params, subset_input, training, state, rng ) # prepend or append the module output if self.prepend: output_NF = np.concatenate( (module_output, input_NF[:, mask]), axis=self.axis + 1, ) else: output_NF = np.concatenate( (input_NF[:, mask], module_output), axis=self.axis + 1, ) return output_NF, new_state return _callable
[docs] def compile(self, rng: Any, input_shape: tuple[int, ...]) -> None: self.input_shape = input_shape # get subinput shape to pass to the module's compile method subset_input_zeros = np.zeros(input_shape, dtype=np.float32)[ self.subset ] self.module.compile(rng, subset_input_zeros.shape) self.output_shape = self.get_output_shape(input_shape)
[docs] def get_output_shape( self, input_shape: tuple[int, ...] ) -> tuple[int, ...]: # easiest way to get the output shape is to perform a dummy forward # pass with zeros and the module's `get_output_shape` method # don't need to include the batch dimension in any of the shapes here mask = self._get_inverse_subset_mask(input_shape) input_zeros = np.zeros(input_shape, dtype=np.float32) subset_input = input_zeros[self.subset] module_output_shape = self.module.get_output_shape(subset_input.shape) module_output_zeros = np.zeros(module_output_shape, dtype=np.float32) if self.prepend: output_zeros = np.concatenate( (module_output_zeros, input_zeros[mask]), axis=self.axis ) else: output_zeros = np.concatenate( (input_zeros[mask], module_output_zeros), axis=self.axis ) return output_zeros.shape
[docs] def get_hyperparameters(self) -> dict[str, Any]: return { "subset": self.subset, "prepend": self.prepend, "axis": self.axis, "passthrough": self.passthrough, "module": self.module, }
[docs] def set_hyperparameters(self, hyperparams: dict[str, Any]) -> None: if self.input_shape is not None or self.output_shape is not None: raise ValueError( "Cannot set hyperparameters after compile. " "Please set hyperparameters before calling compile()." ) super(SubsetModule, self).set_hyperparameters(hyperparams)
[docs] def get_params(self) -> tuple[np.ndarray, ...]: return self.module.get_params()
[docs] def set_params(self, params: tuple[np.ndarray, ...]) -> None: self.module.set_params(params)
[docs] def get_state(self) -> tuple[np.ndarray, ...]: return self.module.get_state()
[docs] def set_state(self, state: tuple[np.ndarray, ...]) -> None: self.module.set_state(state)
[docs] def serialize(self) -> dict[str, Any]: # call the base class serialize method to get most of the data # serialized serial = super(SubsetModule, self).serialize() # then replace the few fields that weren't actually serialized: # module and subset module_typename = self.module.__class__.__name__ module_module = self.module.__module__ module_serial = self.module.serialize() # serialize the subsets by converting any slices to their start, stop, # and step values serial_subsets = ( [ ( (slc.start, slc.stop, slc.step) if isinstance(slc, slice) else slc ) for slc in self.subset ] if self.subset else None ) serial["hyperparameters"]["subset"] = serial_subsets serial["hyperparameters"]["module"] = { "type": module_typename, "module": module_module, "data": module_serial, } return serial
[docs] def deserialize(self, data: dict[str, Any]) -> None: # replace non-standard serialized fields with the partially # deserialized values before calling the base class's deserialize subset_slices = data["hyperparameters"]["subset"] # deserialize the subsets by converting any tuples back to slices if subset_slices is not None: subset = tuple( slice(*slc) if isinstance(slc, tuple) else slc for slc in subset_slices ) else: subset = None module_module = data["hyperparameters"]["module"]["module"] module_typename = data["hyperparameters"]["module"]["type"] module_data = data["hyperparameters"]["module"]["data"] module = getattr(sys.modules[module_module], module_typename)() module.deserialize(module_data) data["hyperparameters"]["subset"] = subset data["hyperparameters"]["module"] = module super(SubsetModule, self).deserialize(data)