diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 85f104b..dd4b3ab 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -24,6 +24,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install pytreeclass>=0.4.0
python -m pip install keras_core>=0.1.1
+ python -m pip install scikit-learn
python -m pip install pytest wheel optax jaxlib coverage kernex
- name: Pytest Check
run: |
diff --git a/docs/API/api.rst b/docs/API/api.rst
index b40d7c9..ec98a6f 100644
--- a/docs/API/api.rst
+++ b/docs/API/api.rst
@@ -14,6 +14,7 @@
activations
attention
+ clustering
containers
convolution
dropout
diff --git a/docs/API/clustering.rst b/docs/API/clustering.rst
new file mode 100644
index 0000000..6f24605
--- /dev/null
+++ b/docs/API/clustering.rst
@@ -0,0 +1,27 @@
+Clustering
+---------------------------------
+.. currentmodule:: serket.nn
+
+
+.. autoclass:: KMeans
+ :members:
+ __call__
+
+.. note::
+
+ Example usage plot of :class:`.nn.KMeans`
+
+ .. code-block::
+
+ >>> import jax
+ >>> import jax.random as jr
+ >>> import matplotlib.pyplot as plt
+ >>> import serket as sk
+ >>> x = jr.uniform(jr.PRNGKey(0), shape=(500, 2))
+ >>> layer = sk.nn.KMeans(clusters=5, tol=1e-6)
+ >>> labels, state = layer(x)
+ >>> plt.scatter(x[:, 0], x[:, 1], c=labels[:, 0], cmap="jet_r")
+ >>> plt.scatter(state.centers[:, 0], state.centers[:, 1], c="r", marker="o", linewidths=4)
+ .. image:: kmeans.svg
+ :width: 600
+ :align: center
\ No newline at end of file
diff --git a/docs/API/kmeans.svg b/docs/API/kmeans.svg
new file mode 100644
index 0000000..85db423
--- /dev/null
+++ b/docs/API/kmeans.svg
@@ -0,0 +1,2405 @@
+
+
+
diff --git a/kmeans.svg b/kmeans.svg
new file mode 100644
index 0000000..20633e8
--- /dev/null
+++ b/kmeans.svg
@@ -0,0 +1,2405 @@
+
+
+
diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py
index 7f706bc..ed7493f 100644
--- a/serket/nn/__init__.py
+++ b/serket/nn/__init__.py
@@ -47,6 +47,7 @@
)
from .attention import MultiHeadAttention
from .blocks import UNetBlock, VGG16Block, VGG19Block
+from .clustering import KMeans
from .containers import RandomApply, Sequential
from .convolution import (
Conv1D,
@@ -284,6 +285,8 @@
"Rotate2D",
"Solarize2D",
"VerticalShear2D",
+ # kmeans
+ "KMeans",
# pooling
"AdaptiveAvgPool1D",
"AdaptiveAvgPool2D",
diff --git a/serket/nn/clustering.py b/serket/nn/clustering.py
new file mode 100644
index 0000000..e628244
--- /dev/null
+++ b/serket/nn/clustering.py
@@ -0,0 +1,230 @@
+# Copyright 2023 Serket authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import functools as ft
+from typing import NamedTuple
+
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from typing_extensions import Annotated
+
+import serket as sk
+from serket.nn.custom_transform import tree_eval, tree_state
+from serket.nn.utils import IsInstance, Range
+
+"""K-means utility functions."""
+
+
+class KMeansState(NamedTuple):
+ centers: Annotated[jax.Array, "f32[k,d]"]
+ error: Annotated[jax.Array, "f32[k,d]"]
+ iters: int = 0
+
+
+def distances_from_centers(
+ data: Annotated[jax.Array, "f32[n,d]"],
+ centers: Annotated[jax.Array, "f32[k,d]"],
+) -> Annotated[jax.Array, "f32[n,k]"]:
+ # for each point find the distance to each center
+ return jax.vmap(lambda xi: jax.vmap(jnp.linalg.norm)(xi - centers))(data)
+
+
+def labels_from_distances(
+ distances: Annotated[jax.Array, "f32[n,k]"]
+) -> Annotated[jax.Array, "f32[n,1]"]:
+ # for each point find the index of the closest center
+ return jnp.argmin(distances, axis=1, keepdims=True)
+
+
+def centers_from_labels(
+ data: Annotated[jax.Array, "f32[n,d]"],
+ labels: Annotated[jax.Array, "i32[n,1]"],
+ k: int,
+) -> Annotated[jax.Array, "f32[k,d]"]:
+ # for each center find the mean of the points assigned to it
+ return jax.vmap(
+ lambda k: jnp.divide(
+ jnp.sum(jnp.where(labels == k, data, 0), axis=0),
+ jnp.sum(jnp.where(labels == k, 1, 0)).clip(min=1),
+ )
+ )(jnp.arange(k))
+
+
+@ft.partial(jax.jit, static_argnames="clusters")
+def kmeans(
+ data: Annotated[jax.Array, "f32[n,d]"],
+ state: KMeansState,
+ *,
+ clusters: int,
+ tol: float = 1e-4,
+) -> KMeansState:
+ """K-means clustering algorithm.
+
+ Steps:
+ 1. Initialize the centers randomly. f32[k,d]
+ 2. Calculate point-wise distances from data and centers. f32[n,d],f32[k,d] -> f32[n,k]
+ 3. Assign each point to the closest center. f32[n,k] -> f32[n,1]
+ 4. Calculate the new centers from data and labels. f32[n,d],f32[n,1] -> f32[k,d]
+ 5. Repeat steps 2-4 until the centers converge.
+
+ Args:
+ data: The data to cluster in the shape of n points with d dimensions.
+ state: initial ``KMeansState`` containing:
+
+ - centers: The initial centers of the clusters.
+ - error: The initial error of the centers at each iteration.
+ - iters: The inital number of iterations (i.e. 0)
+
+ clusters: The number of clusters.
+ tol: The tolerance for convergence. default: 1e-4
+
+ Returns:
+ A ``KMeansState`` named tuple containing:
+
+ - centers: The final centers of the clusters.
+ - error: The error of the centers at each iteration.
+ - iters: The number of iterations until convergence.
+ """
+
+ if not isinstance(state, KMeansState):
+ raise TypeError(f"{state=} not an instance of `KMeansState`")
+
+ def step(state: KMeansState) -> KMeansState:
+ # f32[n,d] -> f32[n,k]
+ distances = distances_from_centers(data, state.centers)
+
+ # f32[n,k] -> f32[n,1]
+ labels = labels_from_distances(distances)
+
+ # f32[n,d] -> f32[k,d]
+ centers = centers_from_labels(data, labels, clusters)
+
+ error = jnp.abs(centers - state.centers)
+
+ return KMeansState(centers, error, state.iters + 1)
+
+ def condition(state: KMeansState) -> bool:
+ return jnp.all(state.error > tol)
+
+ return jax.lax.while_loop(condition, step, state)
+
+
+@sk.autoinit
+class KMeans(sk.TreeClass):
+ """Vanilla K-means clustering algorithm.
+
+ Args:
+ clusters: The number of clusters.
+ tol: The tolerance for convergence. default: 1e-4
+
+ Example:
+ >>> import serket as sk
+ >>> import jax.random as jr
+ >>> features = 3
+ >>> clusters = 4
+ >>> x = jr.uniform(jr.PRNGKey(0), shape=(100, features))
+ >>> layer = sk.nn.KMeans(clusters=clusters, tol=1e-6)
+ >>> labels, state = layer(x)
+ >>> centers = state.centers
+ >>> assert labels.shape == (100, 1)
+ >>> assert centers.shape == (clusters, features)
+
+ Note:
+ To use the :class:`.nn.KMeans` layer in evaluation mode, use :func:`.tree_eval` to
+ disallow centers update and only predict the labels based on the current
+ centers.
+
+ >>> import serket as sk
+ >>> import jax.random as jr
+ >>> features = 3
+ >>> clusters = 4
+ >>> x = jr.uniform(jr.PRNGKey(0), shape=(100, features))
+ >>> layer = sk.nn.KMeans(clusters=clusters, tol=1e-6)
+ >>> x, state = layer(x)
+ >>> eval_layer = sk.tree_eval(layer)
+ >>> y = jr.uniform(jr.PRNGKey(0), shape=(1, features))
+ >>> y, eval_state = eval_layer(y, state)
+ >>> # centers are not updated
+ >>> assert jnp.all(eval_state.centers == state.centers)
+ """
+
+ clusters: int = sk.field(callbacks=[IsInstance(int), Range(1)])
+ tol: float = sk.field(callbacks=[IsInstance(float), Range(0, min_inclusive=False)])
+
+ def __call__(
+ self,
+ x: jax.Array,
+ state: KMeansState | None = None,
+ ) -> tuple[jax.Array, KMeansState]:
+ """K-means clustering algorithm.
+
+ Args:
+ x: The data to cluster in the shape of n points with d dimensions.
+ state: initial ``KMeansState`` containing:
+
+ - centers: The initial centers of the clusters.
+ - error: The initial error of the centers at each iteration.
+ - iters: The inital number of iterations (i.e. 0)
+
+ if ``None`` then the initial state is initialized using the rule
+ defined in :func:`.tree_state`
+
+ Returns:
+ A tuple containing the labels and a ``KMeansState``.
+ """
+
+ state = sk.tree_state(self, x) if state is None else state
+ clusters, tol, state = jax.lax.stop_gradient((self.clusters, self.tol, state))
+ state = kmeans(x, state, clusters=clusters, tol=tol)
+ distances = distances_from_centers(x, state.centers)
+ labels = labels_from_distances(distances)
+ return labels, state
+
+
+class EvalKMeans(sk.TreeClass):
+ """K-means clustering algorithm evaluation.
+
+ Evaluates the K-means clustering algorithm on the input data and returns the
+ input data and the final ``KMeansState`` with no further updates.
+ """
+
+ def __call__(
+ self,
+ x: jax.Array,
+ state: KMeansState,
+ ) -> tuple[jax.Array, KMeansState]:
+ distances = distances_from_centers(x, state.centers)
+ labels = labels_from_distances(distances)
+ state = state._replace(iters=None, error=None)
+ return labels, state
+
+
+@tree_state.def_state(KMeans)
+def init_kmeans(layer: KMeans, data: jax.Array) -> KMeansState:
+ centers = jr.uniform(
+ key=jr.PRNGKey(0),
+ minval=data.min(),
+ maxval=data.max(),
+ shape=(layer.clusters, data.shape[1]),
+ )
+
+ return KMeansState(centers=centers, error=centers + jnp.inf, iters=0)
+
+
+@tree_eval.def_eval(KMeans)
+def eval_kmeans(_: KMeans) -> EvalKMeans:
+ return EvalKMeans()
diff --git a/tests/test_clustering.py b/tests/test_clustering.py
new file mode 100644
index 0000000..1177661
--- /dev/null
+++ b/tests/test_clustering.py
@@ -0,0 +1,49 @@
+# Copyright 2023 Serket authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import warnings
+
+import jax.numpy as jnp
+import numpy.testing as npt
+from jax import random
+
+import serket as sk
+
+# Suppress FutureWarning
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+def test_kmeans():
+ from sklearn.cluster import KMeans
+
+ rng = random.PRNGKey(42)
+ k = 3
+ x = random.uniform(rng, (100, 2))
+ sc_ = KMeans(n_clusters=k, tol=1e-5).fit(x)
+ layer = sk.nn.KMeans(k, tol=1e-5)
+ _, state = layer(x)
+ npt.assert_allclose(
+ jnp.sort(sc_.cluster_centers_, axis=0),
+ jnp.sort(state.centers, axis=0),
+ atol=1e-6,
+ )
+ # pick a point near one of the centers
+ xx = jnp.array([[0.5, 0.2]])
+ labels, eval_state = sk.tree_eval(layer)(xx, state)
+ assert eval_state.iters is None
+ assert eval_state.error is None
+ # centers should not change
+ npt.assert_allclose(state.centers, eval_state.centers, atol=1e-6)
+ assert labels[0] == 0