[WIP] Coreset for Kmeans and GaussianMixture clustering #799

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions dask_ml/cluster/__init__.py
@@ -1,4 +1,5 @@
"""Unsupervised Clustering Algorithms"""

from .coreset import Coreset # noqa
from .k_means import KMeans # noqa
from .spectral import SpectralClustering # noqa
188 changes: 188 additions & 0 deletions dask_ml/cluster/coreset.py
@@ -0,0 +1,188 @@
import inspect
import logging

import dask.array as da
import dask.dataframe as dd
import numpy as np
from sklearn.base import TransformerMixin, clone

from .._utils import copy_learned_attributes
from ..utils import _timer, check_array
from ..wrappers import ParallelPostFit

logger = logging.getLogger(__name__)


def lightweight_coresets(X, m, *, gen=da.random.RandomState()):
"""Sample a weighted lightweight coreset of `m` points from `X`.

Parameters
----------
X : dask.array, shape = [n_samples, n_features]
input dask array to be sampled
m : int
number of samples to pick from `X`
gen : da.random.RandomState
random state to use for sampling

Returns
-------
X_lwcs : dask.array, shape = [m, n_features]
the sampled points
w_lwcs : dask.array, shape = [m]
the importance weight of each sampled point
"""
# squared distance of each row to the (scalar) mean of all entries of X
dists = da.power(X - X.mean(), 2).sum(axis=1)
# sampling distribution: a uniform floor plus a term proportional to that distance
q = 0.5 / X.shape[0] + 0.5 * dists / dists.sum()
# importance sampling: draw m indices according to q, then reweight by 1 / (m * q)
indices = gen.choice(X.shape[0], size=m, p=q, replace=True)
w_lwcs = 1.0 / (m * q[indices])
X_lwcs = X[indices, :]
return X_lwcs, w_lwcs
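# A minimal standalone usage sketch (toy shapes, not part of the module):
#
#   X = da.random.random((10_000, 5), chunks=1_000)
#   X_sub, w = lightweight_coresets(X, 500, gen=da.random.RandomState(0))
#
# X_sub holds 500 sampled rows of X and w the matching importance weights,
# suitable for a weighted clustering fit (e.g. via `sample_weight`).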


def get_m(X, k, eps, mode="hard"):
"""
Returns the coreset size, i.e the number of data points to be sampled
from the original set of points
See Theorem 2 from `Scalable k-Means Clustering via Lightweight Coresets`

The resulting coreset is a (eps, k)-lightweight coreset of X

Parameters
----------
X: dask.array
input data. We aim at finding minimum coreset size to be sampled from `X`
mu: float
average value for the input data
k: int
number of cluster k
eps: float
between 0 and 1


Returns
-------
m: int
Number of points to sample
For this number, the set `C` is a (`eps`, `k`)-lightweight coreset of `X`

Notes
-----
The `delta` parameter from the original paper is not used.
In practice most of the time it vanishes,
as a log is applied to its inverse, before being summed to big values
(values which depend on the data size | number of cluster)
"""
X_m, d = X.shape
if hasattr(X_m, "compute"):
X_m = X_m.compute()

if mode == "hard": # hard clustering
numerator = d * k * np.log(k)
elif mode == "soft": # soft clustering
numerator = (d ** 2) * (k ** 2)
else:
raise ValueError("`mode` should be in (hard|soft)")
m = np.ceil(numerator / eps)
if m >= X_m:
_m = np.ceil((d ** 2) * (k ** 2))
logger.warning(
f"""
Number of points to sample ({m}) higher
than the number of input samples ({X_m}),
forcing reduction to {_m}
"""
)
m = _m
return m
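# Worked example of the bound above (matches the first case in tests/test_coreset.py):
# with d = 5 features, k = 5 clusters, eps = 0.1 and mode="hard",
#   m = ceil(d * k * log(k) / eps) = ceil(5 * 5 * ln(5) / 0.1) = ceil(402.4) = 403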


class Coreset(ParallelPostFit, TransformerMixin):
"""Coreset sampling implementation

Parameters
----------
estimator : Estimator
The underlying estimator to be fitted.

eps : float, default=0.05
For k clusters, the coreset is guaranteed to be an (`eps`, `k`)
coreset of the original data.

`eps` must be greater than or equal to 0.05
(<= 5% difference in the discretization error).

delta : float, default=0.01
The `delta` parameter from the referenced paper. It must lie strictly
between 0 and 1 and is currently not used when computing the coreset
size (see the Notes of `get_m`).

m : int, default None
Number of points to select to form a coreset.

If it is `None` and the estimator has an `n_clusters` or `n_components`
attribute, `m` will automatically be set from `n_clusters|n_components`,
`eps` and the input data when calling `.fit`.

random_state : int, RandomState instance or None, optional, default: None
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, a new `dask.array.random.RandomState` is created with a random seed.

References
----------
- Scalable k-Means Clustering via Lightweight Coresets, 2018
Olivier Bachem, Mario Lucic, Andreas Krause
https://arxiv.org/pdf/1702.08248.pdf

Notes
-----
``A Coreset is a small set of points that approximates
the shape of a larger set of points``.

A clustering algorithm can be applied on the selected subset of points.

Formally, a weighted set `C` is an (`eps`, `k`)-coreset for
some input data X if for any set of cluster centers `Q` (with `|Q| <= k`)
the quantization error computed via `Q` on `X` and
the quantization error computed via `Q` on `C` have at most an
`eps` relative difference.

"""

def __init__(self, estimator, *, eps=0.05, delta=0.01, m=None, random_state=None):
if not (0 < delta < 1):
raise ValueError("`delta` both should be a float between 0 and 1")
if not (0.05 <= eps < 1):
raise ValueError("`eps` should be a float between 0.05 and 1")
if m is None:
k = getattr(estimator, "n_clusters", None) or getattr(
estimator, "n_components", None
)
if not k or not isinstance(k, int):
raise ValueError(
"""when `m` is None, `estimator` must expose
an integer `n_clusters` or `n_components` attribute"""
)
self.k = k

self.m = m
self.eps = eps
self.estimator = clone(estimator)
self.random_state = da.random.RandomState(random_state)

def fit(self, X, y=None, **kwargs):
if isinstance(X, dd.DataFrame):
X = X.to_dask_array(lengths=True) # if Dask.Dataframe
X = check_array(X, accept_dask_dataframe=False)
if self.m is None:
self.m = get_m(X, self.k, self.eps)

logger.info(f"sampling {self.m} points out of {X.shape[0]}")

logger.info("Starting sampling")
with _timer("sampling", _logger=logger):
Xcs, weights = lightweight_coresets(X, self.m)
Xcs = Xcs.compute()

logger.info("Starting fit")
with _timer("fit", _logger=logger):
if "sample_weights" in inspect.signature(self.estimator.fit).parameters:
kwargs["sample_weights"] = weights
updated_est = self.estimator.fit(Xcs, y, **kwargs)

# Copy over learned attributes
copy_learned_attributes(updated_est, self)
ParallelPostFit.__init__(self, estimator=updated_est)
return self

# TODO : partial fit ?
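For context, a rough end-to-end sketch of how the wrapper is intended to be used; the toy array and its shape are made up for illustration, while the class and constructor arguments follow the tests in this PR.

import dask.array as da
from sklearn.cluster import KMeans

from dask_ml.cluster import Coreset

# toy stand-in data: 10,000 points with 5 features, split into 10 chunks
X = da.random.random((10_000, 5), chunks=1_000)

est = Coreset(KMeans(n_clusters=3, random_state=0), eps=0.05)
est.fit(X)               # samples a lightweight coreset of X, then fits KMeans on it
labels = est.predict(X)  # prediction is delegated to the fitted KMeans via ParallelPostFit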
1 change: 1 addition & 0 deletions docs/source/clustering.rst
@@ -8,6 +8,7 @@ Clustering
.. autosummary::
KMeans
SpectralClustering
Coreset

The :mod:`dask_ml.cluster` module implements several algorithms for clustering unlabeled data.

1 change: 1 addition & 0 deletions docs/source/modules/api.rst
@@ -138,6 +138,7 @@ with Dask Arrays or DataFrames.

cluster.KMeans
cluster.SpectralClustering
cluster.Coreset


:mod:`dask_ml.decomposition`: Matrix Decomposition
139 changes: 139 additions & 0 deletions tests/test_coreset.py
@@ -0,0 +1,139 @@
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pytest
from sklearn.base import BaseEstimator
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

from dask_ml.cluster.coreset import Coreset, get_m, lightweight_coresets
from dask_ml.datasets import make_classification
from dask_ml.utils import assert_estimator_equal


class DummyEstimator(BaseEstimator):
pass


@pytest.mark.parametrize(
"kwargs, m",
[
(dict(X=da.ones((10000, 5)), k=5, eps=0.1), 403),
(dict(X=da.ones((10000, 10)), k=5, eps=0.01), 8048),
(dict(X=da.ones((10000, 10)), k=5, eps=0.05), 1610),
(dict(X=da.ones((10000, 5)), k=3, eps=0.2), 83),
(dict(X=da.ones((10000, 2)), k=20, eps=0.2), 600),
(dict(X=da.ones((10000, 10)), k=20, eps=0.05), 40000), # m > 10k -> fallback
(dict(X=da.ones((10000, 3)), k=3, eps=0.1, mode="soft"), 810),
(dict(X=da.ones((10000, 3)), k=3, eps=0.1, mode="hard"), 99),
],
)
def test_get_m(kwargs, m):
# See Theorem 2 and Section 6. from
# https://dl.acm.org/doi/pdf/10.1145/3219819.3219973
# `d` in the theorem is simply X.shape[1]
computed_m = get_m(**kwargs)
assert int(computed_m) == m


@pytest.mark.parametrize(
"estimator, kwargs, error",
[
(DummyEstimator(), dict(eps=0.05, delta=0.2, m=None), ValueError),
(
DummyEstimator(),
dict(m=200),
None,
), # m is here, no need for n_clusters|n_components
(KMeans(), dict(eps=2), ValueError), # eps between 0 and 1
(KMeans(), dict(delta=2), ValueError), # delta between 0 and 1
(KMeans(), dict(eps=0.1), None), # extracting `n_clusters` attr
(GaussianMixture(), dict(delta=0.2), None), # extracting `n_components` attr
(KMeans(), dict(eps=0.02, delta=0.2, m=None), ValueError),
],
)
def test_init(estimator, kwargs, error):
if error is None:
Coreset(estimator, **kwargs)
else:
with pytest.raises(error):
Coreset(estimator, **kwargs)


def test_lightweight_coresets():
X = da.array([[3, 5], [3, 10], [4, 4]])
gen = da.random.RandomState(3)
Xcs, wcs = lightweight_coresets(X, 2, gen=gen)
np.testing.assert_array_equal(Xcs.compute(), X[[2, 1]].compute())

np.testing.assert_array_almost_equal(
wcs, np.array([2.67948718, 0.836]), decimal=3,
)
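# Hand check of the expected weights above: the scalar mean of X is 29/6 ≈ 4.833,
# the per-row squared distances to it are ≈ [3.389, 30.056, 1.389] (sum ≈ 34.833), so
#   q = 0.5 / 3 + 0.5 * dists / dists.sum() ≈ [0.215, 0.598, 0.187]
# and for the sampled indices [2, 1] the weights are 1 / (2 * q[[2, 1]]) ≈ [2.679, 0.836].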


class TestKMeans:
def test_basic(self, Xl_blobs_easy):
X, _ = Xl_blobs_easy
m = X.shape[0] / 2

# make it super easy to cluster
skkm = KMeans(n_clusters=3, random_state=0)
dkkm = Coreset(KMeans(n_clusters=3, random_state=0), m=m)
skkm.fit(X)
dkkm.fit(X)

assert dkkm.m == m
assert_estimator_equal(
skkm, dkkm, exclude=["n_iter_", "inertia_", "cluster_centers_", "labels_"]
)

# with the coreset weights passed as `sample_weight`, dkkm.inertia_ approximates
# the full-data inertia at the coreset centers, so it should stay within the
# loose (1 + 2 * eps)-style bound also used below in `test_inertia`
assert dkkm.inertia_ <= (1 + 2 * 0.05) * skkm.inertia_

assert dkkm.n_iter_

@pytest.mark.parametrize("eps", [0.05, 0.2])
@pytest.mark.parametrize("k", [3, 10])
def test_inertia(self, eps, k):
"""
Test that we find an (`eps`, `k`)-lightweight coreset of X
for different values of `eps` and `k`.

See section 2 from
https://dl.acm.org/doi/pdf/10.1145/3219819.3219973
"""
X, _ = make_classification(
n_samples=10_000, n_features=k, chunks=100, random_state=1
)

def get_inertia(est, X):
"""
The `intertia_` attribute is relative to the fitted data,
We have to compute intertia regarding to the entire input data
for the coreset version, if we want to compare it with the non-coreset one
"""
return (est.transform(X).min(axis=1) ** 2).sum()

skkm = KMeans(n_clusters=k, random_state=0)
dkkm = Coreset(KMeans(n_clusters=k, random_state=0), eps=eps)
skkm.fit(X)
dkkm.fit(X)

assert_estimator_equal(
skkm, dkkm, exclude=["n_iter_", "inertia_", "cluster_centers_", "labels_"]
)

dkkm_X_inertia = get_inertia(dkkm, X).compute()

# See Section 2, Equation 2, of the referenced paper.
assert dkkm_X_inertia <= (1 + 2 * eps) * skkm.inertia_


def test_dataframes():
df = dd.from_pandas(
pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [6, 7, 8, 9, 10]}), npartitions=2
)

kmeans = Coreset(KMeans(n_clusters=2))
kmeans.fit(df)