Skip to content

Commit

Permalink
promote kemeans to cluster.kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 8, 2023
1 parent fdbb33d commit 81b8bb1
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 18 deletions.
4 changes: 2 additions & 2 deletions docs/API/clustering.rst → docs/API/cluster.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Clustering
🌈 Clustering API
---------------------------------
.. currentmodule:: serket.nn
.. currentmodule:: serket.cluster


.. autoclass:: KMeans
Expand Down
1 change: 0 additions & 1 deletion docs/API/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

activations
attention
clustering
containers
convolution
dropout
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Install from github::
:maxdepth: 1

API/common
API/cluster
API/nn
API/image
API/pytreeclass
Expand Down
9 changes: 5 additions & 4 deletions serket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@
unfreeze,
)

from . import image, nn
from . import cluster, image, nn
from .custom_transform import tree_eval, tree_state
from .nn.activation import def_act_entry
from .nn.initialization import def_init_entry

__all__ = (
__all__ = [
# general utils
"TreeClass",
"is_tree_equal",
Expand Down Expand Up @@ -80,13 +80,14 @@
"Partial",
"leafwise",
# serket
"cluster",
"nn",
"image",
"tree_eval",
"tree_state",
"def_init_entry",
"def_act_entry",
)
]


__version__ = "0.2.0rc2"
__version__ = "0.2.0rc3"
18 changes: 18 additions & 0 deletions serket/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2023 Serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .kmeans import KMeans

__all__ = ["KMeans"]
10 changes: 5 additions & 5 deletions serket/nn/clustering.py → serket/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,14 @@ class KMeans(sk.TreeClass):
Example:
Example usage plot of :class:`.nn.KMeans`
Example usage plot of :class:`.cluster.KMeans`
>>> import jax
>>> import jax.random as jr
>>> import matplotlib.pyplot as plt
>>> import serket as sk
>>> x = jr.uniform(jr.PRNGKey(0), shape=(500, 2))
>>> layer = sk.nn.KMeans(clusters=5, tol=1e-6)
>>> layer = sk.cluster.KMeans(clusters=5, tol=1e-6)
>>> labels, state = layer(x)
>>> plt.scatter(x[:, 0], x[:, 1], c=labels[:, 0], cmap="jet_r") # doctest: +SKIP
>>> plt.scatter(state.centers[:, 0], state.centers[:, 1], c="r", marker="o", linewidths=4) # doctest: +SKIP
Expand All @@ -154,14 +154,14 @@ class KMeans(sk.TreeClass):
>>> features = 3
>>> clusters = 4
>>> x = jr.uniform(jr.PRNGKey(0), shape=(100, features))
>>> layer = sk.nn.KMeans(clusters=clusters, tol=1e-6)
>>> layer = sk.cluster.KMeans(clusters=clusters, tol=1e-6)
>>> labels, state = layer(x)
>>> centers = state.centers
>>> assert labels.shape == (100, 1)
>>> assert centers.shape == (clusters, features)
Note:
To use the :class:`.nn.KMeans` layer in evaluation mode, use :func:`.tree_eval` to
To use the :class:`.cluster.KMeans` layer in evaluation mode, use :func:`.tree_eval` to
disallow centers update and only predict the labels based on the current
centers.
Expand All @@ -170,7 +170,7 @@ class KMeans(sk.TreeClass):
>>> features = 3
>>> clusters = 4
>>> x = jr.uniform(jr.PRNGKey(0), shape=(100, features))
>>> layer = sk.nn.KMeans(clusters=clusters, tol=1e-6)
>>> layer = sk.cluster.KMeans(clusters=clusters, tol=1e-6)
>>> x, state = layer(x)
>>> eval_layer = sk.tree_eval(layer)
>>> y = jr.uniform(jr.PRNGKey(0), shape=(1, features))
Expand Down
4 changes: 2 additions & 2 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
VerticalTranslate2D,
)

__all__ = (
__all__ = [
# augment
"AdjustContrast2D",
"JigSaw2D",
Expand Down Expand Up @@ -75,4 +75,4 @@
"VerticalFlip2D",
"VerticalShear2D",
"VerticalTranslate2D",
)
]
3 changes: 0 additions & 3 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
ThresholdedReLU,
)
from .attention import MultiHeadAttention
from .clustering import KMeans
from .containers import RandomApply, RandomChoice, Sequential
from .convolution import (
Conv1D,
Expand Down Expand Up @@ -239,8 +238,6 @@
"GroupNorm",
"InstanceNorm",
"LayerNorm",
# kmeans
"KMeans",
# pooling
"AdaptiveAvgPool1D",
"AdaptiveAvgPool2D",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_kmeans():
k = 3
x = random.uniform(rng, (100, 2))
sc_ = KMeans(n_clusters=k, tol=1e-5).fit(x)
layer = sk.nn.KMeans(k, tol=1e-5)
layer = sk.cluster.KMeans(k, tol=1e-5)
_, state = layer(x)
npt.assert_allclose(
jnp.sort(sc_.cluster_centers_, axis=0),
Expand Down

0 comments on commit 81b8bb1

Please sign in to comment.