Skip to content

Commit

Permalink
Update train_bilstm.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 2, 2023
1 parent dc6779c commit 3499890
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions docs/notebooks/train_bilstm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -54,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -79,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -99,15 +99,15 @@
"\n",
" def __call__(self, input: Input) -> Input:\n",
" # initialize the states of the cells\n",
" (state1, state2, state3) = sk.tree_state((self.cell1, self.cell2, self.cell3))\n",
" state = sk.tree_state(self)\n",
" # run the forward cell\n",
" output1, state1 = sk.nn.scan_cell(self.cell1)(input, state1)\n",
" output1, state1 = sk.nn.scan_cell(self.cell1)(input, state.cell1)\n",
" # run the backward cell\n",
" output2, state2 = sk.nn.scan_cell(self.cell2, reverse=True)(input, state2)\n",
" output2, state2 = sk.nn.scan_cell(self.cell2, reverse=True)(input, state.cell2)\n",
" # concatenate the outputs\n",
" output = jnp.concatenate((output1, output2), axis=1)\n",
" # run the final cell\n",
" output, state3 = sk.nn.scan_cell(self.cell3)(output, state3)\n",
" output, state3 = sk.nn.scan_cell(self.cell3)(output, state.cell3)\n",
" # return the last time step\n",
" return output[-1]\n",
"\n",
Expand All @@ -130,7 +130,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -171,23 +171,23 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 100/100\tBatch: 100/100\tBatch loss: 2.065103e-03\tTime: 0.019\r"
"Epoch: 100/100\tBatch: 100/100\tBatch loss: 2.065103e-03\tTime: 0.022\r"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x28a8d2450>"
"<matplotlib.legend.Legend at 0x298c1dc10>"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
Expand Down

0 comments on commit 3499890

Please sign in to comment.