Skip to content

Commit

Permalink
fix kwonly treestate
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 14, 2023
1 parent 6770dbd commit 41736b0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions serket/nn/custom_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
class NoState(sk.TreeClass):
"""No state placeholder."""

def __init__(self, _: Any, __: Any):
del _, __
def __init__(self, _: Any, *, array: Any):
del _, array


def tree_state(tree: T, *, array: jax.Array | None = None) -> T:
Expand Down Expand Up @@ -72,7 +72,7 @@ def tree_state(tree: T, *, array: jax.Array | None = None) -> T:
... return "some state"
>>> sk.tree_state(LayerWithState())
'some state'
>>> sk.tree_state(LayerWithState(), jax.numpy.ones((1, 1)))
>>> sk.tree_state(LayerWithState(), array=jax.numpy.ones((1, 1)))
'some state'
"""

Expand Down
4 changes: 2 additions & 2 deletions serket/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,7 +1350,7 @@ def materialize_cell(instance, x: jax.Array, state=None, **__) -> RNNCell:
# in case of lazy initialization, we need to materialize the cell
# before it can be passed to the scan function
cell = instance.cell
state = state if state is not None else sk.tree_state(instance, x)
state = state if state is not None else sk.tree_state(instance, array=x)
state = split_state(state, 2) if instance.backward_cell is not None else [state]
_, cell = cell.at["__call__"](x[0], state[0])
return cell
Expand All @@ -1360,7 +1360,7 @@ def materialize_backward_cell(instance, x, state=None, **__) -> RNNCell | None:
if instance.backward_cell is None:
return None
cell = instance.cell
state = state if state is not None else sk.tree_state(instance, x)
state = state if state is not None else sk.tree_state(instance, array=x)
state = split_state(state, 2) if instance.backward_cell is not None else [state]
_, cell = cell.at["__call__"](x[0], state[-1])
return cell
Expand Down

0 comments on commit 41736b0

Please sign in to comment.