Skip to content

Commit

Permalink
Kmeans (#30)
Browse files Browse the repository at this point in the history
* kmeans

* Update tests.yml

* edit kmeans return
  • Loading branch information
ASEM000 authored Aug 13, 2023
1 parent d89d01c commit 81b20ce
Show file tree
Hide file tree
Showing 8 changed files with 5,121 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install pytreeclass>=0.4.0
python -m pip install keras_core>=0.1.1
python -m pip install scikit-learn
python -m pip install pytest wheel optax jaxlib coverage kernex
- name: Pytest Check
run: |
Expand Down
1 change: 1 addition & 0 deletions docs/API/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

activations
attention
clustering
containers
convolution
dropout
Expand Down
27 changes: 27 additions & 0 deletions docs/API/clustering.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Clustering
---------------------------------
.. currentmodule:: serket.nn


.. autoclass:: KMeans
:members:
__call__

.. note::

Example usage plot of :class:`.nn.KMeans`

.. code-block::
>>> import jax
>>> import jax.random as jr
>>> import matplotlib.pyplot as plt
>>> import serket as sk
>>> x = jr.uniform(jr.PRNGKey(0), shape=(500, 2))
>>> layer = sk.nn.KMeans(clusters=5, tol=1e-6)
>>> labels, state = layer(x)
>>> plt.scatter(x[:, 0], x[:, 1], c=labels[:, 0], cmap="jet_r")
>>> plt.scatter(state.centers[:, 0], state.centers[:, 1], c="r", marker="o", linewidths=4)
.. image:: kmeans.svg
:width: 600
:align: center
2,405 changes: 2,405 additions & 0 deletions docs/API/kmeans.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2,405 changes: 2,405 additions & 0 deletions kmeans.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from .attention import MultiHeadAttention
from .blocks import UNetBlock, VGG16Block, VGG19Block
from .clustering import KMeans
from .containers import RandomApply, Sequential
from .convolution import (
Conv1D,
Expand Down Expand Up @@ -284,6 +285,8 @@
"Rotate2D",
"Solarize2D",
"VerticalShear2D",
# kmeans
"KMeans",
# pooling
"AdaptiveAvgPool1D",
"AdaptiveAvgPool2D",
Expand Down
230 changes: 230 additions & 0 deletions serket/nn/clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# Copyright 2023 Serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools as ft
from typing import NamedTuple

import jax
import jax.numpy as jnp
import jax.random as jr
from typing_extensions import Annotated

import serket as sk
from serket.nn.custom_transform import tree_eval, tree_state
from serket.nn.utils import IsInstance, Range

"""K-means utility functions."""


class KMeansState(NamedTuple):
centers: Annotated[jax.Array, "f32[k,d]"]
error: Annotated[jax.Array, "f32[k,d]"]
iters: int = 0


def distances_from_centers(
data: Annotated[jax.Array, "f32[n,d]"],
centers: Annotated[jax.Array, "f32[k,d]"],
) -> Annotated[jax.Array, "f32[n,k]"]:
# for each point find the distance to each center
return jax.vmap(lambda xi: jax.vmap(jnp.linalg.norm)(xi - centers))(data)


def labels_from_distances(
distances: Annotated[jax.Array, "f32[n,k]"]
) -> Annotated[jax.Array, "f32[n,1]"]:
# for each point find the index of the closest center
return jnp.argmin(distances, axis=1, keepdims=True)


def centers_from_labels(
data: Annotated[jax.Array, "f32[n,d]"],
labels: Annotated[jax.Array, "i32[n,1]"],
k: int,
) -> Annotated[jax.Array, "f32[k,d]"]:
# for each center find the mean of the points assigned to it
return jax.vmap(
lambda k: jnp.divide(
jnp.sum(jnp.where(labels == k, data, 0), axis=0),
jnp.sum(jnp.where(labels == k, 1, 0)).clip(min=1),
)
)(jnp.arange(k))


@ft.partial(jax.jit, static_argnames="clusters")
def kmeans(
data: Annotated[jax.Array, "f32[n,d]"],
state: KMeansState,
*,
clusters: int,
tol: float = 1e-4,
) -> KMeansState:
"""K-means clustering algorithm.
Steps:
1. Initialize the centers randomly. f32[k,d]
2. Calculate point-wise distances from data and centers. f32[n,d],f32[k,d] -> f32[n,k]
3. Assign each point to the closest center. f32[n,k] -> f32[n,1]
4. Calculate the new centers from data and labels. f32[n,d],f32[n,1] -> f32[k,d]
5. Repeat steps 2-4 until the centers converge.
Args:
data: The data to cluster in the shape of n points with d dimensions.
state: initial ``KMeansState`` containing:
- centers: The initial centers of the clusters.
- error: The initial error of the centers at each iteration.
- iters: The inital number of iterations (i.e. 0)
clusters: The number of clusters.
tol: The tolerance for convergence. default: 1e-4
Returns:
A ``KMeansState`` named tuple containing:
- centers: The final centers of the clusters.
- error: The error of the centers at each iteration.
- iters: The number of iterations until convergence.
"""

if not isinstance(state, KMeansState):
raise TypeError(f"{state=} not an instance of `KMeansState`")

def step(state: KMeansState) -> KMeansState:
# f32[n,d] -> f32[n,k]
distances = distances_from_centers(data, state.centers)

# f32[n,k] -> f32[n,1]
labels = labels_from_distances(distances)

# f32[n,d] -> f32[k,d]
centers = centers_from_labels(data, labels, clusters)

error = jnp.abs(centers - state.centers)

return KMeansState(centers, error, state.iters + 1)

def condition(state: KMeansState) -> bool:
return jnp.all(state.error > tol)

return jax.lax.while_loop(condition, step, state)


@sk.autoinit
class KMeans(sk.TreeClass):
"""Vanilla K-means clustering algorithm.
Args:
clusters: The number of clusters.
tol: The tolerance for convergence. default: 1e-4
Example:
>>> import serket as sk
>>> import jax.random as jr
>>> features = 3
>>> clusters = 4
>>> x = jr.uniform(jr.PRNGKey(0), shape=(100, features))
>>> layer = sk.nn.KMeans(clusters=clusters, tol=1e-6)
>>> labels, state = layer(x)
>>> centers = state.centers
>>> assert labels.shape == (100, 1)
>>> assert centers.shape == (clusters, features)
Note:
To use the :class:`.nn.KMeans` layer in evaluation mode, use :func:`.tree_eval` to
disallow centers update and only predict the labels based on the current
centers.
>>> import serket as sk
>>> import jax.random as jr
>>> features = 3
>>> clusters = 4
>>> x = jr.uniform(jr.PRNGKey(0), shape=(100, features))
>>> layer = sk.nn.KMeans(clusters=clusters, tol=1e-6)
>>> x, state = layer(x)
>>> eval_layer = sk.tree_eval(layer)
>>> y = jr.uniform(jr.PRNGKey(0), shape=(1, features))
>>> y, eval_state = eval_layer(y, state)
>>> # centers are not updated
>>> assert jnp.all(eval_state.centers == state.centers)
"""

clusters: int = sk.field(callbacks=[IsInstance(int), Range(1)])
tol: float = sk.field(callbacks=[IsInstance(float), Range(0, min_inclusive=False)])

def __call__(
self,
x: jax.Array,
state: KMeansState | None = None,
) -> tuple[jax.Array, KMeansState]:
"""K-means clustering algorithm.
Args:
x: The data to cluster in the shape of n points with d dimensions.
state: initial ``KMeansState`` containing:
- centers: The initial centers of the clusters.
- error: The initial error of the centers at each iteration.
- iters: The inital number of iterations (i.e. 0)
if ``None`` then the initial state is initialized using the rule
defined in :func:`.tree_state`
Returns:
A tuple containing the labels and a ``KMeansState``.
"""

state = sk.tree_state(self, x) if state is None else state
clusters, tol, state = jax.lax.stop_gradient((self.clusters, self.tol, state))
state = kmeans(x, state, clusters=clusters, tol=tol)
distances = distances_from_centers(x, state.centers)
labels = labels_from_distances(distances)
return labels, state


class EvalKMeans(sk.TreeClass):
"""K-means clustering algorithm evaluation.
Evaluates the K-means clustering algorithm on the input data and returns the
input data and the final ``KMeansState`` with no further updates.
"""

def __call__(
self,
x: jax.Array,
state: KMeansState,
) -> tuple[jax.Array, KMeansState]:
distances = distances_from_centers(x, state.centers)
labels = labels_from_distances(distances)
state = state._replace(iters=None, error=None)
return labels, state


@tree_state.def_state(KMeans)
def init_kmeans(layer: KMeans, data: jax.Array) -> KMeansState:
centers = jr.uniform(
key=jr.PRNGKey(0),
minval=data.min(),
maxval=data.max(),
shape=(layer.clusters, data.shape[1]),
)

return KMeansState(centers=centers, error=centers + jnp.inf, iters=0)


@tree_eval.def_eval(KMeans)
def eval_kmeans(_: KMeans) -> EvalKMeans:
return EvalKMeans()
49 changes: 49 additions & 0 deletions tests/test_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2023 Serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings

import jax.numpy as jnp
import numpy.testing as npt
from jax import random

import serket as sk

# Suppress FutureWarning
warnings.simplefilter(action="ignore", category=FutureWarning)


def test_kmeans():
from sklearn.cluster import KMeans

rng = random.PRNGKey(42)
k = 3
x = random.uniform(rng, (100, 2))
sc_ = KMeans(n_clusters=k, tol=1e-5).fit(x)
layer = sk.nn.KMeans(k, tol=1e-5)
_, state = layer(x)
npt.assert_allclose(
jnp.sort(sc_.cluster_centers_, axis=0),
jnp.sort(state.centers, axis=0),
atol=1e-6,
)
# pick a point near one of the centers
xx = jnp.array([[0.5, 0.2]])
labels, eval_state = sk.tree_eval(layer)(xx, state)
assert eval_state.iters is None
assert eval_state.error is None
# centers should not change
npt.assert_allclose(state.centers, eval_state.centers, atol=1e-6)
assert labels[0] == 0

0 comments on commit 81b20ce

Please sign in to comment.