diff --git a/docs/notebooks/train_bilstm.ipynb b/docs/notebooks/train_bilstm.ipynb index 9c900a8..f9fc894 100644 --- a/docs/notebooks/train_bilstm.ipynb +++ b/docs/notebooks/train_bilstm.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -79,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -99,15 +99,15 @@ "\n", " def __call__(self, input: Input) -> Input:\n", " # initialize the states of the cells\n", - " (state1, state2, state3) = sk.tree_state((self.cell1, self.cell2, self.cell3))\n", + " state = sk.tree_state(self)\n", " # run the forward cell\n", - " output1, state1 = sk.nn.scan_cell(self.cell1)(input, state1)\n", + " output1, state1 = sk.nn.scan_cell(self.cell1)(input, state.cell1)\n", " # run the backward cell\n", - " output2, state2 = sk.nn.scan_cell(self.cell2, reverse=True)(input, state2)\n", + " output2, state2 = sk.nn.scan_cell(self.cell2, reverse=True)(input, state.cell2)\n", " # concatenate the outputs\n", " output = jnp.concatenate((output1, output2), axis=1)\n", " # run the final cell\n", - " output, state3 = sk.nn.scan_cell(self.cell3)(output, state3)\n", + " output, state3 = sk.nn.scan_cell(self.cell3)(output, state.cell3)\n", " # return the last time step\n", " return output[-1]\n", "\n", @@ -130,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -171,23 +171,23 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 100/100\tBatch: 100/100\tBatch loss: 2.065103e-03\tTime: 0.019\r" + "Epoch: 100/100\tBatch: 100/100\tBatch loss: 2.065103e-03\tTime: 0.022\r" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" },