Skip to content

Commit

Permalink
more doc simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 24, 2023
1 parent 2df9291 commit 482a48f
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<div align="center">
<img width="350px" src="assets/logo.svg"></div>
<img width="250px" src="assets/logo.svg"></div>

<h2 align="center">The ✨Magical✨ JAX Scientific ML Library.</h2>
<h2 align="center">The ✨Magical✨ JAX ML Library.</h2>
<h5 align = "center"> *Serket is the goddess of magic in Egyptian mythology

[**Installation**](#Installation)
Expand Down Expand Up @@ -44,20 +44,13 @@ import jax, jax.numpy as jnp
import serket as sk
import optax

x_train = ...
y_train = ...

x_train, y_train = ..., ...
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

nn = sk.nn.Sequential(
sk.nn.Conv2D(1, 32, 3, key=k1, padding="valid"),
jax.nn.relu,
sk.nn.MaxPool2D(2, 2),
sk.nn.Conv2D(32, 64, 3, key=k2, padding="valid"),
jax.nn.relu,
sk.nn.MaxPool2D(2, 2),
jnp.ravel,
sk.nn.Linear(1600, 10, key=k3),
sk.nn.Linear(28 * 28, 64, key=k1), jax.nn.relu,
sk.nn.Linear(64, 64, key=k2), jax.nn.relu,
sk.nn.Linear(64, 10, key=k3),
)

nn = sk.tree_mask(nn) # pass non-jaxtype through jax-transforms
Expand Down

0 comments on commit 482a48f

Please sign in to comment.