From 57035dd04896577b037667fbaccbafbfc0f9ac17 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Sat, 9 Sep 2023 04:32:48 +0900 Subject: [PATCH] move to _src --- docs/notebooks/layers_overview.ipynb | 2 +- serket/__init__.py | 9 +- serket/_src/__init__.py | 13 ++ serket/_src/cluster/__init__.py | 13 ++ serket/{ => _src}/cluster/kmeans.py | 6 +- serket/{ => _src}/custom_transform.py | 4 +- serket/_src/image/__init__.py | 13 ++ serket/{ => _src}/image/augment.py | 8 +- serket/{ => _src}/image/filter.py | 8 +- serket/{ => _src}/image/geometric.py | 8 +- serket/_src/nn/__init__.py | 13 ++ serket/{ => _src}/nn/activation.py | 4 +- serket/{ => _src}/nn/attention.py | 6 +- serket/{ => _src}/nn/containers.py | 6 +- serket/{ => _src}/nn/convolution.py | 6 +- serket/{ => _src}/nn/dropout.py | 11 +- serket/{ => _src}/nn/initialization.py | 4 +- serket/{ => _src}/nn/linear.py | 8 +- serket/{ => _src}/nn/normalization.py | 9 +- serket/{ => _src}/nn/pooling.py | 4 +- serket/{ => _src}/nn/recurrent.py | 10 +- serket/{ => _src}/nn/reshape.py | 8 +- serket/{ => _src}/utils.py | 2 +- serket/cluster/__init__.py | 4 +- serket/experimental/__init__.py | 2 +- serket/image/__init__.py | 8 +- serket/nn/__init__.py | 34 ++-- tests/__init__.py | 2 +- tests/test_activation.py | 4 +- tests/test_attention.py | 2 +- tests/test_clustering.py | 2 +- tests/test_containers.py | 2 +- tests/test_conv.py | 2 +- tests/test_convolution.py | 222 +++++++++++++------------ tests/test_dropout.py | 13 +- tests/test_image_filter.py | 116 +++++-------- tests/test_init.py | 4 +- tests/test_linear.py | 59 ++++--- tests/test_normalization.py | 23 ++- tests/test_pooling.py | 4 +- tests/test_reshape.py | 105 +++++------- tests/test_rnn.py | 6 +- tests/test_sequential.py | 18 +- tests/test_utils.py | 6 +- 44 files changed, 405 insertions(+), 408 deletions(-) create mode 100644 serket/_src/__init__.py create mode 100644 serket/_src/cluster/__init__.py rename serket/{ => _src}/cluster/kmeans.py (98%) rename serket/{ => _src}/custom_transform.py (99%) create mode 100644 serket/_src/image/__init__.py rename serket/{ => _src}/image/augment.py (98%) rename serket/{ => _src}/image/filter.py (98%) rename serket/{ => _src}/image/geometric.py (99%) create mode 100644 serket/_src/nn/__init__.py rename serket/{ => _src}/nn/activation.py (99%) rename serket/{ => _src}/nn/attention.py (98%) rename serket/{ => _src}/nn/containers.py (98%) rename serket/{ => _src}/nn/convolution.py (99%) rename serket/{ => _src}/nn/dropout.py (98%) rename serket/{ => _src}/nn/initialization.py (97%) rename serket/{ => _src}/nn/linear.py (99%) rename serket/{ => _src}/nn/normalization.py (99%) rename serket/{ => _src}/nn/pooling.py (99%) rename serket/{ => _src}/nn/recurrent.py (99%) rename serket/{ => _src}/nn/reshape.py (99%) rename serket/{ => _src}/utils.py (99%) diff --git a/docs/notebooks/layers_overview.ipynb b/docs/notebooks/layers_overview.ipynb index fae676c..e1840d6 100644 --- a/docs/notebooks/layers_overview.ipynb +++ b/docs/notebooks/layers_overview.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# 📙 `serket.nn` layers overview" + "# 📙 `serket._src.nn` layers overview" ] }, { diff --git a/serket/__init__.py b/serket/__init__.py index 261c68a..4f96eb9 100644 --- a/serket/__init__.py +++ b/serket/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,10 +42,11 @@ unfreeze, ) +from serket._src.custom_transform import tree_eval, tree_state +from serket._src.nn.activation import def_act_entry +from serket._src.nn.initialization import def_init_entry + from . import cluster, image, nn -from .custom_transform import tree_eval, tree_state -from .nn.activation import def_act_entry -from .nn.initialization import def_init_entry __all__ = [ # general utils diff --git a/serket/_src/__init__.py b/serket/_src/__init__.py new file mode 100644 index 0000000..dbf0b04 --- /dev/null +++ b/serket/_src/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/serket/_src/cluster/__init__.py b/serket/_src/cluster/__init__.py new file mode 100644 index 0000000..dbf0b04 --- /dev/null +++ b/serket/_src/cluster/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/serket/cluster/kmeans.py b/serket/_src/cluster/kmeans.py similarity index 98% rename from serket/cluster/kmeans.py rename to serket/_src/cluster/kmeans.py index 05b643f..65ac90e 100644 --- a/serket/cluster/kmeans.py +++ b/serket/_src/cluster/kmeans.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,8 +23,8 @@ from typing_extensions import Annotated import serket as sk -from serket.custom_transform import tree_eval, tree_state -from serket.utils import IsInstance, Range +from serket._src.custom_transform import tree_eval, tree_state +from serket._src.utils import IsInstance, Range """K-means utility functions.""" diff --git a/serket/custom_transform.py b/serket/_src/custom_transform.py similarity index 99% rename from serket/custom_transform.py rename to serket/_src/custom_transform.py index 9a9da1d..42e7ba2 100644 --- a/serket/custom_transform.py +++ b/serket/_src/custom_transform.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -133,7 +133,7 @@ def tree_eval(tree): Note: To define evaluation rule for a custom layer, use the decorator - :func:`.tree_eval.def_eval` on a function that accepts the layer. The + :func:`.tree_eval.def_eval` on a function that accepts the layer. The function should return the evaluation layer. >>> import serket as sk diff --git a/serket/_src/image/__init__.py b/serket/_src/image/__init__.py new file mode 100644 index 0000000..dbf0b04 --- /dev/null +++ b/serket/_src/image/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/serket/image/augment.py b/serket/_src/image/augment.py similarity index 98% rename from serket/image/augment.py rename to serket/_src/image/augment.py index 47b30ca..a3c5250 100644 --- a/serket/image/augment.py +++ b/serket/_src/image/augment.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +22,9 @@ from jax import lax import serket as sk -from serket.custom_transform import tree_eval -from serket.nn.linear import Identity -from serket.utils import IsInstance, Range, validate_spatial_ndim +from serket._src.custom_transform import tree_eval +from serket._src.nn.linear import Identity +from serket._src.utils import IsInstance, Range, validate_spatial_ndim def pixel_shuffle_2d(x: jax.Array, upscale_factor: int | tuple[int, int]) -> jax.Array: diff --git a/serket/image/filter.py b/serket/_src/image/filter.py similarity index 98% rename from serket/image/filter.py rename to serket/_src/image/filter.py index a336652..543d4dd 100644 --- a/serket/image/filter.py +++ b/serket/_src/image/filter.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,9 +21,9 @@ from typing_extensions import Annotated import serket as sk -from serket.nn.convolution import fft_conv_general_dilated -from serket.nn.initialization import DType -from serket.utils import ( +from serket._src.nn.convolution import fft_conv_general_dilated +from serket._src.nn.initialization import DType +from serket._src.utils import ( generate_conv_dim_numbers, positive_int_cb, resolve_string_padding, diff --git a/serket/image/geometric.py b/serket/_src/image/geometric.py similarity index 99% rename from serket/image/geometric.py rename to serket/_src/image/geometric.py index cfae1bb..690393d 100644 --- a/serket/image/geometric.py +++ b/serket/_src/image/geometric.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +22,9 @@ from jax.scipy.ndimage import map_coordinates import serket as sk -from serket.custom_transform import tree_eval -from serket.nn.linear import Identity -from serket.utils import IsInstance, validate_spatial_ndim +from serket._src.custom_transform import tree_eval +from serket._src.nn.linear import Identity +from serket._src.utils import IsInstance, validate_spatial_ndim def affine(image, matrix): diff --git a/serket/_src/nn/__init__.py b/serket/_src/nn/__init__.py new file mode 100644 index 0000000..dbf0b04 --- /dev/null +++ b/serket/_src/nn/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/serket/nn/activation.py b/serket/_src/nn/activation.py similarity index 99% rename from serket/nn/activation.py rename to serket/_src/nn/activation.py index bfbef06..a88e443 100644 --- a/serket/nn/activation.py +++ b/serket/_src/nn/activation.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ from jax import lax import serket as sk -from serket.utils import IsInstance, Range, ScalarLike +from serket._src.utils import IsInstance, Range, ScalarLike T = TypeVar("T") diff --git a/serket/nn/attention.py b/serket/_src/nn/attention.py similarity index 98% rename from serket/nn/attention.py rename to serket/_src/nn/attention.py index 32903e9..6736a81 100644 --- a/serket/nn/attention.py +++ b/serket/_src/nn/attention.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,8 +22,8 @@ from typing_extensions import Annotated import serket as sk -from serket.nn.initialization import InitType -from serket.utils import maybe_lazy_call, maybe_lazy_init +from serket._src.nn.initialization import InitType +from serket._src.utils import maybe_lazy_call, maybe_lazy_init """Defines attention layers.""" diff --git a/serket/nn/containers.py b/serket/_src/nn/containers.py similarity index 98% rename from serket/nn/containers.py rename to serket/_src/nn/containers.py index b6fb743..d51b671 100644 --- a/serket/nn/containers.py +++ b/serket/_src/nn/containers.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,8 @@ import jax.random as jr import serket as sk -from serket.custom_transform import tree_eval -from serket.utils import Range +from serket._src.custom_transform import tree_eval +from serket._src.utils import Range def sequential( diff --git a/serket/nn/convolution.py b/serket/_src/nn/convolution.py similarity index 99% rename from serket/nn/convolution.py rename to serket/_src/nn/convolution.py index 9d0615f..eae3efa 100644 --- a/serket/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,8 +27,8 @@ from typing_extensions import Annotated import serket as sk -from serket.nn.initialization import DType, InitType, resolve_init -from serket.utils import ( +from serket._src.nn.initialization import DType, InitType, resolve_init +from serket._src.utils import ( DilationType, KernelSizeType, PaddingType, diff --git a/serket/nn/dropout.py b/serket/_src/nn/dropout.py similarity index 98% rename from serket/nn/dropout.py rename to serket/_src/nn/dropout.py index eea5192..0ec141b 100644 --- a/serket/nn/dropout.py +++ b/serket/_src/nn/dropout.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,9 +23,8 @@ import jax.random as jr import serket as sk -from serket.custom_transform import tree_eval -from serket.nn.linear import Identity -from serket.utils import ( +from serket._src.custom_transform import tree_eval +from serket._src.utils import ( IsInstance, Range, canonicalize, @@ -434,5 +433,5 @@ def spatial_ndim(self) -> int: @tree_eval.def_eval(RandomCutout2D) @tree_eval.def_eval(GeneralDropout) @tree_eval.def_eval(DropoutND) -def dropout_evaluation(_) -> Identity: - return Identity() +def dropout_evaluation(_) -> sk.nn.Identity: + return sk.nn.Identity() diff --git a/serket/nn/initialization.py b/serket/_src/nn/initialization.py similarity index 97% rename from serket/nn/initialization.py rename to serket/_src/nn/initialization.py index 61bed3a..e503fbb 100644 --- a/serket/nn/initialization.py +++ b/serket/_src/nn/initialization.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ import functools as ft from collections.abc import Callable as ABCCallable -from typing import Callable,Any, Literal, Tuple, Union, get_args +from typing import Any, Callable, Literal, Tuple, Union, get_args import jax import jax.nn.initializers as ji diff --git a/serket/nn/linear.py b/serket/_src/nn/linear.py similarity index 99% rename from serket/nn/linear.py rename to serket/_src/nn/linear.py index b90881a..b18d530 100644 --- a/serket/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,13 +22,13 @@ import jax.random as jr import serket as sk -from serket.nn.activation import ( +from serket._src.nn.activation import ( ActivationFunctionType, ActivationType, resolve_activation, ) -from serket.nn.initialization import DType, InitType, resolve_init -from serket.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb +from serket._src.nn.initialization import DType, InitType, resolve_init +from serket._src.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb T = TypeVar("T") diff --git a/serket/nn/normalization.py b/serket/_src/nn/normalization.py similarity index 99% rename from serket/nn/normalization.py rename to serket/_src/nn/normalization.py index 2f83cae..6654121 100644 --- a/serket/nn/normalization.py +++ b/serket/_src/nn/normalization.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +22,9 @@ from jax.custom_batching import custom_vmap import serket as sk -from serket.custom_transform import tree_eval, tree_state -from serket.nn.initialization import DType, InitType, resolve_init -from serket.utils import ( +from serket._src.custom_transform import tree_eval, tree_state +from serket._src.nn.initialization import DType, InitType, resolve_init +from serket._src.utils import ( Range, ScalarLike, maybe_lazy_call, @@ -225,6 +225,7 @@ class GroupNorm(sk.TreeClass): Reference: https://nn.labml.ai/normalization/group_norm/index.html """ + eps: float = sk.field(on_setattr=[Range(0), ScalarLike()]) @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) diff --git a/serket/nn/pooling.py b/serket/_src/nn/pooling.py similarity index 99% rename from serket/nn/pooling.py rename to serket/_src/nn/pooling.py index e4b53d3..d9cdc87 100644 --- a/serket/nn/pooling.py +++ b/serket/_src/nn/pooling.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import kernex as kex import serket as sk -from serket.utils import ( +from serket._src.utils import ( KernelSizeType, PaddingType, StridesType, diff --git a/serket/nn/recurrent.py b/serket/_src/nn/recurrent.py similarity index 99% rename from serket/nn/recurrent.py rename to serket/_src/nn/recurrent.py index 5e6f66c..9d6312f 100644 --- a/serket/nn/recurrent.py +++ b/serket/_src/nn/recurrent.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,10 +24,10 @@ import jax.tree_util as jtu import serket as sk -from serket.custom_transform import tree_state -from serket.nn.activation import ActivationType, resolve_activation -from serket.nn.initialization import DType, InitType -from serket.utils import ( +from serket._src.custom_transform import tree_state +from serket._src.nn.activation import ActivationType, resolve_activation +from serket._src.nn.initialization import DType, InitType +from serket._src.utils import ( DilationType, KernelSizeType, PaddingType, diff --git a/serket/nn/reshape.py b/serket/_src/nn/reshape.py similarity index 99% rename from serket/nn/reshape.py rename to serket/_src/nn/reshape.py index 5dacfd1..d443998 100644 --- a/serket/nn/reshape.py +++ b/serket/_src/nn/reshape.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,9 +23,9 @@ import jax.random as jr import serket as sk -from serket.custom_transform import tree_eval -from serket.nn.linear import Identity -from serket.utils import ( +from serket._src.custom_transform import tree_eval +from serket._src.nn.linear import Identity +from serket._src.utils import ( IsInstance, canonicalize, delayed_canonicalize_padding, diff --git a/serket/utils.py b/serket/_src/utils.py similarity index 99% rename from serket/utils.py rename to serket/_src/utils.py index 615be5b..fceb1cf 100644 --- a/serket/utils.py +++ b/serket/_src/utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/cluster/__init__.py b/serket/cluster/__init__.py index 9046273..885ffe3 100644 --- a/serket/cluster/__init__.py +++ b/serket/cluster/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,6 @@ # limitations under the License. -from .kmeans import KMeans +from serket._src.cluster.kmeans import KMeans __all__ = ["KMeans"] diff --git a/serket/experimental/__init__.py b/serket/experimental/__init__.py index afa43b8..dbf0b04 100644 --- a/serket/experimental/__init__.py +++ b/serket/experimental/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/image/__init__.py b/serket/image/__init__.py index 6335842..efee2de 100644 --- a/serket/image/__init__.py +++ b/serket/image/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .augment import ( +from serket._src.image.augment import ( AdjustContrast2D, JigSaw2D, PixelShuffle2D, Posterize2D, RandomContrast2D, ) -from .filter import ( +from serket._src.image.filter import ( AvgBlur2D, FFTAvgBlur2D, FFTFilter2D, @@ -27,7 +27,7 @@ Filter2D, GaussianBlur2D, ) -from .geometric import ( +from serket._src.image.geometric import ( HorizontalFlip2D, HorizontalShear2D, HorizontalTranslate2D, diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index 2cdf47e..9ac2c6c 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .activation import ( +from serket._src.nn.activation import ( ELU, GELU, GLU, @@ -44,9 +44,9 @@ TanhShrink, ThresholdedReLU, ) -from .attention import MultiHeadAttention -from .containers import RandomApply, RandomChoice, Sequential -from .convolution import ( +from serket._src.nn.attention import MultiHeadAttention +from serket._src.nn.containers import RandomApply, RandomChoice, Sequential +from serket._src.nn.convolution import ( Conv1D, Conv1DLocal, Conv1DTranspose, @@ -75,7 +75,7 @@ SeparableFFTConv2D, SeparableFFTConv3D, ) -from .dropout import ( +from serket._src.nn.dropout import ( Dropout, Dropout1D, Dropout2D, @@ -84,9 +84,17 @@ RandomCutout1D, RandomCutout2D, ) -from .linear import FNN, MLP, Embedding, GeneralLinear, Identity, Linear, Multilinear -from .normalization import BatchNorm, GroupNorm, InstanceNorm, LayerNorm -from .pooling import ( +from serket._src.nn.linear import ( + FNN, + MLP, + Embedding, + GeneralLinear, + Identity, + Linear, + Multilinear, +) +from serket._src.nn.normalization import BatchNorm, GroupNorm, InstanceNorm, LayerNorm +from serket._src.nn.pooling import ( AdaptiveAvgPool1D, AdaptiveAvgPool2D, AdaptiveAvgPool3D, @@ -109,7 +117,7 @@ MaxPool2D, MaxPool3D, ) -from .recurrent import ( +from serket._src.nn.recurrent import ( ConvGRU1DCell, ConvGRU2DCell, ConvGRU3DCell, @@ -128,7 +136,7 @@ ScanRNN, SimpleRNNCell, ) -from .reshape import ( +from serket._src.nn.reshape import ( Crop1D, Crop2D, Crop3D, @@ -151,7 +159,7 @@ Upsample3D, ) -__all__ = ( +__all__ = [ # activation "ELU", "GELU", @@ -299,4 +307,4 @@ "Upsample1D", "Upsample2D", "Upsample3D", -) +] diff --git a/tests/__init__.py b/tests/__init__.py index afa43b8..dbf0b04 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_activation.py b/tests/test_activation.py index 5de3e1f..fa49109 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ import numpy.testing as npt import pytest -from serket.nn.activation import ( +from serket._src.nn.activation import ( ELU, GELU, GLU, diff --git a/tests/test_attention.py b/tests/test_attention.py index 8628017..4f43418 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_clustering.py b/tests/test_clustering.py index b1e269f..034007c 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_containers.py b/tests/test_containers.py index f809337..626773c 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_conv.py b/tests/test_conv.py index c969181..1fc89ba 100644 --- a/tests/test_conv.py +++ b/tests/test_conv.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 9a27e7b..d87f20a 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,33 +16,7 @@ import numpy.testing as npt import pytest -from serket.nn import ( - Conv1D, - Conv1DTranspose, - Conv2D, - Conv2DTranspose, - Conv3D, - Conv3DTranspose, - DepthwiseConv1D, - DepthwiseConv2D, - DepthwiseConv3D, - DepthwiseFFTConv1D, - DepthwiseFFTConv2D, - DepthwiseFFTConv3D, - FFTConv1D, - FFTConv1DTranspose, - FFTConv2D, - FFTConv2DTranspose, - FFTConv3D, - FFTConv3DTranspose, - SeparableConv1D, - SeparableConv2D, - SeparableConv3D, - SeparableFFTConv1D, - SeparableFFTConv2D, - SeparableFFTConv3D, -) -from serket.nn.convolution import Conv1DLocal, Conv2DLocal +import serket as sk def test_fft_conv1d(): @@ -75,7 +49,7 @@ def test_fft_conv1d(): ] ) - layer = FFTConv1D(2, 6, kernel_size=3, padding=0, strides=1, groups=2) + layer = sk.nn.FFTConv1D(2, 6, kernel_size=3, padding=0, strides=1, groups=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) @@ -257,7 +231,7 @@ def test_fft_conv2d(): ] ) - ls = FFTConv2D(2, 6, kernel_size=3, padding=0, strides=1, groups=2) + ls = sk.nn.FFTConv2D(2, 6, kernel_size=3, padding=0, strides=1, groups=2) ls = ls.at["weight"].set(w) ls = ls.at["bias"].set(b) @@ -522,7 +496,7 @@ def test_fft_conv3d(): ] ) - ls = FFTConv3D(2, 6, kernel_size=3, padding=0, strides=1, groups=2) + ls = sk.nn.FFTConv3D(2, 6, kernel_size=3, padding=0, strides=1, groups=2) ls = ls.at["weight"].set(w) ls = ls.at["bias"].set(b) @@ -531,44 +505,56 @@ def test_fft_conv3d(): def test_fft_conv(): x = jnp.ones([10, 1]) - npt.assert_allclose(FFTConv1D(10, 1, 3)(x), Conv1D(10, 1, 3)(x), atol=1e-4) + npt.assert_allclose( + sk.nn.FFTConv1D(10, 1, 3)(x), sk.nn.Conv1D(10, 1, 3)(x), atol=1e-4 + ) x = jnp.ones([7, 8]) - npt.assert_allclose(FFTConv1D(7, 1, 3)(x), Conv1D(7, 1, 3)(x), atol=1e-4) + npt.assert_allclose( + sk.nn.FFTConv1D(7, 1, 3)(x), sk.nn.Conv1D(7, 1, 3)(x), atol=1e-4 + ) x = jnp.ones([10, 1]) npt.assert_allclose( - FFTConv1D(10, 1, 3, dilation=2)(x), - Conv1D(10, 1, 3, dilation=2)(x), + sk.nn.FFTConv1D(10, 1, 3, dilation=2)(x), + sk.nn.Conv1D(10, 1, 3, dilation=2)(x), atol=1e-5, ) x = jnp.ones([10, 1, 1]) - npt.assert_allclose(FFTConv2D(10, 1, 3)(x), Conv2D(10, 1, 3)(x), atol=1e-4) + npt.assert_allclose( + sk.nn.FFTConv2D(10, 1, 3)(x), sk.nn.Conv2D(10, 1, 3)(x), atol=1e-4 + ) x = jnp.ones([7, 8, 9]) - npt.assert_allclose(FFTConv2D(7, 1, 3)(x), Conv2D(7, 1, 3)(x), atol=1e-4) + npt.assert_allclose( + sk.nn.FFTConv2D(7, 1, 3)(x), sk.nn.Conv2D(7, 1, 3)(x), atol=1e-4 + ) x = jnp.ones([10, 10, 10]) npt.assert_allclose( - FFTConv2D(10, 1, 3, dilation=3)(x), - Conv2D(10, 1, 3, dilation=3)(x), + sk.nn.FFTConv2D(10, 1, 3, dilation=3)(x), + sk.nn.Conv2D(10, 1, 3, dilation=3)(x), atol=1e-5, ) x = jnp.ones([7, 8, 9]) npt.assert_allclose( - FFTConv2D(7, 1, 3, dilation=2)(x), - Conv2D(7, 1, 3, dilation=2)(x), + sk.nn.FFTConv2D(7, 1, 3, dilation=2)(x), + sk.nn.Conv2D(7, 1, 3, dilation=2)(x), atol=1e-5, ) x = jnp.ones([10, 1, 1, 1]) - npt.assert_allclose(FFTConv3D(10, 1, 3)(x), Conv3D(10, 1, 3)(x), atol=1e-4) + npt.assert_allclose( + sk.nn.FFTConv3D(10, 1, 3)(x), sk.nn.Conv3D(10, 1, 3)(x), atol=1e-4 + ) x = jnp.ones([7, 8, 9, 10]) - npt.assert_allclose(FFTConv3D(7, 1, 3)(x), Conv3D(7, 1, 3)(x), atol=1e-4) + npt.assert_allclose( + sk.nn.FFTConv3D(7, 1, 3)(x), sk.nn.Conv3D(7, 1, 3)(x), atol=1e-4 + ) x = jnp.ones([7, 8, 9, 10]) npt.assert_allclose( - FFTConv3D(7, 1, 3, dilation=(1, 2, 3))(x), - Conv3D(7, 1, 3, dilation=(1, 2, 3))(x), + sk.nn.FFTConv3D(7, 1, 3, dilation=(1, 2, 3))(x), + sk.nn.Conv3D(7, 1, 3, dilation=(1, 2, 3))(x), atol=1e-5, ) @@ -576,42 +562,46 @@ def test_fft_conv(): def test_depthwise_fft_conv(): x = jnp.ones([10, 1]) npt.assert_allclose( - DepthwiseFFTConv1D(10, 3)(x), DepthwiseConv1D(10, 3)(x), atol=1e-4 + sk.nn.DepthwiseFFTConv1D(10, 3)(x), sk.nn.DepthwiseConv1D(10, 3)(x), atol=1e-4 ) x = jnp.ones([10, 1, 1]) npt.assert_allclose( - DepthwiseFFTConv2D(10, 3)(x), DepthwiseConv2D(10, 3)(x), atol=1e-4 + sk.nn.DepthwiseFFTConv2D(10, 3)(x), sk.nn.DepthwiseConv2D(10, 3)(x), atol=1e-4 ) x = jnp.ones([10, 1, 1, 1]) npt.assert_allclose( - DepthwiseFFTConv3D(10, 3)(x), DepthwiseConv3D(10, 3)(x), atol=1e-4 + sk.nn.DepthwiseFFTConv3D(10, 3)(x), sk.nn.DepthwiseConv3D(10, 3)(x), atol=1e-4 ) def test_conv_transpose(): x = jnp.ones([10, 4]) npt.assert_allclose( - Conv1DTranspose(10, 4, 3)(x), FFTConv1DTranspose(10, 4, 3)(x), atol=1e-4 + sk.nn.Conv1DTranspose(10, 4, 3)(x), + sk.nn.FFTConv1DTranspose(10, 4, 3)(x), + atol=1e-4, ) x = jnp.ones([10, 4]) npt.assert_allclose( - Conv1DTranspose(10, 4, 3, dilation=2)(x), - FFTConv1DTranspose(10, 4, 3, dilation=2)(x), + sk.nn.Conv1DTranspose(10, 4, 3, dilation=2)(x), + sk.nn.FFTConv1DTranspose(10, 4, 3, dilation=2)(x), atol=1e-5, ) x = jnp.ones([10, 4, 4]) npt.assert_allclose( - Conv2DTranspose(10, 4, 3)(x), FFTConv2DTranspose(10, 4, 3)(x), atol=1e-4 + sk.nn.Conv2DTranspose(10, 4, 3)(x), + sk.nn.FFTConv2DTranspose(10, 4, 3)(x), + atol=1e-4, ) x = jnp.ones([10, 4, 4, 4]) npt.assert_allclose( - Conv3DTranspose(10, 4, 3, dilation=2)(x), - FFTConv3DTranspose(10, 4, 3, dilation=2)(x), + sk.nn.Conv3DTranspose(10, 4, 3, dilation=2)(x), + sk.nn.FFTConv3DTranspose(10, 4, 3, dilation=2)(x), atol=1e-5, ) @@ -619,22 +609,28 @@ def test_conv_transpose(): def test_separable_conv(): x = jnp.ones([10, 4]) npt.assert_allclose( - SeparableConv1D(10, 4, 3)(x), SeparableFFTConv1D(10, 4, 3)(x), atol=1e-4 + sk.nn.SeparableConv1D(10, 4, 3)(x), + sk.nn.SeparableFFTConv1D(10, 4, 3)(x), + atol=1e-4, ) x = jnp.ones([10, 4, 4]) npt.assert_allclose( - SeparableConv2D(10, 4, 3)(x), SeparableFFTConv2D(10, 4, 3)(x), atol=1e-4 + sk.nn.SeparableConv2D(10, 4, 3)(x), + sk.nn.SeparableFFTConv2D(10, 4, 3)(x), + atol=1e-4, ) x = jnp.ones([10, 4, 4, 4]) npt.assert_allclose( - SeparableConv3D(10, 4, 3)(x), SeparableFFTConv3D(10, 4, 3)(x), atol=1e-4 + sk.nn.SeparableConv3D(10, 4, 3)(x), + sk.nn.SeparableFFTConv3D(10, 4, 3)(x), + atol=1e-4, ) def test_conv1D(): - layer = Conv1D( + layer = sk.nn.Conv1D( in_features=1, out_features=1, kernel_size=2, @@ -646,7 +642,7 @@ def test_conv1D(): x = jnp.arange(1, 11).reshape([1, 10]).astype(jnp.float32) npt.assert_allclose(layer(x), jnp.array([[3, 5, 7, 9, 11, 13, 15, 17, 19, 10]])) - layer = Conv1D( + layer = sk.nn.Conv1D( in_features=1, out_features=1, kernel_size=2, padding="same", strides=2 ) layer = layer.at["weight"].set(jnp.ones([1, 1, 2], dtype=jnp.float32)) @@ -654,7 +650,7 @@ def test_conv1D(): npt.assert_allclose(layer(x), jnp.array([[3, 7, 11, 15, 19]])) - layer = Conv1D( + layer = sk.nn.Conv1D( in_features=1, out_features=1, kernel_size=2, padding="VALID", strides=1 ) layer = layer.at["weight"].set(jnp.ones([1, 1, 2], dtype=jnp.float32)) @@ -716,19 +712,19 @@ def test_conv1D(): ] ) - layer = Conv1D(1, 2, 3, padding=2, strides=1, dilation=2) + layer = sk.nn.Conv1D(1, 2, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) npt.assert_allclose(layer(x), y) - layer = Conv1D(1, 2, 3, padding=2, strides=1, dilation=2, bias_init=None) + layer = sk.nn.Conv1D(1, 2, 3, padding=2, strides=1, dilation=2, bias_init=None) layer = layer.at["weight"].set(w) npt.assert_allclose(layer(x), y) def test_conv2D(): - layer = Conv2D(in_features=1, out_features=1, kernel_size=2) + layer = sk.nn.Conv2D(in_features=1, out_features=1, kernel_size=2) layer = layer.at["weight"].set(jnp.ones([1, 1, 2, 2], dtype=jnp.float32)) # OIHW x = jnp.arange(1, 17).reshape([1, 4, 4]).astype(jnp.float32) @@ -739,7 +735,7 @@ def test_conv2D(): ), ) - layer = Conv2D(in_features=1, out_features=1, kernel_size=2, padding="VALID") + layer = sk.nn.Conv2D(in_features=1, out_features=1, kernel_size=2, padding="VALID") layer = layer.at["weight"].set(jnp.ones([1, 1, 2, 2], dtype=jnp.float32)) x = jnp.arange(1, 17).reshape([1, 4, 4]).astype(jnp.float32) @@ -754,7 +750,7 @@ def test_conv2D(): ), ) - layer = Conv2D(1, 2, 2, padding="same", strides=2) + layer = sk.nn.Conv2D(1, 2, 2, padding="same", strides=2) layer = layer.at["weight"].set(jnp.ones([2, 1, 2, 2], dtype=jnp.float32)) x = jnp.arange(1, 17).reshape([1, 4, 4]).astype(jnp.float32) @@ -768,7 +764,7 @@ def test_conv2D(): ), ) - layer = Conv2D(1, 2, 2, padding="same", strides=1) + layer = sk.nn.Conv2D(1, 2, 2, padding="same", strides=1) layer = layer.at["weight"].set(jnp.ones([2, 1, 2, 2], dtype=jnp.float32)) x = jnp.arange(1, 17).reshape([1, 4, 4]).astype(jnp.float32) @@ -786,7 +782,7 @@ def test_conv2D(): ), ) - layer = Conv2D(1, 2, 2, padding="same", strides=1, bias_init=None) + layer = sk.nn.Conv2D(1, 2, 2, padding="same", strides=1, bias_init=None) layer = layer.at["weight"].set(jnp.ones([2, 1, 2, 2], dtype=jnp.float32)) x = jnp.arange(1, 17).reshape([1, 4, 4]).astype(jnp.float32) @@ -806,7 +802,7 @@ def test_conv2D(): def test_conv3D(): - layer = Conv3D(1, 3, 3) + layer = sk.nn.Conv3D(1, 3, 3) layer = layer.at["weight"].set(jnp.ones([3, 1, 3, 3, 3])) layer = layer.at["bias"].set(jnp.zeros([3, 1, 1, 1])) npt.assert_allclose( @@ -838,13 +834,15 @@ def test_conv1dtranspose(): b = jnp.array([[[0.0]]]) - layer = Conv1DTranspose(4, 1, 3, padding=2, strides=1, dilation=2) + layer = sk.nn.Conv1DTranspose(4, 1, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) y = jnp.array([[0.27022034, 0.24495776, -0.00368674]]) npt.assert_allclose(layer(x), y, atol=1e-5) - layer = Conv1DTranspose(4, 1, 3, padding=2, strides=1, dilation=2, bias_init=None) + layer = sk.nn.Conv1DTranspose( + 4, 1, 3, padding=2, strides=1, dilation=2, bias_init=None + ) layer = layer.at["weight"].set(w) y = jnp.array([[0.27022034, 0.24495776, -0.00368674]]) npt.assert_allclose(layer(x), y, atol=1e-5) @@ -898,7 +896,7 @@ def test_conv2dtranspose(): b = jnp.array([[[0.0]]]) - layer = Conv2DTranspose(3, 1, 3, padding=2, strides=1, dilation=2) + layer = sk.nn.Conv2DTranspose(3, 1, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) @@ -916,7 +914,9 @@ def test_conv2dtranspose(): npt.assert_allclose(layer(x), y, atol=1e-5) - layer = Conv2DTranspose(3, 1, 3, padding=2, strides=1, dilation=2, bias_init=None) + layer = sk.nn.Conv2DTranspose( + 3, 1, 3, padding=2, strides=1, dilation=2, bias_init=None + ) layer = layer.at["weight"].set(w) @@ -1085,7 +1085,7 @@ def test_conv3dtranspose(): b = jnp.array([[[[0.0]]]]) - layer = Conv3DTranspose(4, 1, 3, padding=2, strides=1, dilation=2) + layer = sk.nn.Conv3DTranspose(4, 1, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) @@ -1113,7 +1113,9 @@ def test_conv3dtranspose(): npt.assert_allclose(y, layer(x), atol=1e-5) - layer = Conv3DTranspose(4, 1, 3, padding=2, strides=1, dilation=2, bias_init=None) + layer = sk.nn.Conv3DTranspose( + 4, 1, 3, padding=2, strides=1, dilation=2, bias_init=None + ) layer = layer.at["weight"].set(w) y = jnp.array( @@ -1182,7 +1184,7 @@ def test_depthwise_conv1d(): ] ) - layer = DepthwiseConv1D(in_features=5, kernel_size=3, depth_multiplier=2) + layer = sk.nn.DepthwiseConv1D(in_features=5, kernel_size=3, depth_multiplier=2) layer = layer.at["weight"].set(w) npt.assert_allclose(y, layer(x), atol=1e-5) @@ -1246,7 +1248,7 @@ def test_depthwise_conv2d(): ] ) - layer = DepthwiseConv2D(2, 3) + layer = sk.nn.DepthwiseConv2D(2, 3) layer = layer.at["weight"].set(w) npt.assert_allclose(y, layer(x), atol=1e-5) @@ -1273,7 +1275,7 @@ def test_seperable_conv1d(): y = jnp.array([[0.5005436, 0.44051802, 0.5662357, 0.13085097, -0.22720146]]) - layer = SeparableConv1D( + layer = sk.nn.SeparableConv1D( in_features=2, out_features=1, kernel_size=3, depth_multiplier=2 ) @@ -1350,7 +1352,7 @@ def test_seperable_conv2d(): ] ) - layer_jax = SeparableConv2D( + layer_jax = sk.nn.SeparableConv2D( in_features=2, out_features=1, kernel_size=3, depth_multiplier=2 ) @@ -1480,7 +1482,7 @@ def test_conv1d_local(): ] ) - layer = Conv1DLocal( + layer = sk.nn.Conv1DLocal( in_features=2, out_features=1, kernel_size=3, @@ -1518,7 +1520,7 @@ def test_conv2d_local(): x = jnp.ones((2, 4, 4)) - layer = Conv2DLocal(2, 1, (3, 2), in_size=(4, 4), padding="valid", strides=2) + layer = sk.nn.Conv2DLocal(2, 1, (3, 2), in_size=(4, 4), padding="valid", strides=2) layer = layer.at["weight"].set(w) npt.assert_allclose(y, layer(x), atol=1e-5) @@ -1526,91 +1528,91 @@ def test_conv2d_local(): def test_in_feature_error(): with pytest.raises(ValueError): - Conv1D(0, 1, 2) + sk.nn.Conv1D(0, 1, 2) with pytest.raises(ValueError): - Conv2D(0, 1, 2) + sk.nn.Conv2D(0, 1, 2) with pytest.raises(ValueError): - Conv3D(0, 1, 2) + sk.nn.Conv3D(0, 1, 2) with pytest.raises(ValueError): - Conv1DLocal(0, 1, 2, in_size=(2,)) + sk.nn.Conv1DLocal(0, 1, 2, in_size=(2,)) with pytest.raises(ValueError): - Conv2DLocal(0, 1, 2, in_size=(2, 2)) + sk.nn.Conv2DLocal(0, 1, 2, in_size=(2, 2)) with pytest.raises(ValueError): - Conv1DTranspose(0, 1, 3) + sk.nn.Conv1DTranspose(0, 1, 3) with pytest.raises(ValueError): - Conv2DTranspose(0, 1, 3) + sk.nn.Conv2DTranspose(0, 1, 3) with pytest.raises(ValueError): - Conv3DTranspose(0, 1, 3) + sk.nn.Conv3DTranspose(0, 1, 3) with pytest.raises(ValueError): - DepthwiseConv1D(0, 1) + sk.nn.DepthwiseConv1D(0, 1) with pytest.raises(ValueError): - DepthwiseConv2D(0, 1) + sk.nn.DepthwiseConv2D(0, 1) def test_out_feature_error(): with pytest.raises(ValueError): - Conv1D(1, 0, 2) + sk.nn.Conv1D(1, 0, 2) with pytest.raises(ValueError): - Conv2D(1, 0, 2) + sk.nn.Conv2D(1, 0, 2) with pytest.raises(ValueError): - Conv3D(1, 0, 2) + sk.nn.Conv3D(1, 0, 2) with pytest.raises(ValueError): - Conv1DLocal(1, 0, 2, in_size=(2,)) + sk.nn.Conv1DLocal(1, 0, 2, in_size=(2,)) with pytest.raises(ValueError): - Conv2DLocal(1, 0, 2, in_size=(2, 2)) + sk.nn.Conv2DLocal(1, 0, 2, in_size=(2, 2)) with pytest.raises(ValueError): - Conv1DTranspose(1, 0, 3) + sk.nn.Conv1DTranspose(1, 0, 3) with pytest.raises(ValueError): - Conv2DTranspose(1, 0, 3) + sk.nn.Conv2DTranspose(1, 0, 3) with pytest.raises(ValueError): - Conv3DTranspose(1, 0, 3) + sk.nn.Conv3DTranspose(1, 0, 3) def test_groups_error(): with pytest.raises(ValueError): - Conv1D(1, 1, 2, groups=0) + sk.nn.Conv1D(1, 1, 2, groups=0) with pytest.raises(ValueError): - Conv2D(1, 1, 2, groups=0) + sk.nn.Conv2D(1, 1, 2, groups=0) with pytest.raises(ValueError): - Conv3D(1, 1, 2, groups=0) + sk.nn.Conv3D(1, 1, 2, groups=0) with pytest.raises(ValueError): - Conv1DTranspose(1, 1, 3, groups=0) + sk.nn.Conv1DTranspose(1, 1, 3, groups=0) with pytest.raises(ValueError): - Conv2DTranspose(1, 1, 3, groups=0) + sk.nn.Conv2DTranspose(1, 1, 3, groups=0) with pytest.raises(ValueError): - Conv3DTranspose(1, 1, 3, groups=0) + sk.nn.Conv3DTranspose(1, 1, 3, groups=0) @pytest.mark.parametrize( "layer,array,expected_shape", [ - [Conv1D, jnp.ones([10, 3]), (1, 3)], - [Conv2D, jnp.ones([10, 3, 3]), (1, 3, 3)], - [Conv3D, jnp.ones([10, 3, 3, 3]), (1, 3, 3, 3)], - [Conv1DTranspose, jnp.ones([10, 3]), (1, 3)], - [Conv2DTranspose, jnp.ones([10, 3, 3]), (1, 3, 3)], - [Conv3DTranspose, jnp.ones([10, 3, 3, 3]), (1, 3, 3, 3)], + [sk.nn.Conv1D, jnp.ones([10, 3]), (1, 3)], + [sk.nn.Conv2D, jnp.ones([10, 3, 3]), (1, 3, 3)], + [sk.nn.Conv3D, jnp.ones([10, 3, 3, 3]), (1, 3, 3, 3)], + [sk.nn.Conv1DTranspose, jnp.ones([10, 3]), (1, 3)], + [sk.nn.Conv2DTranspose, jnp.ones([10, 3, 3]), (1, 3, 3)], + [sk.nn.Conv3DTranspose, jnp.ones([10, 3, 3, 3]), (1, 3, 3, 3)], ], ) def test_lazy_conv(layer, array, expected_shape): diff --git a/tests/test_dropout.py b/tests/test_dropout.py index 37575d1..d970990 100644 --- a/tests/test_dropout.py +++ b/tests/test_dropout.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,34 +17,33 @@ import pytest import serket as sk -from serket.nn import Dropout, RandomCutout1D, RandomCutout2D def test_dropout(): x = jnp.array([1, 2, 3, 4, 5]) - layer = Dropout(1.0) + layer = sk.nn.Dropout(1.0) npt.assert_allclose(layer(x), jnp.array([0.0, 0.0, 0.0, 0.0, 0.0])) layer = layer.at["drop_rate"].set(0.0, is_leaf=sk.is_frozen) npt.assert_allclose(layer(x), x) with pytest.raises(ValueError): - Dropout(1.1) + sk.nn.Dropout(1.1) with pytest.raises(ValueError): - Dropout(-0.1) + sk.nn.Dropout(-0.1) def test_random_cutout_1d(): - layer = RandomCutout1D(3, 1) + layer = sk.nn.RandomCutout1D(3, 1) x = jnp.ones((1, 10)) y = layer(x) npt.assert_equal(y.shape, (1, 10)) def test_random_cutout_2d(): - layer = RandomCutout2D((3, 3), 1) + layer = sk.nn.RandomCutout2D((3, 3), 1) x = jnp.ones((1, 10, 10)) y = layer(x) npt.assert_equal(y.shape, (1, 10, 10)) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index ac806cb..4ea98a8 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,38 +20,10 @@ import pytest import serket as sk -from serket.image.augment import ( - AdjustContrast2D, - JigSaw2D, - PixelShuffle2D, - RandomContrast2D, -) -from serket.image.filter import ( - AvgBlur2D, - FFTAvgBlur2D, - FFTFilter2D, - FFTGaussianBlur2D, - Filter2D, - GaussianBlur2D, -) -from serket.image.geometric import ( - HorizontalFlip2D, - HorizontalShear2D, - HorizontalTranslate2D, - Pixelate2D, - RandomHorizontalShear2D, - RandomRotate2D, - RandomVerticalShear2D, - Rotate2D, - Solarize2D, - VerticalFlip2D, - VerticalShear2D, - VerticalTranslate2D, -) def test_AvgBlur2D(): - x = AvgBlur2D(3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32)) + x = sk.image.AvgBlur2D(3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32)) y = [ [ @@ -66,12 +38,14 @@ def test_AvgBlur2D(): npt.assert_allclose(x, y, atol=1e-5) # test with - z = FFTAvgBlur2D(3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32)) + z = sk.image.FFTAvgBlur2D(3)( + jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32) + ) npt.assert_allclose(y, z, atol=1e-5) def test_GaussBlur2D(): - layer = GaussianBlur2D(kernel_size=3, sigma=1.0) + layer = sk.image.GaussianBlur2D(kernel_size=3, sigma=1.0) x = jnp.ones([1, 5, 5]) npt.assert_allclose( @@ -91,47 +65,29 @@ def test_GaussBlur2D(): ) with pytest.raises(ValueError): - GaussianBlur2D(0, sigma=1.0) + sk.image.GaussianBlur2D(0, sigma=1.0) - z = FFTGaussianBlur2D(3, sigma=1.0)(jnp.ones([1, 5, 5])).astype(jnp.float32) + z = sk.image.FFTGaussianBlur2D(3, sigma=1.0)(jnp.ones([1, 5, 5])).astype( + jnp.float32 + ) npt.assert_allclose(layer(x), z, atol=1e-5) -# def test_lazy_blur(): -# layer = GaussianBlur2D(in_features=None, kernel_size=3, sigma=1.0) -# assert layer(jnp.ones([10, 5, 5])).shape == (10, 5, 5) - -# layer = AvgBlur2D(None, 3) -# assert layer(jnp.ones([10, 5, 5])).shape == (10, 5, 5) - -# layer = Filter2D(None, jnp.ones([3, 3])) -# assert layer(jnp.ones([10, 5, 5])).shape == (10, 5, 5) - -# with pytest.raises(ConcretizationTypeError): -# jax.jit(GaussianBlur2D(in_features=None, kernel_size=3, sigma=1.0))(jnp.ones([10, 5, 5])) # fmt: skip - -# with pytest.raises(ConcretizationTypeError): -# jax.jit(AvgBlur2D(in_features=None, kernel_size=3))(jnp.ones([10, 5, 5])) - -# with pytest.raises(ConcretizationTypeError): -# jax.jit(Filter2D(in_features=None, kernel=jnp.ones([4, 4])))(jnp.ones([10, 5, 5])) # fmt: skip - - def test_filter2d(): - layer = Filter2D(kernel=jnp.ones([3, 3]) / 9.0) + layer = sk.image.Filter2D(kernel=jnp.ones([3, 3]) / 9.0) x = jnp.ones([1, 5, 5]) - npt.assert_allclose(AvgBlur2D(3)(x), layer(x), atol=1e-4) + npt.assert_allclose(sk.image.AvgBlur2D(3)(x), layer(x), atol=1e-4) - layer2 = FFTFilter2D(kernel=jnp.ones([3, 3]) / 9.0) + layer2 = sk.image.FFTFilter2D(kernel=jnp.ones([3, 3]) / 9.0) npt.assert_allclose(layer(x), layer2(x), atol=1e-4) def test_solarize2d(): x = jnp.arange(1, 26).reshape(1, 5, 5) - layer = Solarize2D(threshold=10, max_val=25) + layer = sk.image.Solarize2D(threshold=10, max_val=25) npt.assert_allclose( layer(x), jnp.array( @@ -150,7 +106,7 @@ def test_solarize2d(): def test_horizontal_translate(): x = jnp.arange(1, 26).reshape(1, 5, 5) - layer = HorizontalTranslate2D(2) + layer = sk.image.HorizontalTranslate2D(2) npt.assert_allclose( layer(x), jnp.array( @@ -166,7 +122,7 @@ def test_horizontal_translate(): ), ) - layer = HorizontalTranslate2D(-2) + layer = sk.image.HorizontalTranslate2D(-2) npt.assert_allclose( layer(x), jnp.array( @@ -182,14 +138,14 @@ def test_horizontal_translate(): ), ) - layer = HorizontalTranslate2D(0) + layer = sk.image.HorizontalTranslate2D(0) npt.assert_allclose(layer(x), x) def test_vertical_translate(): x = jnp.arange(1, 26).reshape(1, 5, 5) - layer = VerticalTranslate2D(2) + layer = sk.image.VerticalTranslate2D(2) npt.assert_allclose( layer(x), jnp.array( @@ -205,7 +161,7 @@ def test_vertical_translate(): ), ) - layer = VerticalTranslate2D(-2) + layer = sk.image.VerticalTranslate2D(-2) npt.assert_allclose( layer(x), jnp.array( @@ -221,14 +177,14 @@ def test_vertical_translate(): ), ) - layer = VerticalTranslate2D(0) + layer = sk.image.VerticalTranslate2D(0) npt.assert_allclose(layer(x), x) def test_jigsaw(): x = jnp.arange(1, 17).reshape(1, 4, 4) - layer = JigSaw2D(2) + layer = sk.image.JigSaw2D(2) npt.assert_allclose( layer(x), jnp.array([[[9, 10, 3, 4], [13, 14, 7, 8], [11, 12, 1, 2], [15, 16, 5, 6]]]), @@ -236,7 +192,7 @@ def test_jigsaw(): def test_rotate(): - layer = Rotate2D(90) + layer = sk.image.Rotate2D(90) x = jnp.arange(1, 26).reshape(1, 5, 5) # ccw rotation @@ -257,7 +213,7 @@ def test_rotate(): # random roate - layer = RandomRotate2D((90, 90)) + layer = sk.image.RandomRotate2D((90, 90)) npt.assert_allclose(layer(x), rot) npt.assert_allclose(sk.tree_eval(layer)(x), x) @@ -265,7 +221,7 @@ def test_rotate(): def test_horizontal_shear(): x = jnp.arange(1, 26).reshape(1, 5, 5) - layer = HorizontalShear2D(45) + layer = sk.image.HorizontalShear2D(45) shear = jnp.array( [ [ @@ -280,7 +236,7 @@ def test_horizontal_shear(): npt.assert_allclose(layer(x), shear) - layer = RandomHorizontalShear2D((45, 45)) + layer = sk.image.RandomHorizontalShear2D((45, 45)) npt.assert_allclose(layer(x), shear) npt.assert_allclose(sk.tree_eval(layer)(x), x) @@ -288,7 +244,7 @@ def test_horizontal_shear(): def test_vertical_shear(): x = jnp.arange(1, 26).reshape(1, 5, 5) - layer = VerticalShear2D(45) + layer = sk.image.VerticalShear2D(45) shear = jnp.array( [ [ @@ -303,7 +259,7 @@ def test_vertical_shear(): npt.assert_allclose(layer(x), shear) - layer = RandomVerticalShear2D((45, 45)) + layer = sk.image.RandomVerticalShear2D((45, 45)) npt.assert_allclose(layer(x), shear) npt.assert_allclose(sk.tree_eval(layer)(x), x) @@ -329,7 +285,7 @@ def test_posterize(): def test_pixelate(): x = jnp.arange(1, 26).reshape(1, 5, 5) - layer = Pixelate2D(1) + layer = sk.image.Pixelate2D(1) npt.assert_allclose(layer(x), x) @@ -363,7 +319,7 @@ def test_adjust_contrast_2d(): ] ) - npt.assert_allclose(AdjustContrast2D(contrast_factor=0.5)(x), y, atol=1e-5) + npt.assert_allclose(sk.image.AdjustContrast2D(contrast_factor=0.5)(x), y, atol=1e-5) def test_random_contrast_2d(): @@ -398,21 +354,23 @@ def test_random_contrast_2d(): ) npt.assert_allclose( - RandomContrast2D(contrast_range=(0.5, 1))(x, key=jax.random.PRNGKey(0)), + sk.image.RandomContrast2D(contrast_range=(0.5, 1))( + x, key=jax.random.PRNGKey(0) + ), y, atol=1e-5, ) def test_flip_left_right_2d(): - flip = HorizontalFlip2D() + flip = sk.image.HorizontalFlip2D() x = jnp.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) y = flip(x) npt.assert_allclose(y, jnp.array([[[3, 2, 1], [6, 5, 4], [9, 8, 7]]])) def test_flip_up_down_2d(): - flip = VerticalFlip2D() + flip = sk.image.VerticalFlip2D() x = jnp.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) y = flip(x) npt.assert_allclose(y, jnp.array([[[7, 8, 9], [4, 5, 6], [1, 2, 3]]])) @@ -428,13 +386,13 @@ def test_pixel_shuffle(): ] ) - ps = PixelShuffle2D(2) + ps = sk.image.PixelShuffle2D(2) y = jnp.array([0.08482574, 0.33432344, 1.9097648, -0.82606775]) npt.assert_allclose(ps(x)[0, 0], y, atol=1e-5) with pytest.raises(ValueError): - PixelShuffle2D(3)(jnp.ones([6, 4, 4])) + sk.image.PixelShuffle2D(3)(jnp.ones([6, 4, 4])) with pytest.raises(ValueError): - PixelShuffle2D(-3)(jnp.ones([9, 6, 4])) + sk.image.PixelShuffle2D(-3)(jnp.ones([9, 6, 4])) diff --git a/tests/test_init.py b/tests/test_init.py index 41a3055..6f6f48f 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import pytest -from serket.nn.initialization import def_init_entry +from serket._src.nn.initialization import def_init_entry def test_def_init_entry(): diff --git a/tests/test_linear.py b/tests/test_linear.py index fbc699b..214e457 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,11 +19,10 @@ import pytest import serket as sk -from serket.nn import FNN, MLP, Embedding, GeneralLinear, Identity, Linear, Multilinear def test_embed(): - table = Embedding(10, 3) + table = sk.nn.Embedding(10, 3) x = jnp.array([9]) npt.assert_allclose(table(x), jnp.array([[0.43810904, 0.35078037, 0.13254273]])) @@ -45,7 +44,7 @@ def update(NN, x, y): value, grad = loss_func(NN, x, y) return value, jtu.tree_map(lambda x, g: x - 1e-3 * g, NN, grad) - nn = FNN( + nn = sk.nn.FNN( [1, 128, 128, 1], act="relu", weight_init="he_normal", @@ -59,13 +58,13 @@ def update(NN, x, y): npt.assert_allclose(jnp.array(4.933563e-05), value, atol=1e-3) - layer = Linear(1, 1, bias_init=None) + layer = sk.nn.Linear(1, 1, bias_init=None) w = jnp.array([[-0.31568417]]) layer = layer.at["weight"].set(w) y = jnp.array([[-0.31568417]]) npt.assert_allclose(layer(jnp.array([[1.0]])), y) - layer = Linear(None, 1, bias_init="zeros") + layer = sk.nn.Linear(None, 1, bias_init="zeros") _, layer = layer.at["__call__"](jnp.ones([100, 2])) assert layer.in_features == (2,) @@ -87,12 +86,12 @@ def test_bilinear(): x1 = jnp.array([[-0.7676, -0.7205, -0.0586]]) x2 = jnp.array([[0.4600, -0.2508, 0.0115, 0.6155]]) y = jnp.array([[-0.3001916, 0.28336674]]) - layer = Multilinear((3, 4), 2, bias_init=None) + layer = sk.nn.Multilinear((3, 4), 2, bias_init=None) layer = layer.at["weight"].set(W) npt.assert_allclose(y, layer(x1, x2), atol=1e-4) - layer = Multilinear((3, 4), 2, bias_init="zeros") + layer = sk.nn.Multilinear((3, 4), 2, bias_init="zeros") layer = layer.at["weight"].set(W) npt.assert_allclose(y, layer(x1, x2), atol=1e-4) @@ -100,63 +99,63 @@ def test_bilinear(): def test_identity(): x = jnp.array([[1, 2, 3], [4, 5, 6]]) - layer = Identity() + layer = sk.nn.Identity() npt.assert_allclose(x, layer(x)) def test_multi_linear(): x = jnp.linspace(0, 1, 100)[:, None] - lhs = Linear(1, 10) - rhs = Multilinear((1,), 10) + lhs = sk.nn.Linear(1, 10) + rhs = sk.nn.Multilinear((1,), 10) npt.assert_allclose(lhs(x), rhs(x), atol=1e-4) with pytest.raises(ValueError): - Multilinear([1, 2], 10) + sk.nn.Multilinear([1, 2], 10) def test_general_linear(): x = jnp.ones([1, 2, 3, 4]) - layer = GeneralLinear(in_features=(1, 2), in_axes=(0, 1), out_features=5) + layer = sk.nn.GeneralLinear(in_features=(1, 2), in_axes=(0, 1), out_features=5) assert layer(x).shape == (3, 4, 5) x = jnp.ones([1, 2, 3, 4]) - layer = GeneralLinear(in_features=(1, 2), in_axes=(0, 1), out_features=5) + layer = sk.nn.GeneralLinear(in_features=(1, 2), in_axes=(0, 1), out_features=5) assert layer(x).shape == (3, 4, 5) x = jnp.ones([1, 2, 3, 4]) - layer = GeneralLinear(in_features=(1, 2), in_axes=(0, -3), out_features=5) + layer = sk.nn.GeneralLinear(in_features=(1, 2), in_axes=(0, -3), out_features=5) assert layer(x).shape == (3, 4, 5) x = jnp.ones([1, 2, 3, 4]) - layer = GeneralLinear(in_features=(2, 3), in_axes=(1, -2), out_features=5) + layer = sk.nn.GeneralLinear(in_features=(2, 3), in_axes=(1, -2), out_features=5) assert layer(x).shape == (1, 4, 5) with pytest.raises(TypeError): - GeneralLinear(in_features=2, in_axes=(1, -2), out_features=5) + sk.nn.GeneralLinear(in_features=2, in_axes=(1, -2), out_features=5) with pytest.raises(TypeError): - GeneralLinear(in_features=(2, 3), in_axes=2, out_features=5) + sk.nn.GeneralLinear(in_features=(2, 3), in_axes=2, out_features=5) with pytest.raises(ValueError): - GeneralLinear(in_features=(1,), in_axes=(0, -3), out_features=5) + sk.nn.GeneralLinear(in_features=(1,), in_axes=(0, -3), out_features=5) with pytest.raises(TypeError): - GeneralLinear(in_features=(1, "s"), in_axes=(0, -3), out_features=5) + sk.nn.GeneralLinear(in_features=(1, "s"), in_axes=(0, -3), out_features=5) with pytest.raises(TypeError): - GeneralLinear(in_features=(1, 2), in_axes=(0, "s"), out_features=3) + sk.nn.GeneralLinear(in_features=(1, 2), in_axes=(0, "s"), out_features=3) def test_mlp(): x = jnp.linspace(0, 1, 100)[:, None] - fnn = FNN([1, 2, 1]) - mlp = MLP(in_features=1, out_features=1, hidden_size=2, num_hidden_layers=1) + fnn = sk.nn.FNN([1, 2, 1]) + mlp = sk.nn.MLP(in_features=1, out_features=1, hidden_size=2, num_hidden_layers=1) npt.assert_allclose(fnn(x), mlp(x), atol=1e-4) - fnn = FNN([1, 2, 2, 1], act=("relu", "tanh")) - mlp = MLP( + fnn = sk.nn.FNN([1, 2, 2, 1], act=("relu", "tanh")) + mlp = sk.nn.MLP( in_features=1, out_features=1, hidden_size=2, @@ -166,7 +165,7 @@ def test_mlp(): npt.assert_allclose(fnn(x), mlp(x), atol=1e-4) - layer = MLP( + layer = sk.nn.MLP( 1, 4, hidden_size=10, @@ -197,7 +196,7 @@ def test_mlp(): def test_fnn(): - layer = FNN([1, 2, 3, 4], act=("relu", "tanh")) + layer = sk.nn.FNN([1, 2, 3, 4], act=("relu", "tanh")) assert layer.act[0] is not layer.act[1] assert layer.layers[0] is not layer.layers[1] @@ -212,7 +211,7 @@ def test_fnn(): y = jax.nn.relu(y) y = y @ w3 - l1 = FNN([1, 5, 3, 4], act=("tanh", "relu"), bias_init=None) + l1 = sk.nn.FNN([1, 5, 3, 4], act=("tanh", "relu"), bias_init=None) l1 = l1.at["layers"][0]["weight"].set(w1) l1 = l1.at["layers"][1]["weight"].set(w2) l1 = l1.at["layers"][2]["weight"].set(w3) @@ -221,7 +220,7 @@ def test_fnn(): def test_fnn_mlp(): - fnn = FNN(layers=[2, 4, 4, 2], act="relu") - mlp = MLP(2, 2, hidden_size=4, num_hidden_layers=2, act="relu") + fnn = sk.nn.FNN(layers=[2, 4, 4, 2], act="relu") + mlp = sk.nn.MLP(2, 2, hidden_size=4, num_hidden_layers=2, act="relu") x = jax.random.normal(jax.random.PRNGKey(0), (10, 2)) npt.assert_allclose(fnn(x), mlp(x)) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 56b8bcd..3eff3ad 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,13 +20,12 @@ import pytest import serket as sk -from serket.nn import BatchNorm, GroupNorm, InstanceNorm, LayerNorm os.environ["KERAS_BACKEND"] = "jax" -def test_LayerNorm(): - layer = LayerNorm((5, 2), bias_init=None, weight_init=None) +def test_layer_norm(): + layer = sk.nn.LayerNorm((5, 2), bias_init=None, weight_init=None) x = jnp.array( [ @@ -51,7 +50,7 @@ def test_LayerNorm(): npt.assert_allclose(layer(x), y, atol=1e-5) -def test_InstanceNorm(): +def test_instance_norm(): x = jnp.array( [ [ @@ -98,11 +97,11 @@ def test_InstanceNorm(): ] ) - layer = InstanceNorm(in_features=3) + layer = sk.nn.InstanceNorm(in_features=3) npt.assert_allclose(layer(x), y, atol=1e-5) - layer = InstanceNorm(in_features=3, weight_init=None, bias_init=None) + layer = sk.nn.InstanceNorm(in_features=3, weight_init=None, bias_init=None) npt.assert_allclose(layer(x), y, atol=1e-5) @@ -202,18 +201,18 @@ def test_group_norm(): ] ) - layer = GroupNorm(in_features=6, groups=2) + layer = sk.nn.GroupNorm(in_features=6, groups=2) npt.assert_allclose(layer(x), y, atol=1e-5) with pytest.raises(ValueError): - layer = GroupNorm(in_features=6, groups=4) + layer = sk.nn.GroupNorm(in_features=6, groups=4) with pytest.raises(ValueError): - layer = GroupNorm(in_features=0, groups=1) + layer = sk.nn.GroupNorm(in_features=0, groups=1) with pytest.raises(ValueError): - layer = GroupNorm(in_features=-1, groups=0) + layer = sk.nn.GroupNorm(in_features=-1, groups=0) @pytest.mark.parametrize("axis", [0, 1, 2, 3]) @@ -231,7 +230,7 @@ def test_batchnorm(axis): for i in range(5): x_keras = bn_keras(x_keras, training=True) - bn_sk = BatchNorm( + bn_sk = sk.nn.BatchNorm( x_keras.shape[axis], momentum=0.5, eps=bn_keras.epsilon, diff --git a/tests/test_pooling.py b/tests/test_pooling.py index e1417a6..3c9fe59 100644 --- a/tests/test_pooling.py +++ b/tests/test_pooling.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ import jax.numpy as jnp import numpy.testing as npt -from serket.nn.pooling import ( +from serket._src.nn.pooling import ( AdaptiveAvgPool1D, AdaptiveAvgPool2D, AdaptiveAvgPool3D, diff --git a/tests/test_reshape.py b/tests/test_reshape.py index 1c5b704..ebd3e93 100644 --- a/tests/test_reshape.py +++ b/tests/test_reshape.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,74 +17,51 @@ import jax.numpy as jnp import numpy.testing as npt -from serket.nn import ( - Pad1D, - Pad2D, - Pad3D, - RandomZoom1D, - RandomZoom2D, - RandomZoom3D, - Resize1D, - Resize2D, - Resize3D, - Upsample1D, - Upsample2D, - Upsample3D, -) -from serket.nn.reshape import ( - Crop1D, - Crop2D, - Crop3D, - Flatten, - RandomCrop1D, - RandomCrop2D, - RandomCrop3D, - Unflatten, -) +import serket as sk def test_flatten(): - assert Flatten(0, 1)(jnp.ones([1, 2, 3, 4, 5])).shape == (2, 3, 4, 5) - assert Flatten(0, 2)(jnp.ones([1, 2, 3, 4, 5])).shape == (6, 4, 5) - assert Flatten(1, 2)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 6, 4, 5) - assert Flatten(-1, -1)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 2, 3, 4, 5) - assert Flatten(-2, -1)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 2, 3, 20) - assert Flatten(-3, -1)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 2, 60) + assert sk.nn.Flatten(0, 1)(jnp.ones([1, 2, 3, 4, 5])).shape == (2, 3, 4, 5) + assert sk.nn.Flatten(0, 2)(jnp.ones([1, 2, 3, 4, 5])).shape == (6, 4, 5) + assert sk.nn.Flatten(1, 2)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 6, 4, 5) + assert sk.nn.Flatten(-1, -1)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 2, 3, 4, 5) + assert sk.nn.Flatten(-2, -1)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 2, 3, 20) + assert sk.nn.Flatten(-3, -1)(jnp.ones([1, 2, 3, 4, 5])).shape == (1, 2, 60) def test_unflatten(): - assert Unflatten(0, (1, 2, 3))(jnp.ones([6])).shape == (1, 2, 3) - assert Unflatten(0, (1, 2, 3))(jnp.ones([6, 4])).shape == (1, 2, 3, 4) + assert sk.nn.Unflatten(0, (1, 2, 3))(jnp.ones([6])).shape == (1, 2, 3) + assert sk.nn.Unflatten(0, (1, 2, 3))(jnp.ones([6, 4])).shape == (1, 2, 3, 4) def test_crop_1d(): x = jnp.arange(10)[None, :] - assert jnp.all(Crop1D(5, 0)(x)[0] == jnp.arange(5)) - assert jnp.all(Crop1D(5, 5)(x)[0] == jnp.arange(5, 10)) - assert jnp.all(Crop1D(5, 2)(x)[0] == jnp.arange(2, 7)) + assert jnp.all(sk.nn.Crop1D(5, 0)(x)[0] == jnp.arange(5)) + assert jnp.all(sk.nn.Crop1D(5, 5)(x)[0] == jnp.arange(5, 10)) + assert jnp.all(sk.nn.Crop1D(5, 2)(x)[0] == jnp.arange(2, 7)) # this is how jax.lax.dynamic_slice handles it - assert jnp.all(Crop1D(5, 7)(x)[0] == jnp.array([5, 6, 7, 8, 9])) + assert jnp.all(sk.nn.Crop1D(5, 7)(x)[0] == jnp.array([5, 6, 7, 8, 9])) def test_crop_2d(): x = jnp.arange(25).reshape(1, 5, 5) y = jnp.array([[0, 1, 2], [5, 6, 7], [10, 11, 12]]) - assert jnp.all(Crop2D((3, 3), (0, 0))(x)[0] == y) + assert jnp.all(sk.nn.Crop2D((3, 3), (0, 0))(x)[0] == y) y = jnp.array([[2, 3, 4], [7, 8, 9], [12, 13, 14]]) - assert jnp.all(Crop2D((3, 3), (0, 2))(x)[0] == y) + assert jnp.all(sk.nn.Crop2D((3, 3), (0, 2))(x)[0] == y) y = jnp.array([[10, 11, 12], [15, 16, 17], [20, 21, 22]]) - assert jnp.all(Crop2D((3, 3), (2, 0))(x)[0] == y) + assert jnp.all(sk.nn.Crop2D((3, 3), (2, 0))(x)[0] == y) y = jnp.array([[12, 13, 14], [17, 18, 19], [22, 23, 24]]) - assert jnp.all(Crop2D((3, 3), (2, 2))(x)[0] == y) + assert jnp.all(sk.nn.Crop2D((3, 3), (2, 2))(x)[0] == y) y = jnp.array([[12, 13, 14], [17, 18, 19], [22, 23, 24]]) - assert jnp.all(Crop2D((3, 3), (2, 2))(x)[0] == y) + assert jnp.all(sk.nn.Crop2D((3, 3), (2, 2))(x)[0] == y) y = jnp.array([[12, 13, 14], [17, 18, 19], [22, 23, 24]]) - assert jnp.all(Crop2D((3, 3), (2, 2))(x)[0] == y) + assert jnp.all(sk.nn.Crop2D((3, 3), (2, 2))(x)[0] == y) def test_crop_3d(): @@ -96,78 +73,80 @@ def test_crop_3d(): [[50, 51, 52], [55, 56, 57], [60, 61, 62]], ] ) - assert jnp.all(Crop3D((3, 3, 3), (0, 0, 0))(x)[0] == y) + assert jnp.all(sk.nn.Crop3D((3, 3, 3), (0, 0, 0))(x)[0] == y) def test_random_crop_1d(): x = jnp.arange(10)[None, :] - assert RandomCrop1D(size=5)(x).shape == (1, 5) + assert sk.nn.RandomCrop1D(size=5)(x).shape == (1, 5) def test_random_crop_2d(): x = jnp.arange(25).reshape(1, 5, 5) - assert RandomCrop2D(size=(3, 3))(x).shape == (1, 3, 3) + assert sk.nn.RandomCrop2D(size=(3, 3))(x).shape == (1, 3, 3) def test_random_crop_3d(): x = jnp.arange(125).reshape(1, 5, 5, 5) - assert RandomCrop3D(size=(3, 3, 3))(x).shape == (1, 3, 3, 3) + assert sk.nn.RandomCrop3D(size=(3, 3, 3))(x).shape == (1, 3, 3, 3) def test_resize1d(): - assert Resize1D(4)(jnp.ones([1, 2])).shape == (1, 4) + assert sk.nn.Resize1D(4)(jnp.ones([1, 2])).shape == (1, 4) def test_resize2d(): - assert Resize2D(4)(jnp.ones([1, 2, 2])).shape == (1, 4, 4) + assert sk.nn.Resize2D(4)(jnp.ones([1, 2, 2])).shape == (1, 4, 4) def test_resize3d(): - assert Resize3D(4)(jnp.ones([1, 2, 2, 2])).shape == (1, 4, 4, 4) + assert sk.nn.Resize3D(4)(jnp.ones([1, 2, 2, 2])).shape == (1, 4, 4, 4) def test_upsample1d(): - assert Upsample1D(2)(jnp.ones([1, 2])).shape == (1, 4) + assert sk.nn.Upsample1D(2)(jnp.ones([1, 2])).shape == (1, 4) def test_upsample2d(): - assert Upsample2D(2)(jnp.ones([1, 2, 2])).shape == (1, 4, 4) - assert Upsample2D((2, 3))(jnp.ones([1, 2, 2])).shape == (1, 4, 6) + assert sk.nn.Upsample2D(2)(jnp.ones([1, 2, 2])).shape == (1, 4, 4) + assert sk.nn.Upsample2D((2, 3))(jnp.ones([1, 2, 2])).shape == (1, 4, 6) def test_upsample3d(): - assert Upsample3D(2)(jnp.ones([1, 2, 2, 2])).shape == (1, 4, 4, 4) - assert Upsample3D((2, 3, 4))(jnp.ones([1, 2, 2, 2])).shape == (1, 4, 6, 8) + assert sk.nn.Upsample3D(2)(jnp.ones([1, 2, 2, 2])).shape == (1, 4, 4, 4) + assert sk.nn.Upsample3D((2, 3, 4))(jnp.ones([1, 2, 2, 2])).shape == (1, 4, 6, 8) def test_padding1d(): - layer = Pad1D(padding=1) + layer = sk.nn.Pad1D(padding=1) assert layer(jnp.ones((1, 1))).shape == (1, 3) def test_padding2d(): - layer = Pad2D(padding=1) + layer = sk.nn.Pad2D(padding=1) assert layer(jnp.ones((1, 1, 1))).shape == (1, 3, 3) - layer = Pad2D(padding=((1, 2), (3, 4))) + layer = sk.nn.Pad2D(padding=((1, 2), (3, 4))) assert layer(jnp.ones((1, 1, 1))).shape == (1, 4, 8) def test_padding3d(): - layer = Pad3D(padding=1) + layer = sk.nn.Pad3D(padding=1) assert layer(jnp.ones((1, 1, 1, 1))).shape == (1, 3, 3, 3) - layer = Pad3D(padding=((1, 2), (3, 4), (5, 6))) + layer = sk.nn.Pad3D(padding=((1, 2), (3, 4), (5, 6))) assert layer(jnp.ones((1, 1, 1, 1))).shape == (1, 4, 8, 12) def test_random_zoom(): - npt.assert_allclose(RandomZoom1D((0, 0))(jnp.ones((10, 5))), jnp.ones((10, 5))) + npt.assert_allclose( + sk.nn.RandomZoom1D((0, 0))(jnp.ones((10, 5))), jnp.ones((10, 5)) + ) npt.assert_allclose( - RandomZoom2D((0.5, 0.5))(jnp.ones((10, 5, 5))).shape, (10, 5, 5) + sk.nn.RandomZoom2D((0.5, 0.5))(jnp.ones((10, 5, 5))).shape, (10, 5, 5) ) npt.assert_allclose( - RandomZoom3D((0.5, 0.5))(jnp.ones((10, 5, 5, 5))).shape, (10, 5, 5, 5) + sk.nn.RandomZoom3D((0.5, 0.5))(jnp.ones((10, 5, 5, 5))).shape, (10, 5, 5, 5) ) diff --git a/tests/test_rnn.py b/tests/test_rnn.py index 3422d78..4656510 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ # import tensorflow.keras as tfk # import tensorflow as tf # import numpy as np -# from serket.nn.recurrent import LSTMCell, ScanRNN +# from serket._src.nn.recurrent import LSTMCell, ScanRNN # batch_size = 1 # time_steps = 2 @@ -68,7 +68,7 @@ import numpy.testing as npt import pytest -from serket.nn.recurrent import ( # ConvGRU1DCell,; ConvGRU2DCell,; ConvGRU3DCell,; ConvLSTM2DCell,; ConvLSTM3DCell, +from serket._src.nn.recurrent import ( ConvLSTM1DCell, DenseCell, FFTConvLSTM1DCell, diff --git a/tests/test_sequential.py b/tests/test_sequential.py index 74aef7c..c75e573 100644 --- a/tests/test_sequential.py +++ b/tests/test_sequential.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,22 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from serket.nn import Sequential +import serket as sk def test_sequential(): - model = Sequential(lambda x: x) + model = sk.nn.Sequential(lambda x: x) assert model(1.0) == 1.0 - model = Sequential(lambda x: x + 1, lambda x: x + 1) + model = sk.nn.Sequential(lambda x: x + 1, lambda x: x + 1) assert model(1.0) == 3.0 - model = Sequential(lambda x, key: x) + model = sk.nn.Sequential(lambda x, key: x) assert model(1.0) == 1.0 def test_sequential_getitem(): - model = Sequential(lambda x: x + 1, lambda x: x + 1) + model = sk.nn.Sequential(lambda x: x + 1, lambda x: x + 1) assert model[0](1.0) == 2.0 assert model[1](1.0) == 2.0 assert model[0:1](1.0) == 2.0 @@ -36,15 +36,15 @@ def test_sequential_getitem(): def test_sequential_len(): - model = Sequential(lambda x: x + 1, lambda x: x + 1) + model = sk.nn.Sequential(lambda x: x + 1, lambda x: x + 1) assert len(model) == 2 def test_sequential_iter(): - model = Sequential(lambda x: x + 1, lambda x: x + 1) + model = sk.nn.Sequential(lambda x: x + 1, lambda x: x + 1) assert list(model) == [model[0], model[1]] def test_sequential_reversed(): - model = Sequential(lambda x: x + 1, lambda x: x + 1) + model = sk.nn.Sequential(lambda x: x + 1, lambda x: x + 1) assert list(reversed(model)) == [model[1], model[0]] diff --git a/tests/test_utils.py b/tests/test_utils.py index d77224f..9f32ab1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 Serket authors +# Copyright 2023 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,8 +17,8 @@ import jax.tree_util as jtu import pytest -from serket.nn.initialization import resolve_init -from serket.utils import canonicalize +from serket._src.nn.initialization import resolve_init +from serket._src.utils import canonicalize def test_canonicalize_init_func():