Skip to content

Commit

Permalink
move to _src
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 8, 2023
1 parent 7c7b4ec commit 57035dd
Show file tree
Hide file tree
Showing 44 changed files with 405 additions and 408 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/layers_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# 📙 `serket.nn` layers overview"
"# 📙 `serket._src.nn` layers overview"
]
},
{
Expand Down
9 changes: 5 additions & 4 deletions serket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -42,10 +42,11 @@
unfreeze,
)

from serket._src.custom_transform import tree_eval, tree_state
from serket._src.nn.activation import def_act_entry
from serket._src.nn.initialization import def_init_entry

from . import cluster, image, nn
from .custom_transform import tree_eval, tree_state
from .nn.activation import def_act_entry
from .nn.initialization import def_init_entry

__all__ = [
# general utils
Expand Down
13 changes: 13 additions & 0 deletions serket/_src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions serket/_src/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
6 changes: 3 additions & 3 deletions serket/cluster/kmeans.py → serket/_src/cluster/kmeans.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,8 +23,8 @@
from typing_extensions import Annotated

import serket as sk
from serket.custom_transform import tree_eval, tree_state
from serket.utils import IsInstance, Range
from serket._src.custom_transform import tree_eval, tree_state
from serket._src.utils import IsInstance, Range

"""K-means utility functions."""

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -133,7 +133,7 @@ def tree_eval(tree):
Note:
To define evaluation rule for a custom layer, use the decorator
:func:`.tree_eval.def_eval` on a function that accepts the layer. The
:func:`.tree_eval.def_eval` on a function that accepts the layer. The
function should return the evaluation layer.
>>> import serket as sk
Expand Down
13 changes: 13 additions & 0 deletions serket/_src/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
8 changes: 4 additions & 4 deletions serket/image/augment.py → serket/_src/image/augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,9 +22,9 @@
from jax import lax

import serket as sk
from serket.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.utils import IsInstance, Range, validate_spatial_ndim
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils import IsInstance, Range, validate_spatial_ndim


def pixel_shuffle_2d(x: jax.Array, upscale_factor: int | tuple[int, int]) -> jax.Array:
Expand Down
8 changes: 4 additions & 4 deletions serket/image/filter.py → serket/_src/image/filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,9 +21,9 @@
from typing_extensions import Annotated

import serket as sk
from serket.nn.convolution import fft_conv_general_dilated
from serket.nn.initialization import DType
from serket.utils import (
from serket._src.nn.convolution import fft_conv_general_dilated
from serket._src.nn.initialization import DType
from serket._src.utils import (
generate_conv_dim_numbers,
positive_int_cb,
resolve_string_padding,
Expand Down
8 changes: 4 additions & 4 deletions serket/image/geometric.py → serket/_src/image/geometric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,9 +22,9 @@
from jax.scipy.ndimage import map_coordinates

import serket as sk
from serket.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.utils import IsInstance, validate_spatial_ndim
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils import IsInstance, validate_spatial_ndim


def affine(image, matrix):
Expand Down
13 changes: 13 additions & 0 deletions serket/_src/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
4 changes: 2 additions & 2 deletions serket/nn/activation.py → serket/_src/nn/activation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,7 +22,7 @@
from jax import lax

import serket as sk
from serket.utils import IsInstance, Range, ScalarLike
from serket._src.utils import IsInstance, Range, ScalarLike

T = TypeVar("T")

Expand Down
6 changes: 3 additions & 3 deletions serket/nn/attention.py → serket/_src/nn/attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,8 +22,8 @@
from typing_extensions import Annotated

import serket as sk
from serket.nn.initialization import InitType
from serket.utils import maybe_lazy_call, maybe_lazy_init
from serket._src.nn.initialization import InitType
from serket._src.utils import maybe_lazy_call, maybe_lazy_init

"""Defines attention layers."""

Expand Down
6 changes: 3 additions & 3 deletions serket/nn/containers.py → serket/_src/nn/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,8 +21,8 @@
import jax.random as jr

import serket as sk
from serket.custom_transform import tree_eval
from serket.utils import Range
from serket._src.custom_transform import tree_eval
from serket._src.utils import Range


def sequential(
Expand Down
6 changes: 3 additions & 3 deletions serket/nn/convolution.py → serket/_src/nn/convolution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,8 +27,8 @@
from typing_extensions import Annotated

import serket as sk
from serket.nn.initialization import DType, InitType, resolve_init
from serket.utils import (
from serket._src.nn.initialization import DType, InitType, resolve_init
from serket._src.utils import (
DilationType,
KernelSizeType,
PaddingType,
Expand Down
11 changes: 5 additions & 6 deletions serket/nn/dropout.py → serket/_src/nn/dropout.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,9 +23,8 @@
import jax.random as jr

import serket as sk
from serket.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.utils import (
from serket._src.custom_transform import tree_eval
from serket._src.utils import (
IsInstance,
Range,
canonicalize,
Expand Down Expand Up @@ -434,5 +433,5 @@ def spatial_ndim(self) -> int:
@tree_eval.def_eval(RandomCutout2D)
@tree_eval.def_eval(GeneralDropout)
@tree_eval.def_eval(DropoutND)
def dropout_evaluation(_) -> Identity:
return Identity()
def dropout_evaluation(_) -> sk.nn.Identity:
return sk.nn.Identity()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@

import functools as ft
from collections.abc import Callable as ABCCallable
from typing import Callable,Any, Literal, Tuple, Union, get_args
from typing import Any, Callable, Literal, Tuple, Union, get_args

import jax
import jax.nn.initializers as ji
Expand Down
8 changes: 4 additions & 4 deletions serket/nn/linear.py → serket/_src/nn/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,13 +22,13 @@
import jax.random as jr

import serket as sk
from serket.nn.activation import (
from serket._src.nn.activation import (
ActivationFunctionType,
ActivationType,
resolve_activation,
)
from serket.nn.initialization import DType, InitType, resolve_init
from serket.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb
from serket._src.nn.initialization import DType, InitType, resolve_init
from serket._src.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb

T = TypeVar("T")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,9 +22,9 @@
from jax.custom_batching import custom_vmap

import serket as sk
from serket.custom_transform import tree_eval, tree_state
from serket.nn.initialization import DType, InitType, resolve_init
from serket.utils import (
from serket._src.custom_transform import tree_eval, tree_state
from serket._src.nn.initialization import DType, InitType, resolve_init
from serket._src.utils import (
Range,
ScalarLike,
maybe_lazy_call,
Expand Down Expand Up @@ -225,6 +225,7 @@ class GroupNorm(sk.TreeClass):
Reference:
https://nn.labml.ai/normalization/group_norm/index.html
"""

eps: float = sk.field(on_setattr=[Range(0), ScalarLike()])

@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
Expand Down
4 changes: 2 additions & 2 deletions serket/nn/pooling.py → serket/_src/nn/pooling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,7 +23,7 @@
import kernex as kex

import serket as sk
from serket.utils import (
from serket._src.utils import (
KernelSizeType,
PaddingType,
StridesType,
Expand Down
10 changes: 5 additions & 5 deletions serket/nn/recurrent.py → serket/_src/nn/recurrent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,10 +24,10 @@
import jax.tree_util as jtu

import serket as sk
from serket.custom_transform import tree_state
from serket.nn.activation import ActivationType, resolve_activation
from serket.nn.initialization import DType, InitType
from serket.utils import (
from serket._src.custom_transform import tree_state
from serket._src.nn.activation import ActivationType, resolve_activation
from serket._src.nn.initialization import DType, InitType
from serket._src.utils import (
DilationType,
KernelSizeType,
PaddingType,
Expand Down
8 changes: 4 additions & 4 deletions serket/nn/reshape.py → serket/_src/nn/reshape.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,9 +23,9 @@
import jax.random as jr

import serket as sk
from serket.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.utils import (
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils import (
IsInstance,
canonicalize,
delayed_canonicalize_padding,
Expand Down
2 changes: 1 addition & 1 deletion serket/utils.py → serket/_src/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Serket authors
# Copyright 2023 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit 57035dd

Please sign in to comment.