From 4fca95e9048f96d07516d87288f6cdc27e779ce4 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Sat, 13 Apr 2024 20:46:50 +0900 Subject: [PATCH] merge padding and convert, move stranded types to typing --- serket/_src/image/filter.py | 6 +- serket/_src/nn/activation.py | 4 +- serket/_src/nn/convolution.py | 6 +- serket/_src/nn/linear.py | 13 +-- serket/_src/nn/normalization.py | 7 +- serket/_src/nn/pooling.py | 3 +- serket/_src/nn/recurrent.py | 14 +--- serket/_src/nn/reshape.py | 3 +- serket/_src/utils/convert.py | 122 +++++++++++++++++++++++++++- serket/_src/utils/inspect.py | 3 +- serket/_src/utils/padding.py | 136 -------------------------------- serket/_src/utils/typing.py | 12 ++- tests/test_utils.py | 4 +- 13 files changed, 152 insertions(+), 181 deletions(-) delete mode 100644 serket/_src/utils/padding.py diff --git a/serket/_src/image/filter.py b/serket/_src/image/filter.py index 0d1327a..4c3f0fa 100644 --- a/serket/_src/image/filter.py +++ b/serket/_src/image/filter.py @@ -28,12 +28,12 @@ fft_conv_general_dilated, generate_conv_dim_numbers, ) -from serket._src.utils.convert import canonicalize -from serket._src.utils.mapping import kernel_map -from serket._src.utils.padding import ( +from serket._src.utils.convert import ( + canonicalize, delayed_canonicalize_padding, resolve_string_padding, ) +from serket._src.utils.mapping import kernel_map from serket._src.utils.typing import CHWArray, DType, HWArray from serket._src.utils.validate import validate_spatial_ndim diff --git a/serket/_src/nn/activation.py b/serket/_src/nn/activation.py index 7a0280b..a0274fc 100644 --- a/serket/_src/nn/activation.py +++ b/serket/_src/nn/activation.py @@ -16,7 +16,7 @@ import inspect from collections.abc import Callable as ABCCallable -from typing import Callable, TypeVar, Union, get_args +from typing import Callable, Union, get_args import jax import jax.numpy as jnp @@ -26,8 +26,6 @@ from serket._src.utils.typing import ActivationLiteral from serket._src.utils.validate import IsInstance, Range, ScalarLike -T = TypeVar("T") - @autoinit class CeLU(TreeClass): diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py index e432a0c..cdd7b4a 100644 --- a/serket/_src/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -29,12 +29,12 @@ from serket import TreeClass from serket._src.nn.initialization import resolve_init -from serket._src.utils.convert import canonicalize -from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init -from serket._src.utils.padding import ( +from serket._src.utils.convert import ( calculate_transpose_padding, + canonicalize, delayed_canonicalize_padding, ) +from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init from serket._src.utils.typing import ( DilationType, DType, diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index 26f80d7..f766c93 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -15,7 +15,7 @@ from __future__ import annotations import functools as ft -from typing import Any, Generic, Sequence, TypeVar +from typing import Sequence import jax import jax.numpy as jnp @@ -30,16 +30,9 @@ from serket._src.nn.initialization import resolve_init from serket._src.utils.convert import tuplify from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init -from serket._src.utils.typing import DType, InitType +from serket._src.utils.typing import Batched, DType, InitType from serket._src.utils.validate import validate_pos_int -T = TypeVar("T") -PyTree = Any - - -class Batched(Generic[T]): - pass - def generate_einsum_pattern( lhs_ndim: int, @@ -308,7 +301,7 @@ def scan_func(input: jax.Array, weight_bias: Batched[jax.Array]): return output -def infer_in_features(instance, x, **__) -> tuple[int, ...]: +def infer_in_features(_1, x, **_2) -> int: return x.shape[-1] diff --git a/serket/_src/nn/normalization.py b/serket/_src/nn/normalization.py index 5b4e621..64ed268 100644 --- a/serket/_src/nn/normalization.py +++ b/serket/_src/nn/normalization.py @@ -15,7 +15,7 @@ from __future__ import annotations import functools as ft -from typing import Sequence, TypeVar +from typing import Sequence import jax import jax.numpy as jnp @@ -26,7 +26,7 @@ from serket._src.nn.initialization import resolve_init from serket._src.utils.convert import tuplify from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init -from serket._src.utils.typing import DType, InitType +from serket._src.utils.typing import DType, InitType, T from serket._src.utils.validate import ( Range, ScalarLike, @@ -844,9 +844,6 @@ def _(batch_norm: BatchNorm) -> BatchNormState: return BatchNormState(running_mean, running_var) -T = TypeVar("T") - - def weight_norm(leaf: T, axis: int | None = -1, eps: float = 1e-12) -> T: """Apply L2 weight normalization to an input. diff --git a/serket/_src/nn/pooling.py b/serket/_src/nn/pooling.py index 624a152..9a66337 100644 --- a/serket/_src/nn/pooling.py +++ b/serket/_src/nn/pooling.py @@ -23,9 +23,8 @@ from typing_extensions import Annotated from serket import TreeClass -from serket._src.utils.convert import canonicalize +from serket._src.utils.convert import canonicalize, delayed_canonicalize_padding from serket._src.utils.mapping import kernel_map -from serket._src.utils.padding import delayed_canonicalize_padding from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType from serket._src.utils.validate import validate_spatial_ndim diff --git a/serket/_src/nn/recurrent.py b/serket/_src/nn/recurrent.py index fbeef1d..0d156ca 100644 --- a/serket/_src/nn/recurrent.py +++ b/serket/_src/nn/recurrent.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Defines RNN related classes.""" + from __future__ import annotations import abc import functools as ft -from typing import Any, Callable, TypeVar +from typing import Any, Callable import jax import jax.numpy as jnp import jax.random as jr -from typing_extensions import ParamSpec from serket import TreeClass, autoinit from serket._src.custom_transform import tree_state @@ -42,6 +43,7 @@ InitType, KernelSizeType, PaddingType, + S, StridesType, ) from serket._src.utils.validate import ( @@ -50,14 +52,6 @@ validate_spatial_ndim, ) -P = ParamSpec("P") -T = TypeVar("T") -S = TypeVar("S") - -State = Any - -"""Defines RNN related classes.""" - def is_lazy_call(instance, *_1, **_2) -> bool: return instance.in_features is None diff --git a/serket/_src/nn/reshape.py b/serket/_src/nn/reshape.py index 8e3596a..f499d74 100644 --- a/serket/_src/nn/reshape.py +++ b/serket/_src/nn/reshape.py @@ -24,9 +24,8 @@ from serket import TreeClass from serket._src.custom_transform import tree_eval from serket._src.nn.linear import Identity -from serket._src.utils.convert import canonicalize +from serket._src.utils.convert import canonicalize, delayed_canonicalize_padding from serket._src.utils.mapping import kernel_map -from serket._src.utils.padding import delayed_canonicalize_padding from serket._src.utils.typing import ( KernelSizeType, MethodKind, diff --git a/serket/_src/utils/convert.py b/serket/_src/utils/convert.py index 59847f2..0dfa08f 100644 --- a/serket/_src/utils/convert.py +++ b/serket/_src/utils/convert.py @@ -14,12 +14,13 @@ from __future__ import annotations -from typing import Sequence, TypeVar +import functools as ft +from typing import Sequence import jax import jax.numpy as jnp -T = TypeVar("T") +from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType, T def canonicalize(value, ndim, name: str | None = None): @@ -36,3 +37,120 @@ def canonicalize(value, ndim, name: str | None = None): def tuplify(value: T) -> T | tuple[T]: return tuple(value) if isinstance(value, Sequence) else (value,) + + +def same_padding_along_dim( + in_dim: int, + kernel_size: int, + stride: int, +) -> tuple[int, int]: + # https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 + # di: input dimension + # ki: kernel size + # si: stride + if in_dim % stride == 0: + pad = max(kernel_size - stride, 0) + else: + pad = max(kernel_size - (in_dim % stride), 0) + + return (pad // 2, pad - pad // 2) + + +def resolve_tuple_padding( + in_dim: tuple[int, ...], + padding: PaddingType, + kernel_size: KernelSizeType, + strides: StridesType, +) -> tuple[tuple[int, int], ...]: + del in_dim, strides + if len(padding) != len(kernel_size): + raise ValueError(f"Length mismatch {len(kernel_size)=}!={len(padding)=}.") + + resolved_padding = [[]] * len(kernel_size) + + for i, item in enumerate(padding): + if isinstance(item, int): + resolved_padding[i] = (item, item) # ex: padding = (1, 2, 3) + + elif isinstance(item, tuple): + if len(item) != 2: + raise ValueError(f"Expected tuple of length 2, got {len(item)=}") + resolved_padding[i] = item + + return tuple(resolved_padding) + + +def resolve_int_padding( + in_dim: tuple[int, ...], + padding: PaddingType, + kernel_size: KernelSizeType, + strides: StridesType, +): + del in_dim, strides + return ((padding, padding),) * len(kernel_size) + + +def resolve_string_padding(in_dim, padding, kernel_size, strides): + if padding.lower() == "same": + return tuple( + same_padding_along_dim(di, ki, si) + for di, ki, si in zip(in_dim, kernel_size, strides) + ) + + if padding.lower() == "valid": + return ((0, 0),) * len(kernel_size) + + raise ValueError(f'string argument must be in ["same","valid"].Found {padding}') + + +@ft.lru_cache(maxsize=128) +def delayed_canonicalize_padding( + in_dim: tuple[int, ...], + padding: PaddingType, + kernel_size: KernelSizeType, + strides: StridesType, +): + # in case of `str` padding, we need to know the input dimension + # to calculate the padding thus we need to delay the canonicalization + # until the call + + if isinstance(padding, int): + return resolve_int_padding(in_dim, padding, kernel_size, strides) + + if isinstance(padding, str): + return resolve_string_padding(in_dim, padding, kernel_size, strides) + + if isinstance(padding, tuple): + return resolve_tuple_padding(in_dim, padding, kernel_size, strides) + + raise ValueError( + "Expected padding to be of:\n" + "* int, for same padding along all dimensions\n" + "* str, for `same` or `valid` padding along all dimensions\n" + "* tuple of int, for individual padding along each dimension\n" + "* tuple of tuple of int, for padding before and after each dimension\n" + f"Got {padding=}." + ) + + +@ft.lru_cache(maxsize=128) +def calculate_transpose_padding( + padding, + kernel_size, + input_dilation, + extra_padding, +): + """Transpose padding to get the padding for the transpose convolution. + + Args: + padding: padding to transpose + kernel_size: kernel size to use for transposing padding + input_dilation: input dilation to use for transposing padding + extra_padding: extra padding to use for transposing padding + """ + return tuple( + ((ki - 1) * di - pl, (ki - 1) * di - pr + ep) + for (pl, pr), ki, ep, di in zip( + padding, kernel_size, extra_padding, input_dilation + ) + ) diff --git a/serket/_src/utils/inspect.py b/serket/_src/utils/inspect.py index 65568f9..5386d85 100644 --- a/serket/_src/utils/inspect.py +++ b/serket/_src/utils/inspect.py @@ -16,10 +16,9 @@ import functools as ft import inspect -from types import MethodType @ft.lru_cache(maxsize=128) -def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]: +def get_params(func) -> tuple[inspect.Parameter, ...]: """Get the arguments of func.""" return tuple(inspect.signature(func).parameters.values()) diff --git a/serket/_src/utils/padding.py b/serket/_src/utils/padding.py deleted file mode 100644 index c542614..0000000 --- a/serket/_src/utils/padding.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2024 serket authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import functools as ft - -from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType - - -def same_padding_along_dim( - in_dim: int, - kernel_size: int, - stride: int, -) -> tuple[int, int]: - # https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 - # di: input dimension - # ki: kernel size - # si: stride - if in_dim % stride == 0: - pad = max(kernel_size - stride, 0) - else: - pad = max(kernel_size - (in_dim % stride), 0) - - return (pad // 2, pad - pad // 2) - - -def resolve_tuple_padding( - in_dim: tuple[int, ...], - padding: PaddingType, - kernel_size: KernelSizeType, - strides: StridesType, -) -> tuple[tuple[int, int], ...]: - del in_dim, strides - if len(padding) != len(kernel_size): - raise ValueError(f"Length mismatch {len(kernel_size)=}!={len(padding)=}.") - - resolved_padding = [[]] * len(kernel_size) - - for i, item in enumerate(padding): - if isinstance(item, int): - resolved_padding[i] = (item, item) # ex: padding = (1, 2, 3) - - elif isinstance(item, tuple): - if len(item) != 2: - raise ValueError(f"Expected tuple of length 2, got {len(item)=}") - resolved_padding[i] = item - - return tuple(resolved_padding) - - -def resolve_int_padding( - in_dim: tuple[int, ...], - padding: PaddingType, - kernel_size: KernelSizeType, - strides: StridesType, -): - del in_dim, strides - return ((padding, padding),) * len(kernel_size) - - -def resolve_string_padding(in_dim, padding, kernel_size, strides): - if padding.lower() == "same": - return tuple( - same_padding_along_dim(di, ki, si) - for di, ki, si in zip(in_dim, kernel_size, strides) - ) - - if padding.lower() == "valid": - return ((0, 0),) * len(kernel_size) - - raise ValueError(f'string argument must be in ["same","valid"].Found {padding}') - - -@ft.lru_cache(maxsize=128) -def delayed_canonicalize_padding( - in_dim: tuple[int, ...], - padding: PaddingType, - kernel_size: KernelSizeType, - strides: StridesType, -): - # in case of `str` padding, we need to know the input dimension - # to calculate the padding thus we need to delay the canonicalization - # until the call - - if isinstance(padding, int): - return resolve_int_padding(in_dim, padding, kernel_size, strides) - - if isinstance(padding, str): - return resolve_string_padding(in_dim, padding, kernel_size, strides) - - if isinstance(padding, tuple): - return resolve_tuple_padding(in_dim, padding, kernel_size, strides) - - raise ValueError( - "Expected padding to be of:\n" - "* int, for same padding along all dimensions\n" - "* str, for `same` or `valid` padding along all dimensions\n" - "* tuple of int, for individual padding along each dimension\n" - "* tuple of tuple of int, for padding before and after each dimension\n" - f"Got {padding=}." - ) - - -@ft.lru_cache(maxsize=128) -def calculate_transpose_padding( - padding, - kernel_size, - input_dilation, - extra_padding, -): - """Transpose padding to get the padding for the transpose convolution. - - Args: - padding: padding to transpose - kernel_size: kernel size to use for transposing padding - input_dilation: input dilation to use for transposing padding - extra_padding: extra padding to use for transposing padding - """ - return tuple( - ((ki - 1) * di - pl, (ki - 1) * di - pr + ep) - for (pl, pr), ki, ep, di in zip( - padding, kernel_size, extra_padding, input_dilation - ) - ) diff --git a/serket/_src/utils/typing.py b/serket/_src/utils/typing.py index 2e62f3f..995866c 100644 --- a/serket/_src/utils/typing.py +++ b/serket/_src/utils/typing.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Callable, Literal, Sequence, Tuple, TypeVar, Union +from typing import Any, Callable, Generic, Literal, Sequence, Tuple, TypeVar, Union import jax import numpy as np @@ -94,3 +94,13 @@ "tanh_shrink", "thresholded_relu", ] + + +P = ParamSpec("P") +T = TypeVar("T") +S = TypeVar("S") +PyTree = TypeVar("PyTree", bound=Any) + + +class Batched(Generic[T]): + pass diff --git a/tests/test_utils.py b/tests/test_utils.py index 9504330..92f81b0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -21,8 +21,8 @@ import serket as sk from serket._src.nn.initialization import resolve_init -from serket._src.utils.convert import canonicalize -from serket._src.utils.padding import ( +from serket._src.utils.convert import ( + canonicalize, delayed_canonicalize_padding, resolve_string_padding, resolve_tuple_padding,