
Commit 797e6c8: expose linear

ASEM000 committed Mar 31, 2024
1 parent ff3b193
Showing 3 changed files with 18 additions and 7 deletions.
4 changes: 3 additions & 1 deletion docs/API/linear.rst
@@ -5,4 +5,6 @@ Linear
 .. autoclass:: Linear
 .. autoclass:: Identity
 .. autoclass:: Embedding
-.. autoclass:: MLP
+.. autoclass:: MLP
+
+.. autofunction:: linear
16 changes: 13 additions & 3 deletions serket/_src/nn/linear.py
@@ -48,13 +48,23 @@ def is_lazy_init(_, in_features, *__, **___) -> bool:
     return in_features is None


-def general_linear(
+def linear(
     input: jax.Array,
     weight: jax.Array,
     bias: jax.Array | None,
     in_axis: tuple[int, ...],
     out_axis: int,
 ) -> jax.Array:
+    """A @ B + C.
+
+    Args:
+        input: input array.
+        weight: weight array of shape ``(out_features, in_feature_1, in_feature_2, ...)``.
+        bias: bias array of shape ``(out_features,)``, or ``None`` for no bias.
+        in_axis: axes of the input over which to apply the linear layer; a ``tuple``
+            of ``int`` corresponding to ``(in_feature_1, in_feature_2, ...)``.
+        out_axis: the axis at which to place the result; accepts ``int`` values.
+    """
     in_axis = [axis if axis >= 0 else axis + input.ndim for axis in in_axis]
     lhs = "".join(str(axis) for axis in range(input.ndim))  # 0, 1, 2, 3
     rhs = "F" + "".join(str(axis) for axis in in_axis)  # F, 1, 2, 3
@@ -194,7 +204,7 @@ def __init__(
     @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
     def __call__(self, input: jax.Array) -> jax.Array:
         """Apply a linear transformation to the input."""
-        return general_linear(
+        return linear(
             input=input,
             weight=self.weight,
             bias=self.bias,
@@ -369,7 +379,7 @@ def batched_linear(key: jax.Array) -> Batched[Linear]:
             return sk.tree_mask(layer)

         self.in_linear = Linear(in_features, hidden_features, key=keys[0], **kwargs)
-        self.mid_linear = sk.tree_unmask(batched_linear(keys[1:-1]))
+        self.mid_linear: Batched[Linear] = sk.tree_unmask(batched_linear(keys[1:-1]))
         self.out_linear = Linear(hidden_features, out_features, key=keys[-1], **kwargs)

     def __call__(self, input: jax.Array) -> jax.Array:
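
The newly exposed functional form can be called directly. A minimal sketch of its use, assuming the signature and weight layout documented in the hunk above; the example shapes and the `import serket as sk` alias are illustrative, not from the diff:

import jax.numpy as jnp
import serket as sk

# illustrative shapes: 10 input features mapped to 3 output features
x = jnp.ones((4, 10))   # batch of 4 samples, features on the last axis
w = jnp.ones((3, 10))   # weight: (out_features, in_feature_1)
b = jnp.zeros((3,))     # bias: (out_features,)

# contract over the last input axis, placing output features on the last axis
y = sk.nn.linear(input=x, weight=w, bias=b, in_axis=(-1,), out_axis=-1)
print(y.shape)  # expected: (4, 3)
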
5 changes: 2 additions & 3 deletions serket/nn/__init__.py
@@ -98,7 +98,7 @@
     dropout_nd,
     random_cutout_nd,
 )
-from serket._src.nn.linear import MLP, Embedding, Identity, Linear
+from serket._src.nn.linear import MLP, Embedding, Identity, Linear, linear
 from serket._src.nn.normalization import (
     BatchNorm,
     EvalBatchNorm,
@@ -262,10 +262,9 @@
     "dropout_nd",
     "random_cutout_nd",
     # linear
-    "FNN",
     "MLP",
+    "linear",
     "Embedding",
-    "GeneralLinear",
     "Identity",
     "Linear",
     # norms
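
With the `__init__.py` change, `linear` is importable from `serket.nn` alongside the `Linear` class it powers. A sketch of the expected equivalence; `Linear`'s default `in_axis`/`out_axis` of -1 is an assumption here, while the `key` keyword and the `weight`/`bias` attributes are confirmed by the hunks above:

import jax
import jax.numpy as jnp
from serket.nn import Linear, linear

layer = Linear(10, 3, key=jax.random.PRNGKey(0))
x = jnp.ones((10,))

# Linear.__call__ defers to the functional form (see the linear.py hunk above)
y1 = layer(x)
y2 = linear(x, layer.weight, layer.bias, in_axis=(-1,), out_axis=-1)
assert jnp.allclose(y1, y2)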
