Skip to content

Commit

Permalink
more doc edits
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 20, 2023
1 parent 7862d62 commit 23fb3d9
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 36 deletions.
32 changes: 16 additions & 16 deletions serket/nn/blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ class AvgBlur2D(sk.TreeClass):
>>> layer = sk.nn.AvgBlur2D(in_features=1, kernel_size=3)
>>> print(layer(jnp.ones((1,5,5))))
[[[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]
"""

def __init__(self, in_features: int, kernel_size: int | tuple[int, int]):
Expand Down Expand Up @@ -97,10 +97,10 @@ class GaussianBlur2D(sk.TreeClass):
>>> layer = sk.nn.GaussianBlur2D(in_features=1, kernel_size=3)
>>> print(layer(jnp.ones((1,5,5))))
[[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
"""

def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
Expand Down Expand Up @@ -158,10 +158,10 @@ class Filter2D(sk.TreeClass):
>>> layer = sk.nn.Filter2D(in_features=1, kernel=jnp.ones((3,3)))
>>> print(layer(jnp.ones((1,5,5))))
[[[4. 6. 6. 6. 4.]
[6. 9. 9. 9. 6.]
[6. 9. 9. 9. 6.]
[6. 9. 9. 9. 6.]
[4. 6. 6. 6. 4.]]]
[6. 9. 9. 9. 6.]
[6. 9. 9. 9. 6.]
[6. 9. 9. 9. 6.]
[4. 6. 6. 6. 4.]]]
"""

def __init__(self, in_features: int, kernel: jax.Array):
Expand Down Expand Up @@ -203,10 +203,10 @@ class FFTFilter2D(sk.TreeClass):
>>> layer = sk.nn.FFTFilter2D(in_features=1, kernel=jnp.ones((3,3)))
>>> print(layer(jnp.ones((1,5,5))))
[[[4.0000005 6.0000005 6.000001 6.0000005 4.0000005]
[6.0000005 9. 9. 9. 6.0000005]
[6.0000005 9. 9. 9. 6.0000005]
[6.0000005 9. 9. 9. 6.0000005]
[4. 6.0000005 6.0000005 6.0000005 4. ]]]
[6.0000005 9. 9. 9. 6.0000005]
[6.0000005 9. 9. 9. 6.0000005]
[6.0000005 9. 9. 9. 6.0000005]
[4. 6.0000005 6.0000005 6.0000005 4. ]]]
"""

def __init__(self, in_features: int, kernel: jax.Array):
Expand Down
12 changes: 6 additions & 6 deletions serket/nn/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ class Crop2D(CropND):
>>> x = jnp.arange(1, 26).reshape((1, 5, 5))
>>> print(x)
[[[ 1 2 3 4 5]
[ 6 7 8 9 10]
[11 12 13 14 15]
[16 17 18 19 20]
[21 22 23 24 25]]]
[ 6 7 8 9 10]
[11 12 13 14 15]
[16 17 18 19 20]
[21 22 23 24 25]]]
>>> print(sk.nn.Crop2D(size=3, start=(2, 0))(x))
[[[11 12 13]
[16 17 18]
[21 22 23]]]
[16 17 18]
[21 22 23]]]
"""

def __init__(self, size: int | tuple[int, int], start: int | tuple[int, int]):
Expand Down
16 changes: 8 additions & 8 deletions serket/nn/flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ class FlipLeftRight2D(sk.TreeClass):
>>> x = jnp.arange(1,10).reshape(1,3, 3)
>>> print(x)
[[[1 2 3]
[4 5 6]
[7 8 9]]]
[4 5 6]
[7 8 9]]]
>>> print(sk.nn.FlipLeftRight2D()(x))
[[[3 2 1]
[6 5 4]
[9 8 7]]]
[6 5 4]
[9 8 7]]]
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
Expand All @@ -66,13 +66,13 @@ class FlipUpDown2D(sk.TreeClass):
>>> x = jnp.arange(1,10).reshape(1,3, 3)
>>> print(x)
[[[1 2 3]
[4 5 6]
[7 8 9]]]
[4 5 6]
[7 8 9]]]
>>> print(sk.nn.FlipUpDown2D()(x))
[[[7 8 9]
[4 5 6]
[1 2 3]]]
[4 5 6]
[1 2 3]]]
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
Expand Down
12 changes: 6 additions & 6 deletions serket/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> cell = SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state
>>> cell = sk.nn.SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state
>>> rnn_state = sk.tree_state(cell) # 20-dimensional hidden state
>>> x = jnp.ones((10,)) # 10 features
>>> result = cell(x, rnn_state)
Expand Down Expand Up @@ -174,7 +174,7 @@ class DenseCell(RNNCell):
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> cell = DenseCell(10, 20) # 10-dimensional input, 20-dimensional hidden state
>>> cell = sk.nn.DenseCell(10, 20) # 10-dimensional input, 20-dimensional hidden state
>>> dummy_state = sk.tree_state(cell) # 20-dimensional hidden state
>>> x = jnp.ones((10,)) # 10 features
>>> result = cell(x, dummy_state)
Expand Down Expand Up @@ -1006,14 +1006,14 @@ class ScanRNN(sk.TreeClass):
return_sequences: whether to return the hidden state for each timestep.
Example:
>>> cell = SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state
>>> rnn = ScanRNN(cell)
>>> cell = sk.nn.SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state
>>> rnn = sk.nn.ScanRNN(cell)
>>> x = jnp.ones((5, 10)) # 5 timesteps, 10 features
>>> result, state = rnn(x) # 20 features
>>> print(result.shape)
(20,)
>>> cell = SimpleRNNCell(10, 20)
>>> rnn = ScanRNN(cell, return_sequences=True)
>>> cell = sk.nn.SimpleRNNCell(10, 20)
>>> rnn = sk.nn.ScanRNN(cell, return_sequences=True)
>>> x = jnp.ones((5, 10)) # 5 timesteps, 10 features
>>> result, state = rnn(x) # 5 timesteps, 20 features
>>> print(result.shape)
Expand Down

0 comments on commit 23fb3d9

Please sign in to comment.