From 7eb14d2ed5e416d6d77980f5a8abacb31f276278 Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 17 Oct 2024 21:34:00 +0200 Subject: [PATCH 01/24] bw barycenter with batched sqrtm --- ot/gaussian.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 832d193da..25e0f5f6e 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -344,6 +344,26 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, return W +def bures_wasserstein_barycenter_fixpoint(): + pass # TODO + + +def bures_wasserstein_barycenter_gradient_descent(): + r""" + + References + ---------- + [] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020, July). + Gradient descent algorithms for Bures-Wasserstein barycenters. + In Conference on Learning Theory (pp. 1276-1304). PMLR. + + [] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). + Averaging on the Bures-Wasserstein manifold: dimension-free convergence + of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. + """ + pass # TODO + + def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False): r"""Return OT linear operator between samples. @@ -412,11 +432,7 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo # fixed point update Cb12 = nx.sqrtm(Cb) - Cnew = Cb12 @ C @ Cb12 - C_ = [] - for i in range(len(C)): - C_.append(nx.sqrtm(Cnew[i])) - Cnew = nx.stack(C_, axis=0) + Cnew = nx.sqrtm(Cb12 @ C @ Cb12) Cnew *= weights[:, None, None] Cnew = nx.sum(Cnew, axis=0) From 869955c0c8b8c17bcaa5755cf118ddab0033a66e Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 19 Oct 2024 18:00:44 +0200 Subject: [PATCH 02/24] BWGD for barycenters --- README.md | 4 + ot/gaussian.py | 224 +++++++++++++++++++++++++++++++++++------- ot/utils.py | 12 +++ test/test_gaussian.py | 9 +- 4 files changed, 207 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 89e16b508..19b5dbbad 100644 --- a/README.md +++ b/README.md @@ -384,3 +384,7 @@ Artificial Intelligence. [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. + +[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. 
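For intuition before the implementation below, the fixed-point update that PATCH 01 batches through nx.sqrtm can be sketched in plain NumPy roughly as follows (a rough sketch only, assuming SPD inputs and scipy.linalg.sqrtm for the matrix square root; the actual code uses the POT backend and a batched sqrtm instead):

    import numpy as np
    from scipy.linalg import sqrtm

    def bw_barycenter_fixpoint_sketch(C, weights, num_iter=1000, eps=1e-7):
        # C: (k, d, d) stack of SPD covariances, weights: (k,) summing to one
        Cb = np.sum(weights[:, None, None] * C, axis=0)  # init: Euclidean mean
        for _ in range(num_iter):
            Cb12 = np.real(sqrtm(Cb))  # real part guards against numerical noise
            # Sigma_b <- sum_i w_i (Sigma_b^{1/2} Sigma_i Sigma_b^{1/2})^{1/2}
            Cnew = sum(w * np.real(sqrtm(Cb12 @ Ci @ Cb12)) for w, Ci in zip(weights, C))
            converged = np.linalg.norm(Cb - Cnew) <= eps
            Cb = Cnew
            if converged:
                break
        return Cb

The patches below implement exactly this iteration, but vectorized over the stack of covariances with a single batched nx.sqrtm call instead of a Python loop.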
diff --git a/ot/gaussian.py b/ot/gaussian.py index 25e0f5f6e..68170c715 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -11,7 +11,7 @@ import warnings from .backend import get_backend -from .utils import dots, is_all_finite, list_to_array +from .utils import dots, is_all_finite, list_to_array, exp_bures def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): @@ -344,35 +344,185 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, return W -def bures_wasserstein_barycenter_fixpoint(): - pass # TODO +def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=False): + r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. + The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\left{\mathcal{N}(\mu_i,\Sigma_i)\right}_{i=1}^n` + :ref:`[1] ` by solving -def bures_wasserstein_barycenter_gradient_descent(): - r""" + .. math:: + \Sigma_b = \argmin_{\Sigma \in S_d^{+}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big) + + The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)` + where :math: `\Sigma_b` is solution of the following fixed-point algorithm: + + .. math:: + \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2} + + Parameters + ---------- + C : array-like (k,d,d) + covariance of k distributions + weights : array-like (k), optional + weights for each distribution + num_iter : int, optional + number of iterations for the fixed point algorithm + eps : float, optional + tolerance for the fixed point algorithm + log : bool, optional + record log if True + + Returns + ------- + Cb : (d, d) array-like + covariance of the barycenter + log : dict + log dictionary returned only if log==True in parameters + + References + ---------- + .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", + SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, + 2011. + """ + nx = get_backend(*C,) + + if weights is None: + weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] + + # Init the covariance barycenter + Cb = nx.mean(C * weights[:, None, None], axis=0) + + for it in range(num_iter): + # fixed point update + Cb12 = nx.sqrtm(Cb) + + Cnew = nx.sqrtm(Cb12 @ C @ Cb12) + Cnew *= weights[:, None, None] + Cnew = nx.sum(Cnew, axis=0) + + # check convergence + diff = nx.norm(Cb - Cnew) + if diff <= eps: + break + Cb = Cnew + + if diff > eps: + print("Did not converge.") + + if log: + log = {} + log['num_iter'] = it + log['final_diff'] = diff + return Cb, log + else: + return Cb + + +def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None): + r"""Return OT linear operator between covariances. + + The function estimates the optimal barycenter of empirical distributions. This is equivalent to resolving the fixed point + algorithm for multiple Gaussian distributions :math:`\left{\mathcal{N}(0,\Sigma)\right}_{i=1}^n` + :ref:`[1] `.
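+
+    More precisely, the descent iterates
+
+    .. math::
+        \Sigma \leftarrow \exp_\Sigma\left(-\eta\left(\mathrm{Id} - \sum_{i=1}^n w_i T_i\right)\right),
+
+    where :math:`T_i = \Sigma^{-1/2}\big(\Sigma^{1/2}\Sigma_i\Sigma^{1/2}\big)^{1/2}\Sigma^{-1/2}`
+    is the OT map from :math:`\mathcal{N}(0,\Sigma)` to :math:`\mathcal{N}(0,\Sigma_i)`,
+    :math:`\eta` is the step size and :math:`\exp_\Sigma` is the exponential map
+    `exp_bures`; this is exactly the einsum-based update implemented below.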
+ + The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)` + + Parameters + ---------- + C : array-like (k,d,d) + covariance of k distributions + weights : array-like (k), optional + weights for each distribution + num_iter : int, optional + number of iterations for the gradient descent + eps : float, optional + tolerance for the convergence criterion + log : bool, optional + record log if True + step_size: float, optional + step size for the gradient descent, 1 by default + batch_size: int, optional + batch size if using stochastic gradient descent + + Returns + ------- + Cb : (d, d) array-like + covariance of the barycenter + log : dict + log dictionary returned only if log==True in parameters References ---------- - [] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020, July). + .. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). Gradient descent algorithms for Bures-Wasserstein barycenters. In Conference on Learning Theory (pp. 1276-1304). PMLR. - [] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). + .. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - pass # TODO + nx = get_backend(*C,) + if weights is None: + weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] -def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False): - r"""Return OT linear operator between samples. + # Init the covariance barycenter + Cb = nx.mean(C * weights[:, None, None], axis=0) + Id = nx.eye(C.shape[-1], type_as=Cb) - The function estimates the optimal barycenter of the - empirical distributions. This is equivalent to resolving the fixed point - algorithm for multiple Gaussian distributions :math:`\left{\mathcal{N}(\mu,\Sigma)\right}_{i=1}^n` - :ref:`[1] `. + for it in range(num_iter): + Cb12 = nx.sqrtm(Cb) + Cb12_ = nx.inv(Cb12) - The barycenter still following a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)` + # TODO: add stochastic option with batch_size != number of covs + + # if batch_size is not None: + # inds = np.random.choice(len(sigmas), batch_size, replace=True, p=weights.cpu().numpy()) + # M = sqrtm(dots(sk12, sigmas[inds], sk12)) + # grad_bw = Id - torch.mean(dots(sk_12, M, sk_12), axis=0) + # else: + # M = sqrtm(dots(sk12, sigmas, sk12)) + # grad_bw = Id - torch.sum(dots(sk_12, M, sk_12) * weights[:, None, None], axis=0) + + M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12)) + grad_bw = Id - nx.sum(nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) * weights[:, None, None], axis=0) + Cnew = exp_bures(Cb, - step_size * grad_bw) + + # Right criteria? + # check convergence + diff = nx.norm(Cb - Cnew) + if diff <= eps: + break + + Cb = Cnew + + if diff > eps: + print("Did not converge.") + + if log: + log = {} + log['num_iter'] = it + log['final_diff'] = diff + return Cb, log + else: + return Cb + + +def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_iter=1000, eps=1e-7, log=False): + r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions. + + The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\left{\mathcal{N}(\mu_i,\Sigma_i)\right}_{i=1}^n` + :ref:`[1] ` by solving + + .. 
math:: + (\mu_b, \Sigma_b) = \argmin_{\mu,\Sigma}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(\mu,\Sigma), \mathcal{N}(\mu_i, \Sigma_i)\big) + + The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)` where : .. math:: @@ -383,6 +533,8 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo .. math:: \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2} + We propose two solvers: one based on solving the previous fixed-point problem [1], and another based on + gradient descent in the Bures-Wasserstein space [74, 75]. Parameters ---------- @@ -392,6 +544,8 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo covariance of k distributions weights : array-like (k), optional weights for each distribution + method : str + method used for the solver, either 'fixed_point' or 'gradient_descent' num_iter : int, optional number of iterations for the fixed point algorithm eps : float, optional @@ -409,15 +563,22 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo log : dict log dictionary returned only if log==True in parameters - - .. _references-OT-mapping-linear-barycenter: + .. _references-OT-bures_wasserstein-barycenter: References ---------- .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. + + .. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). + Gradient descent algorithms for Bures-Wasserstein barycenters. + In Conference on Learning Theory (pp. 1276-1304). PMLR. + + .. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). + Averaging on the Bures-Wasserstein manifold: dimension-free convergence + of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - nx = get_backend(*C, *m,) + nx = get_backend(*m,) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -425,31 +586,18 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo # Compute the mean barycenter mb = nx.sum(m * weights[:, None], axis=0) - # Init the covariance barycenter - Cb = nx.mean(C * weights[:, None, None], axis=0) - - for it in range(num_iter): - # fixed point update - Cb12 = nx.sqrtm(Cb) - - Cnew = nx.sqrtm(Cb12 @ C @ Cb12) - Cnew *= weights[:, None, None] - Cnew = nx.sum(Cnew, axis=0) - - # check convergence - diff = nx.norm(Cb - Cnew) - if diff <= eps: - break - Cb = Cnew + if method == "fixed_point": + out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, eps=eps, log=log) + elif method == "gradient_descent": + out = bures_barycenter_gradient_descent(C, weights=weights, num_iter=num_iter, eps=eps, log=log, step_size=1, batch_size=None) else: - print("Dit not converge.") + raise ValueError("Unknown method '%s'." 
% method) if log: - log = {} - log['num_iter'] = it - log['final_diff'] = diff + Cb, log = out return mb, Cb, log else: + Cb = out return mb, Cb diff --git a/ot/utils.py b/ot/utils.py index 12910c479..009b481db 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1308,3 +1308,15 @@ def proj_SDP(S, nx=None, vmin=0.): Q = nx.einsum('ijk,ik->ijk', P, w) # Q[i] = P[i] @ diag(w[i]) # R[i] = Q[i] @ P[i].T return nx.einsum('ijk,ikl->ijl', Q, nx.transpose(P, (0, 2, 1))) + + +def exp_bures(Sigma, S): + r""" + Exponential map Bures-Wasserstein space as Sigma: \exp_\Sigma(S) + """ + nx = get_backend(S) + d = S.shape[-1] + Id = nx.eye(d, type_as=S) + C = Id + S + + return dots(C, Sigma, C) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index c66d5908c..6a27edb7a 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -108,6 +108,7 @@ def test_empirical_bures_wasserstein_distance(nx, bias): np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) +@pytest.mark.parametrize("method", ["fixed_point", "gradient_descent"]) def test_bures_wasserstein_barycenter(nx): n = 50 k = 10 @@ -129,21 +130,21 @@ def test_bures_wasserstein_barycenter(nx): m = nx.from_numpy(m) C = nx.from_numpy(C) - mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, log=True) - mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, log=False) + mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=True) + mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=False) np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2) np.testing.assert_allclose(mb, mblog, rtol=1e-2, atol=1e-2) # Test weights argument weights = nx.ones(k) / k - mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, log=False) + mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, method=method, log=False) np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2) # test with closed form for diagonal covariance matrices Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)] Cdiag = nx.stack(Cdiag, axis=0) - mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, log=False) + mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, method=method, log=False) Cdiag_sqrt = [nx.sqrtm(C) for C in Cdiag] Cdiag_sqrt = nx.stack(Cdiag_sqrt, axis=0) From be985d1a21135aa35a61a06a6881ef5851e47eac Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 19 Oct 2024 18:18:59 +0200 Subject: [PATCH 03/24] sbwgd for barycenters --- ot/gaussian.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 68170c715..c5d5373c0 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -9,6 +9,7 @@ # License: MIT License import warnings +import numpy as np from .backend import get_backend from .utils import dots, is_all_finite, list_to_array, exp_bures @@ -468,8 +469,10 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, """ nx = get_backend(*C,) + n = C.shape[0] + if weights is None: - weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] + weights = nx.ones(C.shape[0], type_as=C[0]) / n # Init the covariance barycenter Cb = nx.mean(C * weights[:, None, None], axis=0) @@ -479,21 +482,19 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, Cb12 = nx.sqrtm(Cb) Cb12_ = nx.inv(Cb12) - # TODO: add stochastic option with batch_size != number of covs - - # if batch_size is not None: - # inds = 
np.random.choice(len(sigmas), batch_size, replace=True, p=weights.cpu().numpy()) # M = sqrtm(dots(sk12, sigmas[inds], sk12)) # grad_bw = Id - torch.mean(dots(sk_12, M, sk_12), axis=0) # else: # M = sqrtm(dots(sk12, sigmas, sk12)) # grad_bw = Id - torch.sum(dots(sk_12, M, sk_12) * weights[:, None, None], axis=0) + if batch_size is not None and batch_size < n: # if stochastic gradient descent + if batch_size <= 0: + raise ValueError("batch_size must be an integer between 0 and {}".format(n)) + inds = np.random.choice(n, batch_size, replace=True, p=nx._to_numpy(weights)) + M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12)) + grad_bw = Id - nx.mean(nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_), axis=0) + else: # gradient descent + M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12)) + grad_bw = Id - nx.sum(nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) * weights[:, None, None], axis=0) Cnew = exp_bures(Cb, - step_size * grad_bw) - # Right criteria? + # Right criteria? (for GD, seems fine, but for SGD?) # check convergence diff = nx.norm(Cb - Cnew) if diff <= eps: break @@ -513,7 +514,7 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, return Cb -def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_iter=1000, eps=1e-7, log=False): +def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None): r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions. The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\left{\mathcal{N}(\mu_i,\Sigma_i)\right}_{i=1}^n` @@ -552,6 +553,10 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_i tolerance for the fixed point algorithm log : bool, optional record log if True + step_size: float, optional + step size for the gradient descent, 1 by default + batch_size: int, optional + batch size if using stochastic gradient descent. If not None, use method='gradient_descent' Returns @@ -586,10 +591,10 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_i # Compute the mean barycenter mb = nx.sum(m * weights[:, None], axis=0) - if method == "fixed_point": + if method == "gradient_descent" or batch_size is not None: + out = bures_barycenter_gradient_descent(C, weights=weights, num_iter=num_iter, eps=eps, log=log, step_size=step_size, batch_size=batch_size) + elif method == "fixed_point": out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, eps=eps, log=log) - elif method == "gradient_descent": - out = bures_barycenter_gradient_descent(C, weights=weights, num_iter=num_iter, eps=eps, log=log, step_size=1, batch_size=None) else: raise ValueError("Unknown method '%s'." 
% method) From 9a433696a01c03b304b5036fb94f96ae7f73066c Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 19 Oct 2024 18:36:33 +0200 Subject: [PATCH 04/24] Test fixed_point vs gradient_descent --- test/test_gaussian.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 6a27edb7a..81e073911 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -154,6 +154,34 @@ def test_bures_wasserstein_barycenter(nx): np.testing.assert_allclose(Cbdiag, Cdiag_cf, rtol=1e-2, atol=1e-2) +def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): + n = 50 + k = 10 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif('3gauss', n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method="fixed_point", log=False) + mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(m, C, method="gradient_descent", log=False) + + np.testing.assert_allclose(mb, mb2, atol=1e-5) + np.testing.assert_allclose(Cb, Cb2, atol=1e-5) + + @pytest.mark.parametrize("bias", [True, False]) def test_empirical_bures_wasserstein_barycenter(nx, bias): n = 50 From 8afc00b5e1fff71a628da73ee0b74686b0035ec1 Mon Sep 17 00:00:00 2001 From: Clement Date: Fri, 25 Oct 2024 18:40:55 +0200 Subject: [PATCH 05/24] fix test bwgd --- ot/gaussian.py | 12 +++++++++--- test/test_gaussian.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index c5d5373c0..8540d0213 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -514,7 +514,9 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, return Cb -def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None): +def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", + num_iter=1000, eps=1e-7, log=False, + step_size=1, batch_size=None): r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions. The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\left{\mathcal{N}(\mu_i,\Sigma_i)\right}_{i=1}^n` @@ -592,9 +594,13 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_i mb = nx.sum(m * weights[:, None], axis=0) if method == "gradient_descent" or batch_size is not None: - out = bures_barycenter_gradient_descent(C, weights=weights, num_iter=num_iter, eps=eps, log=log, step_size=step_size, batch_size=batch_size) + out = bures_barycenter_gradient_descent(C, weights=weights, + num_iter=num_iter, eps=eps, + log=log, step_size=step_size, + batch_size=batch_size) elif method == "fixed_point": - out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, eps=eps, log=log) + out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, + eps=eps, log=log) else: raise ValueError("Unknown method '%s'." 
% method) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 81e073911..2e54356fd 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -109,7 +109,7 @@ def test_empirical_bures_wasserstein_distance(nx, bias): @pytest.mark.parametrize("method", ["fixed_point", "gradient_descent"]) -def test_bures_wasserstein_barycenter(nx): +def test_bures_wasserstein_barycenter(nx, method): n = 50 k = 10 X = [] From b2b0bca3984c45380be3deb4756558552e894e1e Mon Sep 17 00:00:00 2001 From: Clement Date: Fri, 25 Oct 2024 19:00:07 +0200 Subject: [PATCH 06/24] nx exp_bures --- ot/gaussian.py | 19 +++++++++++++------ ot/utils.py | 5 +++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 8540d0213..37ecac1c8 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -485,18 +485,25 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, if batch_size is not None and batch_size < n: # if stochastic gradient descent if batch_size <= 0: raise ValueError("batch_size must be an integer between 0 and {}".format(n)) - inds = np.random.choice(n, batch_size, replace=True, p=nx._to_numpy(weights)) + inds = np.random.choice(n, batch_size, replace=True, + p=nx._to_numpy(weights)) M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12)) - grad_bw = Id - nx.mean(nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_), axis=0) + ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) + grad_bw = Id - nx.mean(ot_maps, axis=0) else: # gradient descent M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12)) - grad_bw = Id - nx.sum(nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) * weights[:, None, None], axis=0) + ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) + grad_bw = Id - nx.sum(ot_maps * weights[:, None, None], axis=0) - Cnew = exp_bures(Cb, - step_size * grad_bw) + Cnew = exp_bures(Cb, - step_size * grad_bw, nx=nx) - # Right criteria? (for GD, seems fine, but for SGD?) # check convergence - diff = nx.norm(Cb - Cnew) + if batch_size is not None and batch_size < n: + # TODO: criteria for SGD: on gradients? 
+ test SGD + diff = nx.norm(Cb - Cnew) + else: + diff = nx.norm(Cb - Cnew) + if diff <= eps: break diff --git a/ot/utils.py b/ot/utils.py index 009b481db..348b98393 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1310,11 +1310,12 @@ def proj_SDP(S, nx=None, vmin=0.): return nx.einsum('ijk,ikl->ijl', Q, nx.transpose(P, (0, 2, 1))) -def exp_bures(Sigma, S): +def exp_bures(Sigma, S, nx=None): r""" Exponential map Bures-Wasserstein space as Sigma: \exp_\Sigma(S) """ - nx = get_backend(S) + if nx is None: + nx = get_backend(Sigma, S) d = S.shape[-1] Id = nx.eye(d, type_as=S) C = Id + S From d287a2a431b773babf4bd371bc9c8ddf0d71ea8d Mon Sep 17 00:00:00 2001 From: Clement Date: Fri, 25 Oct 2024 23:48:41 +0200 Subject: [PATCH 07/24] update doc --- RELEASES.md | 1 + ot/gaussian.py | 66 +++++++++++++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 821432548..7a8d9eca8 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -19,6 +19,7 @@ - Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663) - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequately `ot.solvers` (PR #663) - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676) +- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) diff --git a/ot/gaussian.py b/ot/gaussian.py index 37ecac1c8..daee94b99 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -348,17 +348,17 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=False): r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. - The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\left{\mathcal{N}(\mu_i,\Sigma_i)\right}_{i=1}^n` - :ref:`[1] ` by solving + The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n` + :ref:`[16] ` by solving .. math:: - \Sigma_b = \argmin_{\Sigma \in S_d^{+}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big) + \Sigma_b = \mathrm{argmin}_{\Sigma \in S_d^{+}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big). The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)` - where :math: `\Sigma_b` is solution of the following fixed-point algorithm: + where :math:`\Sigma_b` is the solution of the following fixed-point algorithm: .. math:: - \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2} + \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2}. Parameters ---------- @@ -382,9 +382,11 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals log : dict log dictionary returned only if log==True in parameters + + .. _references-OT-bures-barycenter-fixed-point: References ---------- - .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", + .. [16] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. 
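+
+    Examples
+    --------
+    A minimal sketch with NumPy inputs; for diagonal covariances the
+    barycenter has the closed form :math:`\Sigma_b^{1/2} = \sum_i w_i \Sigma_i^{1/2}`
+    (the property checked on diagonal matrices in the tests):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> C = np.stack([np.eye(2), 4 * np.eye(2)])
+    >>> Cb = ot.gaussian.bures_barycenter_fixpoint(C)  # close to 2.25 * np.eye(2)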
""" @@ -423,13 +425,14 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None): - r"""Return OT linear operator between covariances. + r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. - The function estimates the optimal barycenter of empirical distributions. This is equivalent to resolving the fixed point - algorithm for multiple Gaussian distributions :math:`\left{\mathcal{N}(0,\Sigma)\right}_{i=1}^n` - :ref:`[1] `. + The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n` + by using a gradient descent in the Wasserstein space :ref:`[74, 75] ` + on the objective - The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)` + .. math:: + \mathcal{L}(\Sigma) = \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0,\Sigma_i)\big). Parameters ---------- @@ -445,9 +448,9 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, tolerance for the fixed point algorithm log : bool, optional record log if True - step_size: float, optional + step_size : float, optional step size for the gradient descent, 1 by default - batch_size: int, optional + batch_size : int, optional batch size if use a stochastic gradient descent Returns @@ -457,15 +460,17 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, log : dict log dictionary return only if log==True in parameters + + .. _references-OT-bures-barycenter-gradient_descent: References ---------- .. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). - Gradient descent algorithms for Bures-Wasserstein barycenters. - In Conference on Learning Theory (pp. 1276-1304). PMLR. + Gradient descent algorithms for Bures-Wasserstein barycenters. + In Conference on Learning Theory (pp. 1276-1304). PMLR. .. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). - Averaging on the Bures-Wasserstein manifold: dimension-free convergence - of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. + Averaging on the Bures-Wasserstein manifold: dimension-free convergence + of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ nx = get_backend(*C,) @@ -526,19 +531,19 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", step_size=1, batch_size=None): r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions. - The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\left{\mathcal{N}(\mu_i,\Sigma_i)\right}_{i=1}^n` - :ref:`[1] ` by solving + The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\big(\mathcal{N}(\mu_i,\Sigma_i)\big)_{i=1}^n` + :ref:`[16, 74, 75] ` by solving .. math:: - (\mu_b, \Sigma_b) = \argmin_{\mu,\Sigma}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(\mu,\Sigma), \mathcal{N}(\mu_i, \Sigma_i)\big) + (\mu_b, \Sigma_b) = \mathrm{argmin}_{\mu,\Sigma}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(\mu,\Sigma), \mathcal{N}(\mu_i, \Sigma_i)\big) The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)` - where : + where: .. 
math:: - \mu_b = \sum_{i=1}^n w_i \mu_i + \mu_b = \sum_{i=1}^n w_i \mu_i, - And the barycentric covariance is the solution of the following fixed-point algorithm: + and the barycentric covariance is the solution of the following fixed-point algorithm: .. math:: \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2} @@ -562,9 +567,9 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", tolerance for the fixed point algorithm log : bool, optional record log if True - step_size: float, optional + step_size : float, optional step size for the gradient descent, 1 by default - batch_size: int, optional + batch_size : int, optional batch size if using stochastic gradient descent. If not None, use method='gradient_descent' Returns @@ -577,20 +582,21 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", log : dict log dictionary returned only if log==True in parameters + .. _references-OT-bures_wasserstein-barycenter: References ---------- .. [16] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. .. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). - Gradient descent algorithms for Bures-Wasserstein barycenters. - In Conference on Learning Theory (pp. 1276-1304). PMLR. + Gradient descent algorithms for Bures-Wasserstein barycenters. + In Conference on Learning Theory (pp. 1276-1304). PMLR. .. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). - Averaging on the Bures-Wasserstein manifold: dimension-free convergence - of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. + Averaging on the Bures-Wasserstein manifold: dimension-free convergence + of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - nx = get_backend(*C, *m,) + nx = get_backend(*m,) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -425,31 +586,18 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo # Compute the mean barycenter mb = nx.sum(m * weights[:, None], axis=0) if method == "gradient_descent" or batch_size is not None: out = bures_barycenter_gradient_descent(C, weights=weights, num_iter=num_iter, eps=eps, log=log, step_size=step_size, batch_size=batch_size) elif method == "fixed_point": out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, eps=eps, log=log) else: raise ValueError("Unknown method '%s'." % method) if log: Cb, log = out return mb, Cb, log else: Cb = out return mb, Cb From 4f648bb9b35f50b2337281aac60b0db06131346d Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 31 Oct 2024 22:30:24 +0100 Subject: [PATCH 08/24] fix merge --- ot/gaussian.py | 71 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 0d99d9fae..eb38f0bc0 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -393,7 +393,9 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. """ - nx = get_backend(*C,) + nx = get_backend( + *C, + ) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -420,19 +422,21 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals if log: log = {} - log['num_iter'] = it - log['final_diff'] = diff + log["num_iter"] = it + log["final_diff"] = diff return Cb, log else: return Cb -def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None): +def bures_barycenter_gradient_descent( + C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None +): r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n` by using a gradient descent in the Wasserstein space :ref:`[74, 75] ` - on the objective + on the objective .. 
math:: \mathcal{L}(\Sigma) = \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0,\Sigma_i)\big). @@ -475,7 +479,9 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - nx = get_backend(*C,) + nx = get_backend( + *C, + ) n = C.shape[0] @@ -492,9 +498,12 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, if batch_size is not None and batch_size < n: # if stochastic gradient descent if batch_size <= 0: - raise ValueError("batch_size must be an integer between 0 and {}".format(n)) - inds = np.random.choice(n, batch_size, replace=True, - p=nx._to_numpy(weights)) + raise ValueError( + "batch_size must be an integer between 0 and {}".format(n) + ) + inds = np.random.choice( + n, batch_size, replace=True, p=nx._to_numpy(weights) + ) M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12)) ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) grad_bw = Id - nx.mean(ot_maps, axis=0) @@ -503,7 +512,7 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) grad_bw = Id - nx.sum(ot_maps * weights[:, None, None], axis=0) - Cnew = exp_bures(Cb, - step_size * grad_bw, nx=nx) + Cnew = exp_bures(Cb, -step_size * grad_bw, nx=nx) # check convergence if batch_size is not None and batch_size < n: @@ -522,16 +531,24 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, if log: log = {} - log['num_iter'] = it - log['final_diff'] = diff + log["num_iter"] = it + log["final_diff"] = diff return Cb, log else: return Cb -def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", - num_iter=1000, eps=1e-7, log=False, - step_size=1, batch_size=None): +def bures_wasserstein_barycenter( + m, + C, + weights=None, + method="fixed_point", + num_iter=1000, + eps=1e-7, + log=False, + step_size=1, + batch_size=None, +): r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions. The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\big(\mathcal{N}(\mu_i,\Sigma_i)\big)_{i=1}^n` @@ -601,7 +618,9 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - nx = get_backend(*m,) + nx = get_backend( + *m, + ) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -610,20 +629,24 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", mb = nx.sum(m * weights[:, None], axis=0) if method == "gradient_descent" or batch_size is not None: - out = bures_barycenter_gradient_descent(C, weights=weights, - num_iter=num_iter, eps=eps, - log=log, step_size=step_size, - batch_size=batch_size) + out = bures_barycenter_gradient_descent( + C, + weights=weights, + num_iter=num_iter, + eps=eps, + log=log, + step_size=step_size, + batch_size=batch_size, + ) elif method == "fixed_point": - out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, - eps=eps, log=log) + out = bures_barycenter_fixpoint( + C, weights=weights, num_iter=num_iter, eps=eps, log=log + ) else: raise ValueError("Unknown method '%s'." 
% method) if log: Cb, log = out - log["num_iter"] = it - log["final_diff"] = diff return mb, Cb, log else: Cb = out From b821ee8100872435ab13ae4348570abadd807d14 Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 31 Oct 2024 23:35:25 +0100 Subject: [PATCH 09/24] doc exp bw --- ot/utils.py | 22 ++++++++++++++++++++-- test/test_gaussian.py | 22 ++++++++++++++++------ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index 5094f21c0..a436a3bbf 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1340,12 +1340,30 @@ def proj_SDP(S, nx=None, vmin=0.0): else: # input was (n, d, d): broadcasting Q = nx.einsum("ijk,ik->ijk", P, w) # Q[i] = P[i] @ diag(w[i]) # R[i] = Q[i] @ P[i].T - return nx.einsum('ijk,ikl->ijl', Q, nx.transpose(P, (0, 2, 1))) + return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1))) def exp_bures(Sigma, S, nx=None): r""" - Exponential map Bures-Wasserstein space as Sigma: \exp_\Sigma(S) + Exponential map in Bures-Wasserstein space at Sigma: + + .. math:: + \exp_\Sigma(S) = (I_d+S)\Sigma(I_d+S). + + Parameters + ---------- + Sigma : array-like (d,d) + SPD matrix + S : array-like (d,d) + Symmetric matrix + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrices `Sigma, S`. + + Returns + ------- + P : array-like (d,d) + SPD matrix obtained as the exponential map of S at Sigma """ if nx is None: nx = get_backend(Sigma, S) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 0b59c026c..26925c37b 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -144,7 +144,9 @@ def test_bures_wasserstein_barycenter(nx, method): m = nx.from_numpy(m) C = nx.from_numpy(C) - mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=True) + mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter( + m, C, method=method, log=True + ) mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=False) np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2) @@ -152,13 +154,17 @@ def test_bures_wasserstein_barycenter(nx, method): # Test weights argument weights = nx.ones(k) / k - mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, method=method, log=False) + mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter( + m, C, weights=weights, method=method, log=False + ) np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2) # test with closed form for diagonal covariance matrices Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)] Cdiag = nx.stack(Cdiag, axis=0) - mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, method=method, log=False) + mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter( + m, Cdiag, method=method, log=False + ) Cdiag_sqrt = [nx.sqrtm(C) for C in Cdiag] Cdiag_sqrt = nx.stack(Cdiag_sqrt, axis=0) @@ -176,7 +182,7 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): m = [] C = [] for _ in range(k): - X_, y_ = make_data_classif('3gauss', n) + X_, y_ = make_data_classif("3gauss", n) m_ = np.mean(X_, axis=0)[None, :] C_ = np.cov(X_.T) X.append(X_) @@ -189,8 +195,12 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): m = nx.from_numpy(m) C = nx.from_numpy(C) - mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method="fixed_point", log=False) - mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(m, C, method="gradient_descent", log=False) + mb, Cb = ot.gaussian.bures_wasserstein_barycenter( + m, C, 
method="fixed_point", log=False + ) + mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( + m, C, method="gradient_descent", log=False + ) np.testing.assert_allclose(mb, mb2, atol=1e-5) np.testing.assert_allclose(Cb, Cb2, atol=1e-5) From d22028b8540efe7fb411c088e579eef6d38cb54f Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 5 Nov 2024 22:29:24 +0100 Subject: [PATCH 10/24] First tests stochastic + exp --- ot/gaussian.py | 16 ++++++---- test/test_gaussian.py | 68 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 26 +++++++++++++++++ 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index eb38f0bc0..8f9f51418 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -492,6 +492,8 @@ def bures_barycenter_gradient_descent( Cb = nx.mean(C * weights[:, None, None], axis=0) Id = nx.eye(C.shape[-1], type_as=Cb) + L_grads = [] + for it in range(num_iter): Cb12 = nx.sqrtm(Cb) Cb12_ = nx.inv(Cb12) @@ -517,7 +519,11 @@ def bures_barycenter_gradient_descent( # check convergence if batch_size is not None and batch_size < n: # TODO: criteria for SGD: on gradients? + test SGD - diff = nx.norm(Cb - Cnew) + L_grads.append(nx.sum(grad_bw**2)) + diff = np.mean(L_grads) + + # L_values.append(nx.norm(Cb - Cnew)) + # print(diff, np.mean(L_values)) else: diff = nx.norm(Cb - Cnew) @@ -530,10 +536,10 @@ def bures_barycenter_gradient_descent( print("Dit not converge.") if log: - log = {} - log["num_iter"] = it - log["final_diff"] = diff - return Cb, log + dict_log = {} + dict_log["num_iter"] = it + dict_log["final_diff"] = diff + return Cb, dict_log else: return Cb diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 26925c37b..5b11c06ab 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -206,6 +206,74 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): np.testing.assert_allclose(Cb, Cb2, atol=1e-5) +def test_stochastic_gd_bures_wasserstein_barycenter(nx): + n = 50 + k = 10 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif("3gauss", n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + mb, Cb = ot.gaussian.bures_wasserstein_barycenter( + m, C, method="fixed_point", log=False + ) + + n_samples = [10, 20, 50] # for 1 or 5, too slow to converge + for n in n_samples: + mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( + m, C, method="gradient_descent", log=False, batch_size=n + ) + + np.testing.assert_allclose(mb, mb2, atol=1e-5) + np.testing.assert_allclose(Cb, Cb2, atol=1e-5) + + with pytest.raises(ValueError): + mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( + m, C, method="gradient_descent", log=False, batch_size=-5 + ) + + +def test_not_implemented_method(nx): + n = 50 + k = 10 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif("3gauss", n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + not_implemented = "new_method" + with pytest.raises(ValueError): + mb, Cb = ot.gaussian.bures_wasserstein_barycenter( + m, C, method=not_implemented, log=False + ) + + @pytest.mark.parametrize("bias", [True, False]) def test_empirical_bures_wasserstein_barycenter(nx, bias): n = 50 diff --git 
a/test/test_utils.py index d50f29915..76241e89c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -655,3 +655,29 @@ def test_kl_div(nx): kl_mass = nx.kl_div(xb, yb, True) recovered_kl = kl_mass - nx.sum(yb - xb) np.testing.assert_allclose(kl, recovered_kl) + + +def test_exp_bures(nx): + d = 2 + Sigma = nx.eye(d) + + rng = np.random.RandomState(42) + X = rng.randn(d, d) + z = rng.randn(d) + X, z = nx.from_numpy(X, z) + S = X + nx.transpose(X) + + Lambda = ot.utils.exp_bures(Sigma, S) + + # assert SPD + np.testing.assert_array_less(np.zeros(1), nx.to_numpy(z.T @ Lambda @ z)) + + # OT map from Lambda to Sigma + Lambda_12 = nx.sqrtm(Lambda) + Lambda_12_ = nx.inv(Lambda_12) + M = nx.sqrtm(nx.einsum("ij, jk, kl", Lambda_12, Sigma, Lambda_12)) + T = nx.einsum("ij, jk, kl", Lambda_12_, M, Lambda_12_) + + # exp_\Lambda(log_\Lambda(Sigma)) = Sigma + Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d)) + np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5) From dffa0cfce8f0dc717090e2c1844ee4ba72650c7f Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 5 Nov 2024 22:46:14 +0100 Subject: [PATCH 11/24] exp_bures with einsum --- ot/utils.py | 2 +- test/test_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index a436a3bbf..2decee174 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1371,4 +1371,4 @@ def exp_bures(Sigma, S, nx=None): Id = nx.eye(d, type_as=S) C = Id + S - return dots(C, Sigma, C) + return nx.einsum("ij,jk,kl -> il", C, Sigma, C) diff --git a/test/test_utils.py b/test/test_utils.py index 76241e89c..3d6fedd89 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -675,8 +675,8 @@ def test_exp_bures(nx): # OT map from Lambda to Sigma Lambda_12 = nx.sqrtm(Lambda) Lambda_12_ = nx.inv(Lambda_12) - M = nx.sqrtm(nx.einsum("ij, jk, kl", Lambda_12, Sigma, Lambda_12)) - T = nx.einsum("ij, jk, kl", Lambda_12_, M, Lambda_12_) + M = nx.sqrtm(nx.einsum("ij, jk, kl -> il", Lambda_12, Sigma, Lambda_12)) + T = nx.einsum("ij, jk, kl -> il", Lambda_12_, M, Lambda_12_) # exp_\Lambda(log_\Lambda(Sigma)) = Sigma Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d)) From f3e911a1162a1d1604897e6a6819409fba1b41fc Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 6 Nov 2024 09:06:47 +0100 Subject: [PATCH 12/24] type Id test --- test/test_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3d6fedd89..f6b1331fa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -659,7 +659,6 @@ def test_kl_div(nx): def test_exp_bures(nx): d = 2 - Sigma = nx.eye(d) rng = np.random.RandomState(42) X = rng.randn(d, d) z = rng.randn(d) X, z = nx.from_numpy(X, z) S = X + nx.transpose(X) + Sigma = nx.eye(d, type_as=S) + Lambda = ot.utils.exp_bures(Sigma, S) # assert SPD @@ -679,5 +680,5 @@ def test_exp_bures(nx): T = nx.einsum("ij, jk, kl -> il", Lambda_12_, M, Lambda_12_) # exp_\Lambda(log_\Lambda(Sigma)) = Sigma - Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d)) + Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T)) np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5) From 97f226122196db2e34478be281d872ad83e7f8cb Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 7 Nov 2024 20:32:21 +0100 Subject: [PATCH 13/24] up test stochastic --- ot/gaussian.py | 2 ++ test/test_gaussian.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git 
a/ot/gaussian.py b/ot/gaussian.py index 8f9f51418..46e2f3038 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -519,6 +519,8 @@ def bures_barycenter_gradient_descent( # check convergence if batch_size is not None and batch_size < n: # TODO: criteria for SGD: on gradients? + test SGD + # TOO slow, test with value? (but don't want to compute the full barycenter) + # + need to make bures_wasserstein_distance batchable (TODO) L_grads.append(nx.sum(grad_bw**2)) diff = np.mean(L_grads) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 5b11c06ab..f4a5eea5a 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -205,6 +205,13 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): np.testing.assert_allclose(mb, mb2, atol=1e-5) np.testing.assert_allclose(Cb, Cb2, atol=1e-5) + # Test weights argument + weights = nx.ones(k) / k + Cbw = ot.gaussian.bures_barycenter_fixpoint(C, weights=weights) + Cbw2 = ot.gaussian.bures_barycenter_gradient_descent(C, weights=weights) + np.testing.assert_allclose(Cbw, Cb, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(Cbw2, Cb2, rtol=1e-5, atol=1e-5) + def test_stochastic_gd_bures_wasserstein_barycenter(nx): n = 50 @@ -231,14 +238,16 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx): m, C, method="fixed_point", log=False ) - n_samples = [10, 20, 50] # for 1 or 5, too slow to converge + n_samples = [1, 5] for n in n_samples: mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( m, C, method="gradient_descent", log=False, batch_size=n ) np.testing.assert_allclose(mb, mb2, atol=1e-5) - np.testing.assert_allclose(Cb, Cb2, atol=1e-5) + # atol big for now because too slow, need to see if + # it can be improved... + np.testing.assert_allclose(Cb, Cb2, atol=0.5) with pytest.raises(ValueError): mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( From 7594393bd9385c12cce87901adaa0cacfbc7703d Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 7 Nov 2024 21:30:40 +0100 Subject: [PATCH 14/24] test weights --- test/test_gaussian.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index f4a5eea5a..6b12ec942 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -206,9 +206,8 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): np.testing.assert_allclose(Cb, Cb2, atol=1e-5) # Test weights argument - weights = nx.ones(k) / k - Cbw = ot.gaussian.bures_barycenter_fixpoint(C, weights=weights) - Cbw2 = ot.gaussian.bures_barycenter_gradient_descent(C, weights=weights) + Cbw = ot.gaussian.bures_barycenter_fixpoint(C, weights=None) + Cbw2 = ot.gaussian.bures_barycenter_gradient_descent(C, weights=None) np.testing.assert_allclose(Cbw, Cb, rtol=1e-5, atol=1e-5) np.testing.assert_allclose(Cbw2, Cb2, rtol=1e-5, atol=1e-5) From 6c48b3c58a6dfa4e8703a48f3b37aed1993b1048 Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 7 Nov 2024 22:23:21 +0100 Subject: [PATCH 15/24] Add BW distance with batchs --- ot/backend.py | 15 ++++++++++----- ot/gaussian.py | 24 +++++++++++++++++++++++- test/test_backend.py | 4 ++++ test/test_gaussian.py | 41 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 6 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index a99639445..7c4f28dc6 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1355,7 +1355,8 @@ def solve(self, a, b): return np.linalg.solve(a, b) def trace(self, a): - return np.trace(a) + return np.einsum("...ii", a) + # return np.trace(a) def inv(self, a): return 
scipy.linalg.inv(a) @@ -1765,7 +1766,8 @@ def solve(self, a, b): return jnp.linalg.solve(a, b) def trace(self, a): - return jnp.trace(a) + return jnp.einsum("...ii", a) + # return jnp.trace(a) def inv(self, a): return jnp.linalg.inv(a) @@ -2295,7 +2297,8 @@ def solve(self, a, b): return torch.linalg.solve(a, b) def trace(self, a): - return torch.trace(a) + return torch.einsum("...ii", a) + # return torch.trace(a) def inv(self, a): return torch.linalg.inv(a) @@ -2706,7 +2709,8 @@ def solve(self, a, b): return cp.linalg.solve(a, b) def trace(self, a): - return cp.trace(a) + return cp.einsum("...ii", a) + # return cp.trace(a) def inv(self, a): return cp.linalg.inv(a) @@ -3139,7 +3143,8 @@ def solve(self, a, b): return tf.linalg.solve(a, b) def trace(self, a): - return tf.linalg.trace(a) + return tf.einsum("...ii", a) + # return tf.linalg.trace(a) def inv(self, a): return tf.linalg.inv(a) diff --git a/ot/gaussian.py b/ot/gaussian.py index 46e2f3038..71c641498 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -250,7 +250,6 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): nx = get_backend(ms, mt, Cs, Ct) Cs12 = nx.sqrtm(Cs) - B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) W = nx.sqrt(nx.maximum(nx.norm(ms - mt) ** 2 + B, 0)) @@ -262,6 +261,29 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): return W +def bures_wasserstein_distance_batch(ms, mt, Cs, Ct, log=False): + """ + TODO + Maybe try to merge it with bures_wasserstein_distance + """ + ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) + nx = get_backend(ms, mt, Cs, Ct) + + Cs12 = nx.sqrtm(Cs) + M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12) + B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M)) + + squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2 + W = nx.sqrt(nx.maximum(squared_dist_m + B, 0)) + + if log: + log = {} + log["Cs12"] = Cs12 + return W, log + else: + return W + + def empirical_bures_wasserstein_distance( xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False ): diff --git a/test/test_backend.py b/test/test_backend.py index 435c6db8a..9ee784a4d 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -612,6 +612,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("trace") + A = nx.trace(nx.stack([SquareMb, SquareMb], axis=0)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("broadcast trace") + A = nx.inv(SquareMb) lst_b.append(nx.to_numpy(A)) lst_name.append("matrix inverse") diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 6b12ec942..348b7e25a 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -99,6 +99,47 @@ def test_bures_wasserstein_distance(nx): np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) +def test_bures_wasserstein_distance_batch(nx): + n = 50 + k = 2 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif("3gauss", n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + Wb = ot.gaussian.bures_wasserstein_distance(m[0, 0], m[1, 0], C[0], C[1], log=False) + + Wb2 = ot.gaussian.bures_wasserstein_distance_batch( + m[0, 0][None], m[1, 0][None], C[0][None], C[1][None] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) + + Wb2 = ot.gaussian.bures_wasserstein_distance_batch( + m[:, 0], m[1, 0][None], C, C[1][None] + ) + 
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) + np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 0]), atol=1e-5) + + Wb2 = ot.gaussian.bures_wasserstein_distance_batch(m[:, 0], m[:, 0], C, C) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[1, 0]), atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 1]), atol=1e-5) + np.testing.assert_allclose(0, nx.to_numpy(Wb2[0, 0]), atol=1e-5) + np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 1]), atol=1e-5) + + @pytest.mark.parametrize("bias", [True, False]) def test_empirical_bures_wasserstein_distance(nx, bias): ns = 400 From ba806ff7f6026646b98af4ccdf7c3d197c7940a4 Mon Sep 17 00:00:00 2001 From: Clement Date: Mon, 11 Nov 2024 21:51:30 +0100 Subject: [PATCH 16/24] step size SGD BW Barycenter --- ot/gaussian.py | 5 +++++ test/test_gaussian.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 71c641498..be8abd543 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -531,6 +531,11 @@ def bures_barycenter_gradient_descent( M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12)) ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) grad_bw = Id - nx.mean(ot_maps, axis=0) + + # step size from [74] (page 15) + step_size = 2 / (0.7 * (it + 2 / 0.7 + 1)) + + # TODO: Add one where we take samples in order, + averaging? cf [74] else: # gradient descent M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12)) ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 348b7e25a..e353d748a 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -278,16 +278,25 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx): m, C, method="fixed_point", log=False ) + loss = nx.mean( + ot.gaussian.bures_wasserstein_distance_batch(mb[None], m, Cb[None], C) + ) + n_samples = [1, 5] for n in n_samples: mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( m, C, method="gradient_descent", log=False, batch_size=n ) + loss2 = nx.mean( + ot.gaussian.bures_wasserstein_distance_batch(mb2[None], m, Cb2[None], C) + ) + np.testing.assert_allclose(mb, mb2, atol=1e-5) # atol big for now because too slow, need to see if # it can be improved... - np.testing.assert_allclose(Cb, Cb2, atol=0.5) + np.testing.assert_allclose(Cb, Cb2, atol=1e-1) + np.testing.assert_allclose(loss, loss2, atol=1e-3) with pytest.raises(ValueError): mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( From d4045f1cac80fdd8f2fbf4ca443dbebbc3ee3311 Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 19 Nov 2024 20:17:06 +0100 Subject: [PATCH 17/24] batchable BW distance --- ot/gaussian.py | 57 +++++++++++++++++++++---------------------- test/test_gaussian.py | 29 +++++++++++++++------- 2 files changed, 48 insertions(+), 38 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index be8abd543..583e3c1f3 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -207,7 +207,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, discussed in remark 2.31 :ref:`[1] `. - The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}` + The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}_2` .. 
math:: \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} @@ -219,13 +219,13 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Parameters ---------- - ms : array-like (d,) + ms : array-like (d,) or (n,d) mean of the source distribution - mt : array-like (d,) + mt : array-like (d,) or (m,d) mean of the target distribution - Cs : array-like (d,d) + Cs : array-like (d,d) or (n,d,d) covariance of the source distribution - Ct : array-like (d,d) + Ct : array-like (d,d) or (m,d,d) covariance of the target distribution log : bool, optional record log if True @@ -233,7 +233,9 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Returns ------- - W : float + W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d), + mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), + array-like (n,m) if ms of shape (n,d) and mt of shape (m,d) Bures Wasserstein distance log : dict log dictionary return only if log==True in parameters @@ -250,30 +252,27 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): nx = get_backend(ms, mt, Cs, Ct) Cs12 = nx.sqrtm(Cs) - B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) - W = nx.sqrt(nx.maximum(nx.norm(ms - mt) ** 2 + B, 0)) - if log: - log = {} - log["Cs12"] = Cs12 - return W, log + if len(ms.shape) == 1 and len(mt.shape) == 1: + # Return float + squared_dist_m = nx.norm(ms - mt) ** 2 + B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) + elif len(ms.shape) == 1: + # Return shape (m,) + M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12) + B = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M)) + squared_dist_m = nx.norm(ms[None] - mt, axis=-1) ** 2 + elif len(mt.shape) == 1: + # Return shape (n,) + M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12) + B = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M)) + squared_dist_m = nx.norm(ms - mt[None], axis=-1) ** 2 else: - return W - - -def bures_wasserstein_distance_batch(ms, mt, Cs, Ct, log=False): - """ - TODO - Maybe try to merge it with bures_wasserstein_distance - """ - ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) - nx = get_backend(ms, mt, Cs, Ct) - - Cs12 = nx.sqrtm(Cs) - M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12) - B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M)) + # Return shape (n,m) + M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12) + B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M)) + squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2 - squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2 W = nx.sqrt(nx.maximum(squared_dist_m + B, 0)) if log: @@ -361,12 +360,12 @@ def empirical_bures_wasserstein_distance( Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) if log: - W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log) + W, log = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct, log=log) log["Cs"] = Cs log["Ct"] = Ct return W, log else: - W = bures_wasserstein_distance(mxs, mxt, Cs, Ct) + W = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct) return W diff --git a/test/test_gaussian.py b/test/test_gaussian.py index e353d748a..95843374e 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -122,22 +122,35 @@ def test_bures_wasserstein_distance_batch(nx): Wb = ot.gaussian.bures_wasserstein_distance(m[0, 0], m[1, 0], C[0], C[1], log=False) - Wb2 = ot.gaussian.bures_wasserstein_distance_batch( + Wb2 = ot.gaussian.bures_wasserstein_distance( m[0, 0][None], m[1, 0][None], C[0][None], 
C[1][None] ) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1, 1)) - Wb2 = ot.gaussian.bures_wasserstein_distance_batch( - m[:, 0], m[1, 0][None], C, C[1][None] - ) + Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[1, 0][None], C, C[1][None]) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (2, 1)) + + Wb2 = ot.gaussian.bures_wasserstein_distance( + m[0, 0][None], m[1, 0], C[0][None], C[1] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1,)) - Wb2 = ot.gaussian.bures_wasserstein_distance_batch(m[:, 0], m[:, 0], C, C) + Wb2 = ot.gaussian.bures_wasserstein_distance( + m[0, 0], m[1, 0][None], C[0], C[1][None] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1,)) + + Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[:, 0], C, C) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[1, 0]), atol=1e-5) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 1]), atol=1e-5) np.testing.assert_allclose(0, nx.to_numpy(Wb2[0, 0]), atol=1e-5) np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 1]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (2, 2)) @pytest.mark.parametrize("bias", [True, False]) @@ -278,9 +291,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx): m, C, method="fixed_point", log=False ) - loss = nx.mean( - ot.gaussian.bures_wasserstein_distance_batch(mb[None], m, Cb[None], C) - ) + loss = nx.mean(ot.gaussian.bures_wasserstein_distance(mb[None], m, Cb[None], C)) n_samples = [1, 5] for n in n_samples: @@ -289,7 +300,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx): ) loss2 = nx.mean( - ot.gaussian.bures_wasserstein_distance_batch(mb2[None], m, Cb2[None], C) + ot.gaussian.bures_wasserstein_distance(mb2[None], m, Cb2[None], C) ) np.testing.assert_allclose(mb, mb2, atol=1e-5) From 50994ed5ec35231ebd10e35c6c5ef35f68fa2038 Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 12 Feb 2025 22:31:18 +0100 Subject: [PATCH 18/24] RELEASES.md --- ot/utils.py | 2 +- test/test_utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index c61ba86ea..431226910 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1373,7 +1373,7 @@ def exp_bures(Sigma, S, nx=None): return nx.einsum("ij,jk,kl -> il", C, Sigma, C) - + def check_number_threads(numThreads): """Checks whether or not the requested number of threads has a valid value. 
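
For context on the `exp_bures` helper touched above and exercised by the test below: it implements the exponential map at a base covariance `Sigma` on the Bures-Wasserstein manifold, `exp_Sigma(S) = (I + S) @ Sigma @ (I + S)`. Since the optimal transport map from N(0, Lambda) to N(0, Sigma) is `T = Lambda^{-1/2} (Lambda^{1/2} Sigma Lambda^{1/2})^{1/2} Lambda^{-1/2}` and `T @ Lambda @ T = Sigma`, the identity `exp_Lambda(T - I) = Sigma` checked in the test follows. A minimal NumPy sketch of that identity (the `sqrtm_spd` helper here is illustrative, not part of the patch):

import numpy as np

def sqrtm_spd(A):
    # matrix square root of a symmetric positive definite matrix via eigendecomposition
    w, V = np.linalg.eigh(A)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

rng = np.random.default_rng(0)
d = 3
A, B = rng.normal(size=(d, d)), rng.normal(size=(d, d))
Sigma = A @ A.T + 1e-2 * np.eye(d)   # SPD target covariance
Lambda = B @ B.T + 1e-2 * np.eye(d)  # SPD base point

# OT map from N(0, Lambda) to N(0, Sigma)
Lambda12 = sqrtm_spd(Lambda)
Lambda12inv = np.linalg.inv(Lambda12)
T = Lambda12inv @ sqrtm_spd(Lambda12 @ Sigma @ Lambda12) @ Lambda12inv

# exp_Lambda(log_Lambda(Sigma)) recovers Sigma, i.e. (I + S) Lambda (I + S) = Sigma
S = T - np.eye(d)
Sigma_rec = (np.eye(d) + S) @ Lambda @ (np.eye(d) + S)
assert np.allclose(Sigma_rec, Sigma, atol=1e-6)
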
diff --git a/test/test_utils.py b/test/test_utils.py index f6b1331fa..3f5f9ec65 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -674,10 +674,10 @@ def test_exp_bures(nx): np.testing.assert_array_less(np.zeros(1), nx.to_numpy(z.T @ Lambda @ z)) # OT map from Lambda to Sigma - Lambda_12 = nx.sqrtm(Lambda) - Lambda_12_ = nx.inv(Lambda_12) - M = nx.sqrtm(nx.einsum("ij, jk, kl -> il", Lambda_12, Sigma, Lambda_12)) - T = nx.einsum("ij, jk, kl -> il", Lambda_12_, M, Lambda_12_) + Lambda12 = nx.sqrtm(Lambda) + Lambda12inv = nx.inv(Lambda12) + M = nx.sqrtm(nx.einsum("ij, jk, kl -> il", Lambda12, Sigma, Lambda12)) + T = nx.einsum("ij, jk, kl -> il", Lambda12inv, M, Lambda12inv) # exp_\Lambda(log_\Lambda(Sigma)) = Sigma Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T)) From bad385f2d82cecb7c54bb7a3a40898b3dcdca46f Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 12 Feb 2025 22:34:13 +0100 Subject: [PATCH 19/24] precommit --- RELEASES.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 21c13fcc6..f1035dd2e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,6 +8,8 @@ - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) - Fix documentation in the module `ot.gaussian` (PR #718) +- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680) +- `ot.gaussian_bures_wasserstein_distance` can be batched (PR #680) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) @@ -44,7 +46,6 @@ This release also contains few bug fixes, concerning the support of any metric i - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663) - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676) - Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680) -- Refactored `ot.bregman._convolutional` to improve readability (PR #709) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) From 0b20759e5bfaf51a49e4be8146537a8eedc28e07 Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 12 Feb 2025 23:08:56 +0100 Subject: [PATCH 20/24] Add ot.gaussian.bures --- ot/gaussian.py | 101 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 86 insertions(+), 15 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 1ac4d7f11..f98209744 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -200,14 +200,77 @@ def empirical_bures_wasserstein_mapping( return A, b +def bures_distance(Cs, Ct, log=False): + r"""Return Bures distance. + + The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`, + given by: + + .. 
math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + Parameters + ---------- + Cs : array-like (d,d) or (n,d,d) + covariance of the source distribution + Ct : array-like (d,d) or (m,d,d) + covariance of the target distribution + log : bool, optional + record log if True + + + Returns + ------- + W : float if Cs and Cd of shape (d,d), array-like (n,) if Cs of shape (n,d,d), + Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d), + array-like (n,m) if Cs of shape (n,d,d) and mt of shape (m,d,d) + Bures Wasserstein distance + log : dict + log dictionary return only if log==True in parameters + + .. _references-bures-wasserstein-distance: + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + Cs, Ct = list_to_array(Cs, Ct) + nx = get_backend(Cs, Ct) + + Cs12 = nx.sqrtm(Cs) + + if len(Cs.shape) == 2 and len(Ct.shape) == 2: + # Return float + bw2 = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) + elif len(Cs.shape) == 2: + # Return shape (m,) + M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12) + bw2 = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M)) + elif len(Ct.shape) == 2: + # Return shape (n,) + M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12) + bw2 = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M)) + else: + # Return shape (n,m) + M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12) + bw2 = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M)) + + W = nx.sqrt(nx.maximum(bw2, 0)) + + if log: + log = {} + log["Cs12"] = Cs12 + return W, log + else: + return W + + def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): r"""Return Bures Wasserstein distance between samples. - The function estimates the Bures-Wasserstein distance between two - empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, - discussed in remark 2.31 :ref:`[1] `. - - The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}_2` + The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`, + as discussed in remark 2.31 :ref:`[1] `. .. 
math::
         \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
@@ -230,7 +293,6 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
     log : bool, optional
         record log if True
 
-
     Returns
     -------
     W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d),
@@ -251,29 +313,38 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
     ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
     nx = get_backend(ms, mt, Cs, Ct)
 
-    Cs12 = nx.sqrtm(Cs)
+    assert (
+        ms.shape[0] == Cs.shape[0]
+    ), "Source Gaussians have a different number of components"
+
+    assert (
+        mt.shape[0] == Ct.shape[0]
+    ), "Target Gaussians have a different number of components"
+
+    assert (
+        ms.shape[-1] == mt.shape[-1] == Cs.shape[-1] == Ct.shape[-1]
+    ), "All Gaussians must have the same dimension"
+
+    if log:
+        bw, log_dict = bures_distance(Cs, Ct, log)
+        Cs12 = log_dict["Cs12"]
+    else:
+        bw = bures_distance(Cs, Ct)
 
     if len(ms.shape) == 1 and len(mt.shape) == 1:
         # Return float
         squared_dist_m = nx.norm(ms - mt) ** 2
-        B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
     elif len(ms.shape) == 1:
         # Return shape (m,)
-        M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12)
-        B = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M))
         squared_dist_m = nx.norm(ms[None] - mt, axis=-1) ** 2
     elif len(mt.shape) == 1:
         # Return shape (n,)
-        M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12)
-        B = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M))
         squared_dist_m = nx.norm(ms - mt[None], axis=-1) ** 2
     else:
         # Return shape (n,m)
-        M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12)
-        B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M))
         squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2
 
-    W = nx.sqrt(nx.maximum(squared_dist_m + B, 0))
+    W = nx.sqrt(nx.maximum(squared_dist_m + bw**2, 0))
 
     if log:
         log = {}
From fe3d9db13ab490a0a5a1219bbb9ab9ac9275497f Mon Sep 17 00:00:00 2001
From: Clement 
Date: Thu, 13 Feb 2025 14:16:30 +0100
Subject: [PATCH 21/24] Add arg backend

---
 ot/gaussian.py | 52 ++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 38 insertions(+), 14 deletions(-)

diff --git a/ot/gaussian.py b/ot/gaussian.py
index f98209744..74f3e255d 100644
--- a/ot/gaussian.py
+++ b/ot/gaussian.py
@@ -200,7 +200,7 @@ def empirical_bures_wasserstein_mapping(
     return A, b
 
 
-def bures_distance(Cs, Ct, log=False):
+def bures_distance(Cs, Ct, log=False, nx=None):
     r"""Return Bures distance.
 
     The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
@@ -217,7 +217,9 @@ def bures_distance(Cs, Ct, log=False):
         covariance of the target distribution
     log : bool, optional
         record log if True
-
+    nx : module, optional
+        The numerical backend module to use. If not provided, the backend will
+        be fetched from the input matrices `Cs, Ct`.
 
     Returns
     -------
@@ -236,7 +238,11 @@ def bures_distance(Cs, Ct, log=False):
         Transport", 2018. 
""" Cs, Ct = list_to_array(Cs, Ct) - nx = get_backend(Cs, Ct) + + if nx is None: + nx = get_backend(Cs, Ct) + + assert Cs.shape[-1] == Ct.shape[-1], "All Gaussian must have the same dimension" Cs12 = nx.sqrtm(Cs) @@ -326,10 +332,10 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): ), "All Gaussian must have the same dimension" if log: - bw, log_dict = bures_distance(Cs, Ct, log) + bw, log_dict = bures_distance(Cs, Ct, log=log, nx=nx) Cs12 = log_dict["Cs12"] else: - bw = bures_distance(Cs, Ct) + bw = bures_distance(Cs, Ct, nx=nx) if len(ms.shape) == 1 and len(mt.shape) == 1: # Return float @@ -440,7 +446,9 @@ def empirical_bures_wasserstein_distance( return W -def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=False): +def bures_barycenter_fixpoint( + C, weights=None, num_iter=1000, eps=1e-7, log=False, nx=None +): r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n` @@ -469,6 +477,9 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals tolerance for the fixed point algorithm log : bool, optional record log if True + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrices `C`. Returns ------- @@ -485,9 +496,10 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. """ - nx = get_backend( - *C, - ) + if nx is None: + nx = get_backend( + *C, + ) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -522,7 +534,14 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals def bures_barycenter_gradient_descent( - C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None + C, + weights=None, + num_iter=1000, + eps=1e-7, + log=False, + step_size=1, + batch_size=None, + nx=None, ): r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. @@ -551,6 +570,9 @@ def bures_barycenter_gradient_descent( step size for the gradient descent, 1 by default batch_size : int, optional batch size if use a stochastic gradient descent + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrices `C`. Returns ------- @@ -571,9 +593,10 @@ def bures_barycenter_gradient_descent( Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - nx = get_backend( - *C, - ) + if nx is None: + nx = get_backend( + *C, + ) n = C.shape[0] @@ -742,10 +765,11 @@ def bures_wasserstein_barycenter( log=log, step_size=step_size, batch_size=batch_size, + nx=nx, ) elif method == "fixed_point": out = bures_barycenter_fixpoint( - C, weights=weights, num_iter=num_iter, eps=eps, log=log + C, weights=weights, num_iter=num_iter, eps=eps, log=log, nx=nx ) else: raise ValueError("Unknown method '%s'." 
% method) From 506a524e901d76b82f30f63ad9462190ff8904f2 Mon Sep 17 00:00:00 2001 From: Clement Date: Sun, 16 Feb 2025 19:51:04 +0100 Subject: [PATCH 22/24] up stop criteria sgd Gaussian barycenter --- ot/gaussian.py | 66 +++++++++++++++++++++++++++++-------------- test/test_gaussian.py | 23 +++++++++++---- 2 files changed, 62 insertions(+), 27 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 74f3e255d..8cf5022f1 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -541,6 +541,7 @@ def bures_barycenter_gradient_descent( log=False, step_size=1, batch_size=None, + averaged=False, nx=None, ): r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. @@ -570,6 +571,8 @@ def bures_barycenter_gradient_descent( step size for the gradient descent, 1 by default batch_size : int, optional batch size if use a stochastic gradient descent + averaged : bool, optional + if True, use the averaged procedure of :ref:`[74] ` nx : module, optional The numerical backend module to use. If not provided, the backend will be fetched from the input matrices `C`. @@ -607,7 +610,9 @@ def bures_barycenter_gradient_descent( Cb = nx.mean(C * weights[:, None, None], axis=0) Id = nx.eye(C.shape[-1], type_as=Cb) - L_grads = [] + L_diff = [] + + Cb_averaged = nx.copy(Cb) for it in range(num_iter): Cb12 = nx.sqrtm(Cb) @@ -627,8 +632,6 @@ def bures_barycenter_gradient_descent( # step size from [74] (page 15) step_size = 2 / (0.7 * (it + 2 / 0.7 + 1)) - - # TODO: Add one where we take samples in order, + averaging? cf [74] else: # gradient descent M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12)) ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) @@ -636,31 +639,31 @@ def bures_barycenter_gradient_descent( Cnew = exp_bures(Cb, -step_size * grad_bw, nx=nx) + if averaged: + # ot map between Cb_averaged and Cnew + Cb_averaged12 = nx.sqrtm(Cb_averaged) + Cb_averaged12inv = nx.inv(Cb_averaged12) + M = nx.sqrtm(nx.einsum("ij,jk,kl->il", Cb_averaged12, Cnew, Cb_averaged12)) + ot_map = nx.einsum("ij,jk,kl->il", Cb_averaged12inv, M, Cb_averaged12inv) + map = Id * step_size / (step_size + 1) + ot_map / (step_size + 1) + Cb_averaged = nx.einsum("ij,jk,kl->il", map, Cb_averaged, map) + # check convergence - if batch_size is not None and batch_size < n: - # TODO: criteria for SGD: on gradients? + test SGD - # TOO slow, test with value? 
(but don't want to compute the full barycenter) - # + need to make bures_wasserstein_distance batchable (TODO) - L_grads.append(nx.sum(grad_bw**2)) - diff = np.mean(L_grads) - - # L_values.append(nx.norm(Cb - Cnew)) - # print(diff, np.mean(L_values)) - else: - diff = nx.norm(Cb - Cnew) + L_diff.append(nx.norm(Cb - Cnew)) - if diff <= eps: + # Criteria to stop + if np.mean(L_diff[-100:]) <= eps: break Cb = Cnew - if diff > eps: - print("Dit not converge.") + if averaged: + Cb = Cb_averaged if log: dict_log = {} dict_log["num_iter"] = it - dict_log["final_diff"] = diff + dict_log["final_diff"] = L_diff[-1] return Cb, dict_log else: return Cb @@ -708,7 +711,8 @@ def bures_wasserstein_barycenter( weights : array-like (k), optional weights for each distribution method : str - method used for the solver, either 'fixed_point' or 'gradient_descent' + method used for the solver, either 'fixed_point', 'gradient_descent', 'stochastic_gradient_descent' or + 'averaged_stochastic_gradient_descent' num_iter : int, optional number of iteration for the fixed point algorithm eps : float, optional @@ -756,7 +760,7 @@ def bures_wasserstein_barycenter( # Compute the mean barycenter mb = nx.sum(m * weights[:, None], axis=0) - if method == "gradient_descent" or batch_size is not None: + if method == "gradient_descent": out = bures_barycenter_gradient_descent( C, weights=weights, @@ -764,7 +768,27 @@ def bures_wasserstein_barycenter( eps=eps, log=log, step_size=step_size, - batch_size=batch_size, + nx=nx, + ) + elif method == "stochastic_gradient_descent": + out = bures_barycenter_gradient_descent( + C, + weights=weights, + num_iter=num_iter, + eps=eps, + log=log, + batch_size=1 if batch_size is None else batch_size, + nx=nx, + ) + elif method == "averaged_stochastic_gradient_descent": + out = bures_barycenter_gradient_descent( + C, + weights=weights, + num_iter=num_iter, + eps=eps, + log=log, + batch_size=1 if batch_size is None else batch_size, + averaged=True, nx=nx, ) elif method == "fixed_point": diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 95843374e..9d3dc6857 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -176,7 +176,15 @@ def test_empirical_bures_wasserstein_distance(nx, bias): np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) -@pytest.mark.parametrize("method", ["fixed_point", "gradient_descent"]) +@pytest.mark.parametrize( + "method", + [ + "fixed_point", + "gradient_descent", + "stochastic_gradient_descent", + "averaged_stochastic_gradient_descent", + ], +) def test_bures_wasserstein_barycenter(nx, method): n = 50 k = 10 @@ -203,7 +211,7 @@ def test_bures_wasserstein_barycenter(nx, method): ) mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=False) - np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Cb, Cblog, rtol=1e-1, atol=1e-1) np.testing.assert_allclose(mb, mblog, rtol=1e-2, atol=1e-2) # Test weights argument @@ -211,7 +219,7 @@ def test_bures_wasserstein_barycenter(nx, method): mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter( m, C, weights=weights, method=method, log=False ) - np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Cbw, Cb, rtol=1e-1, atol=1e-1) # test with closed form for diagonal covariance matrices Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)] @@ -266,7 +274,10 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): np.testing.assert_allclose(Cbw2, Cb2, rtol=1e-5, atol=1e-5) -def 
test_stochastic_gd_bures_wasserstein_barycenter(nx, method):
     n = 50
     k = 10
     X = []
@@ -296,7 +307,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx):
     n_samples = [1, 5]
     for n in n_samples:
         mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(
-            m, C, method="gradient_descent", log=False, batch_size=n
+            m, C, method=method, log=False, batch_size=n
         )
 
         loss2 = nx.mean(
@@ -311,7 +322,7 @@
 
     with pytest.raises(ValueError):
         mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(
-            m, C, method="gradient_descent", log=False, batch_size=-5
+            m, C, method=method, log=False, batch_size=-5
         )
 

From c640ecb2e231794339e6c722d0f7122eaf3ea2af Mon Sep 17 00:00:00 2001
From: Clement 
Date: Sun, 16 Feb 2025 20:20:32 +0100
Subject: [PATCH 23/24] Fix release

---
 RELEASES.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/RELEASES.md b/RELEASES.md
index f1035dd2e..7fc3008a3 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -8,6 +8,8 @@
 - Automatic PR labeling and release file update check (PR #704)
 - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
 - Fix documentation in the module `ot.gaussian` (PR #718)
+- Refactored `ot.bregman._convolutional` to improve readability (PR #709)
 - Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680)
-- `ot.gaussian_bures_wasserstein_distance` can be batched (PR #680)
+- Added `ot.gaussian.bures_distance` (PR #680)
+- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
@@ -45,7 +47,6 @@ This release also contains few bug fixes, concerning the support of any metric i
 - Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663)
 - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663)
 - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
-- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680)
 
 #### Closed issues
 - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
From 41ebffcc8fae218fa342606c76e52a673dbe8bb1 Mon Sep 17 00:00:00 2001
From: Clement 
Date: Mon, 17 Feb 2025 10:52:54 +0100
Subject: [PATCH 24/24] fix doc

---
 ot/gaussian.py        | 50 ++++++++++++++++++++----------------------
 test/test_gaussian.py |  2 ++
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/ot/gaussian.py b/ot/gaussian.py
index 8cf5022f1..c399bc9d5 100644
--- a/ot/gaussian.py
+++ b/ot/gaussian.py
@@ -204,7 +204,7 @@ def bures_distance(Cs, Ct, log=False, nx=None):
     r"""Return Bures distance.
 
     The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`,
-    given by:
+    given by (see e.g. Remark 2.31 :ref:`[15] `):
 
    .. 
math::
+        \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
 
     Parameters
     ----------
@@ -223,19 +223,17 @@ def bures_distance(Cs, Ct, log=False, nx=None):
 
     Returns
     -------
-    W : float if Cs and Cd of shape (d,d), array-like (n,) if Cs of shape (n,d,d),
-        Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d),
-        array-like (n,m) if Cs of shape (n,d,d) and mt of shape (m,d,d)
+    W : float if Cs and Ct of shape (d,d), array-like (n,) if Cs of shape (n,d,d), Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d), array-like (n,m) if Cs of shape (n,d,d) and Ct of shape (m,d,d)
         Bures Wasserstein distance
     log : dict
         log dictionary return only if log==True in parameters
 
+
     .. _references-bures-wasserstein-distance:
     References
     ----------
-
-    .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
-        Transport", 2018.
+    .. [15] Peyré, G., & Cuturi, M. (2019). Computational optimal transport: With applications to data science.
+        Foundations and Trends® in Machine Learning, 11(5-6), 355-607.
     """
     Cs, Ct = list_to_array(Cs, Ct)
 
@@ -276,7 +274,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
     r"""Return Bures Wasserstein distance between samples.
 
     The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`,
-    as discussed in remark 2.31 :ref:`[1] `.
+    as discussed in remark 2.31 :ref:`[15] `.
 
     .. math::
         \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
@@ -301,9 +299,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
 
     Returns
     -------
-    W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d),
-        mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d),
-        array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
+    W : float if ms and mt of shape (d,), array-like (n,) if ms of shape (n,d), mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
         Bures Wasserstein distance
     log : dict
         log dictionary return only if log==True in parameters
@@ -313,8 +309,8 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
     References
     ----------
 
-    .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
-        Transport", 2018.
+    .. [15] Peyré, G., & Cuturi, M. (2019). Computational optimal transport: With applications to data science.
+        Foundations and Trends® in Machine Learning, 11(5-6), 355-607.
     """
     ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
     nx = get_backend(ms, mt, Cs, Ct)
@@ -455,7 +451,7 @@ def bures_barycenter_fixpoint(
     :ref:`[16] ` by solving
 
     .. math::
-        \Sigma_b = \mathrm{argmin}_{\Sigma \in S_d^{+}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big).
+        \Sigma_b = \mathrm{argmin}_{\Sigma \in S_d^{++}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big).
 
     The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)`
     where :math:`\Sigma_b` is solution of the following fixed-point algorithm:
@@ -699,7 +695,7 @@ def bures_wasserstein_barycenter(
     .. math::
         \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2}
 
-    We propose two solvers: one based on solving the previous fixed-point problem [16]. 
Another based on
+    We propose two solvers: one based on solving the previous fixed-point problem [16], and another based on
     gradient descent in the Bures-Wasserstein space [74,75].
 
     Parameters
@@ -926,9 +922,8 @@ def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False):
     .. _references-gaussien_gromov_wasserstein_distance:
     References
     ----------
-    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
-        distances between Gaussian distributions. Journal of Applied Probability,
-        59(4), 1178-1198.
+    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein distances between Gaussian distributions.
+        Journal of Applied Probability, 59(4), 1178-1198.
     """
 
     nx = get_backend(Cov_s, Cov_t)
@@ -990,9 +985,9 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, wt=None, log
     .. _references-gaussien_gromov_wasserstein:
     References
     ----------
-    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
-        distances between Gaussian distributions. Journal of Applied Probability,
-        59(4), 1178-1198.
+    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022).
+        Gromov–Wasserstein distances between Gaussian distributions.
+        Journal of Applied Probability, 59(4), 1178-1198.
     """
     xs, xt = list_to_array(xs, xt)
     nx = get_backend(xs, xt)
@@ -1058,9 +1053,9 @@ def gaussian_gromov_wasserstein_mapping(
     .. _references-gaussien_gromov_wasserstein_mapping:
     References
     ----------
-    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
-        distances between Gaussian distributions. Journal of Applied Probability,
-        59(4), 1178-1198.
+    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022).
+        Gromov–Wasserstein distances between Gaussian distributions.
+        Journal of Applied Probability, 59(4), 1178-1198.
     """
 
     nx = get_backend(mu_s, mu_t, Cov_s, Cov_t)
 
@@ -1149,12 +1144,13 @@ def empirical_gaussian_gromov_wasserstein_mapping(
     b : (1, dt) array-like
         bias
 
+
     .. _references-empirical_gaussian_gromov_wasserstein_mapping:
     References
     ----------
-    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
-        distances between Gaussian distributions. Journal of Applied Probability,
-        59(4), 1178-1198.
+    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022).
+        Gromov–Wasserstein distances between Gaussian distributions.
+        Journal of Applied Probability, 59(4), 1178-1198.
     """
 
     xs, xt = list_to_array(xs, xt)
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
index 9d3dc6857..a000a9590 100644
--- a/test/test_gaussian.py
+++ b/test/test_gaussian.py
@@ -92,11 +92,13 @@ def test_bures_wasserstein_distance(nx):
     msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
     Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)
     Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False)
+    Wb2 = ot.gaussian.bures_distance(Csb, Ctb, log=False)
 
     np.testing.assert_allclose(
         nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2
     )
     np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+    np.testing.assert_allclose(0, Wb2, rtol=1e-2, atol=1e-2)
 
 
 def test_bures_wasserstein_distance_batch(nx):
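
To close the series, a short usage sketch of the API in its final state. The solver names follow the `method` argument documented above; the random-covariance helper and the concrete shapes are illustrative assumptions, not code from the patches:

import numpy as np
import ot

rng = np.random.default_rng(0)
d, n, m = 2, 4, 3

def random_spd_batch(k):
    # batch of k random symmetric positive definite covariance matrices
    A = rng.normal(size=(k, d, d))
    return A @ A.transpose(0, 2, 1) + 1e-2 * np.eye(d)

ms, mt = rng.normal(size=(n, d)), rng.normal(size=(m, d))
Cs, Ct = random_spd_batch(n), random_spd_batch(m)

# batched pairwise distances between the n source and m target Gaussians
W = ot.gaussian.bures_wasserstein_distance(ms, mt, Cs, Ct)
B = ot.gaussian.bures_distance(Cs, Ct)  # covariance (Bures) part only
assert W.shape == (n, m) and B.shape == (n, m)

# barycenter of the n source Gaussians with two of the solvers
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(ms, Cs, method="fixed_point")
mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(
    ms, Cs, method="stochastic_gradient_descent", batch_size=2
)
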