From 35fc478b8def9aa3d7585e7c91cbbb0ce58b0b21 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Tue, 9 Apr 2024 02:06:18 +0900 Subject: [PATCH] fix doctest fail --- serket/_src/nn/normalization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/serket/_src/nn/normalization.py b/serket/_src/nn/normalization.py index b70ff44..5b4e621 100644 --- a/serket/_src/nn/normalization.py +++ b/serket/_src/nn/normalization.py @@ -571,7 +571,7 @@ class BatchNorm(TreeClass): >>> import serket as sk >>> import jax.random as jr >>> import jax.numpy as jnp - >>> class ThreadedBatchNorm(TreeClass): + >>> class ThreadedBatchNorm(sk.TreeClass): ... def __init__(self, *, key: jax.Array): ... k1, k2 = jax.random.split(key) ... self.bn1 = sk.nn.BatchNorm(5, axis=-1, key=k1) @@ -604,7 +604,7 @@ class BatchNorm(TreeClass): >>> import serket as sk >>> import jax.random as jr >>> import functools as ft - >>> class UnthreadedBatchNorm(TreeClass): + >>> class UnthreadedBatchNorm(sk.TreeClass): ... def __init__(self, *, key: jax.Array): ... k1, k2 = jax.random.split(key) ... self.bn1 = sk.nn.BatchNorm(5, axis=-1, key=k1) @@ -862,7 +862,7 @@ def weight_norm(leaf: T, axis: int | None = -1, eps: float = 1e-12) -> T: >>> import jax >>> import jax.numpy as jnp >>> import serket as sk - >>> class Net(TreeClass): + >>> class Net(sk.TreeClass): ... def __init__(self, *, key: jax.Array): ... k1, k2 = jax.random.split(key) ... self.l1 = sk.nn.Linear(2, 4, key=k1)