From 5c05315e2e94c431e7f5550fd159664f898425c4 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Tue, 25 Jul 2023 04:45:24 +0900 Subject: [PATCH] docs edit --- docs/notebooks/bilstm.ipynb | 4 ++-- docs/notebooks/mnist.ipynb | 4 ++-- serket/nn/random_transform.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/notebooks/bilstm.ipynb b/docs/notebooks/bilstm.ipynb index 15cdb28..43c813c 100644 --- a/docs/notebooks/bilstm.ipynb +++ b/docs/notebooks/bilstm.ipynb @@ -161,13 +161,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.760744e-03\tTime: 0.020\r" + "Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.760744e-03\tTime: 0.025\r" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, diff --git a/docs/notebooks/mnist.ipynb b/docs/notebooks/mnist.ipynb index 9950fbb..e22f8ab 100644 --- a/docs/notebooks/mnist.ipynb +++ b/docs/notebooks/mnist.ipynb @@ -82,7 +82,7 @@ "source": [ "k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)\n", "\n", - "\n", + "@sk.autoinit\n", "class ConvNet(sk.TreeClass):\n", " conv1: sk.nn.Conv2D = sk.nn.Conv2D(1, 32, 3, key=k1, padding=\"valid\")\n", " pool1: sk.nn.MaxPool2D = sk.nn.MaxPool2D(2, 2)\n", @@ -201,7 +201,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 001/001\tBatch: 468/468\tBatch loss: 2.040178e-01\tBatch accuracy: 0.984375\tTime: 18.339\r" + "Epoch: 001/001\tBatch: 468/468\tBatch loss: 2.040178e-01\tBatch accuracy: 0.984375\tTime: 18.784\r" ] }, { diff --git a/serket/nn/random_transform.py b/serket/nn/random_transform.py index 50829c3..e103fc5 100644 --- a/serket/nn/random_transform.py +++ b/serket/nn/random_transform.py @@ -26,7 +26,7 @@ from serket.nn.resize import Resize2D from serket.nn.utils import Range - +@sk.autoinit class RandomApply(sk.TreeClass): """ Randomly applies a layer with probability p.