Skip to content

Commit

Permalink
Update linear.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Nov 29, 2023
1 parent cc8bc24 commit ed889bf
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions serket/_src/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,11 @@ def batched_linear(key: jax.Array) -> Batched[Linear]:
return sk.tree_mask(layer)

self.in_linear = Linear(in_features, hidden_features, key=keys[0], **kwargs)
self.hidden_linear = sk.tree_unmask(batched_linear(keys[1:-1]))
self.mid_linear = sk.tree_unmask(batched_linear(keys[1:-1]))
self.out_linear = Linear(hidden_features, out_features, key=keys[-1], **kwargs)

def __call__(self, input: jax.Array) -> jax.Array:
input = self.act(self.in_linear(input))
weight_h, bias_h = self.hidden_linear.weight, self.hidden_linear.bias
weight_h, bias_h = self.mid_linear.weight, self.mid_linear.bias
input = scan_linear(input, weight_h, bias_h, self.act)
return self.out_linear(input)

0 comments on commit ed889bf

Please sign in to comment.