Skip to content

Commit

Permalink
merge padding and convert, move stranded types to typing
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 13, 2024
1 parent a0ef29e commit 4fca95e
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 181 deletions.
6 changes: 3 additions & 3 deletions serket/_src/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
fft_conv_general_dilated,
generate_conv_dim_numbers,
)
from serket._src.utils.convert import canonicalize
from serket._src.utils.mapping import kernel_map
from serket._src.utils.padding import (
from serket._src.utils.convert import (
canonicalize,
delayed_canonicalize_padding,
resolve_string_padding,
)
from serket._src.utils.mapping import kernel_map
from serket._src.utils.typing import CHWArray, DType, HWArray
from serket._src.utils.validate import validate_spatial_ndim

Expand Down
4 changes: 1 addition & 3 deletions serket/_src/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import inspect
from collections.abc import Callable as ABCCallable
from typing import Callable, TypeVar, Union, get_args
from typing import Callable, Union, get_args

import jax
import jax.numpy as jnp
Expand All @@ -26,8 +26,6 @@
from serket._src.utils.typing import ActivationLiteral
from serket._src.utils.validate import IsInstance, Range, ScalarLike

T = TypeVar("T")


@autoinit
class CeLU(TreeClass):
Expand Down
6 changes: 3 additions & 3 deletions serket/_src/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@

from serket import TreeClass
from serket._src.nn.initialization import resolve_init
from serket._src.utils.convert import canonicalize
from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init
from serket._src.utils.padding import (
from serket._src.utils.convert import (
calculate_transpose_padding,
canonicalize,
delayed_canonicalize_padding,
)
from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init
from serket._src.utils.typing import (
DilationType,
DType,
Expand Down
13 changes: 3 additions & 10 deletions serket/_src/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import functools as ft
from typing import Any, Generic, Sequence, TypeVar
from typing import Sequence

import jax
import jax.numpy as jnp
Expand All @@ -30,16 +30,9 @@
from serket._src.nn.initialization import resolve_init
from serket._src.utils.convert import tuplify
from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init
from serket._src.utils.typing import DType, InitType
from serket._src.utils.typing import Batched, DType, InitType
from serket._src.utils.validate import validate_pos_int

T = TypeVar("T")
PyTree = Any


class Batched(Generic[T]):
pass


def generate_einsum_pattern(
lhs_ndim: int,
Expand Down Expand Up @@ -308,7 +301,7 @@ def scan_func(input: jax.Array, weight_bias: Batched[jax.Array]):
return output


def infer_in_features(instance, x, **__) -> tuple[int, ...]:
def infer_in_features(_1, x, **_2) -> int:
    """Return the size of the last axis of ``x`` (used for lazy ``in_features`` inference)."""
    shape = x.shape
    return shape[-1]


Expand Down
7 changes: 2 additions & 5 deletions serket/_src/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import functools as ft
from typing import Sequence, TypeVar
from typing import Sequence

import jax
import jax.numpy as jnp
Expand All @@ -26,7 +26,7 @@
from serket._src.nn.initialization import resolve_init
from serket._src.utils.convert import tuplify
from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init
from serket._src.utils.typing import DType, InitType
from serket._src.utils.typing import DType, InitType, T
from serket._src.utils.validate import (
Range,
ScalarLike,
Expand Down Expand Up @@ -844,9 +844,6 @@ def _(batch_norm: BatchNorm) -> BatchNormState:
return BatchNormState(running_mean, running_var)


T = TypeVar("T")


def weight_norm(leaf: T, axis: int | None = -1, eps: float = 1e-12) -> T:
"""Apply L2 weight normalization to an input.
Expand Down
3 changes: 1 addition & 2 deletions serket/_src/nn/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
from typing_extensions import Annotated

from serket import TreeClass
from serket._src.utils.convert import canonicalize
from serket._src.utils.convert import canonicalize, delayed_canonicalize_padding
from serket._src.utils.mapping import kernel_map
from serket._src.utils.padding import delayed_canonicalize_padding
from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType
from serket._src.utils.validate import validate_spatial_ndim

Expand Down
14 changes: 4 additions & 10 deletions serket/_src/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines RNN related classes."""

from __future__ import annotations

import abc
import functools as ft
from typing import Any, Callable, TypeVar
from typing import Any, Callable

import jax
import jax.numpy as jnp
import jax.random as jr
from typing_extensions import ParamSpec

from serket import TreeClass, autoinit
from serket._src.custom_transform import tree_state
Expand All @@ -42,6 +43,7 @@
InitType,
KernelSizeType,
PaddingType,
S,
StridesType,
)
from serket._src.utils.validate import (
Expand All @@ -50,14 +52,6 @@
validate_spatial_ndim,
)

P = ParamSpec("P")
T = TypeVar("T")
S = TypeVar("S")

State = Any

"""Defines RNN related classes."""


def is_lazy_call(instance, *_1, **_2) -> bool:
    """Return ``True`` when ``instance`` has not yet materialized ``in_features``."""
    in_features = instance.in_features
    return in_features is None
Expand Down
3 changes: 1 addition & 2 deletions serket/_src/nn/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
from serket import TreeClass
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils.convert import canonicalize
from serket._src.utils.convert import canonicalize, delayed_canonicalize_padding
from serket._src.utils.mapping import kernel_map
from serket._src.utils.padding import delayed_canonicalize_padding
from serket._src.utils.typing import (
KernelSizeType,
MethodKind,
Expand Down
122 changes: 120 additions & 2 deletions serket/_src/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

from __future__ import annotations

from typing import Sequence, TypeVar
import functools as ft
from typing import Sequence

import jax
import jax.numpy as jnp

T = TypeVar("T")
from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType, T


def canonicalize(value, ndim, name: str | None = None):
Expand All @@ -36,3 +37,120 @@ def canonicalize(value, ndim, name: str | None = None):

def tuplify(value: T) -> T | tuple[T]:
    """Wrap ``value`` in a 1-tuple unless it is already a ``Sequence``.

    Sequences (note: including strings) are converted element-wise to a tuple.
    """
    if isinstance(value, Sequence):
        return tuple(value)
    return (value,)


def same_padding_along_dim(
    in_dim: int,
    kernel_size: int,
    stride: int,
) -> tuple[int, int]:
    """Return ``(before, after)`` padding emulating TensorFlow's ``SAME`` padding.

    See: https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2
    When the total padding is odd, the extra unit goes on the ``after`` side.
    """
    remainder = in_dim % stride
    if remainder == 0:
        total = max(kernel_size - stride, 0)
    else:
        total = max(kernel_size - remainder, 0)
    before = total // 2
    return (before, total - before)


def resolve_tuple_padding(
    in_dim: tuple[int, ...],
    padding: PaddingType,
    kernel_size: KernelSizeType,
    strides: StridesType,
) -> tuple[tuple[int, int], ...]:
    """Resolve a per-dimension tuple ``padding`` into ``(before, after)`` pairs.

    Each item may be an ``int`` (symmetric padding) or a 2-tuple
    ``(before, after)``. ``in_dim`` and ``strides`` are unused but kept so all
    padding resolvers share one signature.

    Raises:
        ValueError: on a length mismatch with ``kernel_size``, a tuple item
            whose length is not 2, or an item that is neither ``int`` nor
            ``tuple``.
    """
    del in_dim, strides
    if len(padding) != len(kernel_size):
        raise ValueError(f"Length mismatch {len(kernel_size)=}!={len(padding)=}.")

    resolved_padding = []

    for i, item in enumerate(padding):
        if isinstance(item, int):
            resolved_padding.append((item, item))  # ex: padding = (1, 2, 3)
        elif isinstance(item, tuple):
            if len(item) != 2:
                raise ValueError(f"Expected tuple of length 2, got {len(item)=}")
            resolved_padding.append(item)
        else:
            # fix: unsupported items were previously left as a stale `[]`
            # placeholder (from `[[]] * n`), silently producing an invalid
            # padding spec downstream; fail loudly instead.
            raise ValueError(f"Expected int or tuple for {item=} at index {i}.")

    return tuple(resolved_padding)


def resolve_int_padding(
    in_dim: tuple[int, ...],
    padding: PaddingType,
    kernel_size: KernelSizeType,
    strides: StridesType,
):
    """Broadcast a single int ``padding`` to ``(padding, padding)`` per spatial dimension.

    ``in_dim`` and ``strides`` are unused but kept so all padding resolvers
    share one signature.
    """
    del in_dim, strides
    pair = (padding, padding)
    return tuple(pair for _ in kernel_size)


def resolve_string_padding(in_dim, padding, kernel_size, strides):
    """Resolve a string ``padding`` (``"same"``/``"valid"``, case-insensitive).

    Args:
        in_dim: spatial input dimensions (needed for ``"same"``).
        padding: ``"same"`` or ``"valid"``.
        kernel_size: kernel size along each spatial dimension.
        strides: stride along each spatial dimension.

    Raises:
        ValueError: if ``padding`` is neither ``"same"`` nor ``"valid"``.
    """
    kind = padding.lower()  # hoisted: was computed twice

    if kind == "same":
        return tuple(
            same_padding_along_dim(di, ki, si)
            for di, ki, si in zip(in_dim, kernel_size, strides)
        )

    if kind == "valid":
        return ((0, 0),) * len(kernel_size)

    # fix: added the missing space after the period in the error message.
    raise ValueError(f'string argument must be in ["same","valid"]. Found {padding}')


@ft.lru_cache(maxsize=128)
def delayed_canonicalize_padding(
    in_dim: tuple[int, ...],
    padding: PaddingType,
    kernel_size: KernelSizeType,
    strides: StridesType,
):
    """Canonicalize ``padding`` into per-dimension ``(before, after)`` pairs.

    String padding (``"same"``/``"valid"``) needs the input dimension to be
    resolved, so canonicalization is delayed until call time, when ``in_dim``
    is known. All arguments must be hashable (tuples) for the cache.
    """
    # Dispatch on the padding's type; `int` is checked first so `bool`
    # (an `int` subclass) resolves the same way as before.
    resolvers = (
        (int, resolve_int_padding),
        (str, resolve_string_padding),
        (tuple, resolve_tuple_padding),
    )
    for kind, resolver in resolvers:
        if isinstance(padding, kind):
            return resolver(in_dim, padding, kernel_size, strides)

    raise ValueError(
        "Expected padding to be of:\n"
        "* int, for same padding along all dimensions\n"
        "* str, for `same` or `valid` padding along all dimensions\n"
        "* tuple of int, for individual padding along each dimension\n"
        "* tuple of tuple of int, for padding before and after each dimension\n"
        f"Got {padding=}."
    )


@ft.lru_cache(maxsize=128)
def calculate_transpose_padding(
    padding,
    kernel_size,
    input_dilation,
    extra_padding,
):
    """Transpose padding to get the padding for the transpose convolution.

    Args:
        padding: padding to transpose, as ``(before, after)`` pairs per dimension.
        kernel_size: kernel size to use for transposing padding.
        input_dilation: input dilation to use for transposing padding.
        extra_padding: extra padding to use for transposing padding.

    All arguments must be hashable (tuples) for the cache.
    """
    transposed = []
    for (pad_left, pad_right), ki, ep, di in zip(
        padding, kernel_size, extra_padding, input_dilation
    ):
        span = (ki - 1) * di  # effective (dilated) kernel extent minus one
        transposed.append((span - pad_left, span - pad_right + ep))
    return tuple(transposed)
3 changes: 1 addition & 2 deletions serket/_src/utils/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@

import functools as ft
import inspect
from types import MethodType


@ft.lru_cache(maxsize=128)
def get_params(func) -> tuple[inspect.Parameter, ...]:
    """Return ``func``'s parameters as a tuple of :class:`inspect.Parameter`.

    Cached because signature inspection is repeated on the same callables;
    ``func`` must therefore be hashable.
    """
    signature = inspect.signature(func)
    return tuple(signature.parameters.values())
Loading

0 comments on commit 4fca95e

Please sign in to comment.