Commit

Update fully_connected.py
ASEM000 committed Jul 26, 2023
1 parent 89d0aae commit 6d01cfb
Showing 1 changed file with 13 additions and 63 deletions.
76 changes: 13 additions & 63 deletions serket/nn/fully_connected.py
@@ -14,9 +14,10 @@

from __future__ import annotations

-from typing import Any, Sequence
+from typing import Any, Generic, Sequence, TypeVar

import jax
import jax.numpy as jnp
import jax.random as jr

import serket as sk
@@ -28,6 +29,13 @@
from serket.nn.initialization import InitType
from serket.nn.linear import Linear

+T = TypeVar("T")
+
+
+class Batched(Generic[T]):
+    pass
+
+
PyTree = Any
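
Batched is a new marker type: class Batched(Generic[T]) has no runtime behavior and is used only in annotations (for example the Batched[Linear] return type further down) to signal that a layer's arrays are stacked along a leading axis, one slice per hidden layer. A minimal sketch of the idea, using plain JAX arrays and illustrative sizes rather than serket layers:

# Sketch only: stack per-layer parameters along a leading axis by vmapping
# an init function over a batch of PRNG keys. The sizes and the dict
# container are assumptions for illustration, not the commit's construction.
import jax
import jax.random as jr


def init_layer(key):
    w_key, b_key = jr.split(key)
    return {
        "weight": jr.normal(w_key, (4, 4)),
        "bias": jr.normal(b_key, (4,)),
    }


keys = jr.split(jr.PRNGKey(0), 3)       # one key per hidden layer
stacked = jax.vmap(init_layer)(keys)    # every leaf gains a leading axis of 3
# stacked["weight"].shape == (3, 4, 4), stacked["bias"].shape == (3, 4)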


@@ -73,11 +81,8 @@ def __init__(
        num_hidden_layers = len(layers) - 2

        if isinstance(act_func, tuple):
-            if len(act_func) != (num_hidden_layers + 1):
-                raise ValueError(
-                    "tuple of activation functions must have "
-                    f"length: {(num_hidden_layers+1)=}, "
-                )
+            if len(act_func) != (num_hidden_layers):
+                raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}")

            self.act_func = tuple(resolve_activation(act) for act in act_func)
        else:
@@ -224,11 +229,8 @@ def __init__(
        keys = jr.split(key, num_hidden_layers + 1)

        if isinstance(act_func, tuple):
-            if len(act_func) != (num_hidden_layers + 1):
-                raise ValueError(
-                    "tuple of activation functions must have "
-                    f"length: {(num_hidden_layers+1)=}, "
-                )
+            if len(act_func) != (num_hidden_layers):
+                raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}")
            self.act_func = tuple(resolve_activation(act) for act in act_func)
        else:
            self.act_func = resolve_activation(act_func)
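
Both __init__ methods shown above now apply the revised check: a tuple of activation functions must have exactly num_hidden_layers entries instead of num_hidden_layers + 1, and the error message simply reports the two mismatched values. A standalone sketch of the new rule with an illustrative layer spec (not taken from the library):

# Sketch of the new validation rule outside the class; layer sizes are
# illustrative. For layers (in, h1, h2, out) there are two hidden layers,
# so the activation tuple must have exactly two entries.
layers = (1, 4, 4, 2)
num_hidden_layers = len(layers) - 2      # = 2
act_func = ("relu", "tanh")              # one activation per hidden layer

if len(act_func) != num_hidden_layers:
    raise ValueError(f"{len(act_func)=} != {num_hidden_layers=}")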
@@ -244,58 +246,6 @@ def batched_linear(key) -> Batched[Linear]:
            + [Linear(hidden_size, out_features, key=keys[-1], **kwargs)]
        )

-    def _single_call(self, x: jax.Array, **k) -> jax.Array:
-        x = self.act_func(self.in_layer(x))
-
-        if self.mid_layer.bias is None:
-
-            def scan_mid_layer(x: jax.Array, weight_bias: jax.Array):
-                weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-                return self.act_func(x @ weight + bias), None
-
-            x, _ = jax.lax.scan(scan_mid_layer, x, self.mid_layer.bias)
-
-        else:
-            bias = self.mid_layer.bias[:, :, None]
-            weight_bias = jnp.concatenate([self.mid_layer.weight, bias], axis=-1)
-
-            def scan_mid_layer(x: jax.Array, weight_bias: jax.Array):
-                weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-                return self.act_func(x @ weight + bias), None
-
-            x, _ = jax.lax.scan(scan_mid_layer, x, weight_bias)
-
-        x = self.out_layer(x)
-        return x
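
Note that the no-bias branch above scans over self.mid_layer.bias (which is None in that branch) while still splitting off a bias column; the apparent intent, mirroring _multi_call below, is to scan over the stacked weights alone. A self-contained sketch of that path with assumed shapes:

# Sketch of the bias-free path as presumably intended: scan over the stacked
# mid-layer weights only. The shapes (3 mid layers, hidden size 4) and the
# relu activation are assumptions for illustration.
import jax
import jax.numpy as jnp
import jax.random as jr

weights = jr.normal(jr.PRNGKey(0), (3, 4, 4))   # stacked (num_mid, hidden, hidden)
x = jnp.ones((4,))


def scan_mid_layer(x, weight):
    # lax.scan traces this body once and reuses it for every stacked layer,
    # instead of unrolling a Python loop over the hidden layers.
    return jax.nn.relu(x @ weight), None


x, _ = jax.lax.scan(scan_mid_layer, x, weights)  # x.shape == (4,)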

-    def _multi_call(self, x: jax.Array, **k) -> jax.Array:
-        a0, *ah = self.act_func
-        x = a0(self.in_layer(x))
-
-        if self.mid_layer.bias is None:
-
-            def scan_mid_layer(x_index: jax.Array, weight: jax.Array):
-                x, index = x_index
-                x = jax.lax.switch(index, ah, x @ weight)
-                return [x, index + 1], None
-
-            (x, _), _ = jax.lax.scan(scan_mid_layer, [x, 0], self.mid_layer.weight)
-
-        else:
-            bias = self.mid_layer.bias[:, :, None]
-            weight_bias = jnp.concatenate([self.mid_layer.weight, bias], axis=-1)
-
-            def scan_mid_layer(x_index: jax.Array, weight_bias: jax.Array):
-                x, index = x_index
-                weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-                x = jax.lax.switch(index, ah, x @ weight + bias)
-                return [x, index + 1], None
-
-            (x, _), _ = jax.lax.scan(scan_mid_layer, [x, 0], weight_bias)
-
-        x = self.out_layer(x)
-        return x
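
_multi_call differs from _single_call only in that each hidden layer gets its own activation: the scanned body must be a single function, so the activation is picked with jax.lax.switch on the carried layer index. A minimal standalone example of that selection mechanism (the activations and input are illustrative):

# Minimal jax.lax.switch example: select a branch by integer index.
import jax
import jax.numpy as jnp

branches = (jax.nn.relu, jnp.tanh)
x = jnp.array([-1.0, 2.0])

print(jax.lax.switch(0, branches, x))   # applies relu -> [0., 2.]
print(jax.lax.switch(1, branches, x))   # applies tanh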

    def __call__(self, x: jax.Array, **k) -> jax.Array:
        l0, lm, lh = self.layers

