Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Faster backend compatible ot.dist #701

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 90 additions & 14 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@
from inspect import signature
from .backend import get_backend, Backend, NumpyBackend, JaxBackend

__time_tic_toc = time.time()
__time_tic_toc = time.perf_counter()


def tic():
r"""Python implementation of Matlab tic() function"""
global __time_tic_toc
__time_tic_toc = time.time()
__time_tic_toc = time.perf_counter()


def toc(message="Elapsed time : {} s"):
r"""Python implementation of Matlab toc() function"""
t = time.time()
t = time.perf_counter()
print(message.format(t - __time_tic_toc))
return t - __time_tic_toc


def toq():
r"""Python implementation of Julia toc() function"""
t = time.time()
t = time.perf_counter()
return t - __time_tic_toc


Expand Down Expand Up @@ -251,7 +251,7 @@
return a2, b2, M2


def euclidean_distances(X, Y, squared=False):
def euclidean_distances(X, Y, squared=False, nx=None):
r"""
Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
distance matrix between each pair of vectors.
Expand All @@ -270,13 +270,13 @@
-------
distances : array-like, shape (`n_samples_1`, `n_samples_2`)
"""

nx = get_backend(X, Y)
if nx is None:
nx = get_backend(X, Y)

a2 = nx.einsum("ij,ij->i", X, X)
b2 = nx.einsum("ij,ij->i", Y, Y)

c = -2 * nx.dot(X, Y.T)
c = -2 * nx.dot(X, nx.transpose(Y))
c += a2[:, None]
c += b2[None, :]

Expand All @@ -291,11 +291,21 @@
return c


def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
def dist(
x1,
x2=None,
metric="sqeuclidean",
p=2,
w=None,
backend="auto",
nx=None,
use_tensor=False,
):
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
from all compatible backends for the following metrics:
'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'.

Parameters
----------
Expand All @@ -315,7 +325,16 @@
p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
w : array-like, rank 1
Weights for the weighted metrics.

backend : str, optional
Backend to use for the computation. If 'auto', the backend is
automatically selected based on the input data. if 'scipy',
the ``scipy.spatial.distance.cdist`` function is used (and gradients are
detached).
use_tensor : bool, optional
If true use tensorized computation for the distance matrix which can
cause memory issues for large datasets.
nx : Backend, optional
Backend to perform computations on. If omitted, the backend defaults to that of `x1`.

Returns
-------
Expand All @@ -324,12 +343,69 @@
distance matrix computed with given metric

"""
if nx is None:
nx = get_backend(x1, x2)
if x2 is None:
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
if backend == "scipy": # force scipy backend with cdist function
x1 = nx.to_numpy(x1)
x2 = nx.to_numpy(x2)
if isinstance(metric, str) and metric.endswith("minkowski"):
return nx.from_numpy(cdist(x1, x2, metric=metric, p=p, w=w))
if w is not None:
return nx.from_numpy(cdist(x1, x2, metric=metric, w=w))
return nx.from_numpy(cdist(x1, x2, metric=metric))

Check warning on line 357 in ot/utils.py

View check run for this annotation

Codecov / codecov/patch

ot/utils.py#L351-L357

Added lines #L351 - L357 were not covered by tests
elif metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True, nx=nx)
elif metric == "euclidean":
return euclidean_distances(x1, x2, squared=False)
return euclidean_distances(x1, x2, squared=False, nx=nx)
elif metric == "cityblock":
if use_tensor:
return nx.sum(nx.abs(x1[:, None, :] - x2[None, :, :]), axis=2)

Check warning on line 364 in ot/utils.py

View check run for this annotation

Codecov / codecov/patch

ot/utils.py#L364

Added line #L364 was not covered by tests
else:
M = 0.0
for i in range(x1.shape[1]):
M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :])
return M
elif metric == "minkowski":
if w is None:
if use_tensor:
return nx.power(

Check warning on line 373 in ot/utils.py

View check run for this annotation

Codecov / codecov/patch

ot/utils.py#L373

Added line #L373 was not covered by tests
nx.sum(
nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p), axis=2
),
1 / p,
)
else:
M = 0.0
for i in range(x1.shape[1]):
M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
return M ** (1 / p)
else:
if use_tensor:
return nx.power(

Check warning on line 386 in ot/utils.py

View check run for this annotation

Codecov / codecov/patch

ot/utils.py#L386

Added line #L386 was not covered by tests
nx.sum(
w[None, None, :]
* nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p),
axis=2,
),
1 / p,
)
else:
M = 0.0
for i in range(x1.shape[1]):
M += w[i] * nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
return M ** (1 / p)
elif metric == "cosine":
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
elif metric == "correlation":
x1 = x1 - nx.mean(x1, axis=1)[:, None]
x2 = x2 - nx.mean(x2, axis=1)[:, None]
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
else:
if not get_backend(x1, x2).__name__ == "numpy":
raise NotImplementedError()
Expand Down
53 changes: 45 additions & 8 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@
import numpy as np
import sys
import pytest
import scipy

lst_metrics = [
"euclidean",
"sqeuclidean",
"cityblock",
"cosine",
"minkowski",
"correlation",
]

lst_all_metrics = lst_metrics + [
"braycurtis",
"canberra",
"chebyshev",
"dice",
"hamming",
"jaccard",
"matching",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
"yule",
]


def get_LazyTensor(nx):
Expand Down Expand Up @@ -185,7 +210,7 @@ def test_dist():

assert D4[0, 1] == D4[1, 0]

# dist shoul return squared euclidean
# dist should return squared euclidean
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)

Expand Down Expand Up @@ -230,20 +255,32 @@ def test_dist():
ot.dist(x, x, metric="wminkowski")


def test_dist_backends(nx):
@pytest.mark.parametrize("metric", lst_metrics)
def test_dist_backends(nx, metric):
n = 100
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
x1 = nx.from_numpy(x)

lst_metric = ["euclidean", "sqeuclidean"]
D = ot.dist(x, x, metric=metric)
D1 = ot.dist(x1, x1, metric=metric)

for metric in lst_metric:
D = ot.dist(x, x, metric=metric)
D1 = ot.dist(x1, x1, metric=metric)
# low atol because jax forces float32
np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)

# low atol because jax forces float32
np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)

@pytest.mark.parametrize("metric", lst_all_metrics)
def test_dist_vs_cdist(metric):
n = 10

rng = np.random.RandomState(0)
x = rng.randn(n, 2)
y = rng.randn(n + 1, 2)

D = ot.dist(x, y, metric=metric)
D2 = scipy.spatial.distance.cdist(x, y, metric=metric)

np.testing.assert_allclose(D, D2, atol=1e-15)


def test_dist0():
Expand Down
Loading