linear/conv op dispatch
ASEM000 committed Apr 11, 2024
1 parent e5f0066 commit 3cbd820
Showing 5 changed files with 172 additions and 4 deletions.
109 changes: 108 additions & 1 deletion serket/_src/nn/convolution.py
@@ -30,6 +30,7 @@
from serket import TreeClass
from serket._src.nn.initialization import resolve_init
from serket._src.utils.convert import canonicalize
from serket._src.utils.dispatch import single_dispatch
from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init
from serket._src.utils.padding import (
calculate_transpose_padding,
@@ -139,6 +140,7 @@ def grouped_matmul(x, y, groups) -> jax.Array:
)


@single_dispatch(argnum=1)
def fft_conv_nd(
input: jax.Array,
weight: Weight,
@@ -168,6 +170,21 @@ def fft_conv_nd(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
del input, bias, strides, padding, dilation, groups, mask
raise NotImplementedError(f"{type(weight)=}")


@fft_conv_nd.def_type(jax.Array)
def _(
input: jax.Array,
weight: jax.Array,
bias: jax.Array | None,
strides: Sequence[int],
padding: Sequence[tuple[int, int]],
dilation: Sequence[int],
groups: int,
mask: Weight | None = None,
) -> jax.Array:
x = fft_conv_general_dilated(
lhs=jnp.expand_dims(input, 0),
rhs=weight if mask is None else weight * mask,
@@ -180,6 +197,7 @@ def fft_conv_nd(
return jnp.squeeze(x, 0) if bias is None else jnp.squeeze((x + bias), 0)


@single_dispatch(argnum=1)
def fft_conv_nd_transpose(
input: jax.Array,
weight: Weight,
@@ -206,6 +224,21 @@ def fft_conv_nd_transpose(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
del input, bias, strides, padding, dilation, out_padding, mask
raise NotImplementedError(f"{type(weight)=}")


@fft_conv_nd_transpose.def_type(jax.Array)
def _(
input: jax.Array,
weight: jax.Array,
bias: jax.Array | None,
strides: Sequence[int],
padding: Sequence[tuple[int, int]],
dilation: Sequence[int],
out_padding: int,
mask: Weight | None = None,
) -> jax.Array:
transposed_padding = calculate_transpose_padding(
padding=padding,
extra_padding=out_padding,
@@ -310,6 +343,7 @@ def separable_fft_conv_nd(
)


@single_dispatch(argnum=1)
def conv_nd(
input: jax.Array,
weight: Weight,
@@ -336,6 +370,21 @@ def conv_nd(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
del input, bias, strides, padding, dilation, groups, mask
raise NotImplementedError(f"{type(weight)=}")


@conv_nd.def_type(jax.Array)
def _(
input: jax.Array,
weight: jax.Array,
bias: jax.Array | None,
strides: Sequence[int],
padding: Sequence[tuple[int, int]],
dilation: Sequence[int],
groups: int,
mask: Weight | None = None,
) -> jax.Array:
x = jax.lax.conv_general_dilated(
lhs=jnp.expand_dims(input, 0),
rhs=weight if mask is None else weight * mask,
@@ -349,6 +398,7 @@ def conv_nd(
return jnp.squeeze(x, 0) if bias is None else jnp.squeeze((x + bias), 0)


@single_dispatch(argnum=1)
def conv_nd_transpose(
input: jax.Array,
weight: Weight,
@@ -361,6 +411,37 @@ def conv_nd_transpose(
) -> jax.Array:
"""Transposed convolution function wrapping ``jax.lax.conv_general_dilated``.
Args:
input: input array. shape is ``(in_features, spatial)``.
weight: convolutional kernel. shape is ``(out_features, in_features, kernel)``.
bias: bias. shape is ``(out_features, spatial)``. set to ``None`` to not use a bias.
strides: stride of the convolution accepts tuple of integers for different
strides in each dimension.
padding: padding of the input before convolution accepts tuple of integers
for different padding in each dimension.
dilation: dilation of the convolutional kernel accepts tuple of integers
for different dilation in each dimension.
out_padding: padding of the output after convolution.
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
del input, bias, strides, padding, dilation, out_padding, mask
raise NotImplementedError(f"{type(weight)=}")


@conv_nd_transpose.def_type(jax.Array)
def _(
input: jax.Array,
weight: jax.Array,
bias: jax.Array | None,
strides: Sequence[int],
padding: Sequence[tuple[int, int]],
dilation: Sequence[int],
out_padding: int,
mask: Weight | None = None,
) -> jax.Array:
"""Transposed convolution function wrapping ``jax.lax.conv_general_dilated``.
Args:
input: input array. shape is ``(in_features, spatial)``.
weight: convolutional kernel. shape is ``(out_features, in_features, kernel)``.
@@ -478,6 +559,7 @@ def depthwise_conv_nd(
return jnp.squeeze(x, 0) if bias is None else jnp.squeeze(x + bias, 0)


@single_dispatch(argnum=1)
def spectral_conv_nd(
input: Annotated[jax.Array, "I..."],
weight: Weight,
@@ -491,7 +573,16 @@ def spectral_conv_nd(
where dim is the number of spatial dimensions on the
modes: number of modes included in the fft representation of the input.
"""
del input, modes
raise NotImplementedError(f"{type(weight)=}")


@spectral_conv_nd.def_type(jax.Array)
def _(
input: Annotated[jax.Array, "I..."],
weight: jax.Array,
modes: Sequence[int],
) -> Annotated[jax.Array, "O..."]:
def generate_modes_slices(modes: Sequence[int]):
*ms, ml = modes
slices_ = [[slice(None, ml)]]
@@ -510,9 +601,10 @@ def generate_modes_slices(modes: Sequence[int]):
return jnp.fft.irfftn(out, s=(*si, sl))


@single_dispatch(argnum=1)
def local_conv_nd(
input: jax.Array,
- weight: jax.Array,
+ weight: Weight,
bias: jax.Array | None,
strides: Sequence[int],
padding: Sequence[tuple[int, int]],
@@ -537,6 +629,21 @@ def local_conv_nd(
mask: a binary mask multiplied with the convolutional kernel. shape is
``(out_features, in_features, kernel)``. set to ``None`` to not use a mask.
"""
del input, bias, strides, padding, dilation, kernel_size, mask
raise NotImplementedError(f"{type(weight)=}")


@local_conv_nd.def_type(jax.Array)
def _(
input: jax.Array,
weight: jax.Array,
bias: jax.Array | None,
strides: Sequence[int],
padding: Sequence[tuple[int, int]],
dilation: Sequence[int],
kernel_size: Sequence[int],
mask: Weight | None = None,
) -> jax.Array:
x = jax.lax.conv_general_dilated_local(
lhs=jnp.expand_dims(input, 0),
rhs=weight if mask is None else weight * mask,
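
To make the new hook concrete, here is a minimal sketch (not part of this commit) of registering a custom weight container on `conv_nd` through `def_type`, mirroring what the new test does for `linear` further down. It assumes `conv_nd` is importable from `serket._src.nn.convolution` as shown above; the `ScaledWeight` container and its scaling semantics are hypothetical.

# Sketch only (not from the commit): registering a custom weight container on
# conv_nd via the new def_type hook. ScaledWeight is a hypothetical example type.
from __future__ import annotations

from typing import NamedTuple, Sequence

import jax

from serket._src.nn.convolution import conv_nd


class ScaledWeight(NamedTuple):
    array: jax.Array
    scale: float


@conv_nd.def_type(ScaledWeight)
def _(
    input: jax.Array,
    weight: ScaledWeight,
    bias: jax.Array | None,
    strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    dilation: Sequence[int],
    groups: int,
    mask: jax.Array | None = None,
) -> jax.Array:
    # unwrap the container and re-dispatch on the registered jax.Array branch
    return conv_nd(
        input,
        weight.array * weight.scale,
        bias,
        strides,
        padding,
        dilation,
        groups,
        mask,
    )

In principle a layer's `weight_init` could then return a `ScaledWeight`, analogous to the `PlusOne` container used in the test file below.
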
16 changes: 15 additions & 1 deletion serket/_src/nn/linear.py
@@ -29,6 +29,7 @@
)
from serket._src.nn.initialization import resolve_init
from serket._src.utils.convert import tuplify
from serket._src.utils.dispatch import single_dispatch
from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init
from serket._src.utils.typing import DType, InitType
from serket._src.utils.validate import validate_pos_int
@@ -63,9 +64,10 @@ def generate_einsum_pattern(
return f"{lhs},{rhs}->{''.join(out)}"


@single_dispatch(argnum=1)
def linear(
input: jax.Array,
- weight: jax.Array,
+ weight: Any,
bias: jax.Array | None,
in_axis: Sequence[int] = (-1,),
out_axis: Sequence[int] = (-1,),
@@ -82,6 +84,18 @@
in_axis: axes to apply the linear layer to.
out_axis: result axis.
"""
del input, bias, in_axis, out_axis
raise NotImplementedError(f"{type(weight)=}")


@linear.def_type(jax.Array)
def _(
input: jax.Array,
weight: jax.Array,
bias: jax.Array | None,
in_axis: Sequence[int] = (-1,),
out_axis: Sequence[int] = (-1,),
) -> jax.Array:
pattern = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis)
result = jnp.einsum(pattern, input, weight)

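
As a quick sanity check of the fallback path, a hedged sketch (not from the commit): calling the dispatching `linear` with a weight type that has no registered branch hits the base implementation and raises `NotImplementedError`. It assumes `sk.nn.linear` exposes this function, as the new test below does.

# Sketch only: the base branch of the dispatching linear raises
# NotImplementedError for unregistered weight types.
import jax.numpy as jnp
import serket as sk

input = jnp.ones((4, 2))

try:
    sk.nn.linear(input, weight="not-an-array", bias=None)
except NotImplementedError as error:
    print(error)  # type(weight)=<class 'str'>

A registered branch for a custom container, like the `PlusOne` example in the new test below, is what turns this fallback into a useful extension point.
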
5 changes: 4 additions & 1 deletion serket/_src/utils/dispatch.py
@@ -31,7 +31,10 @@ def wrapper(*args, **kwargs):
klass = type(args[argnum])
except IndexError:
argname = get_params(func)[argnum].name
- klass = type(kwargs[argname])
+ try:
+     klass = type(kwargs[argname])
+ except KeyError:
+     raise TypeError(f"{func.__name__} missing required {argname=}")
return dispatcher.dispatch(klass)(*args, **kwargs)

wrapper.def_type = dispatcher.register
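
For context, a minimal usage sketch of the helper (assumptions: `single_dispatch` is importable from `serket._src.utils.dispatch` as the imports above suggest, and `argnum=0` is accepted the same way as the `argnum=1` used in this commit; the `describe` function and its branches are hypothetical):

# Sketch only: extending a single_dispatch-wrapped function with def_type.
import jax
import jax.numpy as jnp

from serket._src.utils.dispatch import single_dispatch


@single_dispatch(argnum=0)
def describe(x):
    # fallback branch: reached when no branch is registered for type(x)
    raise NotImplementedError(f"{type(x)=}")


@describe.def_type(jax.Array)
def _(x: jax.Array) -> str:
    # branch selected when the argument at position argnum=0 is a jax.Array
    return f"array of shape {x.shape}"


print(describe(jnp.ones((2, 3))))  # array of shape (2, 3)
print(describe(x=jnp.ones((2,))))  # keyword arguments are resolved by name

With the `KeyError` guard added above, calling `describe()` with the argument missing now raises a `TypeError` naming the missing parameter instead of leaking a bare `KeyError`.
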
2 changes: 1 addition & 1 deletion serket/_src/utils/typing.py
@@ -64,7 +64,7 @@
InitFuncType = Callable[[jax.Array, Shape, DType], jax.Array]
InitType = Union[InitLiteral, InitFuncType]
MethodKind = Literal["nearest", "linear", "cubic", "lanczos3", "lanczos5"]
- Weight = Annotated[jax.Array, "OI..."]
+ Weight = Union[jax.Array, Any]


ActivationLiteral = Literal[
44 changes: 44 additions & 0 deletions tests/test_linear.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple, Sequence

import jax
import jax.numpy as jnp
import numpy.testing as npt
@@ -207,3 +209,45 @@ def test_mlp_bias():
layer = layer.at["out_bias"].set(b3)

npt.assert_allclose(layer(x), y)


def test_linear_dispatch():
# test per-weight-type dispatch at the function level
class PlusOne(NamedTuple):
array: jax.Array

# dispatch happens at the function level
@sk.nn.linear.def_type(PlusOne)
def _(
input: jax.Array,
weight: PlusOne,
bias: jax.Array | None,
in_axis: Sequence[int],
out_axis: Sequence[int],
):
return sk.nn.linear(
input=input,
weight=weight.array + 1,
bias=bias,
in_axis=in_axis,
out_axis=out_axis,
)

# layer version that depends on the linear function
# will dispatch on the weight
linear = sk.nn.Linear(
in_features=2,
out_features=3,
key=jax.random.PRNGKey(0),
weight_init=lambda key, shape, dtype: PlusOne(
jax.random.normal(key, shape, dtype)
),
bias_init="zeros",
)

input = jax.random.normal(jax.random.PRNGKey(1), (10, 2))
lhs = linear(input)

rhs = input @ (linear.weight.array.T + 1) + linear.bias

npt.assert_allclose(lhs, rhs, atol=1e-6)
