diff --git a/serket/nn/blur.py b/serket/nn/blur.py index b8c2dcd..27c6c80 100644 --- a/serket/nn/blur.py +++ b/serket/nn/blur.py @@ -40,10 +40,10 @@ class AvgBlur2D(sk.TreeClass): >>> layer = sk.nn.AvgBlur2D(in_features=1, kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) [[[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448] - [0.6666667 1. 1. 1. 0.6666667 ] - [0.6666667 1. 1. 1. 0.6666667 ] - [0.6666667 1. 1. 1. 0.6666667 ] - [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] + [0.6666667 1. 1. 1. 0.6666667 ] + [0.6666667 1. 1. 1. 0.6666667 ] + [0.6666667 1. 1. 1. 0.6666667 ] + [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] """ def __init__(self, in_features: int, kernel_size: int | tuple[int, int]): @@ -97,10 +97,10 @@ class GaussianBlur2D(sk.TreeClass): >>> layer = sk.nn.GaussianBlur2D(in_features=1, kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) [[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764] - [0.7259314 1. 1. 1. 0.7259314] - [0.7259314 1. 1. 1. 0.7259314] - [0.7259314 1. 1. 1. 0.7259314] - [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] + [0.7259314 1. 1. 1. 0.7259314] + [0.7259314 1. 1. 1. 0.7259314] + [0.7259314 1. 1. 1. 0.7259314] + [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] """ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0): @@ -158,10 +158,10 @@ class Filter2D(sk.TreeClass): >>> layer = sk.nn.Filter2D(in_features=1, kernel=jnp.ones((3,3))) >>> print(layer(jnp.ones((1,5,5)))) [[[4. 6. 6. 6. 4.] - [6. 9. 9. 9. 6.] - [6. 9. 9. 9. 6.] - [6. 9. 9. 9. 6.] - [4. 6. 6. 6. 4.]]] + [6. 9. 9. 9. 6.] + [6. 9. 9. 9. 6.] + [6. 9. 9. 9. 6.] + [4. 6. 6. 6. 4.]]] """ def __init__(self, in_features: int, kernel: jax.Array): @@ -203,10 +203,10 @@ class FFTFilter2D(sk.TreeClass): >>> layer = sk.nn.FFTFilter2D(in_features=1, kernel=jnp.ones((3,3))) >>> print(layer(jnp.ones((1,5,5)))) [[[4.0000005 6.0000005 6.000001 6.0000005 4.0000005] - [6.0000005 9. 9. 9. 6.0000005] - [6.0000005 9. 9. 9. 6.0000005] - [6.0000005 9. 9. 9. 6.0000005] - [4. 6.0000005 6.0000005 6.0000005 4. ]]] + [6.0000005 9. 9. 9. 6.0000005] + [6.0000005 9. 9. 9. 6.0000005] + [6.0000005 9. 9. 9. 6.0000005] + [4. 6.0000005 6.0000005 6.0000005 4. ]]] """ def __init__(self, in_features: int, kernel: jax.Array): diff --git a/serket/nn/crop.py b/serket/nn/crop.py index 79bd96a..2d9752c 100644 --- a/serket/nn/crop.py +++ b/serket/nn/crop.py @@ -90,14 +90,14 @@ class Crop2D(CropND): >>> x = jnp.arange(1, 26).reshape((1, 5, 5)) >>> print(x) [[[ 1 2 3 4 5] - [ 6 7 8 9 10] - [11 12 13 14 15] - [16 17 18 19 20] - [21 22 23 24 25]]] + [ 6 7 8 9 10] + [11 12 13 14 15] + [16 17 18 19 20] + [21 22 23 24 25]]] >>> print(sk.nn.Crop2D(size=3, start=(2, 0))(x)) [[[11 12 13] - [16 17 18] - [21 22 23]]] + [16 17 18] + [21 22 23]]] """ def __init__(self, size: int | tuple[int, int], start: int | tuple[int, int]): diff --git a/serket/nn/flip.py b/serket/nn/flip.py index 725360c..a2f1636 100644 --- a/serket/nn/flip.py +++ b/serket/nn/flip.py @@ -35,13 +35,13 @@ class FlipLeftRight2D(sk.TreeClass): >>> x = jnp.arange(1,10).reshape(1,3, 3) >>> print(x) [[[1 2 3] - [4 5 6] - [7 8 9]]] + [4 5 6] + [7 8 9]]] >>> print(sk.nn.FlipLeftRight2D()(x)) [[[3 2 1] - [6 5 4] - [9 8 7]]] + [6 5 4] + [9 8 7]]] """ @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -66,13 +66,13 @@ class FlipUpDown2D(sk.TreeClass): >>> x = jnp.arange(1,10).reshape(1,3, 3) >>> print(x) [[[1 2 3] - [4 5 6] - [7 8 9]]] + [4 5 6] + [7 8 9]]] >>> print(sk.nn.FlipUpDown2D()(x)) [[[7 8 9] - [4 5 6] - [1 2 3]]] + [4 5 6] + [1 2 3]]] """ @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index 6d33cb2..0ca1d4d 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -100,7 +100,7 @@ def __init__( Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> cell = SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state + >>> cell = sk.nn.SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state >>> rnn_state = sk.tree_state(cell) # 20-dimensional hidden state >>> x = jnp.ones((10,)) # 10 features >>> result = cell(x, rnn_state) @@ -174,7 +174,7 @@ class DenseCell(RNNCell): Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> cell = DenseCell(10, 20) # 10-dimensional input, 20-dimensional hidden state + >>> cell = sk.nn.DenseCell(10, 20) # 10-dimensional input, 20-dimensional hidden state >>> dummy_state = sk.tree_state(cell) # 20-dimensional hidden state >>> x = jnp.ones((10,)) # 10 features >>> result = cell(x, dummy_state) @@ -1006,14 +1006,14 @@ class ScanRNN(sk.TreeClass): return_sequences: whether to return the hidden state for each timestep. Example: - >>> cell = SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state - >>> rnn = ScanRNN(cell) + >>> cell = sk.nn.SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state + >>> rnn = sk.nn.ScanRNN(cell) >>> x = jnp.ones((5, 10)) # 5 timesteps, 10 features >>> result, state = rnn(x) # 20 features >>> print(result.shape) (20,) - >>> cell = SimpleRNNCell(10, 20) - >>> rnn = ScanRNN(cell, return_sequences=True) + >>> cell = sk.nn.SimpleRNNCell(10, 20) + >>> rnn = sk.nn.ScanRNN(cell, return_sequences=True) >>> x = jnp.ones((5, 10)) # 5 timesteps, 10 features >>> result, state = rnn(x) # 5 timesteps, 20 features >>> print(result.shape)