From 482a48fcf13ae6a1a13d62bcb246cb8f84de8ff5 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Mon, 24 Jul 2023 09:02:56 +0900 Subject: [PATCH] more doc simplifications --- README.md | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 7690fa9..718ffd9 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
-
+ -

The ✨Magical✨ JAX Scientific ML Library.

+

The ✨Magical✨ JAX ML Library.

*Serket is the goddess of magic in Egyptian mythology [**Installation**](#Installation) @@ -44,20 +44,13 @@ import jax, jax.numpy as jnp import serket as sk import optax -x_train = ... -y_train = ... - +x_train, y_train = ..., ... k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) nn = sk.nn.Sequential( - sk.nn.Conv2D(1, 32, 3, key=k1, padding="valid"), - jax.nn.relu, - sk.nn.MaxPool2D(2, 2), - sk.nn.Conv2D(32, 64, 3, key=k2, padding="valid"), - jax.nn.relu, - sk.nn.MaxPool2D(2, 2), - jnp.ravel, - sk.nn.Linear(1600, 10, key=k3), + sk.nn.Linear(28 * 28, 64, key=k1), jax.nn.relu, + sk.nn.Linear(64, 64, key=k2), jax.nn.relu, + sk.nn.Linear(64, 10, key=k3), ) nn = sk.tree_mask(nn) # pass non-jaxtype through jax-transforms