Skip to content

Commit

Permalink
tree state array is kwonly arg
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 14, 2023
1 parent 81b20ce commit 1492ee5
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 21 deletions.
10 changes: 5 additions & 5 deletions serket/nn/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __call__(
A tuple containing the labels and a ``KMeansState``.
"""

state = sk.tree_state(self, x) if state is None else state
state = sk.tree_state(self, array=x) if state is None else state
clusters, tol, state = jax.lax.stop_gradient((self.clusters, self.tol, state))
state = kmeans(x, state, clusters=clusters, tol=tol)
distances = distances_from_centers(x, state.centers)
Expand All @@ -214,12 +214,12 @@ def __call__(


@tree_state.def_state(KMeans)
def init_kmeans(layer: KMeans, data: jax.Array) -> KMeansState:
def init_kmeans(layer: KMeans, *, array: jax.Array) -> KMeansState:
centers = jr.uniform(
key=jr.PRNGKey(0),
minval=data.min(),
maxval=data.max(),
shape=(layer.clusters, data.shape[1]),
minval=array.min(),
maxval=array.max(),
shape=(layer.clusters, array.shape[1]),
)

return KMeansState(centers=centers, error=centers + jnp.inf, iters=0)
Expand Down
9 changes: 5 additions & 4 deletions serket/nn/custom_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, _: Any, __: Any):
del _, __


def tree_state(tree: T, array: jax.Array | None = None) -> T:
def tree_state(tree: T, *, array: jax.Array | None = None) -> T:
"""Build state for a tree of layers.
Some layers require state to be initialized before training. For example,
Expand All @@ -43,8 +43,9 @@ def tree_state(tree: T, array: jax.Array | None = None) -> T:
Args:
tree: A tree of layers.
array: (Optional) array to use for initializing state required by some layers
(e.g. :class:`nn.ConvGRU1DCell`). default: ``None``.
array: (Optional keyword argument) array to use for initializing state
required by some layers (e.g. :class:`nn.ConvGRU1DCell`). default:
``None``.
Returns:
A tree of state leaves if it has state, otherwise ``NoState`` leaf.
Expand Down Expand Up @@ -86,7 +87,7 @@ def dispatch_func(leaf):
return tree_state.state_dispatcher(leaf)
except TypeError:
# with optional array argument
return tree_state.state_dispatcher(leaf, array)
return tree_state.state_dispatcher(leaf, array=array)

return jax.tree_map(dispatch_func, tree, is_leaf=is_leaf)

Expand Down
24 changes: 12 additions & 12 deletions serket/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,32 +1585,32 @@ def _check_rnn_cell_tree_state_input(cell: RNNCell, array):


@tree_state.def_state(ConvLSTMNDCell)
def conv_lstm_init_state(cell: ConvLSTMNDCell, x: Any) -> ConvLSTMNDState:
x = _check_rnn_cell_tree_state_input(cell, x)
shape = (cell.hidden_features, *x.shape[1:])
zeros = jnp.zeros(shape).astype(x.dtype)
def conv_lstm_init_state(cell: ConvLSTMNDCell, *, array: Any) -> ConvLSTMNDState:
array = _check_rnn_cell_tree_state_input(cell, array)
shape = (cell.hidden_features, *array.shape[1:])
zeros = jnp.zeros(shape).astype(array.dtype)
return ConvLSTMNDState(zeros, zeros)


@tree_state.def_state(ConvGRUNDCell)
def conv_gru_init_state(cell: ConvGRUNDCell, x: Any) -> ConvGRUNDState:
x = _check_rnn_cell_tree_state_input(cell, x)
shape = (cell.hidden_features, *x.shape[1:])
return ConvGRUNDState(jnp.zeros(shape).astype(x.dtype))
def conv_gru_init_state(cell: ConvGRUNDCell, *, array: Any) -> ConvGRUNDState:
array = _check_rnn_cell_tree_state_input(cell, array)
shape = (cell.hidden_features, *array.shape[1:])
return ConvGRUNDState(jnp.zeros(shape).astype(array.dtype))


@tree_state.def_state(ScanRNN)
def scan_rnn_init_state(rnn: ScanRNN, x: jax.Array | None = None) -> RNNState:
def scan_rnn_init_state(rnn: ScanRNN, *, array: jax.Array | None = None) -> RNNState:
# the idea here is to combine the state of the forward and backward cells
# if backward cell exists. to have single state input for `ScanRNN` and
# single state output not to complicate the ``__call__`` signature on the
# user side.
x = [None] if x is None else x
array = [None] if array is None else array
# non-spatial cells don't need an input instead
# pass `None` to `tree_state`
# otherwise pass the a single time step input to the cells
return (
tree_state(rnn.cell, array=x[0])
tree_state(rnn.cell, array=array[0])
if rnn.backward_cell is None
else concat_state(tree_state((rnn.cell, rnn.backward_cell), array=x[0]))
else concat_state(tree_state((rnn.cell, rnn.backward_cell), array=array[0]))
)

0 comments on commit 1492ee5

Please sign in to comment.