Skip to content

Commit

Permalink
fix doctest fail
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 8, 2024
1 parent e3840a1 commit 35fc478
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions serket/_src/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ class BatchNorm(TreeClass):
>>> import serket as sk
>>> import jax.random as jr
>>> import jax.numpy as jnp
>>> class ThreadedBatchNorm(TreeClass):
>>> class ThreadedBatchNorm(sk.TreeClass):
... def __init__(self, *, key: jax.Array):
... k1, k2 = jax.random.split(key)
... self.bn1 = sk.nn.BatchNorm(5, axis=-1, key=k1)
Expand Down Expand Up @@ -604,7 +604,7 @@ class BatchNorm(TreeClass):
>>> import serket as sk
>>> import jax.random as jr
>>> import functools as ft
>>> class UnthreadedBatchNorm(TreeClass):
>>> class UnthreadedBatchNorm(sk.TreeClass):
... def __init__(self, *, key: jax.Array):
... k1, k2 = jax.random.split(key)
... self.bn1 = sk.nn.BatchNorm(5, axis=-1, key=k1)
Expand Down Expand Up @@ -862,7 +862,7 @@ def weight_norm(leaf: T, axis: int | None = -1, eps: float = 1e-12) -> T:
>>> import jax
>>> import jax.numpy as jnp
>>> import serket as sk
>>> class Net(TreeClass):
>>> class Net(sk.TreeClass):
... def __init__(self, *, key: jax.Array):
... k1, k2 = jax.random.split(key)
... self.l1 = sk.nn.Linear(2, 4, key=k1)
Expand Down

0 comments on commit 35fc478

Please sign in to comment.