Skip to content

Commit

Permalink
simplify example
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 23, 2023
1 parent 8d3bbdc commit 2df9291
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,18 @@ nn = sk.nn.Sequential(
sk.nn.Linear(1600, 10, key=k3),
)


nn = sk.tree_mask(nn) # pass non-jaxtype through jax-transforms
optim = optax.adam(LR)
optim_state = optim.init(nn)

@jax.vmap
def softmax_cross_entropy(logits, onehot):
assert onehot.shape == logits.shape == (10,)
return -jnp.sum(jax.nn.log_softmax(logits) * onehot)

@ft.partial(jax.grad, has_aux=True)
def loss_func(nn, x, y):
nn = sk.tree_unmask(nn)
logits = jax.vmap(nn)(x)
onehot = jax.nn.one_hot(y, 10)
loss = jnp.mean(softmax_cross_entropy(logits, onehot))
loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot))
return loss, (loss, logits)

@jax.vmap
def accuracy_func(logits, y):
assert logits.shape == (10,)
return jnp.argmax(logits) == y

@jax.jit
def train_step(nn, optim_state, x, y):
grads, (loss, logits) = loss_func(nn, x, y)
Expand All @@ -92,7 +81,7 @@ def train_step(nn, optim_state, x, y):

for j, (xb, yb) in enumerate(zip(x_train, y_train)):
nn, optim_state, (loss, logits) = train_step(nn, optim_state, xb, yb)
accuracy = jnp.mean(accuracy_func(logits, yb))
accuracy = accuracy_func(logits, y_train)

nn = sk.tree_unmask(nn)
```

0 comments on commit 2df9291

Please sign in to comment.