diff --git a/docs/API/recurrent.rst b/docs/API/recurrent.rst index 2fe915d..b6e7e64 100644 --- a/docs/API/recurrent.rst +++ b/docs/API/recurrent.rst @@ -3,8 +3,6 @@ Recurrent .. currentmodule:: serket.nn -.. autoclass:: RNNCell - .. autoclass:: LSTMCell .. autoclass:: GRUCell .. autoclass:: SimpleRNNCell @@ -24,7 +22,4 @@ Recurrent .. autoclass:: FFTConvGRU2DCell .. autoclass:: FFTConvGRU3DCell -.. autoclass:: ScanRNN - - -.. autofunction:: scan_rnn \ No newline at end of file +.. autofunction:: scan_cell \ No newline at end of file diff --git a/docs/notebooks/train_bilstm.ipynb b/docs/notebooks/train_bilstm.ipynb index 8dacf2e..f9fc894 100644 --- a/docs/notebooks/train_bilstm.ipynb +++ b/docs/notebooks/train_bilstm.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -37,10 +37,12 @@ "import optax # for gradient optimization\n", "import serket as sk\n", "import functools as ft\n", + "from typing_extensions import Annotated\n", "import time\n", "\n", "EPOCHS = 100\n", - "LR = 1e-3" + "LR = 1e-3\n", + "Input = Annotated[jax.Array, \"Float[seq_len, input_dim]\"]" ] }, { @@ -52,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -77,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -91,26 +93,32 @@ " key: jax.Array,\n", " ):\n", " k1, k2, k3 = jr.split(key, 3)\n", - " self.rnn1 = sk.nn.ScanRNN(\n", - " cell=sk.nn.LSTMCell(in_features, hidden_dim, key=k1),\n", - " backward_cell=sk.nn.LSTMCell(in_features, hidden_dim, key=k2),\n", - " return_sequences=True, # return all outputs of the sequence\n", - " )\n", - " self.rnn2 = sk.nn.ScanRNN(\n", - " # in_features is hidden_dim*2 (for each cell 
from previous layer)\n", - " cell=sk.nn.LSTMCell(hidden_dim * 2, out_features, key=k3)\n", - " )\n", + " self.cell1 = sk.nn.LSTMCell(in_features, hidden_dim, key=k1)\n", + " self.cell2 = sk.nn.LSTMCell(in_features, hidden_dim, key=k2)\n", + " self.cell3 = sk.nn.LSTMCell(hidden_dim * 2, out_features, key=k3)\n", "\n", - " def __call__(self, x):\n", - " return self.rnn2(self.rnn1(x))\n", + " def __call__(self, input: Input) -> Input:\n", + " # initialize the states of the cells\n", + " state = sk.tree_state(self)\n", + " # run the forward cell\n", + " output1, state1 = sk.nn.scan_cell(self.cell1)(input, state.cell1)\n", + " # run the backward cell\n", + " output2, state2 = sk.nn.scan_cell(self.cell2, reverse=True)(input, state.cell2)\n", + " # concatenate the outputs\n", + " output = jnp.concatenate((output1, output2), axis=1)\n", + " # run the final cell\n", + " output, state3 = sk.nn.scan_cell(self.cell3)(output, state.cell3)\n", + " # return the last time step\n", + " return output[-1]\n", "\n", "\n", - "nn = BiLstm(1, 64, 1, key=jax.random.PRNGKey(0))\n", + "key = jax.random.PRNGKey(0)\n", + "net = BiLstm(1, 64, 1, key=key)\n", "# 1) mask the non-jaxtype parameters\n", - "nn = sk.tree_mask(nn)\n", + "net = sk.tree_mask(net)\n", "# 2) initialize the optimizer state\n", "optim = optax.adam(LR)\n", - "optim_state = optim.init(nn)" + "optim_state = optim.init(net)" ] }, { @@ -122,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -132,26 +140,26 @@ "\n", "\n", "@ft.partial(jax.grad, has_aux=True)\n", - "def loss_func(nn, x, y):\n", + "def loss_func(net: BiLstm, x: jax.Array, y: jax.Array):\n", " # pass non-jaxtype over jax transformation\n", " # using `tree_mask`/`tree_unmask` scheme\n", " # 3) unmask the non-jaxtype parameters to be used in the computation\n", - " nn = sk.tree_unmask(nn)\n", + " net = sk.tree_unmask(net)\n", " # 4) vectorize the computation over the batch dimension\n", 
" # and get the logits\n", " # here we dont vectorize over state argument so we use `None`\n", - " logits = jax.vmap(nn)(x)\n", + " logits = jax.vmap(net)(x)\n", " # 5) use the appropriate loss function\n", " loss = mse(logits, y)\n", " return loss, (loss, logits)\n", "\n", "\n", "@jax.jit\n", - "def train_step(nn, optim_state, x, y):\n", - " grads, (loss, logits) = loss_func(nn, x, y)\n", + "def train_step(net: BiLstm, optim_state: optax.OptState, x: jax.Array, y: jax.Array):\n", + " grads, (loss, logits) = loss_func(net, x, y)\n", " updates, optim_state = optim.update(grads, optim_state)\n", - " nn = optax.apply_updates(nn, updates)\n", - " return nn, optim_state, (loss, logits)" + " net = optax.apply_updates(net, updates)\n", + " return net, optim_state, (loss, logits)" ] }, { @@ -163,29 +171,29 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.632614e-03\tTime: 0.019\r" + "Epoch: 100/100\tBatch: 100/100\tBatch loss: 2.065103e-03\tTime: 0.022\r" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -198,7 +206,7 @@ "for i in range(1, EPOCHS + 1):\n", " t0 = time.time()\n", " for j, (xb, yb) in enumerate(zip(x_train, y_train)):\n", - " nn, optim_state, (loss, logits) = train_step(nn, optim_state, xb, yb)\n", + " net, optim_state, (loss, logits) = train_step(net, optim_state, xb, yb)\n", " print(\n", " f\"Epoch: {i:003d}/{EPOCHS:003d}\\t\"\n", " f\"Batch: {j+1:003d}/{len(x_train):003d}\\t\"\n", @@ -209,10 +217,10 @@ "\n", "\n", "# 6) un-mask the trained network\n", - "eval_nn = sk.tree_unmask(nn)\n", + "eval_net = sk.tree_unmask(net)\n", "\n", "\n", - "y_pred = jax.vmap(eval_nn)(x_train.reshape(-1, 2, 1))\n", + "y_pred = jax.vmap(eval_net)(x_train.reshape(-1, 2, 1))\n", "plt.plot(x[1:], y[1:], \"--k\", label=\"data\")\n", "plt.plot(x[1:], y_pred, label=\"prediction\")\n", "plt.legend()" diff --git a/docs/notebooks/train_convlstm.ipynb b/docs/notebooks/train_convlstm.ipynb index f4bc827..f0a3aa5 100644 --- a/docs/notebooks/train_convlstm.ipynb +++ b/docs/notebooks/train_convlstm.ipynb @@ -25,28 +25,9 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "YABV1l4q2HR9", - "outputId": "2c1eb5c7-2a0e-4542-ba83-c1baf3cfdc60" + "id": "YABV1l4q2HR9" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.8/57.8 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for serket (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for ml_collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], + "outputs": [], "source": [ "!pip install git+https://github.com/ASEM000/serket --quiet\n", "!pip install optax --quiet\n", @@ -162,7 +143,7 @@ "height": 1000 }, "id": "c_mHfJYd2A68", - "outputId": "bb25c7df-fd0f-410a-f679-a468dba2e5ea" + "outputId": "6cc2ac5e-7add-4429-9d83-89a0894c62bd" }, "outputs": [ { @@ -205,54 +186,27 @@ "outputs": [], "source": [ "class Net(sk.TreeClass):\n", - " def __init__(self, hidden_features: int, *, key: jax.Array):\n", - " k1, k2, k3, k4 = jr.split(key, 4)\n", - " self.convlstm1 = sk.nn.ScanRNN(\n", - " cell=sk.nn.ConvLSTM2DCell(\n", - " in_features=1,\n", - " hidden_features=hidden_features,\n", - " kernel_size=5,\n", - " padding=\"same\",\n", - " key=k1,\n", - " ),\n", - " return_sequences=True,\n", - " )\n", - " self.convlstm2 = sk.nn.ScanRNN(\n", - " cell=sk.nn.ConvLSTM2DCell(\n", - " in_features=hidden_features,\n", - " hidden_features=hidden_features,\n", - " kernel_size=3,\n", - " padding=\"same\",\n", - " key=k2,\n", - " ),\n", - " return_sequences=True,\n", - " )\n", - " self.convlstm3 = sk.nn.ScanRNN(\n", - " cell=sk.nn.ConvLSTM2DCell(\n", - " in_features=hidden_features,\n", - " hidden_features=hidden_features,\n", - " kernel_size=1,\n", - " padding=\"same\",\n", - " key=k3,\n", - " ),\n", - " return_sequences=True,\n", - " )\n", - " self.conv = sk.nn.Conv2D(\n", - " in_features=hidden_features,\n", - " out_features=1,\n", - " kernel_size=3,\n", - " padding=\"same\",\n", - " key=k4,\n", - " )\n", + " def __init__(self, features: int, *, key: jax.Array):\n", + " k1, k2, k3 = jr.split(key, 3)\n", + " self.convlstm1 = sk.nn.ConvLSTM2DCell(1, features, 5, 
key=k1)\n", + " self.convlstm2 = sk.nn.ConvLSTM2DCell(features, features, 3, key=k2)\n", + " self.conv = sk.nn.Conv2D(features, 1, 3, key=k3)\n", "\n", " def __call__(\n", - " self, input: Annotated[jax.Array, \"f32[F,1,H,W]\"]\n", - " ) -> Annotated[jax.Array, \"f32[F,1,H,W]\"]:\n", - " input = jax.nn.relu(self.convlstm1(input))\n", - " input = jax.nn.relu(self.convlstm2(input))\n", - " input = jax.nn.relu(self.convlstm3(input))\n", - " input = jax.vmap(self.conv)(input) # vectorize over frames\n", - " return jax.nn.sigmoid(input)" + " self, input: Annotated[jax.Array, \"Float[F,1,H,W]\"]\n", + " ) -> Annotated[jax.Array, \"Float[F,1,H,W]\"]:\n", + " # F: number of frames\n", + " # C: number of channels\n", + " # H: height of the frame\n", + " # W: width of the frame\n", + " # initialize state for the cells by passing sample input\n", + " state = sk.tree_state(self, input=input[0])\n", + " output, _ = sk.nn.scan_cell(self.convlstm1)(input, state.convlstm1)\n", + " output, _ = sk.nn.scan_cell(self.convlstm2)(output, state.convlstm2)\n", + " # vectorize convolution over frames\n", + " output = jax.vmap(self.conv)(output)\n", + " # apply sigmoid to get values between 0 and 1\n", + " return jax.nn.sigmoid(output)" ] }, { @@ -272,38 +226,36 @@ "base_uri": "https://localhost:8080/" }, "id": "9igKcrLM2A69", - "outputId": "cfb76f8b-a239-49db-83fe-10994c9fe9b6" + "outputId": "182e3f31-0955-47a8-de86-3c7b7ee94877" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "┌──────────┬───────────────────────┬───────┬────────┐\n", - "│Name │Type │Count │Size │\n", - "├──────────┼───────────────────────┼───────┼────────┤\n", - "│.convlstm1│ScanRNN[ConvLSTM2DCell]│416,256│1.59MB │\n", - "├──────────┼───────────────────────┼───────┼────────┤\n", - "│.convlstm2│ScanRNN[ConvLSTM2DCell]│295,168│1.13MB │\n", - "├──────────┼───────────────────────┼───────┼────────┤\n", - "│.convlstm3│ScanRNN[ConvLSTM2DCell]│33,024 │129.00KB│\n", - 
"├──────────┼───────────────────────┼───────┼────────┤\n", - "│.conv │Conv2D │577 │2.25KB │\n", - "├──────────┼───────────────────────┼───────┼────────┤\n", - "│Σ │Net │745,025│2.84MB │\n", - "└──────────┴───────────────────────┴───────┴────────┘\n" + "┌──────────┬──────────────┬───────┬────────┐\n", + "│Name │Type │Count │Size │\n", + "├──────────┼──────────────┼───────┼────────┤\n", + "│.convlstm1│ConvLSTM2DCell│105,728│413.00KB│\n", + "├──────────┼──────────────┼───────┼────────┤\n", + "│.convlstm2│ConvLSTM2DCell│73,856 │288.50KB│\n", + "├──────────┼──────────────┼───────┼────────┤\n", + "│.conv │Conv2D │289 │1.13KB │\n", + "├──────────┼──────────────┼───────┼────────┤\n", + "│Σ │Net │179,873│702.63KB│\n", + "└──────────┴──────────────┴───────┴────────┘\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Loss: 6.7986e-01: 100%|██████████| 1000/1000 [01:20<00:00, 12.39it/s]\n" + "Loss: 6.7987e-01: 100%|██████████| 1000/1000 [00:33<00:00, 30.02it/s]\n" ] } ], "source": [ - "config.hidden_features = 64\n", + "config.features = 32\n", "config.epochs = 1000\n", "config.key = jr.PRNGKey(0)\n", "config.optim = ConfigDict()\n", @@ -317,21 +269,21 @@ "config.optim.scales = [0.5, 0.5, 0.5]\n", "\n", "\n", - "def train(config: ConfigDict)->Net:\n", + "def train(config: ConfigDict) -> Net:\n", " lr = optax.piecewise_constant_schedule(\n", " init_value=config.optim.init_value,\n", " boundaries_and_scales=dict(zip(config.optim.boundaries, config.optim.scales)),\n", " )\n", " optim = getattr(optax, config.optim.kind)(learning_rate=lr)\n", - " net = sk.tree_mask(Net(hidden_features=config.hidden_features, key=config.key))\n", + " net = sk.tree_mask(Net(features=config.features, key=config.key))\n", " optim_state = optim.init(net)\n", "\n", " print(sk.tree_summary(net, depth=1))\n", "\n", " def loss_func(\n", " net: Net,\n", - " xb: Annotated[jax.Array, \"f32[N,F,1,H,W]\"],\n", - " yb: Annotated[jax.Array, \"f32[N,F,1,H,W]\"],\n", + " xb: Annotated[jax.Array, 
\"Float[N,F,1,H,W]\"],\n", + " yb: Annotated[jax.Array, \"Float[N,F,1,H,W]\"],\n", " ):\n", " net = sk.tree_unmask(net)\n", " logits = jax.vmap(net)(xb) # vectorize over the batch dimension\n", @@ -341,8 +293,8 @@ " def train_step(\n", " net: Net,\n", " optim_state: Any,\n", - " xb: Annotated[jax.Array, \"f32[N,F,1,H,W]\"],\n", - " yb: Annotated[jax.Array, \"f32[N,F,1,H,W]\"],\n", + " xb: Annotated[jax.Array, \"Float[N,F,1,H,W]\"],\n", + " yb: Annotated[jax.Array, \"Float[N,F,1,H,W]\"],\n", " ):\n", " loss, grads = jax.value_and_grad(loss_func)(net, xb, yb)\n", " updates, optim_state = optim.update(grads, optim_state)\n", @@ -375,10 +327,10 @@ "metadata": { "colab": { "base_uri": "https://localhost:8080/", - "height": 440 + "height": 439 }, "id": "ALAbakDJ2A69", - "outputId": "07b80e11-8e15-4401-e211-006f4dc84bf9" + "outputId": "5697a561-9a81-43d6-8790-8a712b46d34a" }, "outputs": [ { @@ -390,7 +342,7 @@ }, { "data": { - "image/png": "", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWIAAAGVCAYAAADEy/vbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAr7klEQVR4nO3daXBb12H28edeLAR3bNwXUQxFUYqsWJsjL2Eo2bGU6chNo1q108p2Mk7a2O10kg+eaadNv3TSsad139jT2ImmTixLduU4jjxqKkuulspKJFsLKZIiKVJcxQ0EQZAECIBY7nk/2EBEWwsoAToA+fxm7lgGuBweQX9eXBzcqwghBIiISBpV9gCIiBY7hpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMn28H6goSjLHQUS0IMXz5mXuERMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJppc9ACJKXYqiQFVV2Gw2WK1WBINBOBwOzMzMyB7agsIQE9E1qaoKg8EAk8mEbdu24Rvf+AYGBwfxs5/9DE1NTbKHt6Dw0AQRfU50T1in08FoNKK2thYPPfQQ7r33XthsNiiKInuICwr3iCntRaMBAJqmAQCWLl2KNWvWQFVVNDY24vLlyzKHmDaysrKwdu1aVFVVxUJsMpmwcuVK6HQ6WCwWPPTQQygpKbnm5zscDpw9exZut
/sOjzy9McSU9nQ6HfR6PYQQiEQiEEJg3bp1+Lu/+zvo9Xq88MIL6OnpiUWari8/Px/f+ta3sG3bttgvN1VVkZOTA4PBgNLSUnz3u9+F3++/5uefPn0aQ0NDDPE8McQpQq/XIysrCzqdDrOzs/D7/RBCyB5WylMUBZmZmcjNzQUAhMNhCCFgt9tRWloKg8GAgoICWK3WWIiFEAgEAteNyWKm1+thtVpRXl5+3fttNtt1P7+4uBh2ux0Wi4VzPA8McYqorKzE9u3bUVZWhv/7v//DwYMHEQgEZA8r5en1ejQ0NODhhx+O7RULIbB8+XLk5uZCVVU88sgjqKmpiYVY0zQcO3YMBw8exOzsrOSfYGGprq7GM888g9HRURw9ehSHDh3iHMeBIU4RxcXF2L59O1avXo1QKIQjR44wxHHQ6/VYt24dvvOd78BkMsVuVxQl9oLSV7/6VdTX18fuC4VCCAQC+N///V9GIsFKS0vx6KOPYnZ2Fj6fD0ePHuUcx+GWQ5yRkYGKigrYbDa43W709/dzwuNkt9uxZMkSGAyG2G2rVq1CXl4e9Ho9ysrKcM8998Dr9cbun5iY4BxfR/TFuugxzWvdf/Wr/Hq9HqWlpdiwYQPcbjcGBgYwMTFxp4a74EVf5OPKivjFHWKdThf7sxACNpsNTz75JDZv3oyTJ0/i5ZdfxuDgYFIGudCsX78ezz77LAoLC2O35ebmoqKiAnq9Hps2bUJtbS0ikUjs/uPHj+Oll17C0NCQjCEvKKqq4qGHHsKKFSswMDCAl156CcePH5c9LFrE4g6xXv+HDxVCICsrC8uWLcOGDRswNjYWe6FJ0zS+yHQTdrsda9euRWlp6TXvLy4uRnFx8ZzbhoeH5zz1plunKEpsji0WCywWi+whLTjcG56fuEP8zDPPxP4shIDFYsGyZcug0+lQXV2NnTt3Ynh4GGfOnEFjY+OcvTkiWnw+e0iIri/uEP/oRz+a8/+qqiIzMxOqqmLFihWoqqrC9PQ0XnzxRTQ3NzPERATgkyDzWfKNxR1is9l83fsMBgMMBgP0ej2Ki4tRUVGBmZkZuN1uvriUIJmZmSgpKUEgEMD09DQ8Ho/sIS0Ier0edrsdFRUV8Pl8mJyc5E5EguTm5qKsrAxTU1OYnJzkmuIbSOjyNaPRiC1btmDJkiXo7e3FG2+8gba2tkR+i0Xri1/8Ip577jmMj4/j3XffxcGDBxmMBLBardi5cycaGhpw5swZ7N27F06nU/aw0p5Op0N9fT1sNhuGhoawd+9enD9/XvawUlZCQ6zX6/GlL30Jq1evRnNzMw4dOsQQJ0h5eTnKy8vh8XjQ0dGBQ4cOMcQJkJOTgwceeADAJ0sy9+/fzxAngKqqWLlyJerq6tDd3Y3jx48zxDeQ0BBHIhF0d3ejp6cHly9f5trMBHI4HOjo6IDL5eJ5ExLI7/ejvb0dDocDjY2NfPqcIEII9Pb2oqurC4ODgxgbG5M9pJSW0BAHg0H89re/xWuvvQav18vJT6DW1lY8//zz6O/vx/j4OPeGE8TlcuGXv/wlPvjgA3g8Hu48JEg4HMaxY8fw6quvYnJyEg6HQ/aQUlrcIf7s220VRYHBYICqqtA0DaFQCDMzMxgZGUFnZyeCwWDCB7uYeTwe9PT0oLu7W/ZQFpRgMIjBwUFcunSJr+wn2MTEBLq6uuD1evkM7ibiDvELL7wQWxeo1+uRn5+PTZs2oa6uDl1dXXj//fcxOjqKU6dOcW+N0g4jnHjREzBxbm8u7hD/27/9W2wv2GAwoKKiAmVlZVi+fDm6urqwa9cuXL58GeFwmCFOAj6YKd0wxPGLO8R+vx+KoiAUCsFgMGBqagr9/f24ePEi+vr64PF4uGY4TlNTU+jq6sLU1BSAP5xTt6ioCBkZGXC73XA6nXPeLj44OMjDPQnkcrngdDoxMDDANdlJwPjOT9whDofDAD5ZGREKhTA8PIzXXnsNBw4cgNPp5JKfeTh79iz+6Z/+C
SaTKXamqhUrVuC73/0uqqqqcOzYMezduxc+ny/2gB4bG8P4+LjkkS8MkUgEx44dw549e+ByudDZ2Sl7SLTIxR3iaBAikUgsxryS660ZGRnByMhI7Hi7qqrwer3YsWMHNE1DX18fjh49iunpae5ZJIGmaejt7cWRI0fmnGqUSBaeGF6y6OGH0dFRHDx4EK2trTh37hyCwSAjHAchBMbGxtDW1oacnByYzWZkZGTAaDTCZDIhFAqho6MD/f39CIfDCAQCCAaDOH/+fOxZHiWO0+lEU1MTxsfH0dzczDmOE0MsUfRilwDQ09ODl19+GQaDAV6vl1fniFMkEkFfXx9OnDgBi8WC5cuXw2azIS8vD0ajEX6/H4cPH8aBAwfg8/kwPj6OYDAIr9fL1zSSoL+/H7t27UJ7e3tsrunmGOIUEQwGeZz9FgghYm8eCoVCsQtbhkIhKIoCr9cLh8OB4eFh+P1+jI+PM8BJFH0cj4yMwO/3c/1wnBhiSmuapqG7uxuBQAAZGRk4deoUTCYTMjIykJmZGXstY3x8HOFwmE+Vk0zTNASDQczOznKu54EhprSmaRoGBwdjl5C61onIuZb1zhFCIBgMIhgM8v0E88AQ04IQDS2DS+lIEXzkEhFJde3rjxMR0R3DEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExHJJuIEgBs3btzSclNVNbbpdDrxjW98Q7S3twun0yl+8IMfiIyMjNh9Op1OqKqasO8dD16hgyhNqKoKRVEghOBFOeMQna/oBmDOf6ObTqeDTqeLXd1FURRomhb7WE3Tkn7lF4aYKA3Y7XZ85StfQVlZGS5evIjTp0/D7/fLHlbKKi0tRX19PQoLCwFgTowBYNWqVbBYLMjIyMD999+PSCQyJ7bRX3ZerxenTp1CR0dHUsfLEBOlgeLiYjz55JO499578eabb6K5uZkhvoHKykp873vfw5e+9KXPRRgA9Ho9srKyoCgKtmzZgoaGhjn3a5oGTdMwOjqKf/7nf2aIKT2pqor8/HxkZ2djdnYWk5OTCIVCsoeVFkwmEywWC/T6P/zzLC0thc1mg8Vigd1uR3l5ObKysmL3BwKBRT/HVz/mSkpKYvN1rSt7Xy0zMxOZmZlzbovuEYdCIZSUlKCiogKBQAButxvhcDjhY4/74qE3+2GIrpaXl4c/+7M/Q0NDA9rb27F7924MDAzIHlZaWLduHf7iL/4CRUVFsdssFgvWrFmDgoIC9Pb2orm5GYFAIHb/xYsXsXv3bly5ckXGkFNCTk4OHn30UWzevBnFxcVYu3YtrFbrLX89IQQCgQAuXLiA3t5etLS04I033sDg4OC8v048H8RVE9wSvhUUFIhdu3aJSCQijh49KlavXi19TOmyPfLII+Ly5cvx/tMUQghx5MgRsWrVKuljl7lZrVbx05/+VITD4XnNXbwOHz4svvjFL857XPHgoYlPKYqCJUuWoKamBsFgEB0dHRgbG5M9rLSQm5uLFStWwGazxW7Lz89HRUUFFEWBzWbD/fffj7Kystj9MzMznOOrmEwm1NXVoaSkBOvWrZtz2CEeNpst9mJed3c3enp6Ft3Kis+ukEgnDPGndDodNm3ahL/6q7/CxMQEXnjhBUYiTiUlJXjmmWewcePG2G06nQ4FBQVQFAU1NTV47rnnMDs7G3ua1t/fj+eff55z/Cmr1YqnnnoKW7duRXZ2Nux2+7w+f9myZXjuuefg8Xjw2muv4dVXX51z6IJS26IPsaqqMBgMyMjIQHFxMWpra+F0OmG325GZmYlIJIJQKJT0dYTpzGg0oqKiAsuXL7/m/VlZWaiqqppzm06nQ05Ozh0YXXowGAwoLS297hzeTHSOZ2dnYbPZoKp802w6WfQhrqiowNe//nWUlZVh48aNMJlMsFqt2L59O1atWoW2tjYcOnQIk5OTsoe6oKTj0
8dk45yktmT+/TDEFRV46qmncPfdd0On00Gv1yMjIwPf/OY3EYlEsH//fpw+fZohprSwWGN+p44NR79Pop8hL8oQ63Q6FBYWwmKxoLq6Gnl5ecjIyIjdrygKDAYDDAYDbDYbamtrkZmZCafTCZfLJXHkqUNRFBQUFMBqtaKmpgbZ2dnz+nyj0Yjy8nKsXLkS09PTcDgci3INrNVqRUFBASoqKpCXl3fbX09RFBQWFqKurg7T09MYHR2F1+tNwEhTV35+PoqKimC32+NaN3yrsrOzUV1djXA4DJfLhfHx8YQFeVGuI87Ly8O3v/1tbN26FTabDXV1dcjNzb3mxzocDly6dAkTExPYs2cP9u/fj0gkcodHnHoyMjLw+OOPY/v27bBarairq5vXmk2fz4dLly5hbGwMH374IXbt2rXoXrjT6XR45JFHsHPnTthsNixfvnzO2uFbIYRAb28venp6MDAwgJ/97Gf4+OOPEzTi1PTggw/iO9/5DoqKilBbW4uKioqkfB+Xy4W2tjZMTEzg17/+Nfbt24dgMHjTz4snsYtyj9hoNGLlypX42te+FjsxyPUUFRWhsLAQ09PTOHHixIL6hXQ7dDodli1bhoceeggZGRnznpesrCysWbMGQghMT0/DZDIlaaSpK7pkctOmTcjPz0/IY0tRFFRXV2Pp0qXo7OzEu+++m4CRpraSkhLU19ejpKQkqS9SWq1W3HvvvfD7/WhqaoJOp0vY116UIZ6dncXp06eRmZmJoqIibNiwARaL5ZofOzAwgHPnzmF8fBwdHR2Lbm3m9YTDYTQ1NeGtt95CQUEB1q9fj+Li4rg/3+Px4Ny5cxgcHMTHH3+MmZmZJI42NWmaho6ODrz99tsoKCjAunXrUFlZedtfs62tDS0tLRgaGsLQ0FCCRpu6ent78e6776KoqAhr1qxBbW1tUr7P6Ogozpw5g7GxMVy4cCGxz4zjfVcJUuCdM4naVFUVZrNZlJWViT/5kz8RLS0t1/25f/vb34ovf/nLoqSkROTk5Egfe6psiqKIvLw8UVpaKh588EHx4YcfxvtQEkII0dvbK5566ilRVlYmrFar0Ol00n8mGVtOTo4oKSkR99xzjzhw4MC85vBaZmdnxUsvvSRqampEUVGRMJlM0n/GZG9ZWVmiuLhYrFq1SuzevVtomnbb83gtJ06cEJs2bRJlZWUiLy9PKIoS1/jisSj3iDVNw+TkJCYnJ+FwOOByueB2u2EymZCZmQlN0+Dz+RAMBuF0OjEyMoKRkRHZw04p4tNDCtPT07BarXEdK7ta9AWPxbDHdiNerxderxdGozFhb8DweDwYGRlZNM8yfD4ffD4fQqEQZmZmIIRIyiHE2dlZOJ1ODA8Pc9VEovX19eGVV15BcXExNm3ahC1btsDr9eI3v/kNLly4gO7ubrjdbtnDJIqL+PSsYYkOBX1CCJGUuV30IR4eHsavfvUrGI1GmEwmbN68GVNTUzh48CDee++9pE38Ysc5/TzOSWpL5t/Pog8x8MmhinA4HHvhyO12w+Vy8YW5OPn9frS1tcXO6Rpdh11RURFbcdLX1wefzxd7MA8ODvKZxlVmZ2dx6dIlnDp1ClarFUuWLJnXSpLp6Wn09vZiamoKg4ODfOymGYb4U5FIBEePHkVHRwdCoRDPnTsPw8PD+I//+I85b0iwWCz4/ve/j23btqGrqwv//u//jq6urtj9gUCAc3yViYkJvP766zhw4ADq6+vxt3/7tygvL4/78zs7O/Hiiy+iq6sLo6Oj8z5mT3IxxJ8SQvBFuVvk9/s/dymZgoICOBwOhMNhuN1utLS0oLm5WdIIU18wGMTly5cBfLJ2fWZmBuFweM7FLqNrZKPHga8WnePW1tY7O/AUEp2XcDgcuyBoIkQvm/TZ69olEkNMSeH3+3Hs2DH4fD5cvnwZ4+PjsoeUNnp7e7Fnzx7Y7fZYiAsKCuZcPPTUqVNzrlnX2dm56N9+HwgEcPz4cQSDQZSXl
8+5eOit8vl8+N3vfof29nZ0dHRgYmIiQaP9jHjX0CEF1gtyS59NURSRlZUlzGazyMnJWbTrhG9lMxgMIi8vT5jNZmGxWITFYhH19fXixIkTIhwOi9dff13U1NQIi8UizGZzbI5VVZU+dpmboigiMzNTmM1m8fWvf100NTXFm7frGh0dFd///veFxWK55TmOB/eIKSmEELH1nTQ/oVBozgmQFEWBy+XCyMgIrly5gtHRUbjdbr7Y+RlCCPj9fvj9fjidTgwNDX3u/CfRZxgmkwlmsxmqqmJ6ejp2YiTxmUMPTqcTTqcz6XO9KE/6Q5RuzGYz1qxZg8LCQnR3d6OlpQWzs7Oyh5WyCgoKbnjx0PXr1+Oxxx5Dbm4u3nvvPRw+fPiaV2eOnleir6/vlscSV2Lj3UVHCjz14MaNG7fb2RRFEaqqij/+4z8W7e3twul0ih/84AfCaDTG/Zbl+W48NEFE9BlCCDgcDpw4cQL5+fno6+uT/m5EHpogokUnJycHdrsder0+dq6ZZIknsQwxEVESxZNYXuqViEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDKGmIhIMoaYiEgyhpiISDJ9vB8ohEjmOIiIFi3uERMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRScYQExFJxhATEUnGEBMRySbiBIAbN27cpG6KoghVVYWqqkKn0wmdTid27twp+vr6xNDQkPje974n9Hr9nPtVVZU65njEfYUOIkocVVWhKAo0TePVb+IQna/oBiD2Z1VVY/+vqip0Ol3s4wFA0zQoigIhRGxLNQwx0R1WXFyM+vp62Gw2tLS04OOPP0YwGJQ9rJRVWVmJ+vp6WK3WORGO/nft2rXIzc2FXq/Hpk2bkJmZOSe2QghomoapqSmcPHkSPT09Un6OG2GIie6w8vJyPP3007jrrruwa9cuXLhwgSG+gS984Qt49tlnUVtbO2ePOMpoNMJkMkFRFGzbtg0PP/zwnPs1TYOmaejv74fL5WKIKT3pdDqYzWZkZmYiEAhgcnIS4XBY9rDSQlZWFvLz86HX/+GfWllZGWw2GywWCwoLC1FWVoaZmZnY/X6/H263G5FIRMaQU4JOp0N+fj6ysrLmzNdnI/xZmZmZyMzMnHNbdI/Y6/WipKQEFRUV8Pl8mJycTJk5VkScB0xuNgG0cNlsNvz5n/85Nm7ciPPnz2PPnj0YHR2VPay0cP/99+Oxxx6DzWaL3Wa327F27VqYzWZ0dXWhtbUVoVAodv+5c+fwxhtvYGxsTMaQU4LFYsG3vvUt3HfffSgtLcXatWuRl5d3y19PCAGv14umpiYMDg7izJkz2Lt37x2Z47gSy1UT3G62VVZWin379gkhhDhw4ICoqamRPqZ02R5//HFx5cqVeP+ZCSGE2L9/v/jCF74gfewyt9LSUvHGG28ITdPmNXfx+s1vfiOqq6vvyM8SjwV5aEJVVVRXV2Pp0qXw+Xxob2/HxMSE7GGlBbPZjLq6OlgslthtBQUFKCkpAQAUFhaioaEBy5Yti93v8XjQ1tbGOf5UdnY2VqxYgYKCAtx9992fe6p8M0VFRaivr0d1dTW6u7vR29ubkq/0J9NnX5Rb6BZkiI1GI7Zu3Ypvf/vbGBgYwPPPP4/Tp0/LHlZaqKqqw
g9/+EOsXr06dpter0dhYSEAYOXKlfj7v/97BIPBWBw6Ozvx4x//GB999JGUMaeaoqIi/OVf/iW+8pWvIC8vD/n5+fP6/FWrVuEf/uEfMDU1hVdeeQW//OUv5xy6WCyiy9IWgwUVYlVVYTAYkJWVhZKSEtTW1kKv18NisSAzMxPhcHhRPqDnw2QyYcmSJVi+fPk178/JyUFOTs6c24LBILKysu7E8NKC0WhEeXn5defwZqJz7PV6YbPZFs1e4WK2oEJcU1ODrVu3ori4GPfffz+MRiMKCwvx+OOPY+PGjWhqasIHH3wAr9cre6gLCkPxeZyT1JdKf0cLKsTV1dV4+umnY3vCOp0ORUVFeOyxxxCJRLBnzx78/ve/Z4iJUti11gon43tE/5sKx9/TPsR6vR5FRUXIz89HVVUVcnNzkZGREbtfURQYDAYYDAYUFBRg+fLlyM/Px9jYGCYnJ+UNPIWoqorCwkKYzWYsXbp03i8uZWRkoKqqCitXroTb7cbY2FjKrM+8UxRFgc1mg91uR01NDbKzs2/7a+p0OhQWFmLFihXweDwYHR2Fz+dLwGhTV3RtdWlp6W0tV7uZnJwc1NTUwGAwwOl0wuVyJe17xSPt1xHb7XY8/fTTaGhoQEFBAerq6q57vHJ4eBidnZ1wOp34xS9+gffffz8lfhvKlp2djSeeeALbtm2D1WpFXV3dvF5g8nq96OjogMvlwqFDh/CLX/xi0f2SMxgM2LFjB3bs2BGbQ7vdfltfMxKJoKenB319fbh8+TJ++tOforW1NUEjTj2KouCP/uiP8MQTT8R2mqKrdRLN4XDEHrNvvvkm9u/fn7Sdh3gak/Z7xBkZGbjrrrtib2u80S+M0tJSlJSUwOl04vDhw3dqiCnPYDCgrq4OX/va16DT6eb9SzcnJwfr1q2DEAKDg4MwGo1JGmnqii6Z3Lx5M7KzsxOy46LT6VBTU4OamhrY7Xbs27cvASNNXYqioLy8HA0NDbBarUldNVFYWAibzYbp6WmcPHlS+o5m2ofY5/Ph5MmT0DQNZWVlWL9+PXJzc6/5sT09PTh//jzGxsZw+fLlOzzS1BUMBnH27Fm8+eabKCwsxIYNG+a8E+xmJicncfbsWYyMjOD06dMIBAJJHG1q0jQNLS0t2LdvHwoKCrB+/XqUlpbe1tcMh8NoaWlBe3s7enp64HA4EjTa1CSEQFdXF9555x0UFBRg3bp1WLp0aVK+V/TddePj4+jo6JD/zDjed6IgBd5tc61Np9MJi8UiysrKxBNPPCG6u7uv+zO8/fbb4u677xYlJSUiKytL+thTZVMUReTn54vS0lLxyCOPiMbGxngfFkIIIdrb28WOHTtEWVmZMJvN0s//KmvLzc0VpaWl4qtf/ao4duzYvObwWmZmZsSPf/xjUVVVJQoLC4XRaJT+MyZ7y87OFiUlJWLdunXi3Xffve05vJ73339f3HfffaKsrEzk5uYm9WeKR9rvEUciEbjdbrjdboyMjMDlcsXWDZtMJkQiEfh8PoRCITidTgwPDy/q9/BfixACU1NTmJqawtjY2LzPBBad26GhoSSNMD14PB54PB7k5eVhdnb2tr+epmmYnp7G8PDwojk728zMDGZmZqCqKvx+f9K+TyAQgMPhwPDwsPy9YSyAQxNX6+zsxE9+8hMUFhbi4YcfxoMPPgin04lf//rXuHTpEjo6Orh0LQlS4YFMNB/Rx2yqPHYXVIj7+/tx5coVZGVlwWq1oqGhAS6XCwcOHMCRI0dS9uz8tPAk6nHGx2zypNK8LqgQA588nQuFQhgYGMBHH32EK1euwO12Q9M02UNLC16vFy0tLbGlPIqiwGg0YsmSJbDZbHC73ejv70cgEIg9kLu7uzE9PS1z2CnF7/ejo6MD+fn5KCgoQGVl5bxWkrjdbvT19cHtdqfMU2dKrgUXYuCTY5b/8z//g8bGRgQCAQwMDMgeUtro6+vDiy++OOd8EsXFxfibv/kbPPjgg7h48
SJ+8pOfzJlTn8/HOb6Kw+HAq6++irfeegtbtmzBX//1X8dOmhSPlpYW/L//9/8wMDCA4eHhRffmmKjFtPO0IEOsaRqGhoYW/YtHt8Lr9aKtrW3ObZWVlXA6nYhEInC5XGhqauLyvxsIBAK4dOkSgE8u8+P3+xEOh+e8rTa6RlZ8evWIq7lcLly4cCElL+lzp0QPyYTDYaiqmrA1xdHLJkUikZQK/YIMMSWWx+PBoUOHMDY2hosXL2Jqakr2kNJGZ2cnXn/9deTn58dCXFJSgvr6etjtdjQ3N+Ojjz6asyqipaVl0c/xzMwMPvjgA0xMTKCqqgr19fVzzpF9K6anp3HixAl0d3ejpaUFHo8nQaNNgHjX3SEF1hhyk7MpiiKysrKE2WwW2dnZi3ad8K1sBoNB5OXlCbPZLCwWi7BYLGLLli3i7NmzYnZ2Vrz88suisrJSWCwWYTabOcefblc/5rZv3y4uXboU9xrh6+nr6xM7d+6843McD+4R000JIeDz+Rb8CWeSIRQKzTkHtqIocLlcGB4ehtlshsPhgNvtTq29sxRw9WNufHwcQ0NDnzsZVfQZRvQCrYqiYGpqKnYhVvGZFzmHhobgdDpT8jwoaX/SH6J0c/XFQzs7O3Hx4kVesOAGSkpKcPfdd8NsNn/uPlVVcd9992HHjh3Q6/V45513cPz48Wse//V6vWhsbMTg4OAdGPUfxJNY7hET3WHj4+M86dQ8jIyMYGRk5HO3R89bHD1rm9FoxJkzZ/Bf//Vf0DQtrZb9McRElLaEEBgaGsLx48eh0+lw5cqVtHwTDA9NEFFay8vLg91uh6IoGB8fT7kVJ/EkliEmIkqieBK7eK5XTUSUohhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJ9PF+oBAimeMgIlq0uEdMRCQZQ0xEJBlDTEQkGUNMRCQZQ0xEJBlDTEQkGUNMRCQZQ0xEJBlDTEQkGUNMRCQZQ0xEJBlDTEQkGUNMRCQZQ0xEJBlDTEQkWdznI1YUJZnjICJakOI5lzv3iImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCRjiImIJGOIiYgkY4iJiCTTyx4AUSIoigIAEEJIHkn6U1U1Np83IoSApml3YEQLH0NMaU1RFNjtdthsNgSDQTgcDszMzMgeVtrKycnBAw88gNra2pt+bFtbG37/+9/D5/PdgZEtbAwxpTVVVVFUVISVK1fC4/HA7/fD5/Nxz/gW5efn49FHH8U3v/nNG36cEAJ79+7FhQsXGOIEYIglUBQFOp0OABCJRBiN26AoCnJyclBcXIzc3Fw4nU4YDIbYnEYiEUxNTXEvOU6qqiI7Oxtms/mGHyeEQGFhIcrKymAwGDjHt4khlsBkMiE/Px8A4PF4+AC+DTqdDnfddRe2b98ORVEwMTEBv98fC7HH48Hbb7+No0eP8hdeAimKgg0bNuBHP/oRxsbGs
G/fPhw/fpxzfIsYYgkMBgNycnKgqipmZ2f5VPo2qKqK8vJy3HPPPTCZTJ+7f3x8HOfOncPRo0cljG5hW7p0KZYuXYqxsTGcO3cOx48flz2ktBV3iKurq6FpGjweD9xuN18tvQ2FhYXYuHEjTCYTpqam4PV6AQCapkEIgf7+fnR3dyMcDkseafozGo1YtWoVtm7diomJCbS3t2N6elr2sBaUeFZY0I3FHeLNmzcjEomgvb0d58+fRzAYTOa4FrRVq1bh2WefRVFREcLhcCzA4XAYoVAIb775Jl555RWGOAGys7OxY8cOPPzwwzhz5gz+5V/+BW1tbbKHRTRH3CEuKiqCpmkYHR1FVlYWVFVFKBRCJBJJ5vgWpJycHFRWVqK0tDR229UhLiwsjL2YR7dHp9OhqKgIRUVFcDgc1zx8QSRb3CFuaGgAAKxYsQIPPPAAxsfHcfjwYVy8eDFZY1tUoispNE2DqvINj5R+FEXhax23aN4hvvo45sDAAEOcQKqqQqfTMcSUdhRFiW2M8fzFHWK9fu6HRp9eRxfSj46OIhQKJXyAi42iKLBarairq8PExAQcDgc8Ho/sYS0I2dnZqK6uhs/nw8TEBMbHx/micwLodDqUlJSgrq4OXq8Xo6OjmJ2dlT2stHLLy9esViuefPJJbNmyBadPn8bPf/5zDA8PJ3Jsi5KqqmhoaEB5eTmGhoawa9cunDx5UvawFoTq6mr88Ic/hMvlwnvvvYe9e/fC7/fLHlbay8nJwZ/+6Z/iy1/+Mpqbm/Hzn/8cPT09soeVVm45xCaTCatXr8Zdd92FQCCArKysRI5r0VIUBUuWLEFlZSX6+vqwf/9+2UNaMCwWCzZu3IhgMIi2tja+IJogRqMRK1asQG1tbWyNPM3PLYfY6/Xi/Pnz6O/vR2NjI58+J4gQAh0dHWhubsbw8DCuXLkie0gLxtjYGM6ePYuxsTE0NjZyeWCCBAIBNDY2oqurC+3t7ZicnJQ9pLRzyyGenJzEW2+9hQMHDsDv92NqaiqR41q0IpEITp48iRdffBFut5vzmkC9vb14+eWX0dLSAq/Xy+OYCTIzM4P9+/fjrbfeQiAQYIhvQdwhFkJACIHZ2VkEAgG4XC6Mjo5iaGgomeNblLxeL4aHh/kOsAQLBoNwOp18zCaYpmmYnJzE8PBwbFUVzU/cIfZ4PNA0DadOncKRI0fgdDrR2tqazLEtOtETbfOBnBzRnQlKDs7vrYs7xD6fD+FwGOfOncN//ud/Ynp6mkt/Eiz6QOaDmdINH7e3J+4Qnz17FpqmYWBgAMFgkBG+DdE931AohMHBQYyNjc0530R/fz/fOp4g4XAYAwMDGBsbw8WLF3nKUUpJcYf4H//xHwEATqcTgUAgaQNaDDRNg6ZpmJ6exq9+9Su89957sRPEa5oGh8PBOU6QmZkZvP3229i/fz+mpqa4CiUJuCd8++IOcVNTUxKHsbhE93z9fj96enrw8ccf80odCfLZY+yBQCA2x5zfxBJCIBKJIBKJ8BnybeKJ4SXo6urC7t27odfr0draygfxbYoen3S5XPjwww9je71CCHi9Xp4PJUna29vxu9/9Dg6HAxcvXuQvutvAEEvQ3NyMrq4uKIoCv9/PECeAEAIOhwO7d+/GiRMn5qxAufrSSZQYQgg0NTXhX//1X+FwOPhW8dvEEEsQCoV4gqQEEUJgamoKg4ODGBoagsvlgtvt5qv4tyj6C0zTNMzOzmJqair2DsSr51MIgZGREUxMTPBNRwnAEFNaC4VC+OCDDzA4OAiv14vOzk5G+DYIIRAMBuHz+dDa2op33nkHDofjc6t4hBDo7e2NXeaLbg9DTGktEomgtbWVby5KECEEQqEQgsEgBgYG8N///d/o7u6O7SVTcjDERBQzOzuL1tZW5OXlobGxETMzM3yGcQcoIs4Z5pVaiRY+vV4Pu92OnJwc+Hw+jI+P80LBtymexDLERERJFE9ieXE0IiLJGGIiIskYYiIiy
RhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIskYYiIiyRhiIiLJGGIiIsniPvsaz75ERJQc3CMmIpKMISYikowhJiKSjCEmIpKMISYikowhJiKSjCEmIpKMISYikowhJiKS7P8D6+vs/P3Fr5MAAAAASUVORK5CYII=", "text/plain": [ "
" ] diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index 2539192..2fd28ca 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -61,12 +61,19 @@ def general_linear( out = "".join(str(axis) for axis in range(input.ndim) if axis not in in_axis) out_axis = out_axis if out_axis >= 0 else out_axis + len(out) + 1 out = out[:out_axis] + "F" + out[out_axis:] - result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight) + + try: + einsum = f"{lhs},{rhs}->{out}" + result = jnp.einsum(einsum, input, weight) + except ValueError as error: + raise ValueError(f"{einsum=}\n{input.shape=}\n{weight.shape=}\n{error=}") if bias is None: return result - broadcast_shape = list(range(result.ndim)) - del broadcast_shape[out_axis] + + with jax.ensure_compile_time_eval(): + broadcast_shape = list(range(result.ndim)) + del broadcast_shape[out_axis] bias = jnp.expand_dims(bias, axis=broadcast_shape) return result + bias @@ -306,6 +313,10 @@ class MLP(sk.TreeClass): >>> _, material_layer = lazy_layer.at['__call__'](input) >>> material_layer.in_linear.in_features (10,) + + Note: + :class:`.MLP` uses ``jax.lax.scan`` to reduce the ``jaxpr`` size, + leading to faster compilation times. """ def __init__( diff --git a/serket/_src/nn/normalization.py b/serket/_src/nn/normalization.py index 03412d3..8d4b7b9 100644 --- a/serket/_src/nn/normalization.py +++ b/serket/_src/nn/normalization.py @@ -211,6 +211,7 @@ class GroupNorm(sk.TreeClass): >>> import serket as sk >>> import jax.numpy as jnp >>> import jax.random as jr + >>> key = jr.PRNGKey(0) >>> layer = sk.nn.GroupNorm(5, groups=1, key=key) >>> input = jnp.ones((5,10)) >>> layer(input).shape @@ -426,13 +427,12 @@ class BatchNorm(sk.TreeClass): .. 
warning:: Works under - - ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(x, state)`` - - ``jax.vmap(BatchNorm(...), out_axes=(0, None))(x)`` + - ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(input, state)`` otherwise will be a no-op. Training behavior: - - ``output = (x - batch_mean) / sqrt(batch_var + eps)`` + - ``output = (input - batch_mean) / sqrt(batch_var + eps)`` - ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean`` - ``running_var = momentum * running_var + (1 - momentum) * batch_var`` @@ -461,8 +461,9 @@ class BatchNorm(sk.TreeClass): >>> import jax.random as jr >>> bn = sk.nn.BatchNorm(10, key=jr.PRNGKey(0)) >>> state = sk.tree_state(bn) - >>> x = jax.random.uniform(jax.random.PRNGKey(0), shape=(5, 10)) - >>> x, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(x, state) + >>> key = jr.PRNGKey(0) + >>> input = jr.uniform(key, shape=(5, 10)) + >>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(input, state) Example: Working with :class:`.BatchNorm` with threading the state. @@ -476,19 +477,19 @@ class BatchNorm(sk.TreeClass): ... k1, k2 = jax.random.split(key) ... self.bn1 = sk.nn.BatchNorm(5, axis=-1, key=k1) ... self.bn2 = sk.nn.BatchNorm(5, axis=-1, key=k2) - ... def __call__(self, x, state): - ... x, bn1 = self.bn1(x, state.bn1) - ... x = x + 1.0 - ... x, bn2 = self.bn2(x, state.bn2) + ... def __call__(self, input, state): + ... input, bn1 = self.bn1(input, state.bn1) + ... input = input + 1.0 + ... input, bn2 = self.bn2(input, state.bn2) ... # update the output state ... state = state.at["bn1"].set(bn1).at["bn2"].set(bn2) - ... return x, state + ... return input, state >>> net: ThreadedBatchNorm = ThreadedBatchNorm(key=jr.PRNGKey(0)) >>> # initialize state as the same structure as tree >>> state: ThreadedBatchNorm = sk.tree_state(net) - >>> x = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5) - >>> for xi in x: - ... 
out, state = jax.vmap(net, in_axes=(0, None), out_axes=(0, None))(xi, state) + >>> inputs = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5) + >>> for input in inputs: + ... output, state = jax.vmap(net, in_axes=(0, None), out_axes=(0, None))(input, state) Example: Working with :class:`.BatchNorm` without threading the state. @@ -511,28 +512,28 @@ class BatchNorm(sk.TreeClass): ... self.bn1_state = sk.tree_state(self.bn1) ... self.bn2 = sk.nn.BatchNorm(5, axis=-1, key=k2) ... self.bn2_state = sk.tree_state(self.bn2) - ... def _call(self, x): + ... def _call(self, input): ... # this method will raise `AttributeError` if used directly ... # because this method mutates the state ... # instead, use `at["_call"]` to call this method to ... # return the output and updated state in a functional manner - ... x, self.bn1_state = self.bn1(x, self.bn1_state) - ... x = x + 1.0 - ... x, self.bn2_state = self.bn2(x, self.bn2_state) - ... return x - ... def __call__(self, x): - ... return self.at["_call"](x) + ... input, self.bn1_state = self.bn1(input, self.bn1_state) + ... input = input + 1.0 + ... input, self.bn2_state = self.bn2(input, self.bn2_state) + ... return input + ... def __call__(self, input): + ... return self.at["_call"](input) >>> # define a function to mask and unmask the net across `vmap` >>> # this is necessary because `vmap` needs the output to be of inexact - >>> def mask_vmap(net, x): + >>> def mask_vmap(net, input): ... @ft.partial(jax.vmap, out_axes=(0, None)) - ... def forward(x): - ... return sk.tree_mask(net(x)) - ... return sk.tree_unmask(forward(x)) + ... def forward(input): + ... return sk.tree_mask(net(input)) + ... return sk.tree_unmask(forward(input)) >>> net: UnthreadedBatchNorm = UnthreadedBatchNorm(key=jr.PRNGKey(0)) - >>> input = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5) - >>> for xi in input: - ... 
out, net = mask_vmap(net, xi) + >>> inputs = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5) + >>> for input in inputs: + ... output, net = mask_vmap(net, input) Note: :class:`.BatchNorm` supports lazy initialization, meaning that the @@ -549,8 +550,9 @@ class BatchNorm(sk.TreeClass): >>> key = jr.PRNGKey(0) >>> lazy_layer = sk.nn.BatchNorm(None, key=key) >>> input = jnp.ones((5,10)) - >>> _ , material_layer = lazy_layer.at['__call__'](input) - >>> output, state = jax.vmap(material_layer, out_axes=(0, None))(input) + >>> _ , material_layer = lazy_layer.at["__call__"](input, None) + >>> state = sk.tree_state(material_layer) + >>> output, state = jax.vmap(material_layer, in_axes=(0, None), out_axes=(0, None))(input, state) >>> output.shape (5, 10) @@ -589,11 +591,8 @@ def __init__( @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__( - self, - input: jax.Array, - state: BatchNormState | None = None, + self, input: jax.Array, state: BatchNormState ) -> tuple[jax.Array, BatchNormState]: - state = sk.tree_state(self) if state is None else state batchnorm_impl = custom_vmap(lambda x, state: (x, state)) momentum, eps = jax.lax.stop_gradient((self.momentum, self.eps)) @@ -622,13 +621,12 @@ class EvalNorm(sk.TreeClass): .. warning:: Works under - - ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(x, state)`` - - ``jax.vmap(BatchNorm(...), out_axes=(0, None))(x)`` + - ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(input, state)`` otherwise will be a no-op. Evaluation behavior: - - ``output = (x - running_mean) / sqrt(running_var + eps)`` + - ``output = (input - running_mean) / sqrt(running_var + eps)`` Args: in_features: the shape of the input to be normalized. 
@@ -653,10 +651,10 @@ class EvalNorm(sk.TreeClass): >>> bn = sk.nn.BatchNorm(10, key=jr.PRNGKey(0)) >>> state = sk.tree_state(bn) >>> input = jax.random.uniform(jr.PRNGKey(0), shape=(5, 10)) - >>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(x, state) + >>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(input, state) >>> # convert to evaluation mode >>> bn = sk.tree_eval(bn) - >>> output, state = jax.vmap(bn, in_axes=(0, None))(input, state) + >>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0,None))(input, state) Note: If ``axis_name`` is specified, then ``axis_name`` argument must be passed diff --git a/serket/_src/nn/recurrent.py b/serket/_src/nn/recurrent.py index 25c5c9c..c34688b 100644 --- a/serket/_src/nn/recurrent.py +++ b/serket/_src/nn/recurrent.py @@ -16,13 +16,12 @@ import abc import functools as ft -from contextlib import suppress -from typing import Any, Callable import jax import jax.numpy as jnp import jax.random as jr -import jax.tree_util as jtu +from typing_extensions import ParamSpec +from typing import Callable, Any, TypeVar import serket as sk from serket._src.custom_transform import tree_state @@ -48,6 +47,10 @@ validate_spatial_nd, ) +P = ParamSpec("P") +T = TypeVar("T") +S = TypeVar("S") + State = Any """Defines RNN related classes.""" @@ -73,100 +76,11 @@ class RNNState(sk.TreeClass): hidden_state: jax.Array -class RNNCell(sk.TreeClass): - """Abstract class for RNN cells. - - Subclass this class to define a new RNN cell that can be used with :class:`nn.ScanRNN`. - or :func:`nn.scan_rnn`. - - Subclasses must - - Implement ``__call__`` method that accept an input and a state and returns - tuple of output and new state. - - Define state rule using :func:`serket.tree_state` decorator. - - Define ``spatial_ndim`` attribute that specifies the spatial dimension of - the cell. For non-spatial cells (e.g. :class:`.LSTMCell`), set ``spatial_ndim`` to 0, - for 1D cells (e.g. 
:class:`.ConvLSTM1DCell` ) set it to 1 and so on. - - Note: - :class:`.ScanRNN` and :func:`.scan_rnn` offers a unified interface for - scanning over time steps of RNN cells. Supports forward and backward - scanning, helpful error messages for wrong input shapes and more. - - Example: - Define a simple ``RNN`` cell that matrix multiplies the input with a ones matrix - and adds the result to the hidden state. - - >>> import serket as sk - >>> import jax - >>> import jax.numpy as jnp - >>> class CustomRNNState(sk.TreeClass): - ... def __init__(self, hidden_state: jax.Array): - ... self.hidden_state = hidden_state - - >>> class CustomRNNCell(sk.nn.RNNCell): - ... def __init__(self, in_features: int, hidden_features: int): - ... self.in_features = in_features - ... self.hidden_features = hidden_features - ... self.in_to_hidden = lambda x: x @ jnp.ones((in_features, hidden_features)) - ... def __call__( - ... self, - ... input: jax.Array, - ... state: CustomRNNState | None = None, - ... ) -> CustomRNNState: - ... # if no state is provided, by default it will be initialized with - ... # rule defined using `sk.tree_state.def_state` below when the cell is - ... # wrapped with `sk.nn.ScanRNN`/`sk.nn.scan_rnn` - ... output = self.in_to_hidden(input) - ... state = CustomRNNState(state.hidden_state + output) - ... return output, state - ... # to validate the shape of the input and give more helpful error message - ... # define the shape of the input input. e.g. in case of Non-spatial RNN - ... # spatial_ndim should be 0, otherwise it should be 1 for 1D (e.g. ConvLSTM1D) - ... # 2 for 2D (e.g. ConvLSTM2D) and so on. - ... spatial_ndim: int = 0 - - >>> # initialize the cell with zeros hidden state - >>> @sk.tree_state.def_state(CustomRNNCell) - ... def custom_rnn_state(cell: CustomRNNCell, **_) -> CustomRNNState: - ... zeros = jnp.zeros((cell.hidden_features,)) - ... 
return CustomRNNState(hidden_state=zeros) - >>> cell = CustomRNNCell(5, 10) - >>> print(repr(sk.tree_state(cell))) - CustomRNNState(hidden_state=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])) - >>> inputs = jnp.ones((5, 5)) # 5 time steps, 5 features - >>> # 5 time steps will perform 5 steps of matrix multiplication of - >>> # the running hidden state with ones matrix of shape (5, 10) and - >>> # add the result to the hidden state - >>> print(sk.nn.ScanRNN(cell)(inputs)) - [25. 25. 25. 25. 25. 25. 25. 25. 25. 25.] - - This is equivalent to the following code: - - >>> import jax.numpy as jnp - >>> h = jnp.zeros(10) # 10 hidden features initialized with zeros - >>> inputs = jnp.ones((5, 5)) # 5 time steps, 5 input_features - >>> for i in range(5): # the scanning as a python loop - ... h = h + inputs[i] @ jnp.ones((5, 10)) - >>> print(h) - [25. 25. 25. 25. 25. 25. 25. 25. 25. 25.] - """ - - @abc.abstractmethod - def __call__(self, input: jax.Array, state: State) -> tuple[jax.Array, State]: - ... - - @property - @abc.abstractmethod - def spatial_ndim(self) -> int: - # 0 for non-spatial, 1 for 1D, 2 for 2D, 3 for 3D etc. - ... - - class SimpleRNNState(RNNState): ... -class SimpleRNNCell(RNNCell): +class SimpleRNNCell(sk.TreeClass): """Vanilla RNN cell that defines the update rule for the hidden state Args: @@ -286,7 +200,7 @@ class DenseState(RNNState): ... 
-class DenseCell(RNNCell): +class DenseCell(sk.TreeClass): """No hidden state cell that applies a dense(Linear+activation) layer to the input Args: @@ -369,8 +283,8 @@ def __call__( ) -> tuple[jax.Array, DenseState]: if not isinstance(state, DenseState): raise TypeError(f"Expected {state=} to be an instance of `DenseState`") - - h = self.act(self.in_to_hidden(input)) + h = self.in_to_hidden(input) + h = self.act(h) return h, DenseState(h) spatial_ndim: int = 0 @@ -381,7 +295,7 @@ class LSTMState(RNNState): cell_state: jax.Array -class LSTMCell(RNNCell): +class LSTMCell(sk.TreeClass): """LSTM cell that defines the update rule for the hidden state and cell state Args: @@ -513,7 +427,7 @@ class GRUState(RNNState): ... -class GRUCell(RNNCell): +class GRUCell(sk.TreeClass): """GRU cell that defines the update rule for the hidden state and cell state Args: @@ -633,7 +547,7 @@ class ConvLSTMNDState(RNNState): cell_state: jax.Array -class ConvLSTMNDCell(RNNCell): +class ConvLSTMNDCell(sk.TreeClass): @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, @@ -1064,7 +978,7 @@ class ConvGRUNDState(RNNState): ... 
-class ConvGRUNDCell(RNNCell): +class ConvGRUNDCell(sk.TreeClass): @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, @@ -1471,99 +1385,30 @@ class FFTConvGRU3DCell(ConvGRUNDCell): convolution_layer = FFTConv3D -# Scanning API - - -def materialize_cell(instance, input: jax.Array, state=None, **__) -> RNNCell: - # in case of lazy initialization, we need to materialize the cell - # before it can be passed to the scan function - cell = instance.cell - state = state if state is not None else sk.tree_state(instance, input=input) - state = split_state(state, 2) if instance.backward_cell is not None else [state] - _, cell = cell.at["__call__"](input[0], state[0]) - return cell - - -def materialize_backward_cell(instance, x, state=None, **__) -> RNNCell | None: - if instance.backward_cell is None: - return None - cell = instance.cell - state = state if state is not None else sk.tree_state(instance, input=x) - state = split_state(state, 2) if instance.backward_cell is not None else [state] - _, cell = cell.at["__call__"](x[0], state[-1]) - return cell - - -def is_lazy_init(_, cell, backward_cell=None, **__) -> bool: - lhs = getattr(cell, "in_features", False) is None - rhs = getattr(backward_cell, "in_features", False) is None - return lhs or rhs - - -def is_lazy_call(instance, x, state=None, **_) -> bool: - lhs = getattr(instance.cell, "in_features", False) is None - rhs = getattr(instance.backward_cell, "in_features", False) is None - return lhs or rhs - - -updates = dict(cell=materialize_cell, backward_cell=materialize_backward_cell) - - -def split_state(state: RNNState, splits: int) -> list[RNNState]: - flat_arrays: list[jax.Array] = jtu.tree_leaves(state) - return [type(state)(*x) for x in zip(*(jnp.split(x, splits) for x in flat_arrays))] - - -def concat_state(states: list[RNNState]) -> RNNState: - # undo the split - return ( - states[0] - if len(states) == 1 - else jax.tree_map(lambda *x: jnp.concatenate([*x]), *states) - ) - - -def scan_rnn( - 
cell: RNNCell, - backward_cell: RNNCell | None, - input: jax.Array, - state: State, - return_sequences: bool = False, - return_state: bool = False, -) -> jax.Array | tuple[jax.Array, State]: - """Scans a RNN cell(s) over a sequence. +def scan_cell(cell, in_axis=0, out_axis=0, reverse=False): + """Scan an RNN cell over a sequence. Args: - cell: the forward RNN cell to scan. - backward_cell: the backward RNN cell to scan. Pass ``None`` for unidirectional RNN. - input: the input sequence. - state: the initial state of the RNN cell. In case of bidirectional RNN, - the forward and backward states are concatenated along the first axis. - return_sequences: whether to return the output for each timestep. Defaults - to ``False``. - return_state: whether to return the final state of the RNN cell(s). Defaults - to ``False``. + cell: the RNN cell to scan. The cell should have the following signature: + `cell(input, state) -> tuple[output, state]` + in_axis: the axis to scan over. Defaults to 0. + out_axis: the axis to move the output to. Defaults to 0. + reverse: whether to scan the sequence in reverse order. Defaults to ``False``. 
Example: - Unionidirectional RNN: + Unidirectional RNN: >>> import serket as sk >>> import jax >>> import jax.numpy as jnp - - >>> cell = sk.nn.SimpleRNNCell(1, 2, key=jax.random.PRNGKey(0)) + >>> import jax.random as jr + >>> key = jr.PRNGKey(0) + >>> cell = sk.nn.SimpleRNNCell(1, 2, key=key) >>> state = sk.tree_state(cell) - >>> input = jnp.ones([10, 1]) # [time steps, features] - - >>> out = sk.nn.scan_rnn(cell, None, input, state) - >>> print(out.shape) - (2,) - - >>> out = sk.nn.scan_rnn(cell, None, input, state, return_sequences=True) - >>> print(out.shape) + >>> input = jnp.ones([10, 1]) + >>> output, state = sk.nn.scan_cell(cell)(input, state) + >>> print(output.shape) (10, 2) - - >>> out, state = sk.nn.scan_rnn(cell, None, input, state, return_state=True) Example: Bidirectional RNN: @@ -1571,191 +1416,33 @@ def scan_rnn( >>> import serket as sk >>> import jax >>> import jax.numpy as jnp - - >>> cell = sk.nn.SimpleRNNCell(1, 2, key=jax.random.PRNGKey(0)) - >>> back_cell = sk.nn.SimpleRNNCell(1, 2, key=jax.random.PRNGKey(1)) - >>> # concat state of forward and backward cells - >>> concat_state_func = lambda *x: jnp.concatenate([*x]) - >>> state = jax.tree_map(concat_state_func, *sk.tree_state((cell, back_cell))) - >>> input = jnp.ones([10, 1]) # [time steps, features] - - >>> out = sk.nn.scan_rnn(cell, back_cell, input, state) - >>> print(out.shape) - (4,) - - >>> out = sk.nn.scan_rnn(cell, back_cell, input, state, return_sequences=True) - >>> print(out.shape) - (10, 4) - - >>> out, state = sk.nn.scan_rnn(cell, back_cell, input, state, return_state=True) - - Returns: - return the result and state if ``return_state`` is ``True``. otherwise, - return the result. - - Note: - See :class:`.nn.ScanRNN` for a class-based API. 
- """ - - def accumulate_scan( - cell: RNNCell, - input: jax.Array, - state: State, - reverse: bool = False, - ) -> tuple[jax.Array, State]: - def scan_func(carry, input): - output, state = cell(input, state=carry) - return state, output - - input = jnp.flip(input, axis=0) if reverse else input # flip over time axis - carry, output = jax.lax.scan(scan_func, state, input) - output = jnp.flip(output, axis=-1) if reverse else output - return output, carry - - def unaccumulate_scan( - cell: RNNCell, - input: jax.Array, - state: State, - reverse: bool = False, - ) -> jax.Array: - def scan_func(carry, input): - _, state = cell(input, state=carry) - return state, None - - input = jnp.flip(input, axis=0) if reverse else input - carry, _ = jax.lax.scan(scan_func, state, input) - result = carry.hidden_state - return result, carry - - if backward_cell is None: - scan_func = accumulate_scan if return_sequences else unaccumulate_scan - result, state = scan_func(cell, input, state) - return (result, state) if return_state else result - # bidirectional RNN - lhs_state, rhs_state = split_state(state, splits=2) - scan_func = accumulate_scan if return_sequences else unaccumulate_scan - lhs_result, lhs_state = scan_func(cell, input, lhs_state, False) - rhs_result, rhs_state = scan_func(backward_cell, input, rhs_state, True) - concat_axis = int(return_sequences) - result = jnp.concatenate((lhs_result, rhs_result), axis=concat_axis) - state = concat_state((lhs_state, rhs_state)) - return (result, state) if return_state else result - - -def check_cells(*cells: Any) -> None: - """Checks that the cells are compatible with each other.""" - cell0, *cells = cells - for cell in cells: - if not isinstance(cell, RNNCell): - raise TypeError(f"{cell=} to be an instance of `RNNCell`.") - with suppress(AttributeError): - # if the user has not specified the in_features, we cannot check - # that the cells are compatible - if cell0.in_features != cell.in_features: - raise 
ValueError(f"{cell0.in_features=} != {cell.in_features=}") - with suppress(AttributeError): - # if the user has not specified the hidden_features, we cannot check - # that the cells are compatible - if cell0.hidden_features != cell.hidden_features: - raise ValueError(f"{cell0.hidden_features=} != {cell.hidden_features=}") - - -class ScanRNN(sk.TreeClass): - """Scans RNN cell over a sequence. - - Args: - cell: the RNN cell to scan. - backward_cell: (optional) the backward RNN cell to scan in case of bidirectional RNN. - return_sequences: whether to return the output for each timestep. - return_state: whether to return the final state of the RNN cell(s). - - Example: - >>> import jax.numpy as jnp - >>> import serket as sk >>> import jax.random as jr - >>> # 10-dimensional input, 20-dimensional hidden state - >>> cell = sk.nn.SimpleRNNCell(10, 20, key=jr.PRNGKey(0)) - >>> rnn = sk.nn.ScanRNN(cell, return_state=True) - >>> input = jnp.ones((5, 10)) # 5 timesteps, 10 features - >>> output, state = rnn(input) + >>> k1, k2 = jr.split(jr.PRNGKey(0)) + >>> cell1 = sk.nn.SimpleRNNCell(1, 2, key=k1) + >>> cell2 = sk.nn.SimpleRNNCell(1, 2, key=k2) + >>> state1 = sk.tree_state(cell1) + >>> state2 = sk.tree_state(cell2) + >>> input = jnp.ones([10, 1]) + >>> output1, state1 = sk.nn.scan_cell(cell1)(input, state1) + >>> output2, state2 = sk.nn.scan_cell(cell2, reverse=True)(input, state2) + >>> output = jnp.concatenate((output1, output2), axis=1) >>> print(output.shape) - (20,) - - Example: - >>> import jax.numpy as jnp - >>> import serket as sk - >>> import jax.random as jr - >>> cell = sk.nn.SimpleRNNCell(10, 20, key=jr.PRNGKey(0)) - >>> rnn = sk.nn.ScanRNN(cell, return_sequences=True, return_state=True) - >>> input = jnp.ones((5, 10)) # 5 timesteps, 10 features - >>> output, state = rnn(input) # 5 timesteps, 20 features - >>> output.shape - (5, 20) + (10, 4) """ - @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) - def __init__( - self, - cell: RNNCell, - backward_cell: 
RNNCell | None = None, - *, - return_sequences: bool = False, - return_state: bool = False, - ): - if not isinstance(cell, RNNCell): - raise TypeError(f"Expected {cell=} to be an instance of `RNNCell`.") - - if backward_cell is not None: - check_cells(cell, backward_cell) + def scan_func(state, input): + output, state = cell(input, state) + return state, output - self.cell = cell - self.backward_cell = backward_cell - self.return_sequences = return_sequences - self.return_state = return_state + def wrapper(input: T, state: S) -> tuple[T, S]: + # push the scan axis to the front + input = jnp.moveaxis(input, in_axis, 0) + state, output = jax.lax.scan(scan_func, state, input, reverse=reverse) + # move the output axis to the desired location + output = jnp.moveaxis(output, 0, out_axis) + return output, state - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - def __call__( - self, - input: jax.Array, - state: State | None = None, - ) -> jax.Array | tuple[jax.Array, State]: - """Scans the RNN cell over a sequence. - - Args: - input: the input sequence. - state: the initial state. if None, state is initialized by the rule - defined using :func:`.tree_state`. - - Returns: - return the result and state if ``return_state`` is True. otherwise, - return only the result. - """ - - if input.ndim != self.cell.spatial_ndim + 2: - raise ValueError( - f"Expected input to have {(self.cell.spatial_ndim + 2)=} dimensions corresponds to " - f"(timesteps, in_features, {', '.join(['...']*self.cell.spatial_ndim)})." - f"\nGot {input.ndim=} and {input.shape=}." - ) - - with suppress(AttributeError): - # if the user has not specified the in_features, we cannot check - # that the cells are compatible - if self.cell.in_features != input.shape[1]: - raise ValueError( - f"Expected input to have shape (timesteps, {self.cell.in_features}, " - f"{', '.join(['...']*self.cell.spatial_ndim)})." - f"\nGot {input.shape[1]=} and {self.cell.in_features=}." 
- ) - - return scan_rnn( - self.cell, - self.backward_cell, - input, - tree_state(self, input=input) if state is None else state, - self.return_sequences, - self.return_state, - ) + return wrapper # register state handlers @@ -1782,7 +1469,7 @@ def _(cell: GRUCell) -> GRUState: return GRUState(jnp.zeros([cell.hidden_features])) -def _check_rnn_cell_tree_state_input(cell: RNNCell, input): +def _check_rnn_cell_tree_state_input(cell, input): if not (hasattr(input, "ndim") and hasattr(input, "shape")): raise TypeError( f"Expected {input=} to have `ndim` and `shape` attributes." @@ -1792,9 +1479,9 @@ def _check_rnn_cell_tree_state_input(cell: RNNCell, input): if input.ndim != cell.spatial_ndim + 1: raise ValueError( - f"{input.ndim=} != {(cell.spatial_ndim + 1)=}." - f"Expected input to {type(cell).__name__} to have `shape` (in_features, {'...'*cell.spatial_ndim})." - "Pass a single sample input to `tree_state`" + f"{input.ndim=} != {(cell.spatial_ndim+1)=}.\n" + f"Expected input to {type(cell).__name__} to have `shape` (in_features, {'... '*cell.spatial_ndim}).\n" + f"Pass a single sample input to `tree_state({type(cell).__name__}, input=...)`" ) if len(spatial_dim := input.shape[1:]) != cell.spatial_ndim: @@ -1816,33 +1503,3 @@ def _(cell: ConvGRUNDCell, *, input: Any) -> ConvGRUNDState: input = _check_rnn_cell_tree_state_input(cell, input) shape = (cell.hidden_features, *input.shape[1:]) return ConvGRUNDState(jnp.zeros(shape).astype(input.dtype)) - - -@tree_state.def_state(ScanRNN) -def _(rnn: ScanRNN, input: jax.Array | None = None) -> RNNState: - # the idea here is to combine the state of the forward and backward cells - # if backward cell exists. to have single state input for `ScanRNN` and - # single state output not to complicate the ``__call__`` signature on the - # user side. 
- input = [None] if input is None else input - # non-spatial cells don't need an input instead - # pass `None` to `tree_state` - # otherwise pass the a single time step input to the cells - return ( - tree_state(rnn.cell, input=input[0]) - if rnn.backward_cell is None - else concat_state(tree_state((rnn.cell, rnn.backward_cell), input=input[0])) - ) - - -@sk.tree_summary.def_type(ScanRNN) -def _(rnn: ScanRNN) -> str: - # display the type of the rnn cell and the type of the cell(s) it scans - # e.g. ScanRNN[SimpleRNNCell] instead of ScanRNN - return ( - f"{type(rnn).__name__}" - + "[" - + f"{type(rnn.cell).__name__}" - + (f",{type(rnn.backward_cell).__name__}" if rnn.backward_cell else "") - + "]" - ) diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index d4b9456..d2ccee8 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -145,10 +145,8 @@ FFTConvLSTM3DCell, GRUCell, LSTMCell, - RNNCell, - ScanRNN, SimpleRNNCell, - scan_rnn, + scan_cell, ) from serket._src.nn.reshape import ( CenterCrop1D, @@ -316,10 +314,8 @@ "FFTConvLSTM3DCell", "GRUCell", "LSTMCell", - "RNNCell", - "ScanRNN", "SimpleRNNCell", - "scan_rnn", + "scan_cell", # reshape "CenterCrop1D", "CenterCrop2D", diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 94cabb2..bf5c1c4 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -1200,7 +1200,15 @@ def test_lazy_conv_local(): @pytest.mark.parametrize( - "sk_layer,keras_layer,kernel_size,strides,padding,dilation,ndim", + ( + "sk_layer", + "keras_layer", + "kernel_size", + "strides", + "padding", + "dilation", + "ndim", + ), [ *product( [sk.nn.Conv1D, sk.nn.FFTConv1D], diff --git a/tests/test_rnn.py b/tests/test_rnn.py index 54fedff..68d70fe 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -12,799 +12,208 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# testing against keras -# import tensorflow.keras as tfk -# import tensorflow as tf -# import numpy as np -# from serket._src.nn.recurrent import LSTMCell, ScanRNN +import os -# batch_size = 1 -# time_steps = 2 -# in_features = 3 -# hidden_features=2 +os.environ["KERAS_BACKEND"] = "jax" +from itertools import product -# inputs = np.ones([batch_size,time_steps, in_features]).astype(np.float32) -# inp = tf.keras.Input(shape=(time_steps, in_features)) -# rnn = (tf.keras.layers.LSTM(hidden_features, return_sequences=True, return_state=False))(inp) -# rnn = tf.keras.Model(inputs=inp, outputs=rnn) -# # rnn(inputs) -# w_in_to_hidden = jnp.array(rnn.weights[0].numpy()) -# w_hidden_to_hidden = jnp.array(rnn.weights[1].numpy()) -# b_hidden_to_hidden = jnp.array(rnn.weights[2].numpy()) -# x = jnp.ones([time_steps, in_features]) -# cell = LSTMCell(in_features, hidden_features, recurrent_weight_init="glorot_uniform", bias_init="zeros", -# weight_init="glorot_uniform") -# cell = cell.at["in_to_hidden.weight"].set(w_in_to_hidden) -# cell = cell.at["hidden_to_hidden.weight"].set(w_hidden_to_hidden) -# cell = cell.at["hidden_to_hidden.bias"].set(b_hidden_to_hidden) -# ScanRNN(cell, return_sequences=True)(x) ,rnn(inputs) - -# testing with keras -# inputs = np.ones([batch_size,time_steps, in_features]).astype(np.float32) -# inp = tf.keras.Input(shape=(time_steps, in_features)) -# rnn = tfk.layers.Bidirectional(tf.keras.layers.LSTM(hidden_features, return_sequences=False))(inp) -# rnn = tf.keras.Model(inputs=inp, outputs=rnn) -# # rnn(inputs) -# w_in_to_hidden = jnp.array(rnn.weights[0].numpy()) -# w_hidden_to_hidden = jnp.array(rnn.weights[1].numpy()) -# b_hidden_to_hidden = jnp.array(rnn.weights[2].numpy()) -# x = jnp.ones([time_steps, in_features]) -# cell = LSTMCell(in_features, hidden_features) -# cell = cell.at["in_to_hidden.weight"].set(w_in_to_hidden) -# cell = cell.at["hidden_to_hidden.weight"].set(w_hidden_to_hidden) -# cell = 
cell.at["hidden_to_hidden.bias"].set(b_hidden_to_hidden) - -# w_in_to_hidden_reverse = jnp.array(rnn.weights[3].numpy()) -# w_hidden_to_hidden_reverse = jnp.array(rnn.weights[4].numpy()) -# b_hidden_to_hidden_reverse = jnp.array(rnn.weights[5].numpy()) -# reverse_cell = LSTMCell(in_features, hidden_features) - -# reverse_cell = reverse_cell.at["in_to_hidden.weight"].set(w_in_to_hidden_reverse) -# reverse_cell = reverse_cell.at["hidden_to_hidden.weight"].set(w_hidden_to_hidden_reverse) -# reverse_cell = reverse_cell.at["hidden_to_hidden.bias"].set(b_hidden_to_hidden_reverse) - - -import jax import jax.numpy as jnp +import jax.random as jr +import keras import numpy.testing as npt import pytest -from serket._src.nn.recurrent import ( - ConvLSTM1DCell, - DenseCell, - FFTConvLSTM1DCell, - GRUCell, - LSTMCell, - ScanRNN, - SimpleRNNCell, -) - -# import pytest +import serket as sk -def test_vanilla_rnn(): +def test_simple_rnn(): + key = jr.PRNGKey(0) + time_step = 3 in_features = 2 hidden_features = 3 - # batch_size = 1 - time_steps = 10 - - # test against keras - # copy weights from keras to serket and compare outputs - # inputs = np.ones([batch_size,time_steps, in_features]).astype(np.float32) - # inp = tf.keras.Input(shape=(time_steps, in_features)) - # rnn = (tf.keras.layers.SimpleRNN(hidden_features, return_sequences=False, return_state=False))(inp) - # rnn = tf.keras.Model(inputs=inp, outputs=rnn) - - x = jnp.ones([time_steps, in_features]).astype(jnp.float32) - - w_in_to_hidden = jnp.array( - [[0.6252413, -0.34832734, 0.6286191], [0.84620893, 0.52448165, 0.13104844]] - ) - - w_hidden_to_hidden = jnp.array( - [ - [-0.24631214, -0.86077654, -0.44541454], - [-0.96763766, 0.24441445, 0.06276101], - [-0.05484254, -0.4464587, 0.893122], - ] - ) - - cell = SimpleRNNCell( - in_features=in_features, - hidden_features=hidden_features, - recurrent_weight_init="glorot_uniform", - key=jax.random.PRNGKey(0), - ) - - w_combined = jnp.concatenate([w_in_to_hidden, 
w_hidden_to_hidden], axis=0) - cell = cell.at["in_hidden_to_hidden"]["weight"].set(w_combined.T) - sk_layer = ScanRNN(cell) - y = jnp.array([0.9637042, -0.8282256, 0.7314449]) - npt.assert_allclose(sk_layer(x), y) + input = jr.uniform(key, (time_step, in_features)) + keras_rnn = keras.layers.SimpleRNN( + hidden_features, + return_sequences=True, + return_state=True, + ) + serket_cell = sk.nn.SimpleRNNCell(in_features, hidden_features, key=key) + keras_output, *keras_state = keras_rnn(input[None]) + keras_output = keras_output[0] # drop batch dimension + i2h, h2h, b = keras_rnn.weights + serket_cell = ( + serket_cell.at["in_hidden_to_hidden"]["weight"] + .set(jnp.concatenate([i2h.numpy().T, h2h.numpy().T], axis=-1)) + .at["in_hidden_to_hidden"]["bias"] + .set(b.numpy()) + ) + serket_rnn = sk.nn.scan_cell(serket_cell) + serket_output, serket_state = serket_rnn(input, sk.tree_state(serket_cell)) + npt.assert_allclose(keras_output, serket_output, atol=1e-6) + npt.assert_allclose(keras_state[0][0], serket_state.hidden_state, atol=1e-6) def test_lstm(): - # tensorflow + key = jr.PRNGKey(0) + time_step = 3 in_features = 2 hidden_features = 3 - # batch_size = 1 - time_steps = 10 - - # inputs = np.ones([batch_size,time_steps, in_features]).astype(np.float32) - # inp = tf.keras.Input(shape=(time_steps, in_features)) - # rnn = (tf.keras.layers.LSTM(hidden_features, return_sequences=False, return_state=False))(inp) - # rnn = tf.keras.Model(inputs=inp, outputs=rnn) - - # w_in_to_hidden = jnp.array(rnn.weights[0].numpy()) - # w_hidden_to_hidden = jnp.array(rnn.weights[1].numpy()) - # b_hidden_to_hidden = jnp.array(rnn.weights[2].numpy()) - - x = jnp.ones([time_steps, in_features]).astype(jnp.float32) - - w_in_to_hidden = jnp.array( - [ - [ - -0.1619612, - -0.17861447, - -0.374527, - 0.21063584, - 0.1806348, - 0.0344786, - 0.44189203, - -0.55044144, - 0.28518462, - -0.09390897, - 0.56036115, - 0.19108337, - ], - [ - 0.03269911, - -0.21127799, - 0.55661833, - -0.6470987, - 
-0.27472985, - -0.21884575, - 0.2479667, - -0.34201348, - 0.00261247, - -0.6468279, - 0.5003185, - 0.6460693, - ], - ] - ) - - w_hidden_to_hidden = jnp.array( - [ - [ - 0.3196982, - 0.25284654, - -0.18152222, - 0.44958767, - -0.44068673, - -0.19395973, - -0.00905689, - -0.17610262, - 0.21773854, - -0.47118214, - -0.07700437, - 0.24598895, - ], - [ - -0.23678103, - -0.01854092, - -0.15681103, - -0.20309119, - -0.51169145, - 0.33006623, - 0.35155487, - 0.1802753, - -0.08975402, - -0.30867696, - 0.37548447, - -0.3264465, - ], - [ - -0.14270899, - 0.26242012, - -0.31327525, - 0.206014, - 0.5501963, - 0.14983827, - -0.15515868, - 0.2578809, - -0.14565073, - -0.33286166, - 0.4204296, - 0.21370588, - ], - ] - ) - - b_hidden_to_hidden = jnp.array( - [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - ) - - cell = LSTMCell( - in_features=in_features, - hidden_features=hidden_features, - recurrent_weight_init="glorot_uniform", - key=jax.random.PRNGKey(0), - ) - w_combined = jnp.concatenate([w_in_to_hidden.T, w_hidden_to_hidden.T], axis=1) - - cell = cell.at["in_hidden_to_hidden"]["weight"].set(w_combined) - cell = cell.at["in_hidden_to_hidden"]["bias"].set(b_hidden_to_hidden) - - sk_layer = ScanRNN(cell, return_sequences=False) - - y = jnp.array([0.18658024, -0.6338659, 0.3445018]) - npt.assert_allclose(y, sk_layer(x), atol=1e-5) - - w_in_to_hidden = jnp.array( - [ - [ - 0.11943924, - -0.609248, - -0.45503575, - -0.3439762, - -0.33675978, - 0.05291432, - -0.12904513, - -0.22977036, - 0.32492596, - 0.06835997, - 0.0484916, - 0.07520777, - ], - [ - 0.39872873, - -0.08020723, - -0.4879259, - -0.61926323, - -0.45951623, - -0.44556192, - -0.05298251, - 0.54848397, - 0.19754452, - 0.6012858, - -0.06859863, - 0.16502213, - ], - ] - ) - - w_hidden_to_hidden = jnp.array( - [ - [ - 0.18880641, - 0.21262297, - -0.2961502, - 0.33976135, - -0.09891935, - -0.00502901, - 0.34378093, - 0.4202192, - 0.36584634, - 0.08396737, - -0.4975226, - 0.15165171, - ], - [ - 0.30486387, - 
-0.46795598, - -0.07052832, - 0.51685417, - -0.23734125, - 0.1711132, - 0.16389124, - -0.08915165, - -0.02928232, - -0.2173849, - 0.19655496, - -0.45694238, - ], - [ - -0.1722902, - -0.23029403, - 0.05032581, - 0.21182823, - 0.5298174, - -0.50670344, - -0.18930247, - 0.30799994, - -0.18611868, - -0.08317372, - -0.26286182, - -0.30177474, - ], - ] - ) - - b_hidden_to_hidden = jnp.array( - [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - ) - - cell = LSTMCell( - in_features=in_features, - hidden_features=hidden_features, - recurrent_weight_init="glorot_uniform", - key=jax.random.PRNGKey(0), - ) - - w_combined = jnp.concatenate([w_in_to_hidden.T, w_hidden_to_hidden.T], axis=-1) - - cell = cell.at["in_hidden_to_hidden"]["weight"].set(w_combined) - cell = cell.at["in_hidden_to_hidden"]["bias"].set(b_hidden_to_hidden) - - sk_layer = ScanRNN(cell, return_sequences=True) - - y = jnp.array( - [ - [-0.07431775, 0.05081949, 0.07480226], - [-0.12263095, 0.07622699, 0.1146026], - [-0.15380122, 0.0886446, 0.13589925], - [-0.17376699, 0.0944333, 0.14736715], - [-0.18647897, 0.09689739, 0.1535385], - [-0.19453025, 0.09775667, 0.15683244], - [-0.19960524, 0.09789205, 0.1585632], - [-0.20278986, 0.0977404, 0.15945096], - [-0.20477988, 0.09750732, 0.15989034], - [-0.20601842, 0.09728104, 0.16009602], - ] - ) - - npt.assert_allclose(y, sk_layer(x), atol=1e-5) - - cell = LSTMCell( - in_features=in_features, - hidden_features=hidden_features, - recurrent_weight_init="glorot_uniform", - key=jax.random.PRNGKey(0), - ) - - sk_layer = ScanRNN(cell, return_sequences=True) - assert sk_layer(x).shape == (10, 3) - - -def test_gru(): - w1 = jnp.array( - [ - [ - -0.04667467, - 0.25340378, - 0.26873875, - 0.15961742, - 0.56519365, - 0.46263158, - -0.0030899, - 0.31380886, - 0.44481528, - ] - ] - ) - - w2 = jnp.array( - [ - [ - 0.23404205, - 0.10193896, - 0.27892762, - -0.488236, - -0.4173184, - -0.0588184, - 0.41350085, - 0.36151117, - -0.45407838, - ], - [ - -0.560196, - 
-0.22648495, - -0.12656957, - 0.31881046, - 0.47110367, - 0.30805635, - 0.41259462, - 0.40002275, - -0.0368616, - ], - [ - 0.5745573, - 0.4343021, - 0.42046744, - -0.09401041, - 0.5539224, - -0.13675115, - -0.5197817, - -0.21241805, - -0.16732433, - ], - ] - ) + input = jr.uniform(key, (time_step, in_features)) + keras_rnn = keras.layers.LSTM( + hidden_features, + return_sequences=True, + return_state=True, + ) + serket_cell = sk.nn.LSTMCell(in_features, hidden_features, key=key) + keras_output, *keras_state = keras_rnn(input[None]) + keras_output = keras_output[0] # drop batch dimension + i2h, h2h, b = keras_rnn.weights + + # serket combines the input and hidden weights + serket_cell = ( + serket_cell.at["in_hidden_to_hidden"]["weight"] + .set(jnp.concatenate([i2h.numpy().T, h2h.numpy().T], axis=-1)) + .at["in_hidden_to_hidden"]["bias"] + .set(b.numpy()) + ) + serket_rnn = sk.nn.scan_cell(serket_cell) + serket_output, serket_state = serket_rnn(input, sk.tree_state(serket_cell)) + npt.assert_allclose(keras_output, serket_output, atol=1e-6) + npt.assert_allclose(keras_state[0][0], serket_state.hidden_state, atol=1e-6) + npt.assert_allclose(keras_state[1][0], serket_state.cell_state, atol=1e-6) - cell = GRUCell(1, 3, bias_init=None, key=jax.random.PRNGKey(0)) - cell = cell.at["in_to_hidden"]["weight"].set(w1.T) - cell = cell.at["hidden_to_hidden"]["weight"].set(w2.T) - y = jnp.array([[-0.00142191, 0.11011646, 0.1613554]]) - ypred = ScanRNN(cell, return_sequences=True)(jnp.ones([1, 1])) - npt.assert_allclose(y, ypred, atol=1e-4) - - -@pytest.mark.parametrize("layer", [ConvLSTM1DCell, FFTConvLSTM1DCell]) -def test_conv_lstm1d(layer): - w_in_to_hidden = jnp.array( - [ - [ - [0.3159187, -0.37110862, 0.23497376], - [0.06916022, 0.16520068, -0.1498835], - ], - [ - [0.13892826, -0.2475906, 0.11548725], - [-0.14935534, 0.0077568, 0.31523505], - ], - [ - [0.20523027, 0.333159, -0.26372582], - [0.21769527, -0.28275424, 0.07145688], - ], - [ - [-0.32436138, 0.17985162, 
-0.05102682], - [-0.33781663, 0.07652837, 0.14034107], - ], - [ - [-0.2476197, 0.27073297, -0.15494357], - [-0.17142114, 0.0436784, -0.2635818], - ], - [ - [-0.1563589, -0.30193892, -0.3076105], - [0.30359367, -0.37472126, 0.08727607], - ], - [ - [0.02532503, -0.33569914, -0.16816947], - [-0.28197324, -0.20834318, -0.31490648], - ], - [ - [0.37559494, -0.10307714, -0.28350165], - [0.16282192, 0.25434867, 0.14521858], - ], - [ - [-0.3619054, -0.05932748, 0.13838741], - [0.317831, -0.01710135, 0.01839554], - ], - [ - [-0.33236656, -0.15234765, 0.23833898], - [-0.0525074, -0.1169591, 0.22625437], - ], - [ - [0.3350378, 0.3527101, -0.08017969], - [-0.25890553, 0.24611798, 0.30005935], - ], - [ - [-0.07834777, -0.02483597, -0.28757787], - [-0.15855587, 0.14020738, -0.3187018], - ], - ] - ) - - w_hidden_to_hidden = jnp.array( - [ - [ - [0.44095814, 0.12996325, 0.1313585], - [0.18582591, 0.07248487, -0.7859758], - [-0.17839126, 0.15680492, -0.08622836], - ], - [ - [-0.11601712, 0.00761805, 0.43996823], - [0.27362385, 0.0799137, 0.2580722], - [-0.563254, 0.19736156, 0.26167846], - ], - [ - [-0.28901652, -0.25223732, -0.10025343], - [0.56027263, -0.28712046, -0.18524358], - [0.37074035, 0.3996833, 0.1725195], - ], - [ - [0.07441625, 0.20128009, 0.30421543], - [-0.06981394, -0.17527759, 0.22605616], - [0.11372325, 0.63972735, -0.19949353], - ], - [ - [0.08129799, -0.06646754, -0.44094074], - [-0.09799376, 0.16513337, 0.1980969], - [-0.01823295, 0.33500522, 0.19564764], - ], - [ - [-0.4375121, -0.07695349, 0.27423194], - [0.25537497, 0.64107186, -0.09421141], - [0.21401826, -0.15687335, -0.07473418], - ], - [ - [-0.37147775, 0.06210529, -0.04531584], - [-0.38045418, 0.26204777, -0.17553791], - [-0.16380772, 0.39306286, -0.444068], - ], - [ - [-0.08250815, 0.5762788, 0.3014125], - [0.08091379, -0.20550683, 0.06467859], - [0.02479128, -0.16484486, 0.09149422], - ], - [ - [-0.1793791, 0.23342696, -0.33710676], - [0.4355502, -0.23507121, 0.11481185], - [-0.21538775, -0.16292992, 
-0.6203824], - ], - [ - [-0.1719443, 0.04258863, -0.35778967], - [0.12353352, 0.0826712, -0.10358769], - [-0.55321497, 0.07205058, 0.29797262], - ], - [ - [-0.52755165, 0.27079415, -0.04477403], - [-0.3376618, -0.32239383, -0.3393156], - [0.04485175, -0.04528336, 0.30485243], - ], - [ - [-0.14193594, -0.634814, 0.28351584], - [-0.16348608, -0.4000306, -0.08978741], - [-0.26926947, -0.12314601, -0.19621553], - ], - ] - ) - - b_in_to_hidden = jnp.array( - [ - [0.0], - [0.0], - [0.0], - [1.0], - [1.0], - [1.0], - [0.0], - [0.0], - [0.0], - [0.0], - [0.0], - [0.0], - ] - ) +def test_bilstm(): + key = jr.PRNGKey(0) + time_step = 3 in_features = 2 hidden_features = 3 - time_steps = 1 - spatial_dim = (3,) - - # inputs = np.ones([batch_size,time_steps, in_features,*spatial_dim]).astype(np.float32) - # inp = tf.keras.Input(shape=(time_steps, in_features,*spatial_dim)) - # rnn = (tf.keras.layers.ConvLSTM1D(hidden_features,recurrent_activation="sigmoid", kernel_size=3, padding='same', - # return_sequences=False,data_format='channels_first'))(inp) - # rnn = tf.keras.Model(inputs=inp, outputs=rnn) - - cell = layer( - in_features=in_features, - hidden_features=hidden_features, - recurrent_act="sigmoid", - kernel_size=3, + input = jr.uniform(key, (time_step, in_features)) + keras_rnn = keras.layers.LSTM( + hidden_features, + return_sequences=True, + return_state=True, + ) + + keras_rnn = keras.layers.Bidirectional( + keras.layers.LSTM( + hidden_features, + return_sequences=True, + return_state=False, + ) + ) + + keras_output = keras_rnn(input[None]) + + i2hf, h2hf, bf, i2hb, h2hb, bb = keras_rnn.weights + + i2hf = i2hf.numpy().T + h2hf = h2hf.numpy().T + ih2hf = jnp.concatenate([i2hf, h2hf], axis=-1) + bf = bf.numpy() + i2hb = i2hb.numpy().T + h2hb = h2hb.numpy().T + ih2hb = jnp.concatenate([i2hb, h2hb], axis=-1) + bb = bb.numpy() + + serket_cell = sk.nn.LSTMCell(in_features, hidden_features, key=key) + + forward_cell = ( + serket_cell.at["in_hidden_to_hidden"]["weight"] + 
.set(ih2hf) + .at["in_hidden_to_hidden"]["bias"] + .set(bf) + ) + backward_cell = ( + serket_cell.at["in_hidden_to_hidden"]["weight"] + .set(ih2hb) + .at["in_hidden_to_hidden"]["bias"] + .set(bb) + ) + + state1 = sk.tree_state(forward_cell) + output1, _ = sk.nn.scan_cell(forward_cell)(input, state1) + state2 = sk.tree_state(backward_cell) + output2, _ = sk.nn.scan_cell(backward_cell, reverse=True)(input, state2) + serket_output = jnp.concatenate([output1, output2], axis=1) + + npt.assert_allclose(keras_output[0], serket_output, atol=1e-6) + + +@pytest.mark.parametrize( + ("sk_layer", "keras_layer", "ndim"), + [ + *product( + [sk.nn.ConvLSTM1DCell, sk.nn.FFTConvLSTM1DCell], + [keras.layers.ConvLSTM1D], + [1], + ), + *product( + [sk.nn.ConvLSTM2DCell, sk.nn.FFTConvLSTM2DCell], + [keras.layers.ConvLSTM2D], + [2], + ), + *product( + [sk.nn.ConvLSTM3DCell, sk.nn.FFTConvLSTM3DCell], + [keras.layers.ConvLSTM3D], + [3], + ), + ], +) +def test_conv_lstm(sk_layer, keras_layer, ndim): + key = jr.PRNGKey(0) + time_step = 3 + in_features = 2 + spatial = [5] * ndim + kernel_size = 3 + hidden_features = 3 + input = jr.uniform(key, (time_step, in_features, *spatial)) + keras_rnn = keras_layer( + hidden_features, + kernel_size, + data_format="channels_first", padding="same", - weight_init="glorot_uniform", - recurrent_weight_init="glorot_uniform", - bias_init="zeros", - key=jax.random.PRNGKey(0), - ) - - cell = cell.at["in_to_hidden"]["weight"].set(w_in_to_hidden) - cell = cell.at["hidden_to_hidden"]["weight"].set(w_hidden_to_hidden) - cell = cell.at["in_to_hidden"]["bias"].set(b_in_to_hidden) - - x = jnp.ones([time_steps, in_features, *spatial_dim]) - - res_sk = ScanRNN(cell, return_sequences=False)(x) - - y = jnp.array( - [ - [-0.19088623, -0.20386685, -0.11864982], - [0.00493522, 0.18935747, 0.16954307], - [0.01413723, 0.00672858, -0.03464129], - ] - ) - - assert jnp.allclose(res_sk, y, atol=1e-5) - - cell = layer( - in_features=in_features, - hidden_features=hidden_features, - 
recurrent_act="sigmoid", - kernel_size=3, + use_bias=False, + return_sequences=True, + return_state=True, + ) + keras_output, *keras_state = keras_rnn(input[None]) + serket_cell = sk_layer( + in_features, + hidden_features, + kernel_size, padding="same", - weight_init="glorot_uniform", - recurrent_weight_init="glorot_uniform", - bias_init="zeros", - key=jax.random.PRNGKey(0), - ) - - res_sk = ScanRNN(cell, return_sequences=False)(x) - assert res_sk.shape == (3, 3) - - -def test_bilstm(): - # batch_size = 1 - time_steps = 2 - in_features = 3 - hidden_features = 2 - - x = jnp.ones([time_steps, in_features]) - cell = LSTMCell(in_features, hidden_features, key=jax.random.PRNGKey(0)) - reverse_cell = LSTMCell(in_features, hidden_features, key=jax.random.PRNGKey(0)) - - w_in_to_hidden = jnp.array( - [ - [ - -0.6061297, - 0.6038931, - 0.0219295, - -0.53232527, - 0.63680524, - -0.1877076, - 0.5494583, - 0.5319734, - ], - [ - -0.11174804, - 0.1967476, - -0.01281184, - 0.6291546, - -0.10848027, - -0.32045278, - 0.07772851, - -0.07741755, - ], - [ - 0.69948727, - -0.48679155, - 0.39291233, - -0.0054667, - 0.5324392, - 0.62987834, - -0.2530458, - -0.5623743, - ], - ] - ) - - w_hidden_to_hidden = jnp.array( - [ - [ - -0.07784259, - 0.5912869, - -0.08792564, - -0.07326522, - -0.07806911, - -0.75162244, - 0.01986005, - 0.24453232, - ], - [ - 0.23444527, - -0.5768899, - 0.24225983, - -0.23526284, - -0.2299888, - -0.444415, - 0.4977502, - 0.00633401, - ], - ] - ) - - b_hidden_to_hidden = jnp.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]) - - w_in_to_hidden_reverse = jnp.array( - [ - [ - 0.28273338, - -0.1472258, - 0.3937468, - 0.34040576, - -0.299861, - -0.38785607, - 0.00533426, - 0.06143087, - ], - [ - -0.40093276, - 0.39314228, - -0.43308863, - 0.532469, - -0.71949875, - 0.16529655, - -0.07926816, - -0.5383911, - ], - [ - -0.0023067, - -0.5820745, - 0.31508905, - 0.29104167, - -0.35113502, - -0.6884494, - 0.14833266, - -0.46562153, - ], - ] - ) - - w_hidden_to_hidden_reverse = 
jnp.array( - [ - [ - 3.12127233e-01, - 7.36315727e-01, - -1.91057637e-01, - 1.89247921e-01, - 4.54114564e-02, - 6.95739524e-04, - 5.34631252e-01, - 1.43038025e-02, - ], - [ - 3.68674904e-01, - -1.35606900e-01, - -3.05835426e-01, - -1.86572984e-01, - -7.80997992e-01, - 2.84251571e-02, - -1.41527206e-02, - 3.26157391e-01, - ], - ] - ) - - b_hidden_to_hidden_reverse = jnp.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]) - - combined_w = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["in_hidden_to_hidden"]["weight"].set(combined_w.T) - cell = cell.at["in_hidden_to_hidden"]["bias"].set(b_hidden_to_hidden.T) - - combined_w_reverse = jnp.concatenate( - [w_in_to_hidden_reverse, w_hidden_to_hidden_reverse] - ) - reverse_cell = reverse_cell.at["in_hidden_to_hidden"]["weight"].set( - combined_w_reverse.T + key=key, + bias_init=None, + recurrent_act="sigmoid", ) - reverse_cell = reverse_cell.at["in_hidden_to_hidden"]["bias"].set( - b_hidden_to_hidden_reverse.T + w1, w2 = keras_rnn.weights + serket_cell = ( + serket_cell.at["in_to_hidden"]["weight"] + .set(jnp.transpose(w1.numpy(), (-1, -2, *range(ndim)))) + .at["hidden_to_hidden"]["weight"] + .set(jnp.transpose(w2.numpy(), (-1, -2, *range(ndim)))) ) - res = ScanRNN(cell, reverse_cell, return_sequences=False)(x) - - y = jnp.array([0.35901642, 0.00826644, -0.3015435, -0.13661332]) - - npt.assert_allclose(res, y, atol=1e-5) - - -def test_rnn_error(): - with pytest.raises(TypeError): - ScanRNN(None) - - with pytest.raises(TypeError): - ScanRNN(SimpleRNNCell(3, 3, key=jax.random.PRNGKey(0)), 1) - - layer = ScanRNN( - SimpleRNNCell(3, 3, key=jax.random.PRNGKey(0)), - SimpleRNNCell(3, 3, key=jax.random.PRNGKey(0)), - ) + state = sk.tree_state(serket_cell, input=input[0]) + sekret_output, serket_state = sk.nn.scan_cell(serket_cell)(input, state) - with pytest.raises(ValueError): - layer(jnp.ones([10, 3, 3])) + npt.assert_allclose(keras_output[0], sekret_output, atol=1e-6) + 
npt.assert_allclose(keras_state[0][0], serket_state.hidden_state, atol=1e-6) + npt.assert_allclose(keras_state[1][0], serket_state.cell_state, atol=1e-6) def test_dense_cell(): - cell = DenseCell( + cell = sk.nn.DenseCell( in_features=10, hidden_features=10, act=lambda x: x, weight_init="ones", bias_init=None, - key=jax.random.PRNGKey(0), + key=jr.PRNGKey(0), ) - x = jnp.ones([10, 10]) - res = ScanRNN(cell)(x) + input = jnp.ones([10, 10]) + state = sk.tree_state(cell) + output, _ = sk.nn.scan_cell(cell)(input, state) # 1x10 @ 10x10 => 1x10 - npt.assert_allclose(res, jnp.ones([10]) * 10.0) + npt.assert_allclose(output[-1], jnp.ones([10]) * 10.0)