diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py index 0dd674d..49b5a8b 100644 --- a/serket/_src/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -29,14 +29,12 @@ from serket import TreeClass from serket._src.nn.initialization import resolve_init -from serket._src.utils.convert import canonicalize -from serket._src.utils.dispatch import single_dispatch -from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init -from serket._src.utils.padding import ( +from serket._src.utils.convert import ( calculate_transpose_padding, canonicalize, delayed_canonicalize_padding, ) +from serket._src.utils.dispatch import single_dispatch from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init from serket._src.utils.typing import ( DilationType, diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index c352a88..815ff91 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -15,7 +15,7 @@ from __future__ import annotations import functools as ft -from typing import Sequence +from typing import Any, Sequence import jax import jax.numpy as jnp @@ -35,11 +35,12 @@ from serket._src.utils.validate import validate_pos_int +@ft.lru_cache(maxsize=None) def generate_einsum_pattern( lhs_ndim: int, rhs_ndim: int, - in_axis: Sequence[int], - out_axis: Sequence[int], + in_axis: tuple[int, ...], + out_axis: tuple[int, ...], ) -> tuple[str, str, str]: # helper function to generate the einsum pattern for linear layer # with flexible input and output axes @@ -90,8 +91,9 @@ def _( in_axis: Sequence[int] = (-1,), out_axis: Sequence[int] = (-1,), ) -> jax.Array: - pattern = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis) - result = jnp.einsum(pattern, input, weight) + in_axis, out_axis = tuple(in_axis), tuple(out_axis) + lhs, rhs, out = generate_einsum_pattern(input.ndim, weight.ndim, in_axis, out_axis) + result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight) if bias is None: return result diff --git a/tests/test_rnn.py b/tests/test_rnn.py index 4817fb9..b888008 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os