diff --git a/serket/nn/fully_connected.py b/serket/nn/fully_connected.py
index 5fdf6a7..65a1c26 100644
--- a/serket/nn/fully_connected.py
+++ b/serket/nn/fully_connected.py
@@ -14,9 +14,10 @@
 from __future__ import annotations
 
-from typing import Any, Sequence
+from typing import Any, Generic, Sequence, TypeVar
 
 import jax
+import jax.numpy as jnp
 import jax.random as jr
 
 import serket as sk
 
@@ -28,6 +29,13 @@
 from serket.nn.initialization import InitType
 from serket.nn.linear import Linear
 
+T = TypeVar("T")
+
+
+class Batched(Generic[T]):
+    pass
+
+
 PyTree = Any
 
 
@@ -73,11 +81,8 @@ def __init__(
 
         num_hidden_layers = len(layers) - 2
 
         if isinstance(act_func, tuple):
-            if len(act_func) != (num_hidden_layers + 1):
-                raise ValueError(
-                    "tuple of activation functions must have "
-                    f"length: {(num_hidden_layers+1)=}, "
-                )
+            if len(act_func) != (num_hidden_layers):
+                raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}")
             self.act_func = tuple(resolve_activation(act) for act in act_func)
         else:
@@ -224,11 +229,8 @@ def __init__(
         keys = jr.split(key, num_hidden_layers + 1)
 
         if isinstance(act_func, tuple):
-            if len(act_func) != (num_hidden_layers + 1):
-                raise ValueError(
-                    "tuple of activation functions must have "
-                    f"length: {(num_hidden_layers+1)=}, "
-                )
+            if len(act_func) != (num_hidden_layers):
+                raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}")
             self.act_func = tuple(resolve_activation(act) for act in act_func)
         else:
             self.act_func = resolve_activation(act_func)
@@ -244,58 +246,6 @@ def batched_linear(key) -> Batched[Linear]:
             + [Linear(hidden_size, out_features, key=keys[-1], **kwargs)]
         )
 
-    def _single_call(self, x: jax.Array, **k) -> jax.Array:
-        x = self.act_func(self.in_layer(x))
-
-        if self.mid_layer.bias is None:
-
-            def scan_mid_layer(x: jax.Array, weight_bias: jax.Array):
-                weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-                return self.act_func(x @ weight + bias), None
-
-            x, _ = jax.lax.scan(scan_mid_layer, x, self.mid_layer.bias)
-
-        else:
-            bias = self.mid_layer.bias[:, :, None]
-            weight_bias = jnp.concatenate([self.mid_layer.weight, bias], axis=-1)
-
-            def scan_mid_layer(x: jax.Array, weight_bias: jax.Array):
-                weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-                return self.act_func(x @ weight + bias), None
-
-            x, _ = jax.lax.scan(scan_mid_layer, x, weight_bias)
-
-        x = self.out_layer(x)
-        return x
-
-    def _multi_call(self, x: jax.Array, **k) -> jax.Array:
-        a0, *ah = self.act_func
-        x = a0(self.in_layer(x))
-
-        if self.mid_layer.bias is None:
-
-            def scan_mid_layer(x_index: jax.Array, weight: jax.Array):
-                x, index = x_index
-                x = jax.lax.switch(index, ah, x @ weight)
-                return [x, index + 1], None
-
-            (x, _), _ = jax.lax.scan(scan_mid_layer, [x, 0], self.mid_layer.weight)
-
-        else:
-            bias = self.mid_layer.bias[:, :, None]
-            weight_bias = jnp.concatenate([self.mid_layer.weight, bias], axis=-1)
-
-            def scan_mid_layer(x_index: jax.Array, weight_bias: jax.Array):
-                x, index = x_index
-                weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-                x = jax.lax.switch(index, ah, x @ weight + bias)
-                return [x, index + 1], None
-
-            (x, _), _ = jax.lax.scan(scan_mid_layer, [x, 0], weight_bias)
-
-        x = self.out_layer(x)
-        return x
-
     def __call__(self, x: jax.Array, **k) -> jax.Array:
         l0, lm, lh = self.layers
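
Reviewer note: a minimal sketch of the new activation-count contract follows. The `FNN` constructor call is assumed from the hunks above (`layers`, `act_func`, `key`, and `resolve_activation` all appear in the diff; the concrete sizes and activation names are illustrative only, not taken from the source):

```python
import jax.random as jr
import serket as sk

# layers = [in, hidden, hidden, out] -> num_hidden_layers = len(layers) - 2 = 2.
# After this change, a tuple of activations must have exactly
# num_hidden_layers entries; previously num_hidden_layers + 1 were required.
fnn = sk.nn.FNN(
    layers=[1, 4, 4, 2],
    act_func=("relu", "tanh"),  # len == num_hidden_layers == 2
    key=jr.PRNGKey(0),
)

# A tuple of the old length now fails, with the terser message:
try:
    sk.nn.FNN(layers=[1, 4, 4, 2], act_func=("relu", "tanh", "relu"), key=jr.PRNGKey(0))
except ValueError as e:
    print(e)  # len(act_func)=3 != (num_hidden_layers)=2
```

The same check and message are duplicated in the second constructor touched by this diff, so a tuple passed there must likewise match its `num_hidden_layers`.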