diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py index 8616562..e432a0c 100644 --- a/serket/_src/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -558,7 +558,7 @@ def is_lazy_init(_1, in_features, *_2, **_3) -> bool: return in_features is None -def infer_in_features(instance, x, *_1, **_3) -> int: +def infer_in_features(_1, x, *_2, **_3) -> int: return x.shape[0] diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index 4ec3937..26f80d7 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -46,7 +46,7 @@ def generate_einsum_pattern( rhs_ndim: int, in_axis: Sequence[int], out_axis: Sequence[int], -): +) -> tuple[str, str, str]: # helper function to generate the einsum pattern for linear layer # with flexible input and output axes lhs_alpha = "abcdefghijklmnopqrstuvwxyz" @@ -55,12 +55,13 @@ def generate_einsum_pattern( in_axis = [axis if axis >= 0 else axis + lhs_ndim for axis in in_axis] lhs = "".join(lhs_alpha[axis] for axis in range(lhs_ndim)) rhs = rhs_alpha[: len(out_axis)] + "".join(lhs_alpha[axis] for axis in in_axis) - out = [lhs_alpha[axis] for axis in range(lhs_ndim) if axis not in in_axis] - out_axis = [o if o >= 0 else o + len(out) + 1 for o in out_axis] - + rest_out = [lhs_alpha[axis] for axis in range(lhs_ndim) if axis not in in_axis] + out = [None] * (len(out_axis) + len(rest_out)) + out_axis = [o if o >= 0 else o + len(out) for o in out_axis] for i, axis in enumerate(out_axis): - out.insert(axis, rhs_alpha[i]) - return f"{lhs},{rhs}->{''.join(out)}" + out[axis] = rhs_alpha[i] + out = "".join(rest_out.pop(0) if o is None else o for o in out) + return lhs, rhs, out def linear( @@ -82,18 +83,14 @@ def linear( in_axis: axes to apply the linear layer to. out_axis: result axis. """ - pattern = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis) - result = jnp.einsum(pattern, input, weight) + lhs, rhs, out = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis) + result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight) if bias is None: return result - with jax.ensure_compile_time_eval(): - broadcast_shape = list(range(result.ndim)) - for axis in out_axis: - broadcast_shape[axis] = None - broadcast_shape = [i for i in broadcast_shape if i is not None] - bias = jnp.expand_dims(bias, axis=broadcast_shape) + bias = bias.reshape(*bias.shape, *[1] * (result.ndim - bias.ndim)) + bias = jnp.einsum(f"{''.join(sorted(out))}->{out}", bias) return result + bias diff --git a/serket/_src/utils/dispatch.py b/serket/_src/utils/dispatch.py index 4149a61..181c41a 100644 --- a/serket/_src/utils/dispatch.py +++ b/serket/_src/utils/dispatch.py @@ -19,7 +19,7 @@ from serket._src.utils.inspect import get_params -def single_dispatch(argnum: int = 0): +def single_dispatch(argnum: int): """Single dispatch with argnum""" def decorator(func):