minor edits in linear
ASEM000 committed Apr 11, 2024
1 parent a7749db commit a0ef29e
Showing 3 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion serket/_src/nn/convolution.py
```diff
@@ -558,7 +558,7 @@ def is_lazy_init(_1, in_features, *_2, **_3) -> bool:
     return in_features is None


-def infer_in_features(instance, x, *_1, **_3) -> int:
+def infer_in_features(_1, x, *_2, **_3) -> int:
     return x.shape[0]
```
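For context, a minimal sketch of the lazy-initialization pair these helpers implement (the signatures are from this diff; the usage below is hypothetical, and the `x.shape[0]` lookup implies unbatched, channel-first inputs):

```python
import jax.numpy as jnp

# A layer built with in_features=None is "lazy": the channel
# count is read from the concrete input at first call.
def is_lazy_init(_1, in_features, *_2, **_3) -> bool:
    return in_features is None

def infer_in_features(_1, x, *_2, **_3) -> int:
    return x.shape[0]

x = jnp.ones((3, 32, 32))        # (in_features, height, width)
print(is_lazy_init(None, None))  # True -> shape must be inferred
print(infer_in_features(None, x))  # 3
```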
25 changes: 11 additions & 14 deletions serket/_src/nn/linear.py
```diff
@@ -46,7 +46,7 @@ def generate_einsum_pattern(
     rhs_ndim: int,
     in_axis: Sequence[int],
     out_axis: Sequence[int],
-):
+) -> tuple[str, str, str]:
     # helper function to generate the einsum pattern for linear layer
     # with flexible input and output axes
     lhs_alpha = "abcdefghijklmnopqrstuvwxyz"
@@ -55,12 +55,13 @@ def generate_einsum_pattern(
     in_axis = [axis if axis >= 0 else axis + lhs_ndim for axis in in_axis]
     lhs = "".join(lhs_alpha[axis] for axis in range(lhs_ndim))
     rhs = rhs_alpha[: len(out_axis)] + "".join(lhs_alpha[axis] for axis in in_axis)
-    out = [lhs_alpha[axis] for axis in range(lhs_ndim) if axis not in in_axis]
-    out_axis = [o if o >= 0 else o + len(out) + 1 for o in out_axis]
-
+    rest_out = [lhs_alpha[axis] for axis in range(lhs_ndim) if axis not in in_axis]
+    out = [None] * (len(out_axis) + len(rest_out))
+    out_axis = [o if o >= 0 else o + len(out) for o in out_axis]
     for i, axis in enumerate(out_axis):
-        out.insert(axis, rhs_alpha[i])
-    return f"{lhs},{rhs}->{''.join(out)}"
+        out[axis] = rhs_alpha[i]
+    out = "".join(rest_out.pop(0) if o is None else o for o in out)
+    return lhs, rhs, out
```
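To see what the new return value composes to, here is a self-contained trace of the changed function for a standard dense layer (`rhs_alpha` is collapsed in this diff and assumed here to be the uppercase alphabet):

```python
from typing import Sequence

def generate_einsum_pattern(
    lhs_ndim: int,
    rhs_ndim: int,
    in_axis: Sequence[int],
    out_axis: Sequence[int],
) -> tuple[str, str, str]:
    # body as in the new version of this diff
    lhs_alpha = "abcdefghijklmnopqrstuvwxyz"
    rhs_alpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # hidden context; assumed
    in_axis = [axis if axis >= 0 else axis + lhs_ndim for axis in in_axis]
    lhs = "".join(lhs_alpha[axis] for axis in range(lhs_ndim))
    rhs = rhs_alpha[: len(out_axis)] + "".join(lhs_alpha[axis] for axis in in_axis)
    rest_out = [lhs_alpha[axis] for axis in range(lhs_ndim) if axis not in in_axis]
    out = [None] * (len(out_axis) + len(rest_out))
    out_axis = [o if o >= 0 else o + len(out) for o in out_axis]
    for i, axis in enumerate(out_axis):
        out[axis] = rhs_alpha[i]
    out = "".join(rest_out.pop(0) if o is None else o for o in out)
    return lhs, rhs, out

# dense layer: input (batch, in_features), weight (out_features, in_features)
print(generate_einsum_pattern(2, 2, in_axis=(-1,), out_axis=(-1,)))
# -> ('ab', 'Ab', 'aA'), i.e. einsum "ab,Ab->aA", a plain x @ w.T
```

Returning the three pieces separately (instead of the joined pattern string) lets the caller reuse the `out` subscripts for the bias step below.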


```diff
@@ -82,18 +82,14 @@ def linear(
     in_axis: axes to apply the linear layer to.
     out_axis: result axis.
     """
-    pattern = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis)
-    result = jnp.einsum(pattern, input, weight)
+    lhs, rhs, out = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis)
+    result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight)

     if bias is None:
         return result

-    with jax.ensure_compile_time_eval():
-        broadcast_shape = list(range(result.ndim))
-        for axis in out_axis:
-            broadcast_shape[axis] = None
-        broadcast_shape = [i for i in broadcast_shape if i is not None]
-        bias = jnp.expand_dims(bias, axis=broadcast_shape)
+    bias = bias.reshape(*bias.shape, *[1] * (result.ndim - bias.ndim))
+    bias = jnp.einsum(f"{''.join(sorted(out))}->{out}", bias)
     return result + bias
```
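The new bias path drops `jax.ensure_compile_time_eval`: the bias is padded with trailing singleton axes up to the result's rank, then an einsum transpose moves its real axes into the `out` positions (uppercase letters sort before lowercase, so `sorted(out)` lists the bias's real axes first, matching its padded layout). A minimal sketch with assumed dense-layer shapes:

```python
import jax.numpy as jnp

out = "aA"                    # subscripts from generate_einsum_pattern
result = jnp.zeros((4, 3))    # (batch, out_features)
bias = jnp.ones((3,))         # one real axis per uppercase letter ("A")

# pad: (3,) -> (3, 1) so bias.ndim == result.ndim
bias = bias.reshape(*bias.shape, *[1] * (result.ndim - bias.ndim))
# sorted("aA") == ['A', 'a'], so this is "Aa->aA": transpose the padded
# bias so its real axis lines up with "A" in the result
bias = jnp.einsum(f"{''.join(sorted(out))}->{out}", bias)
print((result + bias).shape)  # (4, 3), via ordinary broadcasting
```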
2 changes: 1 addition & 1 deletion serket/_src/utils/dispatch.py
```diff
@@ -19,7 +19,7 @@
 from serket._src.utils.inspect import get_params


-def single_dispatch(argnum: int = 0):
+def single_dispatch(argnum: int):
     """Single dispatch with argnum"""

     def decorator(func):
```
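Dropping the `argnum=0` default forces call sites to name the dispatch argument explicitly. The decorator body is collapsed in this diff, so the following is only a hypothetical sketch of what dispatching on the type of the `argnum`-th positional argument can look like, built on `functools.singledispatch`:

```python
from functools import singledispatch

def single_dispatch(argnum: int):
    """Single dispatch on the type of the argnum-th positional argument."""
    def decorator(func):
        dispatcher = singledispatch(func)

        def wrapper(*args, **kwargs):
            # look up the overload registered for type(args[argnum])
            return dispatcher.dispatch(type(args[argnum]))(*args, **kwargs)

        wrapper.register = dispatcher.register  # expose overload registration
        return wrapper

    return decorator

@single_dispatch(argnum=1)
def scale(factor, value):
    raise TypeError(f"no overload for {type(value)}")

@scale.register(int)
def _(factor, value):
    return factor * value

print(scale(3, 4))  # 12, dispatched on the type of the second argument
```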
