Skip to content

Commit

Permalink
Update linear.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 8, 2024
1 parent d9e92d3 commit e3840a1
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions serket/_src/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@ class Batched(Generic[T]):
pass


def generate_einsum_pattern(
    lhs_ndim: int,
    rhs_ndim: int,
    in_axis: Sequence[int],
    out_axis: Sequence[int],
) -> str:
    """Generate the einsum pattern for a linear layer with flexible axes.

    The left operand's axes are labelled with their decimal index
    (``0``, ``1``, ...). The right operand is labelled with one letter per
    entry of ``out_axis`` followed by the labels of the contracted
    ``in_axis`` axes. The output keeps the non-contracted left axes in order
    and has the letters inserted at the positions given by ``out_axis``.

    Args:
        lhs_ndim: number of dimensions of the left operand.
        rhs_ndim: number of dimensions of the right operand; must equal
            ``len(in_axis) + len(out_axis)``.
        in_axis: axes of the left operand to contract. Negative values are
            wrapped with ``lhs_ndim``.
        out_axis: positions at which the new axes are inserted in the
            output. Negative values are offset by ``len(kept axes) + 1``.

    Returns:
        A pattern string of the form ``"<lhs>,<rhs>-><out>"``.

    Raises:
        AssertionError: if ``len(in_axis) + len(out_axis) != rhs_ndim``.

    Example:
        >>> generate_einsum_pattern(4, 3, [1, -1], [0])
        '0123,a13->a02'
    """
    # NOTE(review): labels are ``str(axis)``, so ``lhs_ndim > 10`` produces
    # multi-character labels; confirm whether that rank must be supported.
    alpha = "abcdefghijklmnopqrstuvwxyz"
    assert (len(in_axis) + len(out_axis)) == rhs_ndim, (
        "len(in_axis) + len(out_axis) must equal rhs_ndim"
    )
    # wrap negative contraction axes into the range [0, lhs_ndim)
    in_axis = [axis if axis >= 0 else axis + lhs_ndim for axis in in_axis]
    lhs = "".join(str(axis) for axis in range(lhs_ndim))  # 0, 1, 2, 3
    rhs = alpha[: len(out_axis)] + "".join(str(axis) for axis in in_axis)  # F, 1, 2, 3
    # left axes that survive the contraction, in their original order
    out = [str(axis) for axis in range(lhs_ndim) if axis not in in_axis]
    # a negative position counts from the end of the kept axes plus one slot
    out_axis = [o if o >= 0 else o + len(out) + 1 for o in out_axis]

    # insertions are applied sequentially, in the order given by out_axis
    for i, axis in enumerate(out_axis):
        out.insert(axis, alpha[i])
    return f"{lhs},{rhs}->{''.join(out)}"


def linear(
input: jax.Array,
weight: Any,
Expand All @@ -60,20 +81,9 @@ def linear(
in_axis: axes to apply the linear layer to.
out_axis: result axis.
"""
assert (len(in_axis) + len(out_axis)) == weight.ndim
alpha = "abcdefghijklmnopqrstuvwxyz"
in_axis = [axis if axis >= 0 else axis + input.ndim for axis in in_axis]
lhs = "".join(str(axis) for axis in range(input.ndim)) # 0, 1, 2, 3
rhs = alpha[: len(out_axis)] + "".join(str(axis) for axis in in_axis) # F, 1, 2, 3
out = [str(axis) for axis in range(input.ndim) if axis not in in_axis]
out_axis = [o if o >= 0 else o + len(out) + 1 for o in out_axis]

for i, axis in enumerate(out_axis):
out.insert(axis, alpha[i])
out = "".join(out)
einsum_pattern = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis)

try:
einsum_pattern = f"{lhs},{rhs}->{out}"
result = jnp.einsum(einsum_pattern, input, weight)
except ValueError as e:
# the pattern is invalid, raise a more informative error
Expand All @@ -99,7 +109,7 @@ def is_lazy_init(_1, in_features, *_2, **_3) -> bool:
return in_features is None


def infer_in_features(instance, x, **_) -> tuple[int, ...]:
    """Return the sizes of ``x`` along the instance's ``in_axis``.

    Args:
        instance: object whose ``in_axis`` attribute is read; an empty
            tuple is used when the attribute is missing.
        x: object with a ``shape`` that is indexed by each axis.
        **_: extra keyword arguments, accepted and ignored.

    Returns:
        A tuple with ``x.shape[i]`` for every axis ``i`` in ``in_axis``.
    """
    in_axis = getattr(instance, "in_axis", ())
    # tuplify presumably wraps a scalar axis into a tuple; verify against
    # its definition elsewhere in the project
    return tuple(x.shape[i] for i in tuplify(in_axis))

Expand Down

0 comments on commit e3840a1

Please sign in to comment.