edit einsum
ASEM000 committed Dec 1, 2023
1 parent c6f5409 commit 0a70bc9
Showing 2 changed files with 37 additions and 23 deletions.
56 changes: 35 additions & 21 deletions serket/_src/nn/linear.py
@@ -52,26 +52,35 @@ def general_linear(
     input: jax.Array,
     weight: jax.Array,
     bias: jax.Array | None,
-    axes: tuple[int, ...],
+    in_axis: tuple[int, ...],
+    out_axis: int,
 ) -> jax.Array:
-    """Apply a linear layer to input at axes defined by ``axes``"""
-
-    # ensure negative axes
-    def generate_einsum_string(*axes: tuple[int, ...]) -> str:
-        axes = sorted(axes)
-        total_axis = abs(min(axes))  # get the total number of axes
-        alpha = "".join(map(str, range(total_axis + 1)))
-        input = "..." + alpha[:total_axis]
-        weight = alpha[total_axis]
-        weight += "".join([input[axis] for axis in axes])
-        out = "".join([ai for ai in input if ai not in weight])
-        out += alpha[total_axis]
-        return f"{input},{weight}->{out}"
-
-    axes = map(lambda i: i if i < 0 else i - input.ndim, axes)
-    einsum_string = generate_einsum_string(*axes)
-    out = jnp.einsum(einsum_string, input, weight)
-    return out if bias is None else (out + bias)
+    in_axis = sorted([axis if axis >= 0 else axis + input.ndim for axis in in_axis])
+    lhs = "".join(str(axis) for axis in range(input.ndim))  # 0, 1, 2, 3
+    rhs = "F" + "".join(str(axis) for axis in in_axis)  # F, 1, 2, 3
+    out = "".join(str(axis) for axis in range(input.ndim) if axis not in in_axis)
+    out_axis = out_axis if out_axis >= 0 else out_axis + len(out) + 1
+    out = out[:out_axis] + "F" + out[out_axis:]
+
+    try:
+        einsum = f"{lhs},{rhs}->{out}"
+        result = jnp.einsum(einsum, input, weight)
+    except ValueError as error:
+        raise ValueError(
f"{input.shape=}\n"
f"{weight.shape=}\n"
f"{einsum=}\n"
f"{in_axis=}\n"
f"{out_axis=}\n"
f"{error=}"
)

if bias is None:
return result
broadcast_shape = list(range(result.ndim))
del broadcast_shape[out_axis]
bias = jnp.expand_dims(bias, axis=broadcast_shape)
return result + bias


def infer_in_features(instance, x, **__) -> tuple[int, ...]:
@@ -179,8 +188,13 @@ def __init__(

     @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
     def __call__(self, input: jax.Array) -> jax.Array:
-        out = general_linear(input, self.weight, self.bias, self.in_axis)
-        return jnp.moveaxis(out, -1, self.out_axis)
+        return general_linear(
+            input=input,
+            weight=self.weight,
+            bias=self.bias,
+            in_axis=self.in_axis,
+            out_axis=self.out_axis,
+        )
 
 
 class Identity(sk.TreeClass):
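Note (not part of the commit): a minimal sketch tracing the einsum spec that the new general_linear builds, assuming a hypothetical 4-D input with in_axis=(1, -2) and out_axis=0; the shapes and values below are illustrative, not taken from the diff.

import jax.numpy as jnp

input = jnp.ones((2, 3, 4, 5))  # axes 0, 1, 2, 3
weight = jnp.ones((6, 3, 4))    # "F" = 6 features, contracted over input axes 1 and 2
bias = jnp.ones((6,))

# normalize negative axes, exactly as the new code does
in_axis = sorted([axis if axis >= 0 else axis + input.ndim for axis in (1, -2)])  # [1, 2]
out_axis = 0

lhs = "".join(str(axis) for axis in range(input.ndim))                         # "0123"
rhs = "F" + "".join(str(axis) for axis in in_axis)                             # "F12"
out = "".join(str(axis) for axis in range(input.ndim) if axis not in in_axis)  # "03"
out = out[:out_axis] + "F" + out[out_axis:]                                    # "F03"

# digit subscripts are accepted by jnp.einsum (opt_einsum parsing)
result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight)  # "0123,F12->F03", shape (6, 2, 5)

# bias broadcasting as in the new code: expand bias at every axis except out_axis
broadcast_shape = list(range(result.ndim))
del broadcast_shape[out_axis]                              # [1, 2]
print((result + jnp.expand_dims(bias, axis=broadcast_shape)).shape)  # (6, 2, 5)

Compared with the removed generate_einsum_string helper, the rewrite places the feature axis directly through the einsum output spec instead of a trailing jnp.moveaxis in __call__, and the try/except reports the failing spec together with the input and weight shapes.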
4 changes: 2 additions & 2 deletions serket/_src/nn/recurrent.py
@@ -1717,8 +1717,8 @@ def __init__(
     def __call__(
         self,
         input: jax.Array,
-        state: RNNState | None = None,
-    ) -> jax.Array | tuple[jax.Array, RNNState]:
+        state: State | None = None,
+    ) -> jax.Array | tuple[jax.Array, State]:
         """Scans the RNN cell over a sequence.
 
         Args:
