import random
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import jax
import jax.numpy as np
import numpy as onp
from packaging.version import parse
import parametricmatrixmodels as pmm
from .modules import BaseModule
from .training import make_loss_fn, train
class Model:
"""
Model class built from a list of modules.
"""
def __repr__(self) -> str:
num_trainable = self.get_num_trainable_floats()
# get number of modules in order to reserve whitespace
num_modules = len(self.modules)
mod_idx_width = len(str(num_modules - 1))
if num_trainable is None:
trainable_str = "(uninitialized)"
else:
trainable_str = f"(trainable floats: {num_trainable:,})"
rep = (
f"Model(input_shape={self.input_shape}, "
f"output_shape={self.output_shape}, ready={self.ready}) "
f"{trainable_str}\n"
)
input_shape = self.input_shape if self.input_shape else None
for i, module in enumerate(self.modules):
input_shape = (
module.get_output_shape(input_shape) if input_shape else None
)
comment = module.name().startswith("#")
rep += f"\n{i:>{mod_idx_width}}: {module}" + (
f" -> {input_shape}" if input_shape and not comment else ""
)
return rep
def __init__(
self, modules: Optional[List[BaseModule]] = None, rng: Any = None
) -> None:
"""
Initialize the model with the input shape and a list of modules.
Parameters
----------
modules : Optional[List[BaseModule]], optional
List of modules to initialize the model with. Default is None,
which creates a model with no modules. (A None default avoids
sharing one mutable list between instances.)
rng : Any, optional
Initial random key for the model. Default is None. If None, a
new random key is generated from a random seed via
jax.random.key. If an integer is provided, it is used as the
seed to create the key.
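Examples
--------
A minimal sketch; SomeModule and OtherModule stand in for concrete
BaseModule subclasses (hypothetical names):

>>> model = Model(modules=[SomeModule()], rng=0)
>>> model.add(OtherModule())  # append another module via the alias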
"""
self.modules = modules if modules is not None else []
if rng is None:
self.rng = jax.random.key(random.randint(0, 2**32 - 1))
elif isinstance(rng, int):
self.rng = jax.random.key(rng)
else:
self.rng = rng
self.reset()
def get_num_trainable_floats(self) -> Optional[int]:
num_trainable_floats = [
module.get_num_trainable_floats() for module in self.modules
]
if None in num_trainable_floats:
return None
else:
return sum(num_trainable_floats)
def reset(self) -> None:
self.input_shape = None
self.output_shape = None
self.ready = False
self.parameter_counts = None
self.state_counts = None
self.callable = None
def append_module(self, module: BaseModule) -> None:
"""
Append a module to the model.
Parameters
----------
module : BaseModule
Module to append to the model.
"""
self.modules.append(module)
self.reset()
def prepend_module(self, module: BaseModule) -> None:
"""
Prepend a module to the model.
Parameters
----------
module : BaseModule
Module to prepend to the model.
"""
self.modules.insert(0, module)
self.reset()
def insert_module(self, module: BaseModule, index: int) -> None:
"""
Insert a module at the given index in the model.
Parameters
----------
module : BaseModule
Module to insert into the model.
index : int
Index at which to insert the module.
"""
self.modules.insert(index, module)
self.reset()
add = append_module
put = prepend_module
insert = insert_module
def remove_module(self, index: int) -> None:
"""
Remove a module from the model at the given index.
Parameters
----------
index : int
Index of the module to remove.
"""
if index < 0 or index >= len(self.modules):
raise IndexError("Index out of range.")
del self.modules[index]
self.reset()
def pop_module(self) -> BaseModule:
"""
Pop the last module from the model.
Returns
-------
BaseModule
The last module in the model.
"""
if not self.modules:
raise IndexError("No modules to pop.")
module = self.modules.pop()
self.reset()
return module
def __getitem__(
self, key: Union[int, np.ndarray, slice]
) -> Union[List[BaseModule], BaseModule]:
"""
Get the module(s) selected by the given key.
Parameters
----------
key : Union[int, np.ndarray, slice]
Integer index, slice, or 1D array (integer indices or a
boolean mask) selecting modules.
Returns
-------
Union[BaseModule, List[BaseModule]]
The module at the given index, or a list of modules for a
slice or index array.
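Examples
--------
A sketch of the supported key types, assuming a model with at least
three modules:

>>> first = model[0]                                # single module
>>> tail = model[1:]                                # slice -> list
>>> picked = model[np.array([0, 2])]                # 1D index array
>>> masked = model[np.array([True, False, True])]   # boolean mask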
"""
if isinstance(key, np.ndarray):
if key.ndim > 1:
raise ValueError(
"Index array must be 1D. Use a boolean mask or a 1D array."
)
# the key can either be an index array or a boolean mask
if key.dtype == bool:
if len(key) != len(self.modules):
raise ValueError(
"Boolean mask length must match the number of modules."
)
indices = np.where(key)[0]
return [self.modules[i] for i in indices]
elif np.issubdtype(key.dtype, np.integer):
indices = key.flatten()
return [self.modules[i] for i in indices]
else:
raise ValueError(
"Index array must be of type int or bool. "
f"Got {key.dtype}."
)
elif isinstance(key, slice):
# return a slice of the modules
return self.modules[key]
elif isinstance(key, int):
if key < 0 or key >= len(self.modules):
raise IndexError("Index out of range.")
return self.modules[key]
else:
raise TypeError(
"Index must be an integer, a slice, or a 1D numpy array. "
f"Got {type(key)}."
)
def compile(
self,
rngkey: Optional[Union[Any, int]],
input_shape: Tuple[int, ...],
verbose: bool = False,
) -> None:
"""
Compile the model for training by compiling each module.
Parameters
----------
rngkey : Optional[Union[Any, int]]
Random key for initializing the model parameters: a JAX PRNGKey
or an integer seed. If None, a random seed is generated.
input_shape : Tuple[int, ...]
Shape of the input array, excluding the batch size.
For example, (input_features,) for a 1D input or
(input_height, input_width, input_channels) for a 3D input.
verbose : bool, optional
Print debug information during compilation. Default is False.
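Examples
--------
A minimal sketch, compiling for 1D inputs with 16 features:

>>> model.compile(0, (16,), verbose=True)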
"""
if rngkey is None:
rngkey = random.randint(0, 2**32 - 1)
if isinstance(rngkey, int):
rngkey = jax.random.key(rngkey)
if verbose:
print(
f"Compiling model with input shape {input_shape} and "
f"{len(self.modules)} modules."
)
self.input_shape = input_shape
for i, module in enumerate(self.modules):
rngkey, modrng = jax.random.split(rngkey)
module.compile(modrng, input_shape)
input_shape = module.get_output_shape(input_shape)
if verbose:
print(f"({i}) {module.name()} output shape: {input_shape}")
self.output_shape = input_shape
# get number of parameter arrays for each module
self.parameter_counts = [
len(module.get_params()) for module in self.modules
]
self.state_counts = [
len(module.get_state()) for module in self.modules
]
self.ready = True
def get_output_shape(
self, input_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
"""
Get the output shape of the model given an input shape.
Parameters
----------
input_shape : Tuple[int, ...]
Shape of the input array, excluding the batch size.
For example, (input_features,) for a 1D input or
(input_height, input_width, input_channels) for a 3D input.
Returns
-------
Tuple[int, ...]
Shape of the output array after passing through the model.
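Examples
--------
A sketch of shape propagation, assuming 1D inputs with 16 features:

>>> out_shape = model.get_output_shape((16,))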
"""
for module in self.modules:
input_shape = module.get_output_shape(input_shape)
return input_shape
def get_params(self) -> Tuple[np.ndarray, ...]:
"""
Get the parameters of the model as a Tuple of numpy arrays.
Returns
-------
Tuple[np.ndarray, ...]
numpy arrays representing the parameters of the model. The
order of the parameters should match the order in
which they are used in the _get_callable method.
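Examples
--------
The flat parameter tuple can be round-tripped through set_params,
for example after an external update:

>>> params = model.get_params()
>>> model.set_params(params)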
"""
if not self.ready:
raise RuntimeError("Model is not ready. Call compile() first.")
# parameter tuple must be flat
return tuple(
param for module in self.modules for param in module.get_params()
)
def set_params(self, params: Tuple[np.ndarray, ...]) -> None:
"""
Set the parameters of the model from a Tuple of numpy arrays.
Parameters
----------
params: Tuple[np.ndarray, ...]
numpy arrays representing the parameters of the model. The
order of the parameters should match the order in which
they are used in the _get_callable method.
"""
if not self.ready:
raise RuntimeError("Model is not ready. Call compile() first.")
if len(params) != sum(self.parameter_counts):
raise ValueError(
f"Expected {sum(self.parameter_counts)} parameters, "
f"but got {len(params)}."
)
# set parameters for each module
param_index = 0
for module in self.modules:
count = len(module.get_params())
module.set_params(params[param_index : param_index + count])
param_index += count
def get_state(self) -> Tuple[np.ndarray, ...]:
"""
Get the state of the model as a Tuple of numpy arrays.
Returns
-------
Tuple[np.ndarray, ...]
numpy arrays representing the state of the model. The order of
the states should match the order in which they are used in
the _get_callable method.
"""
if not self.ready:
raise RuntimeError("Model is not ready. Call compile() first.")
# state tuple must be flat
return tuple(
state for module in self.modules for state in module.get_state()
)
def set_state(self, state: Tuple[np.ndarray, ...]) -> None:
"""
Set the state of the model from a Tuple of numpy arrays.
Parameters
----------
state: Tuple[np.ndarray, ...]
numpy arrays representing the state of the model. The order of
the states should match the order in which they are used in
the _get_callable method.
"""
if not self.ready:
raise RuntimeError("Model is not ready. Call compile() first.")
if len(state) != sum(self.state_counts):
raise ValueError(
f"Expected {sum(self.state_counts)} states, "
f"but got {len(state)}."
)
# set state for each module
state_index = 0
for module in self.modules:
count = len(module.get_state())
module.set_state(state[state_index : state_index + count])
state_index += count
def get_rng(self) -> Any:
return self.rng
def set_rng(self, rng: Any) -> None:
"""
Set the random key for the model.
Parameters
----------
rng : Any
Random key to set for the model. JAX PRNGKey or an integer seed
"""
if isinstance(rng, int):
self.rng = jax.random.key(rng)
else:
self.rng = rng
def _get_callable(
self,
) -> Callable[
[
Tuple[np.ndarray, ...],
np.ndarray,
bool,
Tuple[np.ndarray, ...],
Any,
],
Tuple[np.ndarray, Tuple[np.ndarray, ...]],
]:
"""
This method must return a jax-jittable and jax-gradable callable in the
form of
```
(
params: Tuple[np.ndarray, ...],
input_NF: np.ndarray[num_samples, num_features],
training: bool,
state: Tuple[np.ndarray, ...],
rng: key<fry>
) -> (
output_NF: np.ndarray[num_samples, num_output_features],
new_state: Tuple[np.ndarray, ...]
)
```
That is, all hyperparameters are traced out and the callable depends
explicitly only on a Tuple of parameter numpy arrays, the input array,
the training flag, a state Tuple of numpy arrays, and a JAX rng key.
The training flag will be traced out, so it does not need to be
jittable.
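For example, __call__ wraps this callable as follows (training is
passed as a static argument so it can be traced out):

>>> fn = jax.jit(model._get_callable(), static_argnames=["training"])
>>> out, new_state = fn(
...     model.get_params(), X, False, model.get_state(), model.get_rng()
... )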
"""
if not self.ready:
raise RuntimeError("Model is not ready. Call compile() first.")
# get the callables for each module
module_callables = [module._get_callable() for module in self.modules]
# parameter tuple must be flattened, so we'll need to iterate over the
# parameter counts
# state tuple must also be flattened, so we'll need to iterate over the
# state counts
# jax will unroll this loop
def model_callable(
params: Tuple[np.ndarray, ...],
X: np.ndarray,
training: bool = False,
states: Tuple[np.ndarray, ...] = (),
rng: Any = None,  # required in practice despite the default
) -> Tuple[np.ndarray, Tuple[np.ndarray, ...]]:
param_index = 0
state_index = 0
# split rng key into a key for each module
rngs = jax.random.split(rng, len(self.modules))
for idx, param_count in enumerate(self.parameter_counts):
state_count = self.state_counts[idx]
module_params = tuple(
params[param_index : param_index + param_count]
)
module_states = tuple(
states[state_index : state_index + state_count]
)
module_rng = rngs[idx]
X, new_module_states = module_callables[idx](
module_params, X, training, module_states, module_rng
)
# update the states
states = (
states[:state_index]
+ new_module_states
+ states[state_index + state_count :]
)
# increment indices
param_index += param_count
state_index += state_count
return X, states
return model_callable
def __call__(
self,
X: np.ndarray,
dtype: Optional[Any] = np.float64,
rng: Any = None,
return_state: bool = False,
update_state: bool = False,
) -> np.ndarray:
"""
Call the model with the input array.
Parameters
----------
X : np.ndarray
Input array of shape (batch_size, <input feature axes>).
For example, (batch_size, input_features) for a 1D input or
(batch_size, input_height, input_width, input_channels) for a
3D input.
dtype : Optional[Any], optional
Data type of the output array. Default is jax.numpy.float64.
It is strongly recommended to train in single precision
(float32/complex64) and to run inference with double-precision
inputs (float64, the default here) while keeping single-precision
weights.
rng : Any, optional
JAX random key for stochastic modules. Default is None.
If None, the saved rng key will be used if it exists, which
would be the final rng key from the last training run. If an
integer is provided, it will be used as the seed to create a
new JAX random key.
return_state : bool, optional
If True, the model will return the state of the model after
evaluation. Default is False.
update_state : bool, optional
If True, the model will update the state of the model after
evaluation. Default is False.
Returns
-------
np.ndarray
Output array of shape (batch_size, <output feature axes>).
For example, (batch_size, output_features) for a 1D output or
(batch_size, output_height, output_width, output_channels) for
a 3D output.
Tuple[np.ndarray, ...], optional
If return_state is True, the model will also return the state
of the model as a Tuple of numpy arrays. The order of the
states will match the order in which they are used in the
_get_callable method.
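Examples
--------
A minimal sketch of inference on a batch X:

>>> Y_pred = model(X)
>>> Y_pred, state = model(X, return_state=True)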
"""
if not self.ready:
raise RuntimeError("Model is not ready. Call compile() first.")
if self.callable is None:
self.callable = jax.jit(
self._get_callable(), static_argnames=["training"]
)
X_ = X.astype(dtype)
# make sure the dtype was converted, issue a warning if not
if X_.dtype != dtype:
warnings.warn(
"While performing inference with model: "
f"Requested dtype ({dtype}) was not successfully applied. "
"This is most likely due to JAX_ENABLE_X64 not being set. "
"See accompanying JAX warning for more details.",
UserWarning,
)
if rng is None:
rng = self.get_rng()
elif isinstance(rng, int):
rng = jax.random.key(rng)
out, new_state = self.callable(
self.get_params(), X_, False, self.get_state(), rng
)
if update_state:
warnings.warn(
"update_state is True. This is an uncommon use case, make "
"sure you know what you are doing.",
UserWarning,
)
self.set_state(new_state)
if return_state:
return out, new_state
else:
return out
# alias for __call__ method
predict = __call__
def set_precision(self, prec: Union[np.dtype, str, int]) -> None:
"""
Set the precision of the model parameters and states.
Parameters
----------
prec : Union[np.dtype, str, int]
Precision to set for the model parameters and states.
Valid options are:
[for 32-bit precision (all options are equivalent)]
- np.float32, np.complex64, "float32", "complex64"
- "single", "f32", "c64", 32
[for 64-bit precision (all options are equivalent)]
- np.float64, np.complex128, "float64", "complex128"
- "double", "f64", "c128", 64
"""
if not self.ready:
raise RuntimeError("Model is not ready. Call compile() first.")
# convert precision to 32 or 64
if prec in [
np.float32,
np.complex64,
"float32",
"complex64",
"single",
"f32",
"c64",
32,
]:
prec = 32
elif prec in [
np.float64,
np.complex128,
"float64",
"complex128",
"double",
"f64",
"c128",
64,
]:
prec = 64
else:
raise ValueError(
"Invalid precision. Valid options are:\n"
"[for 32-bit precision] np.float32, np.complex64, 'float32', "
"'complex64', 'single', 'f32', 'c64', 32;\n"
"[for 64-bit precision] np.float64, np.complex128, 'float64', "
"'complex128', 'double', 'f64', 'c128', 64."
)
# check if dtype is supported
if not jax.config.read("jax_enable_x64") and prec == 64:
raise ValueError(
"JAX_ENABLE_X64 is not set. "
"Please set it to True to use double precision float64 or "
"complex128 data types."
)
for module in self.modules:
module.set_precision(prec)
# convenience wrapper around set_precision that returns self
def astype(self, dtype: Union[np.dtype, str]) -> "Model":
"""
Convenience wrapper to set_precision using the dtype argument, returns
self.
"""
self.set_precision(dtype)
return self
def train(
self,
X: np.ndarray,
Y: Optional[np.ndarray] = None,
Y_unc: Optional[np.ndarray] = None,
X_val: Optional[np.ndarray] = None,
Y_val: Optional[np.ndarray] = None,
Y_val_unc: Optional[np.ndarray] = None,
loss_fn: Union[str, Callable] = "mse",
lr: float = 1e-3,
batch_size: int = 32,
num_epochs: int = 100,
convergence_threshold: float = 1e-12,
early_stopping_patience: int = 10,
early_stopping_tolerance: float = 1e-6,
# advanced options
initialization_seed: Optional[int] = None,
callback: Optional[Callable] = None,
unroll: Optional[int] = None,
verbose: bool = True,
batch_seed: Optional[int] = None,
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
clip: float = 1e3,
) -> None:
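"""
Train the model, updating its parameters, state, and rng in place.
Uses the Adam-based loop from .training.train; b1, b2, and eps are
the usual Adam hyperparameters and clip is a clipping threshold
(see .training.train for details). If the model has not been
compiled yet, it is compiled automatically using
initialization_seed and the shape of X.
Parameters
----------
X : np.ndarray
Training inputs of shape (num_samples, <input feature axes>).
Y, Y_unc : Optional[np.ndarray], optional
Training targets and their uncertainties. If Y is None,
training is unsupervised.
X_val, Y_val, Y_val_unc : Optional[np.ndarray], optional
Validation data, analogous to X, Y, and Y_unc.
loss_fn : Union[str, Callable], optional
Name of a built-in loss (e.g. "mse") or a custom callable; the
required callable signatures are described in the comments
below. Default is "mse".
Examples
--------
A minimal sketch of supervised training on arrays X and Y:

>>> model.train(X, Y, loss_fn="mse", lr=1e-3, num_epochs=50)
"""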
# check if the model is ready
if not self.ready:
if initialization_seed is None:
initialization_seed = random.randint(0, 2**32 - 1)
self.compile(jax.random.key(initialization_seed), X.shape[1:])
# check if any of the model parameters are double precision and give a
# warning if so
if any(
(
np.issubdtype(np.asarray(param).dtype, np.float64)
or np.issubdtype(np.asarray(param).dtype, np.complex128)
)
for param in self.get_params()
):
warnings.warn(
"Some parameters are double precision. "
"This may lead to significantly slower training on certain "
"backends. It is strongly recommended to use single precision "
"(float32/complex64) parameters for training. Set the "
"precision of the model with Model.set_precision.",
UserWarning,
)
# check dimensions
input_shape = X.shape[1:]
if input_shape != self.input_shape:
raise ValueError(
f"Input shape {input_shape} does not match model input shape "
f"{self.input_shape}."
)
if Y is not None and Y.shape[1:] != self.output_shape:
raise ValueError(
f"Output shape {Y.shape[1:]} does not match model output "
f"shape {self.output_shape}."
)
if Y is not None and X_val is not None and Y_val is None:
raise ValueError(
"Validation data Y_val must be provided if validation input "
"X_val is provided for supervised training."
)
# get callable, not jitted since the training function will
# handle that
callable_ = self._get_callable()
# make the loss function
if isinstance(loss_fn, str):
loss_fn_ = make_loss_fn(
loss_fn, lambda x, p, t, s, r: callable_(p, x, t, s, r)
)
else:
# if the loss function is already a callable, we wrap it with the
# model callable
# whether or not Y and Y_unc are provided changes the signature
# of the loss function
if Y is not None and Y_unc is not None:
# the loss function should be
# loss_fn(X, Y, Y_unc, Y_pred) -> err
def loss_fn_(X, Y, Y_unc, params, training, states, rng):
Y_pred, new_states = callable_(
params, X, training, states, rng
)
err = loss_fn(X, Y, Y_unc, Y_pred)
return err, new_states
elif Y is not None and Y_unc is None:
# the loss function should be
# loss_fn(X, Y, Y_pred) -> err
def loss_fn_(X, Y, params, training, states, rng):
Y_pred, new_states = callable_(
params, X, training, states, rng
)
err = loss_fn(X, Y, Y_pred)
return err, new_states
elif Y is None and Y_unc is None:
# the loss function should be
# loss_fn(X, pred) -> err
# (unsupervised training)
def loss_fn_(X, params, training, states, rng):
pred, new_states = callable_(
params, X, training, states, rng
)
err = loss_fn(X, pred)
return err, new_states
else:
raise ValueError(
"Invalid loss function signature. "
"If Y and Y_unc are provided, the loss function should be "
"loss_fn(X, Y, Y_unc, Y_pred) -> err. "
"If only Y is provided, it should be "
"loss_fn(X, Y, Y_pred) -> err. "
"If neither are provided, it should be "
"loss_fn(X, pred) -> err."
)
# train the model
(
final_params,
final_model_states,
final_model_rng,
final_epoch,
final_adam_states,
) = train(
init_params=self.get_params(),
init_states=self.get_state(),
init_rng=self.get_rng(),
loss_fn=loss_fn_,
X=X,
Y=Y,
Y_unc=Y_unc,
X_val=X_val,
Y_val=Y_val,
Y_val_unc=Y_val_unc,
lr=lr,
batch_size=batch_size,
num_epochs=num_epochs,
convergence_threshold=convergence_threshold,
early_stopping_patience=early_stopping_patience,
early_stopping_tolerance=early_stopping_tolerance,
callback=callback,
unroll=unroll,
verbose=verbose,
batch_seed=batch_seed,
b1=b1,
b2=b2,
eps=eps,
clip=clip,
real=False,
)
# set the final parameters
self.set_params(final_params)
# set the final state
self.set_state(final_model_states)
# set the final rng
self.set_rng(final_model_rng)
def serialize(self) -> Dict[str, Union[Any, Dict[str, Any]]]:
"""
Serialize the model to a dictionary. This is done by serializing the
model's parameters/metadata and then serializing each module.
Returns
-------
Dict[str, Union[Any, Dict[str, Any]]]
Dictionary containing the module metadata, each module's
serialized data, the rng key data, and the package version.
"""
module_fulltypenames = [str(type(module)) for module in self.modules]
module_typenames = [
module.__class__.__name__ for module in self.modules
]
module_modules = [module.__module__ for module in self.modules]
module_names = [module.name() for module in self.modules]
serialized_modules = [module.serialize() for module in self.modules]
# serialize rng key
key_data = jax.random.key_data(self.get_rng())
return {
"module_typenames": module_typenames,
"module_modules": module_modules,
"module_fulltypenames": module_fulltypenames,
"module_names": module_names,
"serialized_modules": serialized_modules,
"key_data": key_data,
"package_version": pmm.__version__,
}
def deserialize(self, data: Dict[str, Any]) -> None:
"""
Deserialize the model from a dictionary. This is done by deserializing
the model's parameters/metadata and then deserializing each module.
Parameters
----------
data : Dict[str, Any]
Dictionary containing the serialized model data.
"""
self.reset()
# read the version of the package this model was serialized with
current_version = parse(pmm.__version__)
package_version = parse(str(data["package_version"]))
if current_version != package_version:
# in the future, we will issue DeprecationWarnings or Errors if the
# version is unsupported
# or possibly handle version-specific deserialization
pass
module_typenames = data["module_typenames"]
module_modules = data["module_modules"]
# initialize the modules
self.modules = [
getattr(sys.modules[module_module], module_typename)()
for module_typename, module_module in zip(
module_typenames, module_modules
)
]
# deserialize the modules
for module, serialized_module in zip(
self.modules, data["serialized_modules"]
):
module.deserialize(serialized_module)
# deserialize the rng key
key = jax.random.wrap_key_data(data["key_data"])
self.set_rng(key)
def save(self, filename: str) -> None:
"""
Save the model to a file.
Parameters
----------
filename : str
Name of the file to save the model to.
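Examples
--------
The ".npz" extension is appended if missing:

>>> model.save("my_model")  # writes my_model.npz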
"""
# if everything serializes correctly, we can save the model with just
# savez
data = self.serialize()
filename = filename if filename.endswith(".npz") else filename + ".npz"
np.savez(filename, **data)
def save_compressed(self, filename: str) -> None:
"""
Save the model to a compressed file.
Parameters
----------
filename : str
Name of the file to save the model to.
"""
# if everything serializes correctly, we can save the model with just
# savez_compressed
data = self.serialize()
filename = filename if filename.endswith(".npz") else filename + ".npz"
# jax.numpy doesn't have savez_compressed, so we use numpy
onp.savez_compressed(filename, **data)
def load(self, filename: str) -> None:
"""
Load the model from a file. Supports both compressed and
uncompressed .npz files.
Parameters
----------
filename : str
Name of the file to load the model from.
"""
filename = filename if filename.endswith(".npz") else filename + ".npz"
# jax numpy load supports both compressed and uncompressed npz files
data = np.load(filename, allow_pickle=True)
# deserialize the model
self.deserialize(data)
@classmethod
def from_file(cls, filename: str) -> "Model":
"""
Load a model from a file and return an instance of the Model class.
Parameters
----------
filename : str
Name of the file to load the model from.
Returns
-------
Model
An instance of the Model class with the loaded parameters.
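Examples
--------
A minimal sketch of loading a saved model:

>>> model = Model.from_file("my_model.npz")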
"""
model = cls()
model.load(filename)
return model