{ "cells": [ { "cell_type": "markdown", "id": "c5130161", "metadata": {}, "source": [ "# Regression Example\n", "\n", "This notebook provides two examples for performing basic regression on nonlinear data. The first example builds a traditional neural network and the second uses an affine observable PMM." ] }, { "cell_type": "markdown", "id": "8defe0c7", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "c9583d5a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: scikit-learn in /home/pcook/Research/FRIB/pmmenv/lib/python3.12/site-packages (1.8.0)\n", "Requirement already satisfied: numpy>=1.24.1 in /home/pcook/Research/FRIB/pmmenv/lib/python3.12/site-packages (from scikit-learn) (2.4.0)\n", "Requirement already satisfied: scipy>=1.10.0 in /home/pcook/Research/FRIB/pmmenv/lib/python3.12/site-packages (from scikit-learn) (1.16.3)\n", "Requirement already satisfied: joblib>=1.3.0 in /home/pcook/Research/FRIB/pmmenv/lib/python3.12/site-packages (from scikit-learn) (1.5.3)\n", "Requirement already satisfied: threadpoolctl>=3.2.0 in /home/pcook/Research/FRIB/pmmenv/lib/python3.12/site-packages (from scikit-learn) (3.6.0)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install scikit-learn" ] }, { "cell_type": "code", "execution_count": 2, "id": "e78966bd", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as np\n", "import jax.random as jr\n", "import parametricmatrixmodels as pmm\n", "import matplotlib.pyplot as plt\n", "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", "# enable x64 for JAX\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "markdown", "id": "4b23be7d", "metadata": {}, "source": [ "## Data Creation" ] }, { "cell_type": "code", "execution_count": 3, "id": "d7a41f8e", "metadata": { "scrolled": false }, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "(100, 1)\n", "(100, 1)\n" ] }, { "data": { "text/plain": [ "[]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# random seed for reproducibility\n", "SEED = 0\n", "key = jr.key(SEED)\n", "\n", "n_pnts = 100\n", "X = np.linspace(-2 * np.pi, 2 * np.pi, n_pnts)\n", "Y = np.sin(X) + 0.1 * X**2\n", "\n", "# add noise\n", "key, noise_key = jr.split(key)\n", "\n", "Y = Y + 0.2 * jr.normal(noise_key, Y.shape)\n", "\n", "# data must be in the shape [n_pnts, n_features] for the pmm library\n", "X = X[:, None]\n", "Y = Y[:, None]\n", "print(X.shape)\n", "print(Y.shape)\n", "\n", "plt.plot(X, Y, 'o')" ] }, { "cell_type": "markdown", "id": "54ec5169", "metadata": {}, "source": [ "## Data Partitioning" ] }, { "cell_type": "code", "execution_count": 4, "id": "73771245", "metadata": {}, "outputs": [], "source": [ "# train on 70% of the data, validate on 15%, test on the rest\n", "train_perc = 0.7\n", "val_perc = 0.15\n", "\n", "n_train = int(n_pnts * train_perc)\n", "n_val = int(n_pnts * val_perc)\n", "n_test = n_pnts - n_train - n_val\n", "\n", "key, shuffle_key = jr.split(key)\n", "\n", "# shuffle the data and then split\n", "X_sh = jr.permutation(shuffle_key, X)\n", "Y_sh = jr.permutation(shuffle_key, Y)\n", "\n", "X_train = X_sh[:n_train]\n", "Y_train = Y_sh[:n_train]\n", "X_val = X_sh[n_train:n_train + n_val]\n", "Y_val = Y_sh[n_train:n_train + n_val]\n", "X_test = X_sh[n_train + n_val:]\n", "Y_test = Y_sh[n_train + n_val:]" ] }, { "cell_type": "markdown", "id": "ec05f0c8", "metadata": {}, "source": [ "## Traditional Neural Network" ] }, { "cell_type": "markdown", "id": "ae389dc7", "metadata": {}, "source": [ "### Data Preparation (Scaling)\n", "\n", "Regression neural networks typically function best when the data are scaled to have unit variance and zero mean (`StandardScaler`)" ] }, { "cell_type": "code", "execution_count": 5, "id": "cbfb0405", "metadata": {}, "outputs": [], "source": [ "xscaler = StandardScaler()\n", "yscaler = StandardScaler()\n", "\n", "X_train_sc = xscaler.fit_transform(X_train)\n", 
"X_val_sc = xscaler.transform(X_val)\n", "Y_train_sc = yscaler.fit_transform(Y_train)\n", "Y_val_sc = yscaler.transform(Y_val)\n", "\n", "# we need to convert the arrays back from numpy arrays to jax.numpy arrays, since sklearn uses pure numpy\n", "X_train_sc = np.array(X_train_sc)\n", "X_val_sc = np.array(X_val_sc)\n", "Y_train_sc = np.array(Y_train_sc)\n", "Y_val_sc = np.array(Y_val_sc)" ] }, { "cell_type": "markdown", "id": "9468335e", "metadata": {}, "source": [ "### Model Creation\n", "\n", "The model is built from an ordered sequence of `Module`s. For a traditional neural network, these are just `LinearNN` modules. Since we want linear sequential execution, we use a `SequentialModel`." ] }, { "cell_type": "code", "execution_count": 6, "id": "74afec31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SequentialModel(\n", " [\n", " LinearNN(\n", " []\n", " ),\n", " LinearNN(\n", " []\n", " ),\n", " LinearNN(\n", " []\n", " ),\n", " ]\n", ")\n" ] } ], "source": [ "# two hidden layers with 64 neurons each and the ReLU activation function\n", "# one output neuron\n", "modules = [\n", " pmm.modules.LinearNN(\n", " out_features=64, bias=True, activation=pmm.modules.ReLU()\n", " ),\n", " pmm.modules.LinearNN(\n", " out_features=64, bias=True, activation=pmm.modules.ReLU()\n", " ),\n", " pmm.modules.LinearNN(out_features=1, bias=True),\n", " ]\n", "\n", "model = pmm.SequentialModel(modules)\n", "\n", "# print model summary before compilation\n", "print(model)" ] }, { "cell_type": "markdown", "id": "e80b2a8a", "metadata": {}, "source": [ "We can see that the model summary is very sparse at the moment. This is because the model doesn't know what data to expect yet and doesn't have any initial values for trainable parameters." ] }, { "cell_type": "markdown", "id": "d795ea73", "metadata": {}, "source": [ "### Model Compilation\n", "\n", "It's more accurate to call this preparation or initialization, since all compilation happens JIT. 
This step doesn't _need_ to be done manually, as the model will automatically compile itself for the provided training data when the `Model.train` method is called.\n", "\n", "Models are compiled by providing them with a random key or seed as well as the shape of the input data (excluding the batch dimension). This allows the model to prepare all its modules by, for instance, setting initial values for trainable parameters. Forward passes (inferences/predictions) cannot be done with Models (or Modules) until after compilation with the corresponding input shape." ] }, { "cell_type": "code", "execution_count": 7, "id": "2c490289", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SequentialModel(\n", " [\n", " LinearNN(\n", " [\n", " Flatten,\n", " MatMul(trainable) (trainable floats: 64),\n", " Bias(real=True) (trainable floats: 64),\n", " ReLU,\n", " ]\n", " ),\n", " LinearNN(\n", " [\n", " Flatten,\n", " MatMul(trainable) (trainable floats: 4,096),\n", " Bias(real=True) (trainable floats: 64),\n", " ReLU,\n", " ]\n", " ),\n", " LinearNN(\n", " [\n", " Flatten,\n", " MatMul(trainable) (trainable floats: 64),\n", " Bias(real=True) (trainable floats: 1),\n", " ]\n", " ),\n", " ]\n", ")\n", "Total trainable floats: 4353\n" ] } ], "source": [ "key, compile_key = jr.split(key)\n", "\n", "# the random key here can be replaced with an integer seed, or with None, in which case a random seed will be chosen\n", "# the model only needs to know the input shape without the batch dimension\n", "# use the freshly split compile_key so that `key` is never consumed twice\n", "model.compile(compile_key, X_train_sc.shape[1:])\n", "\n", "# print the model summary after compilation\n", "print(model)\n", "print(f\"Total trainable floats: {model.get_num_trainable_floats()}\")" ] }, { "cell_type": "markdown", "id": "0cd8ab73", "metadata": {}, "source": [ "Now we see some actual details about the model. 
We can see that it is built from three `LinearNN` Modules (which are themselves Models), each using three or four submodules: `Flatten`, a trainable `MatMul`, a trainable and real-valued `Bias`, and, for the hidden layers, `ReLU`. This represents the classic $f(Wx+b)$ neural network layer operation.\n", "\n", "We can perform a forward pass (inference/prediction) with the randomly initialized model just to make sure everything is working:" ] }, { "cell_type": "code", "execution_count": 8, "id": "d3640a40", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([[-0.00960011]], dtype=float64)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(np.array([[1.0]]))" ] }, { "cell_type": "markdown", "id": "4e136541", "metadata": {}, "source": [ "### 64-bit versus 32-bit\n", "\n", "By default, JAX disables 64-bit floating point support (double precision). It can be enabled by calling `jax.config.update(\"jax_enable_x64\", True)` after JAX is imported but _before_ it is initialized (usually before the first array operation). On certain hardware (like consumer GPUs), 64-bit operations can be nearly an order of magnitude slower than 32-bit operations. Additionally, there is little benefit to storing the trainable parameters of a model in 64-bit precision or to training models in 64-bit precision. For this reason, the PMM library defaults to 32-bit models, 32-bit training, and 64-bit inference, and raises warnings otherwise.\n", "\n", "Here, we explicitly cast both the model and the training/validation data to 32-bit." 
] }, { "cell_type": "code", "execution_count": 9, "id": "59c43dd8", "metadata": {}, "outputs": [], "source": [ "model = model.astype(np.float32)\n", "\n", "X_train_sc_32 = X_train_sc.astype(np.float32)\n", "Y_train_sc_32 = Y_train_sc.astype(np.float32)\n", "X_val_sc_32 = X_val_sc.astype(np.float32)\n", "Y_val_sc_32 = Y_val_sc.astype(np.float32)" ] }, { "cell_type": "markdown", "id": "5bccce97", "metadata": {}, "source": [ "### Training\n", "\n", "To train the model, we call the `model.train()` method with the training and validation data. Optionally, we can also supply the loss function and options for the optimization process, such as the learning rate, total number of epochs, and batch size." ] }, { "cell_type": "code", "execution_count": 10, "id": "4618b7da", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/250 | 7.8607e-01/7.8607e-01 [####################] (0s)\n", "2/250 | 5.0318e-01/5.0318e-01 [####################] (0s)\n", "3/250 | 4.5940e-01/4.5940e-01 [####################] (0s)\n", "4/250 | 4.0544e-01/4.0544e-01 [####################] (0s)\n", "5/250 | 3.6005e-01/3.6005e-01 [####################] (0s)\n", "6/250 | 2.8516e-01/2.8516e-01 [####################] (0s)\n", "7/250 | 2.5043e-01/2.5043e-01 [####################] (0s)\n", "8/250 | 2.2804e-01/2.2804e-01 [####################] (0s)\n", "9/250 | 1.9864e-01/1.9864e-01 [####################] (0s)\n", "10/250 | 1.8128e-01/1.8128e-01 [####################] (0s)\n", "11/250 | 1.8519e-01/1.8128e-01 [####################] (0s)\n", "12/250 | 1.6765e-01/1.6765e-01 [####################] (0s)\n", "13/250 | 1.5993e-01/1.5993e-01 [####################] (0s)\n", "14/250 | 1.5118e-01/1.5118e-01 [####################] (0s)\n", "15/250 | 1.4756e-01/1.4756e-01 [####################] (0s)\n", "16/250 | 1.4488e-01/1.4488e-01 [####################] (0s)\n", "17/250 | 1.5423e-01/1.4488e-01 [####################] (0s)\n", "18/250 | 1.3836e-01/1.3836e-01 
[####################] (0s)\n", "19/250 | 1.2858e-01/1.2858e-01 [####################] (0s)\n", "20/250 | 1.3492e-01/1.2858e-01 [####################] (0s)\n", "21/250 | 1.2824e-01/1.2824e-01 [####################] (0s)\n", "22/250 | 1.2356e-01/1.2356e-01 [####################] (0s)\n", "23/250 | 1.1941e-01/1.1941e-01 [####################] (0s)\n", "24/250 | 1.3161e-01/1.1941e-01 [####################] (0s)\n", "25/250 | 1.2221e-01/1.1941e-01 [####################] (0s)\n", "26/250 | 1.2183e-01/1.1941e-01 [####################] (0s)\n", "27/250 | 1.1759e-01/1.1759e-01 [####################] (0s)\n", "28/250 | 1.1619e-01/1.1619e-01 [####################] (0s)\n", "29/250 | 1.1807e-01/1.1619e-01 [####################] (0s)\n", "30/250 | 1.0958e-01/1.0958e-01 [####################] (0s)\n", "31/250 | 1.3600e-01/1.0958e-01 [####################] (0s)\n", "32/250 | 1.1083e-01/1.0958e-01 [####################] (0s)\n", "33/250 | 1.0806e-01/1.0806e-01 [####################] (0s)\n", "34/250 | 1.1257e-01/1.0806e-01 [####################] (0s)\n", "35/250 | 1.2772e-01/1.0806e-01 [####################] (0s)\n", "36/250 | 1.1246e-01/1.0806e-01 [####################] (0s)\n", "37/250 | 1.0966e-01/1.0806e-01 [####################] (0s)\n", "38/250 | 1.2023e-01/1.0806e-01 [####################] (0s)\n", "39/250 | 1.0668e-01/1.0668e-01 [####################] (0s)\n", "40/250 | 1.0356e-01/1.0356e-01 [####################] (0s)\n", "41/250 | 1.1708e-01/1.0356e-01 [####################] (0s)\n", "42/250 | 1.0925e-01/1.0356e-01 [####################] (0s)\n", "43/250 | 1.0966e-01/1.0356e-01 [####################] (0s)\n", "44/250 | 1.0242e-01/1.0242e-01 [####################] (0s)\n", "45/250 | 1.0997e-01/1.0242e-01 [####################] (0s)\n", "46/250 | 1.0356e-01/1.0242e-01 [####################] (0s)\n", "47/250 | 1.0313e-01/1.0242e-01 [####################] (0s)\n", "48/250 | 1.0253e-01/1.0242e-01 [####################] (0s)\n", "49/250 | 1.0032e-01/1.0032e-01 
[####################] (0s)\n", "50/250 | 9.7897e-02/9.7897e-02 [####################] (0s)\n", "51/250 | 1.0216e-01/9.7897e-02 [####################] (0s)\n", "52/250 | 1.0057e-01/9.7897e-02 [####################] (0s)\n", "53/250 | 1.0118e-01/9.7897e-02 [####################] (0s)\n", "54/250 | 9.7351e-02/9.7351e-02 [####################] (0s)\n", "55/250 | 1.1234e-01/9.7351e-02 [####################] (0s)\n", "56/250 | 9.3110e-02/9.3110e-02 [####################] (0s)\n", "57/250 | 9.9734e-02/9.3110e-02 [####################] (0s)\n", "58/250 | 9.1770e-02/9.1770e-02 [####################] (0s)\n", "59/250 | 8.9069e-02/8.9069e-02 [####################] (0s)\n", "60/250 | 1.0289e-01/8.9069e-02 [####################] (0s)\n", "61/250 | 9.6913e-02/8.9069e-02 [####################] (0s)\n", "62/250 | 9.6241e-02/8.9069e-02 [####################] (0s)\n", "63/250 | 9.8420e-02/8.9069e-02 [####################] (0s)\n", "64/250 | 8.7649e-02/8.7649e-02 [####################] (0s)\n", "65/250 | 9.0070e-02/8.7649e-02 [####################] (0s)\n", "66/250 | 9.6413e-02/8.7649e-02 [####################] (0s)\n", "67/250 | 9.2351e-02/8.7649e-02 [####################] (0s)\n", "68/250 | 1.0716e-01/8.7649e-02 [####################] (0s)\n", "69/250 | 1.0409e-01/8.7649e-02 [####################] (0s)\n", "70/250 | 8.6707e-02/8.6707e-02 [####################] (0s)\n", "71/250 | 8.8159e-02/8.6707e-02 [####################] (0s)\n", "72/250 | 8.7628e-02/8.6707e-02 [####################] (0s)\n", "73/250 | 8.3547e-02/8.3547e-02 [####################] (0s)\n", "74/250 | 8.0006e-02/8.0006e-02 [####################] (0s)\n", "75/250 | 8.8055e-02/8.0006e-02 [####################] (0s)\n", "76/250 | 8.6013e-02/8.0006e-02 [####################] (0s)\n", "77/250 | 9.0464e-02/8.0006e-02 [####################] (0s)\n", "78/250 | 8.8312e-02/8.0006e-02 [####################] (0s)\n", "79/250 | 8.0211e-02/8.0006e-02 [####################] (0s)\n", "80/250 | 8.8458e-02/8.0006e-02 
[####################] (0s)\n", "81/250 | 8.6534e-02/8.0006e-02 [####################] (0s)\n", "82/250 | 8.5332e-02/8.0006e-02 [####################] (0s)\n", "83/250 | 8.3566e-02/8.0006e-02 [####################] (0s)\n", "84/250 | 8.0716e-02/8.0006e-02 [####################] (0s)\n", "85/250 | 8.1569e-02/8.0006e-02 [####################] (0s)\n", "86/250 | 8.0372e-02/8.0006e-02 [####################] (0s)\n", "87/250 | 7.9148e-02/7.9148e-02 [####################] (0s)\n", "88/250 | 7.3478e-02/7.3478e-02 [####################] (0s)\n", "89/250 | 8.1903e-02/7.3478e-02 [####################] (0s)\n", "90/250 | 1.0176e-01/7.3478e-02 [####################] (0s)\n", "91/250 | 8.8403e-02/7.3478e-02 [####################] (0s)\n", "92/250 | 9.1995e-02/7.3478e-02 [####################] (0s)\n", "93/250 | 1.0767e-01/7.3478e-02 [####################] (0s)\n", "94/250 | 8.3559e-02/7.3478e-02 [####################] (0s)\n", "95/250 | 7.8728e-02/7.3478e-02 [####################] (0s)\n", "96/250 | 7.3051e-02/7.3051e-02 [####################] (0s)\n", "97/250 | 9.4739e-02/7.3051e-02 [####################] (0s)\n", "98/250 | 9.4280e-02/7.3051e-02 [####################] (0s)\n", "99/250 | 9.2930e-02/7.3051e-02 [####################] (0s)\n", "100/250 | 7.2277e-02/7.2277e-02 [####################] (0s)\n", "101/250 | 7.1439e-02/7.1439e-02 [####################] (0s)\n", "102/250 | 6.8890e-02/6.8890e-02 [####################] (0s)\n", "103/250 | 6.8998e-02/6.8890e-02 [####################] (0s)\n", "104/250 | 7.1708e-02/6.8890e-02 [####################] (0s)\n", "105/250 | 7.4463e-02/6.8890e-02 [####################] (0s)\n", "106/250 | 6.4985e-02/6.4985e-02 [####################] (0s)\n", "107/250 | 6.9123e-02/6.4985e-02 [####################] (0s)\n", "108/250 | 6.5660e-02/6.4985e-02 [####################] (0s)\n", "109/250 | 6.5931e-02/6.4985e-02 [####################] (0s)\n", "110/250 | 6.1449e-02/6.1449e-02 [####################] (0s)\n", "111/250 | 6.4720e-02/6.1449e-02 
[####################] (0s)\n", "112/250 | 6.7240e-02/6.1449e-02 [####################] (0s)\n", "113/250 | 7.3970e-02/6.1449e-02 [####################] (0s)\n", "114/250 | 6.1726e-02/6.1449e-02 [####################] (0s)\n", "115/250 | 8.1460e-02/6.1449e-02 [####################] (0s)\n", "116/250 | 8.6221e-02/6.1449e-02 [####################] (0s)\n", "117/250 | 7.0352e-02/6.1449e-02 [####################] (0s)\n", "118/250 | 7.6675e-02/6.1449e-02 [####################] (0s)\n", "119/250 | 9.5870e-02/6.1449e-02 [####################] (0s)\n", "120/250 | 8.7505e-02/6.1449e-02 [####################] (0s)\n", "121/250 | 8.1257e-02/6.1449e-02 [####################] (0s)\n", "122/250 | 8.4181e-02/6.1449e-02 [####################] (0s)\n", "123/250 | 7.5042e-02/6.1449e-02 [####################] (0s)\n", "124/250 | 6.2688e-02/6.1449e-02 [####################] (0s)\n", "125/250 | 8.3008e-02/6.1449e-02 [####################] (0s)\n", "126/250 | 6.0585e-02/6.0585e-02 [####################] (0s)\n", "127/250 | 6.7834e-02/6.0585e-02 [####################] (0s)\n", "128/250 | 9.3002e-02/6.0585e-02 [####################] (0s)\n", "129/250 | 8.4863e-02/6.0585e-02 [####################] (0s)\n", "130/250 | 6.4601e-02/6.0585e-02 [####################] (0s)\n", "131/250 | 6.9106e-02/6.0585e-02 [####################] (0s)\n", "132/250 | 6.9529e-02/6.0585e-02 [####################] (0s)\n", "133/250 | 6.9735e-02/6.0585e-02 [####################] (0s)\n", "134/250 | 6.3158e-02/6.0585e-02 [####################] (0s)\n", "135/250 | 5.8077e-02/5.8077e-02 [####################] (0s)\n", "136/250 | 5.8051e-02/5.8051e-02 [####################] (0s)\n", "137/250 | 5.1926e-02/5.1926e-02 [####################] (0s)\n", "138/250 | 6.9325e-02/5.1926e-02 [####################] (0s)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "139/250 | 6.6578e-02/5.1926e-02 [####################] (0s)\n", "140/250 | 5.4720e-02/5.1926e-02 [####################] (0s)\n", "141/250 | 
6.1397e-02/5.1926e-02 [####################] (0s)\n", "142/250 | 4.6176e-02/4.6176e-02 [####################] (0s)\n", "143/250 | 5.6338e-02/4.6176e-02 [####################] (0s)\n", "144/250 | 5.9826e-02/4.6176e-02 [####################] (0s)\n", "145/250 | 5.2593e-02/4.6176e-02 [####################] (0s)\n", "146/250 | 4.7631e-02/4.6176e-02 [####################] (0s)\n", "147/250 | 6.0039e-02/4.6176e-02 [####################] (0s)\n", "148/250 | 5.1952e-02/4.6176e-02 [####################] (0s)\n", "149/250 | 6.4384e-02/4.6176e-02 [####################] (0s)\n", "150/250 | 4.7902e-02/4.6176e-02 [####################] (0s)\n", "151/250 | 5.4240e-02/4.6176e-02 [####################] (0s)\n", "152/250 | 8.7593e-02/4.6176e-02 [####################] (0s)\n", "153/250 | 5.5680e-02/4.6176e-02 [####################] (0s)\n", "154/250 | 5.9909e-02/4.6176e-02 [####################] (0s)\n", "155/250 | 4.5948e-02/4.5948e-02 [####################] (0s)\n", "156/250 | 4.5787e-02/4.5787e-02 [####################] (0s)\n", "157/250 | 5.4487e-02/4.5787e-02 [####################] (0s)\n", "158/250 | 4.5993e-02/4.5787e-02 [####################] (0s)\n", "159/250 | 5.2104e-02/4.5787e-02 [####################] (0s)\n", "160/250 | 6.5566e-02/4.5787e-02 [####################] (0s)\n", "161/250 | 4.5222e-02/4.5222e-02 [####################] (0s)\n", "162/250 | 5.6127e-02/4.5222e-02 [####################] (0s)\n", "163/250 | 5.9438e-02/4.5222e-02 [####################] (0s)\n", "164/250 | 4.9327e-02/4.5222e-02 [####################] (0s)\n", "165/250 | 3.8123e-02/3.8123e-02 [####################] (0s)\n", "166/250 | 5.3154e-02/3.8123e-02 [####################] (0s)\n", "167/250 | 4.5029e-02/3.8123e-02 [####################] (0s)\n", "168/250 | 4.1853e-02/3.8123e-02 [####################] (0s)\n", "169/250 | 4.6038e-02/3.8123e-02 [####################] (0s)\n", "170/250 | 5.9997e-02/3.8123e-02 [####################] (0s)\n", "171/250 | 4.2596e-02/3.8123e-02 [####################] 
(0s)\n", "172/250 | 4.8706e-02/3.8123e-02 [####################] (0s)\n", "173/250 | 4.2727e-02/3.8123e-02 [####################] (0s)\n", "174/250 | 5.2578e-02/3.8123e-02 [####################] (0s)\n", "175/250 | 4.3788e-02/3.8123e-02 [####################] (0s)\n", "176/250 | 7.0637e-02/3.8123e-02 [####################] (0s)\n", "177/250 | 4.1196e-02/3.8123e-02 [####################] (0s)\n", "178/250 | 3.1865e-02/3.1865e-02 [####################] (0s)\n", "179/250 | 3.7221e-02/3.1865e-02 [####################] (0s)\n", "180/250 | 3.3198e-02/3.1865e-02 [####################] (0s)\n", "181/250 | 3.1750e-02/3.1750e-02 [####################] (0s)\n", "182/250 | 3.2191e-02/3.1750e-02 [####################] (0s)\n", "183/250 | 3.5275e-02/3.1750e-02 [####################] (0s)\n", "184/250 | 5.1212e-02/3.1750e-02 [####################] (0s)\n", "185/250 | 4.9963e-02/3.1750e-02 [####################] (0s)\n", "186/250 | 4.9463e-02/3.1750e-02 [####################] (0s)\n", "187/250 | 5.6674e-02/3.1750e-02 [####################] (0s)\n", "188/250 | 2.8552e-02/2.8552e-02 [####################] (0s)\n", "189/250 | 2.9090e-02/2.8552e-02 [####################] (0s)\n", "190/250 | 2.9629e-02/2.8552e-02 [####################] (0s)\n", "191/250 | 3.5665e-02/2.8552e-02 [####################] (0s)\n", "192/250 | 3.7032e-02/2.8552e-02 [####################] (0s)\n", "193/250 | 2.9679e-02/2.8552e-02 [####################] (0s)\n", "194/250 | 3.8773e-02/2.8552e-02 [####################] (0s)\n", "195/250 | 3.4265e-02/2.8552e-02 [####################] (0s)\n", "196/250 | 4.1550e-02/2.8552e-02 [####################] (0s)\n", "197/250 | 3.5000e-02/2.8552e-02 [####################] (0s)\n", "198/250 | 3.6903e-02/2.8552e-02 [####################] (0s)\n", "199/250 | 2.7838e-02/2.7838e-02 [####################] (0s)\n", "200/250 | 2.5960e-02/2.5960e-02 [####################] (0s)\n", "201/250 | 2.7936e-02/2.5960e-02 [####################] (0s)\n", "202/250 | 3.4787e-02/2.5960e-02 
[####################] (0s)\n", "203/250 | 3.4081e-02/2.5960e-02 [####################] (0s)\n", "204/250 | 4.0938e-02/2.5960e-02 [####################] (0s)\n", "205/250 | 3.7885e-02/2.5960e-02 [####################] (0s)\n", "206/250 | 3.4678e-02/2.5960e-02 [####################] (0s)\n", "207/250 | 6.0238e-02/2.5960e-02 [####################] (0s)\n", "208/250 | 3.5822e-02/2.5960e-02 [####################] (0s)\n", "209/250 | 2.8570e-02/2.5960e-02 [####################] (0s)\n", "210/250 | 2.7827e-02/2.5960e-02 [####################] (0s)\n", "211/250 | 3.3676e-02/2.5960e-02 [####################] (0s)\n", "212/250 | 3.0658e-02/2.5960e-02 [####################] (0s)\n", "213/250 | 3.1736e-02/2.5960e-02 [####################] (0s)\n", "214/250 | 3.2551e-02/2.5960e-02 [####################] (0s)\n", "215/250 | 2.3595e-02/2.3595e-02 [####################] (0s)\n", "216/250 | 3.0965e-02/2.3595e-02 [####################] (0s)\n", "217/250 | 4.1416e-02/2.3595e-02 [####################] (0s)\n", "218/250 | 7.1400e-02/2.3595e-02 [####################] (0s)\n", "219/250 | 6.0360e-02/2.3595e-02 [####################] (0s)\n", "220/250 | 4.7201e-02/2.3595e-02 [####################] (0s)\n", "221/250 | 4.6297e-02/2.3595e-02 [####################] (0s)\n", "222/250 | 4.8735e-02/2.3595e-02 [####################] (0s)\n", "223/250 | 3.6821e-02/2.3595e-02 [####################] (0s)\n", "224/250 | 3.1825e-02/2.3595e-02 [####################] (0s)\n", "225/250 | 3.2482e-02/2.3595e-02 [####################] (0s)\n", "226/250 | 3.4695e-02/2.3595e-02 [####################] (0s)\n", "227/250 | 3.2790e-02/2.3595e-02 [####################] (0s)\n", "228/250 | 3.5002e-02/2.3595e-02 [####################] (0s)\n", "229/250 | 4.2969e-02/2.3595e-02 [####################] (0s)\n", "230/250 | 4.6462e-02/2.3595e-02 [####################] (0s)\n", "231/250 | 4.0521e-02/2.3595e-02 [####################] (0s)\n", "232/250 | 3.8239e-02/2.3595e-02 [####################] (0s)\n", "233/250 | 
3.0208e-02/2.3595e-02 [####################] (0s)\n", "234/250 | 2.4054e-02/2.3595e-02 [####################] (0s)\n", "235/250 | 4.7296e-02/2.3595e-02 [####################] (0s)\n", "236/250 | 3.9185e-02/2.3595e-02 [####################] (0s)\n", "237/250 | 3.4520e-02/2.3595e-02 [####################] (0s)\n", "238/250 | 3.4001e-02/2.3595e-02 [####################] (0s)\n", "239/250 | 4.1196e-02/2.3595e-02 [####################] (0s)\n", "240/250 | 2.8311e-02/2.3595e-02 [####################] (0s)\n", "241/250 | 5.0485e-02/2.3595e-02 [####################] (0s)\n", "242/250 | 3.6910e-02/2.3595e-02 [####################] (0s)\n", "243/250 | 4.3837e-02/2.3595e-02 [####################] (0s)\n", "244/250 | 3.7291e-02/2.3595e-02 [####################] (0s)\n", "245/250 | 3.3454e-02/2.3595e-02 [####################] (0s)\n", "246/250 | 2.1437e-02/2.1437e-02 [####################] (0s)\n", "247/250 | 3.5610e-02/2.1437e-02 [####################] (0s)\n", "248/250 | 2.7829e-02/2.1437e-02 [####################] (0s)\n", "249/250 | 4.2877e-02/2.1437e-02 [####################] (0s)\n", "250/250 | 3.8412e-02/2.1437e-02 [####################] (0s)\n", "\n", "========================================\n", "Total epochs: 250\n", "(best epoch: 245)\n", "(best validation loss: 2.1437E-02)\n", "========================================\n" ] } ], "source": [ "key, batch_key = jr.split(key)\n", "\n", "model.train(\n", " X_train_sc_32,\n", " Y=Y_train_sc_32,\n", " X_val=X_val_sc_32,\n", " Y_val=Y_val_sc_32,\n", " lr=1e-2,\n", " epochs=250,\n", " batch_size=5,\n", " batch_rng=batch_key,\n", " verbose=True)" ] }, { "cell_type": "markdown", "id": "7c004178", "metadata": {}, "source": [ "### Inference\n", "\n", "Now we're ready to make new predictions with the model. All Models and Modules are directly callable once compiled, so we need only to pass in the (scaled) inputs like `model(X)` and unscale the predictions." 
] }, { "cell_type": "code", "execution_count": 11, "id": "9cb396f2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# pass in the scaled inputs (converted to JAX arrays again)\n", "X_sc = xscaler.transform(X)\n", "X_sc = np.array(X_sc)\n", "\n", "# get scaled predictions\n", "Y_pred_sc = model(X_sc)\n", "\n", "# unscale predictions\n", "Y_pred = yscaler.inverse_transform(Y_pred_sc)\n", "\n", "# plot results\n", "plt.plot(X, Y_pred, label=\"NN\", lw=2, zorder=10)\n", "plt.plot(X_train, Y_train, 'o', label=\"Train\")\n", "plt.plot(X_val, Y_val, 's', label=\"Val\")\n", "plt.plot(X_test, Y_test, 'x', label=\"Test\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "id": "184c3307", "metadata": {}, "source": [ "## Parametric Matrix Model (`AffineObservablePMM`)\n", "\n", "Here we repeat the above process for the `AffineObservablePMM`, which is both a Module and a Model. For consistency, we wrap it with a `SequentialModel`, though this isn't necessary as it is itself a subclass of `SequentialModel`. The only other difference in usage here is that this kind of PMM performs best when the data are scaled uniformly, using `MinMaxScaler`." 
] }, { "cell_type": "markdown", "id": "609b897b", "metadata": {}, "source": [ "### Data Preparation (Scaling)\n", "\n", "This kind of PMM typically functions best when the data are scaled to a fixed range (`MinMaxScaler`, which maps each feature to $[0, 1]$ by default)." ] }, { "cell_type": "code", "execution_count": 12, "id": "640a64ac", "metadata": {}, "outputs": [], "source": [ "xscaler = MinMaxScaler()\n", "yscaler = MinMaxScaler()\n", "\n", "X_train_sc = xscaler.fit_transform(X_train)\n", "X_val_sc = xscaler.transform(X_val)\n", "Y_train_sc = yscaler.fit_transform(Y_train)\n", "Y_val_sc = yscaler.transform(Y_val)\n", "\n", "# we need to convert the arrays back from numpy arrays to jax.numpy arrays, since sklearn uses pure numpy\n", "X_train_sc = np.array(X_train_sc)\n", "X_val_sc = np.array(X_val_sc)\n", "Y_train_sc = np.array(Y_train_sc)\n", "Y_val_sc = np.array(Y_val_sc)" ] }, { "cell_type": "markdown", "id": "5b004f0d", "metadata": {}, "source": [ "### Model Creation\n", "\n", "This model is a single `AffineObservablePMM` module with Hermitian matrices of size `5` (that is, $5 \\times 5$), two eigenvectors, one secondary matrix, and expectation values in place of transition amplitudes. Internally, this Module is just a `SequentialModel` of Modules that first form the primary matrix $H(c) = H_0 + \\sum_i c_i H_i$ from the input features $c_i$, then take the eigendecomposition, then compute the sum of transition amplitudes or expectation values with trainable secondary matrices, and finally add a trainable bias." 
] }, { "cell_type": "code", "execution_count": 13, "id": "674e42ee", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SequentialModel(\n", " [\n", " AffineObservablePMM(\n", " []\n", " ),\n", " ]\n", ")\n", "Total trainable floats: 0\n" ] } ], "source": [ "# just a single PMM module, which contains all the submodules\n", "modules = [\n", " pmm.modules.AffineObservablePMM(\n", " matrix_size=5,\n", " num_eig=2,\n", " num_secondaries=1,\n", " output_size=1,\n", " use_expectation_values=True\n", " )\n", " ]\n", "\n", "model = pmm.SequentialModel(modules)\n", "\n", "# print model summary before compilation\n", "print(model)\n", "print(f\"Total trainable floats: {model.get_num_trainable_floats()}\")" ] }, { "cell_type": "markdown", "id": "36bad79c", "metadata": {}, "source": [ "We can see that the model summary is very sparse at the moment. This is because the model doesn't know what data to expect yet and doesn't have any initial values for trainable parameters." ] }, { "cell_type": "markdown", "id": "1f679fe9", "metadata": {}, "source": [ "### Model Compilation\n", "\n", "It's more accurate to call this preparation or initialization, since all compilation happens JIT. This step doesn't _need_ to be done manually, as the model will automatically compile itself for the provided training data when the `Model.train` method is called.\n", "\n", "Models are compiled by providing them with a random key or seed as well as the shape of the input data (excluding the batch dimension). This allows the model to prepare all its modules by, for instance, setting initial values for trainable parameters. Forward passes (inferences/predictions) cannot be done with Models (or Modules) until after compilation with the corresponding input shape." 
] }, { "cell_type": "code", "execution_count": 14, "id": "349df642", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SequentialModel(\n", " [\n", " AffineObservablePMM(\n", " (\n", " AffineHermitianMatrix(5x5,) (trainable floats: 50),\n", " Eigenvectors(num_eig=2, which=LM),\n", " ExpectationValueSum(output_size=1, num_observables=1, centered=True) (trainable floats: 25),\n", " Bias(real=True) (trainable floats: 1),\n", " )\n", " ),\n", " ]\n", ")\n" ] } ], "source": [ "key, compile_key = jr.split(key)\n", "\n", "# the random key here can be replaced with an integer seed, or with None, in which case a random seed will be chosen\n", "# the model only needs to know the input shape without the batch dimension\n", "model.compile(key, X_train_sc.shape[1:])\n", "\n", "# print the model summary after compilation\n", "print(model)" ] }, { "cell_type": "markdown", "id": "28e0223f", "metadata": {}, "source": [ "Now we see some actual details about the model. It is built from one `AffineObservablePMM` Module (which is itself a Model) that uses four submodules: `AffineHermitianMatrix`, `Eigenvectors`, `ExpectationValueSum`, and `Bias`.\n", "\n", "We can perform a forward pass (inference/prediction) with the randomly initialized model just to make sure everything is working:" ] }, { "cell_type": "code", "execution_count": 15, "id": "d86e12ac", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([[0.01202144]], dtype=float64)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(np.array([[1.0]]))" ] }, { "cell_type": "markdown", "id": "2665a0b5", "metadata": {}, "source": [ "### 64-bit versus 32-bit\n", "\n", "By default, JAX disables 64-bit floating point support (double precision). It can be enabled by calling `jax.config.update(\"jax_enable_x64\", True)` after JAX is imported but _before_ it is initialized (usually before the first array operation). 
On certain hardware (such as consumer GPUs), 64-bit operations can be nearly an order of magnitude slower than 32-bit ones. Additionally, there is little benefit to storing a model's trainable parameters in 64-bit precision or to training in 64-bit precision. For these reasons, the PMM library defaults to 32-bit models, 32-bit training, and 64-bit inference, and raises warnings otherwise.\n", "\n", "Here, we explicitly cast both the model and the training/validation data to 32-bit." ] }, { "cell_type": "code", "execution_count": 16, "id": "27812ed7", "metadata": {}, "outputs": [], "source": [ "model = model.astype(np.float32)\n", "\n", "X_train_sc_32 = X_train_sc.astype(np.float32)\n", "Y_train_sc_32 = Y_train_sc.astype(np.float32)\n", "X_val_sc_32 = X_val_sc.astype(np.float32)\n", "Y_val_sc_32 = Y_val_sc.astype(np.float32)" ] }, { "cell_type": "markdown", "id": "cf6284bf", "metadata": {}, "source": [ "### Training\n", "\n", "To train the model, we call the `model.train()` method with the training and validation data and, optionally, a loss function and optimization settings such as the learning rate, total number of epochs, and batch size."
] }, { "cell_type": "code", "execution_count": 17, "id": "bd0c283a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/250 | 5.4442e-02/5.4442e-02 [####################] (0s)\n", "2/250 | 1.1614e-01/5.4442e-02 [####################] (0s)\n", "3/250 | 5.6446e-02/5.4442e-02 [####################] (0s)\n", "4/250 | 5.8632e-02/5.4442e-02 [####################] (0s)\n", "5/250 | 9.0178e-02/5.4442e-02 [####################] (0s)\n", "6/250 | 5.4176e-02/5.4176e-02 [####################] (0s)\n", "7/250 | 6.2531e-02/5.4176e-02 [####################] (0s)\n", "8/250 | 5.2981e-02/5.2981e-02 [####################] (0s)\n", "9/250 | 3.3681e-02/3.3681e-02 [####################] (0s)\n", "10/250 | 2.9399e-02/2.9399e-02 [####################] (0s)\n", "11/250 | 3.0416e-02/2.9399e-02 [####################] (0s)\n", "12/250 | 2.5683e-02/2.5683e-02 [####################] (0s)\n", "13/250 | 2.2098e-02/2.2098e-02 [####################] (0s)\n", "14/250 | 1.8719e-02/1.8719e-02 [####################] (0s)\n", "15/250 | 1.2611e-02/1.2611e-02 [####################] (0s)\n", "16/250 | 1.0284e-02/1.0284e-02 [####################] (0s)\n", "17/250 | 1.0085e-02/1.0085e-02 [####################] (0s)\n", "18/250 | 1.2628e-02/1.0085e-02 [####################] (0s)\n", "19/250 | 8.1678e-03/8.1678e-03 [####################] (0s)\n", "20/250 | 6.8490e-03/6.8490e-03 [####################] (0s)\n", "21/250 | 7.4929e-03/6.8490e-03 [####################] (0s)\n", "22/250 | 6.5267e-03/6.5267e-03 [####################] (0s)\n", "23/250 | 5.5787e-03/5.5787e-03 [####################] (0s)\n", "24/250 | 7.4365e-03/5.5787e-03 [####################] (0s)\n", "25/250 | 6.4864e-03/5.5787e-03 [####################] (0s)\n", "26/250 | 6.3493e-03/5.5787e-03 [####################] (0s)\n", "27/250 | 5.9987e-03/5.5787e-03 [####################] (0s)\n", "28/250 | 7.7315e-03/5.5787e-03 [####################] (0s)\n", "29/250 | 5.9036e-03/5.5787e-03 [####################] (0s)\n", 
"30/250 | 6.0343e-03/5.5787e-03 [####################] (0s)\n", "31/250 | 6.7738e-03/5.5787e-03 [####################] (0s)\n", "32/250 | 6.3439e-03/5.5787e-03 [####################] (0s)\n", "33/250 | 9.1537e-03/5.5787e-03 [####################] (0s)\n", "34/250 | 7.5088e-03/5.5787e-03 [####################] (0s)\n", "35/250 | 6.0626e-03/5.5787e-03 [####################] (0s)\n", "36/250 | 6.4997e-03/5.5787e-03 [####################] (0s)\n", "37/250 | 6.7489e-03/5.5787e-03 [####################] (0s)\n", "38/250 | 7.2865e-03/5.5787e-03 [####################] (0s)\n", "39/250 | 7.4852e-03/5.5787e-03 [####################] (0s)\n", "40/250 | 6.1094e-03/5.5787e-03 [####################] (0s)\n", "41/250 | 5.7489e-03/5.5787e-03 [####################] (0s)\n", "42/250 | 6.0052e-03/5.5787e-03 [####################] (0s)\n", "43/250 | 7.0590e-03/5.5787e-03 [####################] (0s)\n", "44/250 | 7.0230e-03/5.5787e-03 [####################] (0s)\n", "45/250 | 6.9031e-03/5.5787e-03 [####################] (0s)\n", "46/250 | 6.2807e-03/5.5787e-03 [####################] (0s)\n", "47/250 | 7.0101e-03/5.5787e-03 [####################] (0s)\n", "48/250 | 7.6631e-03/5.5787e-03 [####################] (0s)\n", "49/250 | 6.0248e-03/5.5787e-03 [####################] (0s)\n", "50/250 | 6.5716e-03/5.5787e-03 [####################] (0s)\n", "51/250 | 6.7445e-03/5.5787e-03 [####################] (0s)\n", "52/250 | 5.2785e-03/5.2785e-03 [####################] (0s)\n", "53/250 | 3.1016e-03/3.1016e-03 [####################] (0s)\n", "54/250 | 8.9250e-03/3.1016e-03 [####################] (0s)\n", "55/250 | 6.9021e-03/3.1016e-03 [####################] (0s)\n", "56/250 | 3.8885e-03/3.1016e-03 [####################] (0s)\n", "57/250 | 7.6216e-03/3.1016e-03 [####################] (0s)\n", "58/250 | 7.3961e-03/3.1016e-03 [####################] (0s)\n", "59/250 | 7.0220e-03/3.1016e-03 [####################] (0s)\n", "60/250 | 6.1720e-03/3.1016e-03 [####################] (0s)\n", "61/250 | 
5.7947e-03/3.1016e-03 [####################] (0s)\n", "62/250 | 5.0734e-03/3.1016e-03 [####################] (0s)\n", "63/250 | 4.9879e-03/3.1016e-03 [####################] (0s)\n", "64/250 | 5.5361e-03/3.1016e-03 [####################] (0s)\n", "65/250 | 4.5192e-03/3.1016e-03 [####################] (0s)\n", "66/250 | 6.7327e-03/3.1016e-03 [####################] (0s)\n", "67/250 | 7.7463e-03/3.1016e-03 [####################] (0s)\n", "68/250 | 5.5338e-03/3.1016e-03 [####################] (0s)\n", "69/250 | 3.7159e-03/3.1016e-03 [####################] (0s)\n", "70/250 | 3.1470e-03/3.1016e-03 [####################] (0s)\n", "71/250 | 5.3691e-03/3.1016e-03 [####################] (0s)\n", "72/250 | 9.9612e-03/3.1016e-03 [####################] (0s)\n", "73/250 | 5.4284e-03/3.1016e-03 [####################] (0s)\n", "74/250 | 6.3011e-03/3.1016e-03 [####################] (0s)\n", "75/250 | 4.0493e-03/3.1016e-03 [####################] (0s)\n", "76/250 | 4.3772e-03/3.1016e-03 [####################] (0s)\n", "77/250 | 4.2802e-03/3.1016e-03 [####################] (0s)\n", "78/250 | 6.9974e-03/3.1016e-03 [####################] (0s)\n", "79/250 | 3.5675e-03/3.1016e-03 [####################] (0s)\n", "80/250 | 5.5649e-03/3.1016e-03 [####################] (0s)\n", "81/250 | 7.4596e-03/3.1016e-03 [####################] (0s)\n", "82/250 | 4.3281e-03/3.1016e-03 [####################] (0s)\n", "83/250 | 2.9820e-03/2.9820e-03 [####################] (0s)\n", "84/250 | 6.6735e-03/2.9820e-03 [####################] (0s)\n", "85/250 | 3.7156e-03/2.9820e-03 [####################] (0s)\n", "86/250 | 2.8052e-03/2.8052e-03 [####################] (0s)\n", "87/250 | 1.6796e-02/2.8052e-03 [####################] (0s)\n", "88/250 | 1.0363e-02/2.8052e-03 [####################] (0s)\n", "89/250 | 5.0196e-03/2.8052e-03 [####################] (0s)\n", "90/250 | 3.4737e-03/2.8052e-03 [####################] (0s)\n", "91/250 | 1.9751e-03/1.9751e-03 [####################] (0s)\n", "92/250 | 
2.8802e-03/1.9751e-03 [####################] (0s)\n", "93/250 | 5.8319e-03/1.9751e-03 [####################] (0s)\n", "94/250 | 3.4529e-03/1.9751e-03 [####################] (0s)\n", "95/250 | 2.6395e-03/1.9751e-03 [####################] (0s)\n", "96/250 | 2.8948e-03/1.9751e-03 [####################] (0s)\n", "97/250 | 3.9332e-03/1.9751e-03 [####################] (0s)\n", "98/250 | 5.2598e-03/1.9751e-03 [####################] (0s)\n", "99/250 | 4.5422e-03/1.9751e-03 [####################] (0s)\n", "100/250 | 3.1950e-03/1.9751e-03 [####################] (0s)\n", "101/250 | 7.2841e-03/1.9751e-03 [####################] (0s)\n", "102/250 | 7.1810e-03/1.9751e-03 [####################] (0s)\n", "103/250 | 5.0848e-03/1.9751e-03 [####################] (0s)\n", "104/250 | 3.0136e-03/1.9751e-03 [####################] (0s)\n", "105/250 | 2.4440e-03/1.9751e-03 [####################] (0s)\n", "106/250 | 3.4854e-03/1.9751e-03 [####################] (0s)\n", "107/250 | 3.3913e-03/1.9751e-03 [####################] (0s)\n", "108/250 | 1.9056e-03/1.9056e-03 [####################] (0s)\n", "109/250 | 2.3935e-03/1.9056e-03 [####################] (0s)\n", "110/250 | 2.2171e-03/1.9056e-03 [####################] (0s)\n", "111/250 | 2.8577e-03/1.9056e-03 [####################] (0s)\n", "112/250 | 2.0798e-03/1.9056e-03 [####################] (0s)\n", "113/250 | 2.2106e-03/1.9056e-03 [####################] (0s)\n", "114/250 | 2.2582e-03/1.9056e-03 [####################] (0s)\n", "115/250 | 1.9352e-03/1.9056e-03 [####################] (0s)\n", "116/250 | 3.8784e-03/1.9056e-03 [####################] (0s)\n", "117/250 | 2.3278e-03/1.9056e-03 [####################] (0s)\n", "118/250 | 6.5267e-03/1.9056e-03 [####################] (0s)\n", "119/250 | 7.0995e-03/1.9056e-03 [####################] (0s)\n", "120/250 | 2.6266e-03/1.9056e-03 [####################] (0s)\n", "121/250 | 4.7335e-03/1.9056e-03 [####################] (0s)\n", "122/250 | 7.1512e-03/1.9056e-03 [####################] (0s)\n", 
"123/250 | 4.2358e-03/1.9056e-03 [####################] (0s)\n", "124/250 | 3.5876e-03/1.9056e-03 [####################] (0s)\n", "125/250 | 3.0308e-03/1.9056e-03 [####################] (0s)\n", "126/250 | 4.3498e-03/1.9056e-03 [####################] (0s)\n", "127/250 | 3.6074e-03/1.9056e-03 [####################] (0s)\n", "128/250 | 3.4459e-03/1.9056e-03 [####################] (0s)\n", "129/250 | 2.9632e-03/1.9056e-03 [####################] (0s)\n", "130/250 | 3.1751e-03/1.9056e-03 [####################] (0s)\n", "131/250 | 3.5406e-03/1.9056e-03 [####################] (0s)\n", "132/250 | 2.7615e-03/1.9056e-03 [####################] (0s)\n", "133/250 | 2.3193e-03/1.9056e-03 [####################] (0s)\n", "134/250 | 2.5355e-03/1.9056e-03 [####################] (0s)\n", "135/250 | 2.2808e-03/1.9056e-03 [####################] (0s)\n", "136/250 | 2.4032e-03/1.9056e-03 [####################] (0s)\n", "137/250 | 3.0431e-03/1.9056e-03 [####################] (0s)\n", "138/250 | 2.8142e-03/1.9056e-03 [####################] (0s)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "139/250 | 1.9903e-03/1.9056e-03 [####################] (0s)\n", "140/250 | 1.9782e-03/1.9056e-03 [####################] (0s)\n", "141/250 | 1.9603e-03/1.9056e-03 [####################] (0s)\n", "142/250 | 2.7303e-03/1.9056e-03 [####################] (0s)\n", "143/250 | 5.9771e-03/1.9056e-03 [####################] (0s)\n", "144/250 | 4.0507e-03/1.9056e-03 [####################] (0s)\n", "145/250 | 2.1805e-03/1.9056e-03 [####################] (0s)\n", "146/250 | 1.7523e-03/1.7523e-03 [####################] (0s)\n", "147/250 | 2.2935e-03/1.7523e-03 [####################] (0s)\n", "148/250 | 3.1664e-03/1.7523e-03 [####################] (0s)\n", "149/250 | 3.0230e-03/1.7523e-03 [####################] (0s)\n", "150/250 | 2.4747e-03/1.7523e-03 [####################] (0s)\n", "151/250 | 3.7523e-03/1.7523e-03 [####################] (0s)\n", "152/250 | 5.4125e-03/1.7523e-03 [####################] 
(0s)\n", "153/250 | 2.5811e-03/1.7523e-03 [####################] (0s)\n", "154/250 | 4.2159e-03/1.7523e-03 [####################] (0s)\n", "155/250 | 3.1262e-03/1.7523e-03 [####################] (0s)\n", "156/250 | 1.7317e-03/1.7317e-03 [####################] (0s)\n", "157/250 | 4.1367e-03/1.7317e-03 [####################] (0s)\n", "158/250 | 4.9401e-03/1.7317e-03 [####################] (0s)\n", "159/250 | 7.3494e-03/1.7317e-03 [####################] (0s)\n", "160/250 | 3.3021e-03/1.7317e-03 [####################] (0s)\n", "161/250 | 4.1837e-03/1.7317e-03 [####################] (0s)\n", "162/250 | 2.0441e-03/1.7317e-03 [####################] (0s)\n", "163/250 | 2.8287e-03/1.7317e-03 [####################] (0s)\n", "164/250 | 3.9031e-03/1.7317e-03 [####################] (0s)\n", "165/250 | 3.3648e-03/1.7317e-03 [####################] (0s)\n", "166/250 | 2.8648e-03/1.7317e-03 [####################] (0s)\n", "167/250 | 2.1758e-03/1.7317e-03 [####################] (0s)\n", "168/250 | 1.5925e-03/1.5925e-03 [####################] (0s)\n", "169/250 | 2.4283e-03/1.5925e-03 [####################] (0s)\n", "170/250 | 2.0659e-03/1.5925e-03 [####################] (0s)\n", "171/250 | 2.6633e-03/1.5925e-03 [####################] (0s)\n", "172/250 | 2.4290e-03/1.5925e-03 [####################] (0s)\n", "173/250 | 3.4595e-03/1.5925e-03 [####################] (0s)\n", "174/250 | 4.2879e-03/1.5925e-03 [####################] (0s)\n", "175/250 | 2.5606e-03/1.5925e-03 [####################] (0s)\n", "176/250 | 1.7552e-03/1.5925e-03 [####################] (0s)\n", "177/250 | 1.8150e-03/1.5925e-03 [####################] (0s)\n", "178/250 | 1.8843e-03/1.5925e-03 [####################] (0s)\n", "179/250 | 2.4673e-03/1.5925e-03 [####################] (0s)\n", "180/250 | 2.7264e-03/1.5925e-03 [####################] (0s)\n", "181/250 | 6.2605e-03/1.5925e-03 [####################] (0s)\n", "182/250 | 2.2474e-03/1.5925e-03 [####################] (0s)\n", "183/250 | 5.4456e-03/1.5925e-03 
[####################] (0s)\n", "184/250 | 6.7805e-03/1.5925e-03 [####################] (0s)\n", "185/250 | 5.3535e-03/1.5925e-03 [####################] (0s)\n", "186/250 | 2.3965e-03/1.5925e-03 [####################] (0s)\n", "187/250 | 4.7688e-03/1.5925e-03 [####################] (0s)\n", "188/250 | 3.9332e-03/1.5925e-03 [####################] (0s)\n", "189/250 | 2.9458e-03/1.5925e-03 [####################] (0s)\n", "190/250 | 2.7969e-03/1.5925e-03 [####################] (0s)\n", "191/250 | 3.7599e-03/1.5925e-03 [####################] (0s)\n", "192/250 | 2.3289e-03/1.5925e-03 [####################] (0s)\n", "193/250 | 2.3922e-03/1.5925e-03 [####################] (0s)\n", "194/250 | 1.4419e-03/1.4419e-03 [####################] (0s)\n", "195/250 | 1.7286e-03/1.4419e-03 [####################] (0s)\n", "196/250 | 2.6265e-03/1.4419e-03 [####################] (0s)\n", "197/250 | 5.1592e-03/1.4419e-03 [####################] (0s)\n", "198/250 | 2.0504e-03/1.4419e-03 [####################] (0s)\n", "199/250 | 1.8579e-03/1.4419e-03 [####################] (0s)\n", "200/250 | 5.1812e-03/1.4419e-03 [####################] (0s)\n", "201/250 | 5.8354e-03/1.4419e-03 [####################] (0s)\n", "202/250 | 3.6949e-03/1.4419e-03 [####################] (0s)\n", "203/250 | 2.4144e-03/1.4419e-03 [####################] (0s)\n", "204/250 | 2.9960e-03/1.4419e-03 [####################] (0s)\n", "205/250 | 4.1210e-03/1.4419e-03 [####################] (0s)\n", "206/250 | 6.0538e-03/1.4419e-03 [####################] (0s)\n", "207/250 | 5.2219e-03/1.4419e-03 [####################] (0s)\n", "208/250 | 3.7427e-03/1.4419e-03 [####################] (0s)\n", "209/250 | 4.7810e-03/1.4419e-03 [####################] (0s)\n", "210/250 | 5.7836e-03/1.4419e-03 [####################] (0s)\n", "211/250 | 2.9323e-03/1.4419e-03 [####################] (0s)\n", "212/250 | 3.8641e-03/1.4419e-03 [####################] (0s)\n", "213/250 | 4.2803e-03/1.4419e-03 [####################] (0s)\n", "214/250 | 
2.0412e-03/1.4419e-03 [####################] (0s)\n", "215/250 | 1.6764e-03/1.4419e-03 [####################] (0s)\n", "216/250 | 2.0434e-03/1.4419e-03 [####################] (0s)\n", "217/250 | 3.2198e-03/1.4419e-03 [####################] (0s)\n", "218/250 | 7.8106e-03/1.4419e-03 [####################] (0s)\n", "219/250 | 3.3164e-03/1.4419e-03 [####################] (0s)\n", "220/250 | 3.7824e-03/1.4419e-03 [####################] (0s)\n", "221/250 | 4.9126e-03/1.4419e-03 [####################] (0s)\n", "222/250 | 7.0033e-03/1.4419e-03 [####################] (0s)\n", "223/250 | 7.2956e-03/1.4419e-03 [####################] (0s)\n", "224/250 | 5.2377e-03/1.4419e-03 [####################] (0s)\n", "225/250 | 3.7460e-03/1.4419e-03 [####################] (0s)\n", "226/250 | 7.7423e-03/1.4419e-03 [####################] (0s)\n", "227/250 | 9.5499e-03/1.4419e-03 [####################] (0s)\n", "228/250 | 2.0417e-02/1.4419e-03 [####################] (0s)\n", "229/250 | 7.7284e-03/1.4419e-03 [####################] (0s)\n", "230/250 | 6.8024e-03/1.4419e-03 [####################] (0s)\n", "231/250 | 6.3797e-03/1.4419e-03 [####################] (0s)\n", "232/250 | 5.5671e-03/1.4419e-03 [####################] (0s)\n", "233/250 | 5.5593e-03/1.4419e-03 [####################] (0s)\n", "234/250 | 7.8933e-03/1.4419e-03 [####################] (0s)\n", "235/250 | 1.3514e-02/1.4419e-03 [####################] (0s)\n", "236/250 | 1.0794e-02/1.4419e-03 [####################] (0s)\n", "237/250 | 5.7856e-03/1.4419e-03 [####################] (0s)\n", "238/250 | 2.7149e-03/1.4419e-03 [####################] (0s)\n", "239/250 | 2.0554e-03/1.4419e-03 [####################] (0s)\n", "240/250 | 3.3540e-03/1.4419e-03 [####################] (0s)\n", "241/250 | 5.6794e-03/1.4419e-03 [####################] (0s)\n", "242/250 | 3.3093e-03/1.4419e-03 [####################] (0s)\n", "243/250 | 3.2271e-03/1.4419e-03 [####################] (0s)\n", "244/250 | 3.0530e-03/1.4419e-03 [####################] 
(0s)\n", "245/250 | 1.7806e-03/1.4419e-03 [####################] (0s)\n", "246/250 | 1.5051e-03/1.4419e-03 [####################] (0s)\n", "247/250 | 3.2142e-03/1.4419e-03 [####################] (0s)\n", "248/250 | 2.2576e-03/1.4419e-03 [####################] (0s)\n", "249/250 | 1.9008e-03/1.4419e-03 [####################] (0s)\n", "250/250 | 2.2744e-03/1.4419e-03 [####################] (0s)\n", "\n", "========================================\n", "Total epochs: 250\n", "(best epoch: 193)\n", "(best validation loss: 1.4419E-03)\n", "========================================\n" ] } ], "source": [ "# using the same batch key as with the neural network\n", "model.train(\n", " X_train_sc_32,\n", " Y=Y_train_sc_32,\n", " X_val=X_val_sc_32,\n", " Y_val=Y_val_sc_32,\n", " lr=1e-2,\n", " epochs=250,\n", " batch_size=5,\n", " batch_rng=batch_key,\n", " verbose=True)" ] }, { "cell_type": "markdown", "id": "d5e2e4d9", "metadata": {}, "source": [ "### Inference\n", "\n", "Now we're ready to make new predictions with the model. All Models and Modules are directly callable once compiled, so we need only to pass in the (scaled) inputs like `model(X)` and unscale the predictions." ] }, { "cell_type": "code", "execution_count": 18, "id": "2b1bc51e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# pass in the scaled inputs (converted to JAX arrays again)\n", "X_sc = xscaler.transform(X)\n", "X_sc = np.array(X_sc)\n", "\n", "# get scaled predictions\n", "Y_pred_sc = model(X_sc)\n", "\n", "# unscale predictions\n", "Y_pred = yscaler.inverse_transform(Y_pred_sc)\n", "\n", "# plot results\n", "plt.plot(X, Y_pred, label=\"PMM\", lw=2, zorder=10)\n", "plt.plot(X_train, Y_train, 'o', label=\"Train\")\n", "plt.plot(X_val, Y_val, 's', label=\"Val\")\n", "plt.plot(X_test, Y_test, 'x', label=\"Test\")\n", "plt.legend()" ] } ], "metadata": { "kernelspec": { "display_name": "pmmenv", "language": "python", "name": "pmmenv" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }