From 3ae833dacba58129cf6f4eef6c43edd789799c0f Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Fri, 21 Jul 2023 02:13:36 +0900 Subject: [PATCH] more docs --- README.md | 146 ++++++++++++++--- docs/API/activations.rst | 34 ++++ docs/API/api.rst | 19 +++ docs/API/base.rst | 223 -------------------------- docs/API/containers.rst | 6 + docs/API/convolution.rst | 41 +++++ docs/API/dropout.rst | 8 + docs/API/fully_connected.rst | 6 + docs/API/image_filtering.rst | 8 + docs/API/linear.rst | 10 ++ docs/API/misc.rst | 33 ++++ docs/API/normalization.rst | 8 + docs/API/pooling.rst | 26 +++ docs/API/pytc.rst | 53 ------ docs/API/pytreeclass.rst | 10 ++ docs/API/pytreeclass_advanced_api.rst | 20 +++ docs/API/pytreeclass_core.rst | 11 ++ docs/API/pytreeclass_pretty_print.rst | 19 +++ docs/API/random_transforms.rst | 11 ++ docs/API/recurrent.rst | 18 +++ 20 files changed, 408 insertions(+), 302 deletions(-) create mode 100644 docs/API/activations.rst create mode 100644 docs/API/api.rst delete mode 100644 docs/API/base.rst create mode 100644 docs/API/containers.rst create mode 100644 docs/API/convolution.rst create mode 100644 docs/API/dropout.rst create mode 100644 docs/API/fully_connected.rst create mode 100644 docs/API/image_filtering.rst create mode 100644 docs/API/linear.rst create mode 100644 docs/API/misc.rst create mode 100644 docs/API/normalization.rst create mode 100644 docs/API/pooling.rst delete mode 100644 docs/API/pytc.rst create mode 100644 docs/API/pytreeclass.rst create mode 100644 docs/API/pytreeclass_advanced_api.rst create mode 100644 docs/API/pytreeclass_core.rst create mode 100644 docs/API/pytreeclass_pretty_print.rst create mode 100644 docs/API/random_transforms.rst create mode 100644 docs/API/recurrent.rst diff --git a/README.md b/README.md index 5f3c80a..339eb9f 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,13 @@ ![Tests](https://github.com/ASEM000/serket/actions/workflows/tests.yml/badge.svg) ![pyver](https://img.shields.io/badge/python-3.7%203.8%203.9%203.10-red) -![codestyle](https://img.shields.io/badge/codestyle-black-lightgrey) +![codestyle](https://img.shields.io/badge/codestyle-black-black) [![Downloads](https://pepy.tech/badge/serket)](https://pepy.tech/project/serket) [![codecov](https://codecov.io/gh/ASEM000/serket/branch/main/graph/badge.svg?token=C6NXOK9EVS)](https://codecov.io/gh/ASEM000/serket) +[![Documentation Status](https://readthedocs.org/projects/serket/badge/?version=latest)](https://serket.readthedocs.io/en/latest/?badge=latest) [![DOI](https://zenodo.org/badge/526985786.svg)](https://zenodo.org/badge/latestdoi/526985786) ![PyPI](https://img.shields.io/pypi/v/serket) +[![CodeFactor](https://www.codefactor.io/repository/github/asem000/serket/badge)](https://www.codefactor.io/repository/github/asem000/serket) @@ -30,30 +32,122 @@ pip install git+https://github.com/ASEM000/serket ## 📖 Description and motivation -- `serket` aims to be the most intuitive and easy-to-use physics-based Neural network library in JAX. -- `serket` is fully transparent to `jax` transformation (e.g. `vmap`,`grad`,`jit`,...) -- `serket` current aim to facilitate the integration of numerical methods in a NN setting (see examples for more) +- `serket` aims to be the most intuitive and easy-to-use neural network library in `JAX`. +- `serket` is fully transparent to `jax` transformation (e.g. `vmap`,`grad`,`jit`,...). -
+Example: + +```python +import os +os.environ["KERAS_BACKEND"] = "jax" +from keras_core.datasets import mnist +import jax +import jax.numpy as jnp +import functools as ft +import optax # for gradient optimization +import serket as sk +import time +import matplotlib.pyplot as plt # for plotting the predictions + +EPOCHS = 1 +LR = 1e-3 +BATCH_SIZE = 128 + +(x_train, y_train), _ = mnist.load_data() + +x_train = x_train.reshape(-1, 1, 28, 28).astype("float32") / 255.0 +x_train = jnp.array_split(x_train, x_train.shape[0] // BATCH_SIZE) +y_train = jnp.array_split(y_train, y_train.shape[0] // BATCH_SIZE) + +k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) + +class ConvNet(sk.TreeClass): + conv1: sk.nn.Conv2D = sk.nn.Conv2D(1, 32, 3, key=k1, padding="valid") + pool1: sk.nn.MaxPool2D = sk.nn.MaxPool2D(2, 2) + conv2: sk.nn.Conv2D = sk.nn.Conv2D(32, 64, 3, key=k2, padding="valid") + pool2: sk.nn.MaxPool2D = sk.nn.MaxPool2D(2, 2) + linear: sk.nn.Linear = sk.nn.Linear(1600, 10, key=k3) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.pool1(jax.nn.relu(self.conv1(x))) + x = self.pool2(jax.nn.relu(self.conv2(x))) + x = self.linear(jnp.ravel(x)) + return x + +nn = ConvNet() + +# 1) mask the non-jaxtype parameters +nn = sk.tree_mask(nn) + +# 2) initialize the optimizer state +optim = optax.adam(LR) +optim_state = optim.init(nn) + +@jax.vmap +def softmax_cross_entropy(logits, onehot): + assert onehot.shape == logits.shape == (10,) + return -jnp.sum(jax.nn.log_softmax(logits) * onehot) + +@ft.partial(jax.grad, has_aux=True) +def loss_func(nn, x, y): + # pass non-jaxtype over jax transformation + # using `tree_mask`/`tree_unmask` scheme + # 3) unmask the non-jaxtype parameters to be used in the computation + nn = sk.tree_unmask(nn) + + # 4) vectorize the computation over the batch dimension + # and get the logits + logits = jax.vmap(nn)(x) + onehot = jax.nn.one_hot(y, 10) + + # 5) use the appropriate loss function + loss = jnp.mean(softmax_cross_entropy(logits, onehot)) + return loss, (loss, logits) + + +@jax.vmap +def accuracy_func(logits, y): + assert logits.shape == (10,) + return jnp.argmax(logits) == y + + +@jax.jit +def train_step(nn, optim_state, x, y): + grads, (loss, logits) = loss_func(nn, x, y) + updates, optim_state = optim.update(grads, optim_state) + nn = optax.apply_updates(nn, updates) + return nn, optim_state, (loss, logits) + +for i in range(1, EPOCHS + 1): + t0 = time.time() + for j, (xb, yb) in enumerate(zip(x_train, y_train)): + nn, optim_state, (loss, logits) = train_step(nn, optim_state, xb, yb) + accuracy = jnp.mean(accuracy_func(logits, yb)) + print( + f"Epoch: {i:003d}/{EPOCHS:003d}\t" + f"Batch: {j:003d}/{len(x_train):003d}\t" + f"Batch loss: {loss:3e}\t" + f"Batch accuracy: {accuracy:3f}\t" + f"Time: {time.time() - t0:.3f}", + end="\r", + ) + +# 6) un-mask the trained network +nn = sk.tree_unmask(nn) + +# create 2x5 grid of images +fig, axes = plt.subplots(2, 5, figsize=(10, 4)) +idxs = jax.random.randint(k1, shape=(10,), minval=0, maxval=x_train[0].shape[0]) + +for i, idx in zip(axes.flatten(), idxs): + # get the prediction + pred = nn(x_train[0][idx]) + # plot the image + i.imshow(x_train[0][idx].reshape(28, 28), cmap="gray") + # set the title to be the prediction + i.set_title(jnp.argmax(pred)) + i.set_xticks([]) + i.set_yticks([]) -### 🧠 Neural network package: `serket.nn` 🧠 - -| Group | Layers | -| ------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| Linear | - `Linear`, `Bilinear`, `Multilinear`, `GeneralLinear`, `Identity`, `Embedding` | -| Densely connected | - `FNN` (Fully connected network), | -| Convolution | - `{Conv,FFTConv}{1D,2D,3D}`
- `{Conv,FFTConv}{1D,2D,3D}Transpose`
- `{Depthwise,Separable}{Conv,FFTConv}{1D,2D,3D}`
- `Conv{1D,2D,3D}Local` | -| Containers | - `Sequential`, `Lambda` | -| Pooling
(`kernex` backend) | - `{Avg,Max,LP}Pool{1D,2D,3D}`
- `Global{Avg,Max}Pool{1D,2D,3D}`
- `Adaptive{Avg,Max}Pool{1D,2D,3D}` | -| Reshaping | - `Flatten`, `Unflatten`,
- `FlipLeftRight2D`, `FlipUpDown2D`
- `Resize{1D,2D,3D}`
- `Upsample{1D,2D,3D}`
- `Pad{1D,2D,3D}` | -| Crop | - `Crop{1D,2D}` | -| Normalization | - `{Layer,Instance,Group}Norm` | -| Blurring | - `{Avg,Gaussian}Blur2D` | -| Dropout | - `Dropout`
- `Dropout{1D,2D,3D}` | -| Random transforms | - `RandomCrop{1D,2D}`
- `RandomApply`,
- `RandomCutout{1D,2D}`
- `RandomZoom2D`,
- `RandomContrast2D` | -| Misc | - `HistogramEqualization2D`, `AdjustContrast2D`, `Filter2D`, `PixelShuffle2D` | -| Activations | - `Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}`,
- `CeLU`,`ELU`,`GELU`,`GLU`
- `Hard{SILU,Shrink,Sigmoid,Swish,Tanh}`,
- `Soft{Plus,Sign,Shrink}`
- `LeakyReLU`,`LogSigmoid`,`LogSoftmax`,`Mish`,`PReLU`,
- `ReLU`,`ReLU6`,`SILU`,`SeLU`,`Sigmoid`
- `Swish`,`Tanh`,`TanhShrink`, `ThresholdedReLU`, `Snake`, `Stan`, `SquarePlus` | -| Recurrent cells | - `{SimpleRNN,LSTM,GRU}Cell`
- `Conv{LSTM,GRU}{1D,2D,3D}Cell` | -| Blocks | - `VGG{16,19}Block`, `UNetBlock` | - -
+# Epoch: 001/001 Batch: 467/468 Batch loss: 2.040178e-01 Batch accuracy: 0.984375 Time: 19.284 +``` \ No newline at end of file diff --git a/docs/API/activations.rst b/docs/API/activations.rst new file mode 100644 index 0000000..b490bff --- /dev/null +++ b/docs/API/activations.rst @@ -0,0 +1,34 @@ +Activations +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: AdaptiveLeakyReLU +.. autoclass:: AdaptiveReLU +.. autoclass:: AdaptiveSigmoid +.. autoclass:: AdaptiveTanh +.. autoclass:: CeLU +.. autoclass:: ELU +.. autoclass:: GELU +.. autoclass:: GLU +.. autoclass:: HardShrink +.. autoclass:: HardSigmoid +.. autoclass:: HardSwish +.. autoclass:: HardTanh +.. autoclass:: LeakyReLU +.. autoclass:: LogSigmoid +.. autoclass:: LogSoftmax +.. autoclass:: Mish +.. autoclass:: PReLU +.. autoclass:: ReLU +.. autoclass:: ReLU6 +.. autoclass:: SeLU +.. autoclass:: Sigmoid +.. autoclass:: SoftPlus +.. autoclass:: SoftShrink +.. autoclass:: SoftSign +.. autoclass:: SquarePlus +.. autoclass:: Swish +.. autoclass:: Snake +.. autoclass:: Tanh +.. autoclass:: TanhShrink +.. autoclass:: ThresholdedReLU \ No newline at end of file diff --git a/docs/API/api.rst b/docs/API/api.rst new file mode 100644 index 0000000..2a5f9f6 --- /dev/null +++ b/docs/API/api.rst @@ -0,0 +1,19 @@ +``Serket`` NN API +====================== + +.. toctree:: + :maxdepth: 2 + :caption: API Documentation + + fully_connected + linear + dropout + containers + pooling + convolution + normalization + image_filtering + misc + random_transforms + activations + recurrent \ No newline at end of file diff --git a/docs/API/base.rst b/docs/API/base.rst deleted file mode 100644 index 513e771..0000000 --- a/docs/API/base.rst +++ /dev/null @@ -1,223 +0,0 @@ - -``Serket`` NN API -====================== - - -Fully connected ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: FNN -.. autoclass:: MLP - -Linear ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: Linear -.. autoclass:: Bilinear -.. autoclass:: Identity -.. autoclass:: Multilinear -.. autoclass:: GeneralLinear -.. autoclass:: Embedding - -Dropout ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: Dropout -.. autoclass:: Dropout1D -.. autoclass:: Dropout2D -.. autoclass:: Dropout3D - - -Containers ---------------------------------- -.. currentmodule:: serket.nn - - -.. autoclass:: Sequential - -Pooling ---------------------------------- -.. currentmodule:: serket.nn - - -.. autoclass:: MaxPool1D -.. autoclass:: MaxPool2D -.. autoclass:: MaxPool3D -.. autoclass:: AvgPool1D -.. autoclass:: AvgPool2D -.. autoclass:: AvgPool3D -.. autoclass:: GlobalAvgPool1D -.. autoclass:: GlobalAvgPool2D -.. autoclass:: GlobalAvgPool3D -.. autoclass:: GlobalMaxPool1D -.. autoclass:: GlobalMaxPool2D -.. autoclass:: GlobalMaxPool3D -.. autoclass:: LPPool1D -.. autoclass:: LPPool2D -.. autoclass:: LPPool3D -.. autoclass:: AdaptiveAvgPool1D -.. autoclass:: AdaptiveAvgPool2D -.. autoclass:: AdaptiveAvgPool3D -.. autoclass:: AdaptiveMaxPool1D -.. autoclass:: AdaptiveMaxPool2D -.. autoclass:: AdaptiveMaxPool3D - - -Convolution ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: Conv1D -.. autoclass:: Conv2D -.. autoclass:: Conv3D - -.. autoclass:: Conv1DTranspose -.. autoclass:: Conv2DTranspose -.. autoclass:: Conv3DTranspose - -.. autoclass:: DepthwiseConv1D -.. autoclass:: DepthwiseConv2D -.. autoclass:: DepthwiseConv3D - -.. autoclass:: SeparableConv1D -.. autoclass:: SeparableConv2D -.. autoclass:: SeparableConv3D - -.. autoclass:: Conv1DLocal -.. autoclass:: Conv2DLocal -.. autoclass:: Conv3DLocal - -.. autoclass:: FFTConv1D -.. autoclass:: FFTConv2D -.. autoclass:: FFTConv3D - -.. autoclass:: DepthwiseFFTConv1D -.. autoclass:: DepthwiseFFTConv2D -.. autoclass:: DepthwiseFFTConv3D - -.. autoclass:: FFTConv1DTranspose -.. autoclass:: FFTConv2DTranspose -.. autoclass:: FFTConv3DTranspose - -.. autoclass:: SeparableFFTConv1D -.. autoclass:: SeparableFFTConv2D -.. autoclass:: SeparableFFTConv3D - -Normalization ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: LayerNorm -.. autoclass:: InstanceNorm -.. autoclass:: GroupNorm -.. autoclass:: BatchNorm - -Image filtering ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: AvgBlur2D -.. autoclass:: GaussianBlur2D -.. autoclass:: Filter2D -.. autoclass:: FFTFilter2D - -Misc ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: FlipLeftRight2D -.. autoclass:: FlipUpDown2D -.. autoclass:: Resize1D -.. autoclass:: Resize2D -.. autoclass:: Resize3D -.. autoclass:: Upsample1D -.. autoclass:: Upsample2D -.. autoclass:: Upsample3D -.. autoclass:: Pad1D -.. autoclass:: Pad2D -.. autoclass:: Pad3D - -.. autoclass:: VGG16Block -.. autoclass:: VGG19Block -.. autoclass:: UNetBlock - -.. autoclass:: Crop1D -.. autoclass:: Crop2D -.. autoclass:: Crop3D - -.. autoclass:: Flatten -.. autoclass:: Unflatten - -.. autoclass:: HistogramEqualization2D -.. autoclass:: PixelShuffle2D - -Random transforms ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: RandomCrop1D -.. autoclass:: RandomCrop2D -.. autoclass:: RandomCrop3D -.. autoclass:: RandomCutout1D -.. autoclass:: RandomCutout2D -.. autoclass:: RandomZoom2D -.. autoclass:: RandomApply - - -Activations ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: AdaptiveLeakyReLU -.. autoclass:: AdaptiveReLU -.. autoclass:: AdaptiveSigmoid -.. autoclass:: AdaptiveTanh -.. autoclass:: CeLU -.. autoclass:: ELU -.. autoclass:: GELU -.. autoclass:: GLU -.. autoclass:: HardShrink -.. autoclass:: HardSigmoid -.. autoclass:: HardSwish -.. autoclass:: HardTanh -.. autoclass:: LeakyReLU -.. autoclass:: LogSigmoid -.. autoclass:: LogSoftmax -.. autoclass:: Mish -.. autoclass:: PReLU -.. autoclass:: ReLU -.. autoclass:: ReLU6 -.. autoclass:: SeLU -.. autoclass:: Sigmoid -.. autoclass:: SoftPlus -.. autoclass:: SoftShrink -.. autoclass:: SoftSign -.. autoclass:: SquarePlus -.. autoclass:: Swish -.. autoclass:: Snake -.. autoclass:: Tanh -.. autoclass:: TanhShrink -.. autoclass:: ThresholdedReLU - -.. autoclass:: AdjustContrast2D -.. autoclass:: RandomContrast2D - -Recurrent ---------------------------------- - -.. currentmodule:: serket.nn - -.. autoclass:: LSTMCell -.. autoclass:: GRUCell -.. autoclass:: SimpleRNNCell -.. autoclass:: DenseCell -.. autoclass:: ConvLSTM1DCell -.. autoclass:: ConvLSTM2DCell -.. autoclass:: ConvLSTM3DCell -.. autoclass:: ConvGRU1DCell -.. autoclass:: ConvGRU2DCell -.. autoclass:: ConvGRU3DCell -.. autoclass:: ScanRNN \ No newline at end of file diff --git a/docs/API/containers.rst b/docs/API/containers.rst new file mode 100644 index 0000000..f1860ae --- /dev/null +++ b/docs/API/containers.rst @@ -0,0 +1,6 @@ +Containers +--------------------------------- +.. currentmodule:: serket.nn + + +.. autoclass:: Sequential \ No newline at end of file diff --git a/docs/API/convolution.rst b/docs/API/convolution.rst new file mode 100644 index 0000000..d3b47e4 --- /dev/null +++ b/docs/API/convolution.rst @@ -0,0 +1,41 @@ +Convolution +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: Conv1D +.. autoclass:: Conv2D +.. autoclass:: Conv3D + +.. autoclass:: Conv1DTranspose +.. autoclass:: Conv2DTranspose +.. autoclass:: Conv3DTranspose + +.. autoclass:: DepthwiseConv1D +.. autoclass:: DepthwiseConv2D +.. autoclass:: DepthwiseConv3D + +.. autoclass:: SeparableConv1D +.. autoclass:: SeparableConv2D +.. autoclass:: SeparableConv3D + +.. autoclass:: Conv1DLocal +.. autoclass:: Conv2DLocal +.. autoclass:: Conv3DLocal + +.. autoclass:: FFTConv1D +.. autoclass:: FFTConv2D +.. autoclass:: FFTConv3D + +.. autoclass:: DepthwiseFFTConv1D +.. autoclass:: DepthwiseFFTConv2D +.. autoclass:: DepthwiseFFTConv3D + +.. autoclass:: FFTConv1DTranspose +.. autoclass:: FFTConv2DTranspose +.. autoclass:: FFTConv3DTranspose + +.. autoclass:: SeparableFFTConv1D +.. autoclass:: SeparableFFTConv2D +.. autoclass:: SeparableFFTConv3D + + diff --git a/docs/API/dropout.rst b/docs/API/dropout.rst new file mode 100644 index 0000000..2db3437 --- /dev/null +++ b/docs/API/dropout.rst @@ -0,0 +1,8 @@ +Dropout +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: Dropout +.. autoclass:: Dropout1D +.. autoclass:: Dropout2D +.. autoclass:: Dropout3D \ No newline at end of file diff --git a/docs/API/fully_connected.rst b/docs/API/fully_connected.rst new file mode 100644 index 0000000..2a1afa0 --- /dev/null +++ b/docs/API/fully_connected.rst @@ -0,0 +1,6 @@ +Fully connected +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: FNN +.. autoclass:: MLP \ No newline at end of file diff --git a/docs/API/image_filtering.rst b/docs/API/image_filtering.rst new file mode 100644 index 0000000..8591e77 --- /dev/null +++ b/docs/API/image_filtering.rst @@ -0,0 +1,8 @@ +Image filtering +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: AvgBlur2D +.. autoclass:: GaussianBlur2D +.. autoclass:: Filter2D +.. autoclass:: FFTFilter2D \ No newline at end of file diff --git a/docs/API/linear.rst b/docs/API/linear.rst new file mode 100644 index 0000000..7947521 --- /dev/null +++ b/docs/API/linear.rst @@ -0,0 +1,10 @@ +Linear +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: Linear +.. autoclass:: Bilinear +.. autoclass:: Identity +.. autoclass:: Multilinear +.. autoclass:: GeneralLinear +.. autoclass:: Embedding \ No newline at end of file diff --git a/docs/API/misc.rst b/docs/API/misc.rst new file mode 100644 index 0000000..c0a8979 --- /dev/null +++ b/docs/API/misc.rst @@ -0,0 +1,33 @@ +Misc +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: FlipLeftRight2D +.. autoclass:: FlipUpDown2D +.. autoclass:: Resize1D +.. autoclass:: Resize2D +.. autoclass:: Resize3D +.. autoclass:: Upsample1D +.. autoclass:: Upsample2D +.. autoclass:: Upsample3D +.. autoclass:: Pad1D +.. autoclass:: Pad2D +.. autoclass:: Pad3D + +.. autoclass:: VGG16Block +.. autoclass:: VGG19Block +.. autoclass:: UNetBlock + +.. autoclass:: Crop1D +.. autoclass:: Crop2D +.. autoclass:: Crop3D + +.. autoclass:: Flatten +.. autoclass:: Unflatten + +.. autoclass:: HistogramEqualization2D +.. autoclass:: PixelShuffle2D + +.. autoclass:: AdjustContrast2D +.. autoclass:: RandomContrast2D + diff --git a/docs/API/normalization.rst b/docs/API/normalization.rst new file mode 100644 index 0000000..c671e5e --- /dev/null +++ b/docs/API/normalization.rst @@ -0,0 +1,8 @@ +Normalization +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: LayerNorm +.. autoclass:: InstanceNorm +.. autoclass:: GroupNorm +.. autoclass:: BatchNorm \ No newline at end of file diff --git a/docs/API/pooling.rst b/docs/API/pooling.rst new file mode 100644 index 0000000..55a8882 --- /dev/null +++ b/docs/API/pooling.rst @@ -0,0 +1,26 @@ +Pooling +--------------------------------- +.. currentmodule:: serket.nn + + +.. autoclass:: MaxPool1D +.. autoclass:: MaxPool2D +.. autoclass:: MaxPool3D +.. autoclass:: AvgPool1D +.. autoclass:: AvgPool2D +.. autoclass:: AvgPool3D +.. autoclass:: GlobalAvgPool1D +.. autoclass:: GlobalAvgPool2D +.. autoclass:: GlobalAvgPool3D +.. autoclass:: GlobalMaxPool1D +.. autoclass:: GlobalMaxPool2D +.. autoclass:: GlobalMaxPool3D +.. autoclass:: LPPool1D +.. autoclass:: LPPool2D +.. autoclass:: LPPool3D +.. autoclass:: AdaptiveAvgPool1D +.. autoclass:: AdaptiveAvgPool2D +.. autoclass:: AdaptiveAvgPool3D +.. autoclass:: AdaptiveMaxPool1D +.. autoclass:: AdaptiveMaxPool2D +.. autoclass:: AdaptiveMaxPool3D \ No newline at end of file diff --git a/docs/API/pytc.rst b/docs/API/pytc.rst deleted file mode 100644 index dd867d4..0000000 --- a/docs/API/pytc.rst +++ /dev/null @@ -1,53 +0,0 @@ -``PyTreeClass`` exported API -============================= - - -.. currentmodule:: serket - -.. autoclass:: TreeClass -.. autoclass:::members: at -.. autofunction:: is_tree_equal -.. autofunction:: field -.. autofunction:: fields - -``PyTreeClass`` exported pretty printing API ----------------------------------------------- - -.. currentmodule:: serket - -.. autofunction:: tree_diagram -.. autofunction:: tree_graph -.. autofunction:: tree_mermaid -.. autofunction:: tree_repr -.. autofunction:: tree_str -.. autofunction:: tree_summary -.. autofunction:: tree_repr_with_trace -.. currentmodule:: serket -.. autofunction:: is_nondiff -.. autofunction:: freeze -.. autofunction:: unfreeze -.. autofunction:: is_frozen -.. autofunction:: tree_mask -.. autofunction:: tree_unmask - - -``PyTreeClass`` exported advanced API ---------------------------------------- -.. currentmodule:: serket - -.. autofunction:: bcmap -.. autoclass:: Partial -.. autoclass:: AtIndexer - :members: - get, - set, - apply, - scan, - reduce -.. autoclass:: BaseKey - :members: - __eq__ -.. autofunction:: tree_map_with_trace -.. autofunction:: tree_leaves_with_trace -.. autofunction:: tree_flatten_with_trace - diff --git a/docs/API/pytreeclass.rst b/docs/API/pytreeclass.rst new file mode 100644 index 0000000..edda822 --- /dev/null +++ b/docs/API/pytreeclass.rst @@ -0,0 +1,10 @@ +``PyTreeClass`` exported API +============================= + +.. toctree:: + :maxdepth: 2 + :caption: API Documentation + + pytreeclass_core + pytreeclass_advanced_api + pytreeclass_pretty_print diff --git a/docs/API/pytreeclass_advanced_api.rst b/docs/API/pytreeclass_advanced_api.rst new file mode 100644 index 0000000..253f8cb --- /dev/null +++ b/docs/API/pytreeclass_advanced_api.rst @@ -0,0 +1,20 @@ +Advanced API +--------------------------------------- +.. currentmodule:: serket + +.. autofunction:: bcmap +.. autoclass:: Partial +.. autoclass:: AtIndexer + :members: + get, + set, + apply, + scan, + reduce +.. autoclass:: BaseKey + :members: + __eq__ +.. autofunction:: tree_map_with_trace +.. autofunction:: tree_leaves_with_trace +.. autofunction:: tree_flatten_with_trace + diff --git a/docs/API/pytreeclass_core.rst b/docs/API/pytreeclass_core.rst new file mode 100644 index 0000000..b3aa901 --- /dev/null +++ b/docs/API/pytreeclass_core.rst @@ -0,0 +1,11 @@ +Core API +============================= + + +.. currentmodule:: serket + +.. autoclass:: TreeClass +.. autoclass:::members: at +.. autofunction:: is_tree_equal +.. autofunction:: field +.. autofunction:: fields \ No newline at end of file diff --git a/docs/API/pytreeclass_pretty_print.rst b/docs/API/pytreeclass_pretty_print.rst new file mode 100644 index 0000000..ad19769 --- /dev/null +++ b/docs/API/pytreeclass_pretty_print.rst @@ -0,0 +1,19 @@ +Pretty printing API +---------------------------------------------- + +.. currentmodule:: serket + +.. autofunction:: tree_diagram +.. autofunction:: tree_graph +.. autofunction:: tree_mermaid +.. autofunction:: tree_repr +.. autofunction:: tree_str +.. autofunction:: tree_summary +.. autofunction:: tree_repr_with_trace +.. currentmodule:: serket +.. autofunction:: is_nondiff +.. autofunction:: freeze +.. autofunction:: unfreeze +.. autofunction:: is_frozen +.. autofunction:: tree_mask +.. autofunction:: tree_unmask diff --git a/docs/API/random_transforms.rst b/docs/API/random_transforms.rst new file mode 100644 index 0000000..c0e9b2b --- /dev/null +++ b/docs/API/random_transforms.rst @@ -0,0 +1,11 @@ +Random transforms +--------------------------------- +.. currentmodule:: serket.nn + +.. autoclass:: RandomCrop1D +.. autoclass:: RandomCrop2D +.. autoclass:: RandomCrop3D +.. autoclass:: RandomCutout1D +.. autoclass:: RandomCutout2D +.. autoclass:: RandomZoom2D +.. autoclass:: RandomApply \ No newline at end of file diff --git a/docs/API/recurrent.rst b/docs/API/recurrent.rst new file mode 100644 index 0000000..7069e5a --- /dev/null +++ b/docs/API/recurrent.rst @@ -0,0 +1,18 @@ +Recurrent +--------------------------------- + +.. currentmodule:: serket.nn + +.. autoclass:: LSTMCell +.. autoclass:: GRUCell +.. autoclass:: SimpleRNNCell +.. autoclass:: DenseCell +.. autoclass:: ConvLSTM1DCell +.. autoclass:: ConvLSTM2DCell +.. autoclass:: ConvLSTM3DCell +.. autoclass:: ConvGRU1DCell +.. autoclass:: ConvGRU2DCell +.. autoclass:: ConvGRU3DCell +.. autoclass:: ScanRNN + :members: + __call__ \ No newline at end of file