
Commit

revert dispatchers
ASEM000 committed Apr 2, 2024
1 parent 14a5d6b commit 829b431
Showing 2 changed files with 15 additions and 115 deletions.
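
Context for this revert: before this commit, each of these ops was a generic stub that dispatched on the type of the `weight` argument (via `single_dispatch(argnum=1)`) and raised a TypeError for unregistered types, with the `jax.Array` case registered separately through `.def_type(jax.Array)`. The revert inlines the `jax.Array` body back into a single plain function annotated with `Weight`. Below is a minimal, self-contained analogue of the removed pattern using `functools.singledispatch` instead of serket's helper; the standard-library decorator keys on the first argument, so `weight` is placed first here purely for illustration, and the name and trivial body are hypothetical.

from functools import singledispatch

import jax
import jax.numpy as jnp


@singledispatch
def fft_conv_nd_demo(weight, input):
    # Generic fallback, mirroring the removed `raise TypeError(f"{type(weight)=}")`.
    raise TypeError(f"{type(weight)=}")


@fft_conv_nd_demo.register(jax.Array)
def _(weight, input):
    # Stand-in body; the real pre-revert implementation called fft_conv_general_dilated here.
    return input


print(fft_conv_nd_demo(jnp.ones((4, 3, 3, 3)), jnp.ones((3, 8, 8))).shape)  # (3, 8, 8)
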
96 changes: 6 additions & 90 deletions serket/_src/nn/convolution.py
@@ -120,10 +120,9 @@ def grouped_matmul(x, y, groups) -> jax.Array:
)


@single_dispatch(argnum=1)
def fft_conv_nd(
input: jax.Array,
weight: Any,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
@@ -150,20 +149,6 @@ def fft_conv_nd(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
raise TypeError(f"{type(weight)=}")


@fft_conv_nd.def_type(jax.Array)
def _(
input: jax.Array,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
dilation: tuple[int, ...],
groups: int,
mask: Weight | None = None,
) -> jax.Array:
x = fft_conv_general_dilated(
lhs=jnp.expand_dims(input, 0),
rhs=weight if mask is None else weight * mask,
@@ -176,10 +161,9 @@ def _(
return jnp.squeeze(x, 0) if bias is None else jnp.squeeze((x + bias), 0)


@single_dispatch(argnum=1)
def fft_conv_nd_transpose(
input: jax.Array,
weight: Any,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
@@ -203,20 +187,6 @@ def fft_conv_nd_transpose(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
raise TypeError(f"{type(weight)=}")


@fft_conv_nd_transpose.def_type(jax.Array)
def _(
input: jax.Array,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
dilation: tuple[int, ...],
out_padding: int,
mask: Weight | None = None,
) -> jax.Array:
transposed_padding = calculate_transpose_padding(
padding=padding,
extra_padding=out_padding,
@@ -321,10 +291,9 @@ def separable_fft_conv_nd(
)


@single_dispatch(argnum=1)
def conv_nd(
input: jax.Array,
weight: Any,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
@@ -348,20 +317,6 @@ def conv_nd(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set
"""
raise TypeError(f"{type(weight)=}")


@conv_nd.def_type(jax.Array)
def _(
input: jax.Array,
weight: Any,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
dilation: tuple[int, ...],
groups: int,
mask: Weight | None = None,
) -> jax.Array:
x = jax.lax.conv_general_dilated(
lhs=jnp.expand_dims(input, 0),
rhs=weight if mask is None else weight * mask,
@@ -375,10 +330,9 @@ def _(
return jnp.squeeze(x, 0) if bias is None else jnp.squeeze((x + bias), 0)


@single_dispatch(argnum=1)
def conv_nd_transpose(
input: jax.Array,
weight: Any,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
@@ -402,20 +356,6 @@ def conv_nd_transpose(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
raise TypeError(f"{type(weight)=}")


@conv_nd_transpose.def_type(jax.Array)
def _(
input: jax.Array,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
dilation: tuple[int, ...],
out_padding: int,
mask: Weight | None = None,
) -> jax.Array:
transposed_padding = calculate_transpose_padding(
padding=padding,
extra_padding=out_padding,
@@ -519,10 +459,9 @@ def depthwise_conv_nd(
return jnp.squeeze(x, 0) if bias is None else jnp.squeeze(x + bias, 0)


@single_dispatch(argnum=1)
def spectral_conv_nd(
input: Annotated[jax.Array, "I..."],
weight: Any,
weight: Weight,
modes: tuple[int, ...],
) -> Annotated[jax.Array, "O..."]:
"""fourier neural operator convolution function.
@@ -533,15 +472,7 @@ def spectral_conv_nd(
where dim is the number of spatial dimensions on the
modes: number of modes included in the fft representation of the input.
"""
raise TypeError(f"{type(weight)=}")


@spectral_conv_nd.def_type(jax.Array)
def _(
input: Annotated[jax.Array, "I..."],
weight: Annotated[jax.Array, "2NOI..."],
modes: tuple[int, ...],
) -> Annotated[jax.Array, "O..."]:
def generate_modes_slices(modes: tuple[int, ...]):
*ms, ml = modes
slices_ = [[slice(None, ml)]]
@@ -560,10 +491,9 @@ def generate_modes_slices(modes: tuple[int, ...]):
return jnp.fft.irfftn(out, s=(*si, sl))


@single_dispatch(argnum=1)
def local_conv_nd(
input: jax.Array,
weight: Weight,
weight: jax.Array,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
@@ -588,20 +518,6 @@ def local_conv_nd(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
raise TypeError(f"{type(weight)=}")


@local_conv_nd.def_type(jax.Array)
def _(
input: jax.Array,
weight: Weight,
bias: jax.Array | None,
strides: tuple[int, ...],
padding: tuple[tuple[int, int], ...],
dilation: tuple[int, ...],
kernel_size: tuple[int, ...],
mask: Weight | None = None,
) -> jax.Array:
x = jax.lax.conv_general_dilated_local(
lhs=jnp.expand_dims(input, 0),
rhs=weight if mask is None else weight * mask,
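
The visible bodies above call `jax.lax.conv_general_dilated` on the input with a temporary batch axis and squeeze it off again, adding the bias first when one is given. Here is a hedged sketch of that post-revert call path using the signature shown in this diff; the keyword arguments elided by the collapsed hunks (`window_strides`, `padding`, `rhs_dilation`, `feature_group_count`) are assumptions, and `conv_nd_sketch` is an illustrative name, not serket's API.

import jax
import jax.numpy as jnp


def conv_nd_sketch(input, weight, bias, strides, padding, dilation, groups, mask=None):
    x = jax.lax.conv_general_dilated(
        lhs=jnp.expand_dims(input, 0),                  # (1, C, *spatial): add a batch axis
        rhs=weight if mask is None else weight * mask,  # optionally mask the kernel
        window_strides=strides,
        padding=padding,
        rhs_dilation=dilation,
        feature_group_count=groups,
    )
    # drop the batch axis; add the bias before squeezing when present
    return jnp.squeeze(x, 0) if bias is None else jnp.squeeze(x + bias, 0)


# Example: 2D convolution, 3 -> 8 channels, 3x3 kernel, "same" padding.
x = jnp.ones((3, 32, 32))                               # unbatched (C, H, W) input
w = jnp.ones((8, 3, 3, 3))                              # (out, in, kH, kW) kernel
y = conv_nd_sketch(x, w, bias=None, strides=(1, 1),
                   padding=((1, 1), (1, 1)), dilation=(1, 1), groups=1)
print(y.shape)  # (8, 32, 32)
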
34 changes: 9 additions & 25 deletions serket/_src/nn/linear.py
@@ -37,24 +37,13 @@
)

T = TypeVar("T")
PyTree = Any


class Batched(Generic[T]):
pass


PyTree = Any


def is_lazy_call(instance, *_1, **_2) -> bool:
return getattr(instance, "in_features", False) is None


def is_lazy_init(_1, in_features, *_2, **_3) -> bool:
return in_features is None


@single_dispatch(argnum=1)
def linear(
input: jax.Array,
weight: Any,
@@ -72,19 +61,6 @@ def linear(
corresponding to the (in_feature_1, in_feature_2, ...)
out_axis: the axis to put the result. accepts ``in`` values.
"""
del input, bias, in_axis, out_axis
raise TypeError(f"{type(weight)=}")


@linear.def_type(jax.Array)
def _(
input: jax.Array,
weight: jax.Array,
bias: jax.Array | None,
in_axis: tuple[int, ...],
out_axis: int,
) -> jax.Array:
# weight array handler
in_axis = [axis if axis >= 0 else axis + input.ndim for axis in in_axis]
lhs = "".join(str(axis) for axis in range(input.ndim)) # 0, 1, 2, 3
rhs = "F" + "".join(str(axis) for axis in in_axis) # F, 1, 2, 3
@@ -108,6 +84,14 @@ def _(
return result + bias


def is_lazy_call(instance, *_1, **_2) -> bool:
return getattr(instance, "in_features", False) is None


def is_lazy_init(_1, in_features, *_2, **_3) -> bool:
return in_features is None


def infer_in_features(instance, x, **__) -> tuple[int, ...]:
in_axis = getattr(instance, "in_axis", ())
return tuple(x.shape[i] for i in tuplify(in_axis))
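
The lines kept in this hunk build einsum-style subscript strings from the input's axis indices, contracting the axes listed in `in_axis` against the non-leading axes of the weight (its leading `F` axis is the output feature axis). Only the first few lines of that body are visible here, so the sketch below fills in the rest under stated assumptions: the output subscripts keep the non-contracted input axes and place `F` at `out_axis`, and axis indices are mapped to letters here so the subscript string is unambiguously valid. `linear_sketch` is an illustrative name, not serket's code.

import string

import jax.numpy as jnp


def linear_sketch(input, weight, bias, in_axis, out_axis):
    # normalize negative axes, as in the visible diff lines
    in_axis = [ax if ax >= 0 else ax + input.ndim for ax in in_axis]
    sym = string.ascii_lowercase
    lhs = "".join(sym[ax] for ax in range(input.ndim))  # one letter per input axis
    rhs = "F" + "".join(sym[ax] for ax in in_axis)      # weight: (out_features, *in_features)
    # assumed continuation: remaining input axes, with the feature axis inserted at out_axis
    out = [sym[ax] for ax in range(input.ndim) if ax not in in_axis]
    out.insert(out_axis if out_axis >= 0 else out_axis + len(out) + 1, "F")
    result = jnp.einsum(f"{lhs},{rhs}->{''.join(out)}", input, weight)
    return result if bias is None else result + bias


# Example: a standard dense layer is in_axis=(-1,), out_axis=-1.
x = jnp.ones((4, 5))
w = jnp.ones((3, 5))                                    # (out_features, in_features)
print(linear_sketch(x, w, None, in_axis=(-1,), out_axis=-1).shape)  # (4, 3)
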
