diff --git a/serket/nn/image.py b/serket/nn/image.py index 87e005d..71f6387 100644 --- a/serket/nn/image.py +++ b/serket/nn/image.py @@ -829,7 +829,6 @@ class RandomPerspective2D(sk.TreeClass): >>> import serket as sk >>> import jax.numpy as jnp >>> import jax - >>> layer = sk.nn.RandomPerspective2D(100) >>> x, y = jnp.meshgrid(jnp.linspace(-1, 1, 30), jnp.linspace(-1, 1, 30)) >>> d = jnp.sqrt(x**2 + y**2) >>> mask = d < 1 @@ -864,7 +863,9 @@ class RandomPerspective2D(sk.TreeClass): [0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]] - >>> out = layer(mask[None], key=jax.random.PRNGKey(10))[0] + >>> layer = sk.nn.RandomPerspective2D(100) + >>> key = jax.random.PRNGKey(10) + >>> out = layer(mask[None], key=key)[0] >>> print(out.astype(int)) [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]