From 829b4312edc5fdec5ddacff75e15014a29e8a02c Mon Sep 17 00:00:00 2001
From: ASEM000
Date: Wed, 3 Apr 2024 00:31:38 +0900
Subject: [PATCH] revert dispatchers

---
 serket/_src/nn/convolution.py | 96 +++--------------------------------
 serket/_src/nn/linear.py      | 34 ++++---------
 2 files changed, 15 insertions(+), 115 deletions(-)

diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py
index f47bb0d..41355c1 100644
--- a/serket/_src/nn/convolution.py
+++ b/serket/_src/nn/convolution.py
@@ -120,10 +120,9 @@ def grouped_matmul(x, y, groups) -> jax.Array:
     )
 
 
-@single_dispatch(argnum=1)
 def fft_conv_nd(
     input: jax.Array,
-    weight: Any,
+    weight: Weight,
     bias: jax.Array | None,
     strides: tuple[int, ...],
     padding: tuple[tuple[int, int], ...],
@@ -150,20 +149,6 @@ def fft_conv_nd(
         mask: a binary mask multiplied with the convolutional kernel. shape is
             ``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
     """
-    raise TypeError(f"{type(weight)=}")
-
-
-@fft_conv_nd.def_type(jax.Array)
-def _(
-    input: jax.Array,
-    weight: Weight,
-    bias: jax.Array | None,
-    strides: tuple[int, ...],
-    padding: tuple[tuple[int, int], ...],
-    dilation: tuple[int, ...],
-    groups: int,
-    mask: Weight | None = None,
-) -> jax.Array:
     x = fft_conv_general_dilated(
         lhs=jnp.expand_dims(input, 0),
         rhs=weight if mask is None else weight * mask,
@@ -176,10 +161,9 @@ def _(
     return jnp.squeeze(x, 0) if bias is None else jnp.squeeze((x + bias), 0)
 
 
-@single_dispatch(argnum=1)
 def fft_conv_nd_transpose(
     input: jax.Array,
-    weight: Any,
+    weight: Weight,
     bias: jax.Array | None,
     strides: tuple[int, ...],
     padding: tuple[tuple[int, int], ...],
@@ -203,20 +187,6 @@ def fft_conv_nd_transpose(
         mask: a binary mask multiplied with the convolutional kernel. shape is
             ``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
     """
-    raise TypeError(f"{type(weight)=}")
-
-
-@fft_conv_nd_transpose.def_type(jax.Array)
-def _(
-    input: jax.Array,
-    weight: Weight,
-    bias: jax.Array | None,
-    strides: tuple[int, ...],
-    padding: tuple[tuple[int, int], ...],
-    dilation: tuple[int, ...],
-    out_padding: int,
-    mask: Weight | None = None,
-) -> jax.Array:
     transposed_padding = calculate_transpose_padding(
         padding=padding,
         extra_padding=out_padding,
@@ -321,10 +291,9 @@ def separable_fft_conv_nd(
     )
 
 
-@single_dispatch(argnum=1)
 def conv_nd(
     input: jax.Array,
-    weight: Any,
+    weight: Weight,
     bias: jax.Array | None,
     strides: tuple[int, ...],
     padding: tuple[tuple[int, int], ...],
@@ -348,20 +317,6 @@ def conv_nd(
         mask: a binary mask multiplied with the convolutional kernel. shape is
             ``(out_features, in_features, kernel)``. set
     """
-    raise TypeError(f"{type(weight)=}")
-
-
-@conv_nd.def_type(jax.Array)
-def _(
-    input: jax.Array,
-    weight: Any,
-    bias: jax.Array | None,
-    strides: tuple[int, ...],
-    padding: tuple[tuple[int, int], ...],
-    dilation: tuple[int, ...],
-    groups: int,
-    mask: Weight | None = None,
-) -> jax.Array:
     x = jax.lax.conv_general_dilated(
         lhs=jnp.expand_dims(input, 0),
         rhs=weight if mask is None else weight * mask,
@@ -375,10 +330,9 @@ def _(
     return jnp.squeeze(x, 0) if bias is None else jnp.squeeze((x + bias), 0)
 
 
-@single_dispatch(argnum=1)
 def conv_nd_transpose(
     input: jax.Array,
-    weight: Any,
+    weight: Weight,
     bias: jax.Array | None,
     strides: tuple[int, ...],
     padding: tuple[tuple[int, int], ...],
@@ -402,20 +356,6 @@ def conv_nd_transpose(
         mask: a binary mask multiplied with the convolutional kernel. shape is
             ``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
     """
-    raise TypeError(f"{type(weight)=}")
-
-
-@conv_nd_transpose.def_type(jax.Array)
-def _(
-    input: jax.Array,
-    weight: Weight,
-    bias: jax.Array | None,
-    strides: tuple[int, ...],
-    padding: tuple[tuple[int, int], ...],
-    dilation: tuple[int, ...],
-    out_padding: int,
-    mask: Weight | None = None,
-) -> jax.Array:
     transposed_padding = calculate_transpose_padding(
         padding=padding,
         extra_padding=out_padding,
@@ -519,10 +459,9 @@ def depthwise_conv_nd(
     return jnp.squeeze(x, 0) if bias is None else jnp.squeeze(x + bias, 0)
 
 
-@single_dispatch(argnum=1)
 def spectral_conv_nd(
     input: Annotated[jax.Array, "I..."],
-    weight: Any,
+    weight: Weight,
     modes: tuple[int, ...],
 ) -> Annotated[jax.Array, "O..."]:
     """fourier neural operator convolution function.
@@ -533,15 +472,7 @@ def spectral_conv_nd(
         where dim is the number of spatial dimensions on the
         modes: number of modes included in the fft representation of the input.
     """
-    raise TypeError(f"{type(weight)=}")
-
 
-@spectral_conv_nd.def_type(jax.Array)
-def _(
-    input: Annotated[jax.Array, "I..."],
-    weight: Annotated[jax.Array, "2NOI..."],
-    modes: tuple[int, ...],
-) -> Annotated[jax.Array, "O..."]:
     def generate_modes_slices(modes: tuple[int, ...]):
         *ms, ml = modes
         slices_ = [[slice(None, ml)]]
@@ -560,10 +491,9 @@ def generate_modes_slices(modes: tuple[int, ...]):
     return jnp.fft.irfftn(out, s=(*si, sl))
 
 
-@single_dispatch(argnum=1)
 def local_conv_nd(
     input: jax.Array,
-    weight: Weight,
+    weight: jax.Array,
     bias: jax.Array | None,
     strides: tuple[int, ...],
     padding: tuple[tuple[int, int], ...],
@@ -588,20 +518,6 @@ def local_conv_nd(
         mask: a binary mask multiplied with the convolutional kernel. shape is
             ``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
     """
-    raise TypeError(f"{type(weight)=}")
-
-
-@local_conv_nd.def_type(jax.Array)
-def _(
-    input: jax.Array,
-    weight: Weight,
-    bias: jax.Array | None,
-    strides: tuple[int, ...],
-    padding: tuple[tuple[int, int], ...],
-    dilation: tuple[int, ...],
-    kernel_size: tuple[int, ...],
-    mask: Weight | None = None,
-) -> jax.Array:
     x = jax.lax.conv_general_dilated_local(
         lhs=jnp.expand_dims(input, 0),
         rhs=weight if mask is None else weight * mask,
diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py
index 83f21ac..8926ae6 100644
--- a/serket/_src/nn/linear.py
+++ b/serket/_src/nn/linear.py
@@ -37,24 +37,13 @@
 )
 
 T = TypeVar("T")
+PyTree = Any
 
 
 class Batched(Generic[T]):
     pass
 
 
-PyTree = Any
-
-
-def is_lazy_call(instance, *_1, **_2) -> bool:
-    return getattr(instance, "in_features", False) is None
-
-
-def is_lazy_init(_1, in_features, *_2, **_3) -> bool:
-    return in_features is None
-
-
-@single_dispatch(argnum=1)
 def linear(
     input: jax.Array,
     weight: Any,
@@ -72,19 +61,6 @@ def linear(
             corresponding to the (in_feature_1, in_feature_2, ...)
         out_axis: the axis to put the result. accepts ``in`` values.
     """
-    del input, bias, in_axis, out_axis
-    raise TypeError(f"{type(weight)=}")
-
-
-@linear.def_type(jax.Array)
-def _(
-    input: jax.Array,
-    weight: jax.Array,
-    bias: jax.Array | None,
-    in_axis: tuple[int, ...],
-    out_axis: int,
-) -> jax.Array:
-    # weight array handler
     in_axis = [axis if axis >= 0 else axis + input.ndim for axis in in_axis]
     lhs = "".join(str(axis) for axis in range(input.ndim))  # 0, 1, 2, 3
     rhs = "F" + "".join(str(axis) for axis in in_axis)  # F, 1, 2, 3
@@ -108,6 +84,14 @@ def _(
     return result + bias
 
 
+def is_lazy_call(instance, *_1, **_2) -> bool:
+    return getattr(instance, "in_features", False) is None
+
+
+def is_lazy_init(_1, in_features, *_2, **_3) -> bool:
+    return in_features is None
+
+
 def infer_in_features(instance, x, **__) -> tuple[int, ...]:
     in_axis = getattr(instance, "in_axis", ())
     return tuple(x.shape[i] for i in tuplify(in_axis))