diff --git a/README.md b/README.md index f64db8f56..f7bb9f218 100644 --- a/README.md +++ b/README.md @@ -391,3 +391,7 @@ Artificial Intelligence. [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. + +[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. diff --git a/RELEASES.md b/RELEASES.md index 745a7de67..7fc3008a3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,6 +8,10 @@ - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) - Fix documentation in the module `ot.gaussian` (PR #718) +- Refactored `ot.bregman._convolutional` to improve readability (PR #709) +- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680) +- Added `ot.gaussian.bures_wasserstein_distance` (PR #680) +- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) @@ -43,7 +47,6 @@ This release also contains few bug fixes, concerning the support of any metric i - Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663) - Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663) - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676) -- Refactored `ot.bregman._convolutional` to improve readability (PR #709) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) diff --git a/ot/backend.py b/ot/backend.py index c68887e4b..08cf6211d 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1363,7 +1363,8 @@ def solve(self, a, b): return np.linalg.solve(a, b) def trace(self, a): - return np.trace(a) + return np.einsum("...ii", a) + # return np.trace(a) def inv(self, a): return scipy.linalg.inv(a) @@ -1776,7 +1777,8 @@ def solve(self, a, b): return jnp.linalg.solve(a, b) def trace(self, a): - return jnp.trace(a) + return jnp.einsum("...ii", a) + # return jnp.trace(a) def inv(self, a): return jnp.linalg.inv(a) @@ -2309,7 +2311,8 @@ def solve(self, a, b): return torch.linalg.solve(a, b) def trace(self, a): - return torch.trace(a) + return torch.einsum("...ii", a) + # return torch.trace(a) def inv(self, a): return torch.linalg.inv(a) @@ -2723,7 +2726,8 @@ def solve(self, a, b): return cp.linalg.solve(a, b) def trace(self, a): - return cp.trace(a) + return 
cp.einsum("...ii", a) + # return cp.trace(a) def inv(self, a): return cp.linalg.inv(a) @@ -3159,7 +3163,8 @@ def solve(self, a, b): return tf.linalg.solve(a, b) def trace(self, a): - return tf.linalg.trace(a) + return tf.einsum("...ii", a) + # return tf.linalg.trace(a) def inv(self, a): return tf.linalg.inv(a) diff --git a/ot/gaussian.py b/ot/gaussian.py index 002e69fb4..c399bc9d5 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -9,9 +9,10 @@ # License: MIT License import warnings +import numpy as np from .backend import get_backend -from .utils import dots, is_all_finite, list_to_array +from .utils import dots, is_all_finite, list_to_array, exp_bures def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): @@ -199,14 +200,81 @@ def empirical_bures_wasserstein_mapping( return A, b +def bures_distance(Cs, Ct, log=False, nx=None): + r"""Return Bures distance. + + The function computes the Bures distance between :math:`\mu_s=\mathcal{N}(0,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(0,\Sigma_t)`, + given by (see e.g. Remark 2.31 :ref:`[15] `): + + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + Parameters + ---------- + Cs : array-like (d,d) or (n,d,d) + covariance of the source distribution + Ct : array-like (d,d) or (m,d,d) + covariance of the target distribution + log : bool, optional + record log if True + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrices `Cs, Ct`. + + Returns + ------- + W : float if Cs and Ct of shape (d,d), array-like (n,) if Cs of shape (n,d,d) and Ct of shape (d,d), array-like (m,) if Cs of shape (d,d) and Ct of shape (m,d,d), array-like (n,m) if Cs of shape (n,d,d) and Ct of shape (m,d,d) + Bures distance + log : dict + log dictionary returned only if log==True in parameters + + + .. _references-bures-wasserstein-distance: + References + ---------- + .. [15] Peyré, G., & Cuturi, M. (2019). Computational optimal transport: With applications to data science. + Foundations and Trends® in Machine Learning, 11(5-6), 355-607. + """ + Cs, Ct = list_to_array(Cs, Ct) + + if nx is None: + nx = get_backend(Cs, Ct) + + assert Cs.shape[-1] == Ct.shape[-1], "All Gaussians must have the same dimension" + + Cs12 = nx.sqrtm(Cs) + + if len(Cs.shape) == 2 and len(Ct.shape) == 2: + # Return float + bw2 = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) + elif len(Cs.shape) == 2: + # Return shape (m,) + M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12) + bw2 = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M)) + elif len(Ct.shape) == 2: + # Return shape (n,) + M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12) + bw2 = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M)) + else: + # Return shape (n,m) + M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12) + bw2 = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M)) + + W = nx.sqrt(nx.maximum(bw2, 0)) + + if log: + log = {} + log["Cs12"] = Cs12 + return W, log + else: + return W + + def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): r"""Return Bures Wasserstein distance between samples. - The function estimates the Bures-Wasserstein distance between two - empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, - discussed in remark 2.31 :ref:`[1] `. 
- - The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}` + The function computes the Bures-Wasserstein distance between :math:`\mu_s=\mathcal{N}(m_s,\Sigma_s)` and :math:`\mu_t=\mathcal{N}(m_t,\Sigma_t)`, + as discussed in Remark 2.31 :ref:`[15] `. .. math:: \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} @@ -218,21 +286,20 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Parameters ---------- - ms : array-like (d,) + ms : array-like (d,) or (n,d) mean of the source distribution - mt : array-like (d,) + mt : array-like (d,) or (m,d) mean of the target distribution - Cs : array-like (d,d) + Cs : array-like (d,d) or (n,d,d) covariance of the source distribution - Ct : array-like (d,d) + Ct : array-like (d,d) or (m,d,d) covariance of the target distribution log : bool, optional record log if True - Returns ------- - W : float + W : float if ms and mt of shape (d,), array-like (n,) if ms of shape (n,d) and mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), array-like (n,m) if ms of shape (n,d) and mt of shape (m,d) Bures Wasserstein distance log : dict log dictionary return only if log==True in parameters @@ -242,16 +309,44 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): References ---------- - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. + .. [15] Peyré, G., & Cuturi, M. (2019). Computational optimal transport: With applications to data science. + Foundations and Trends® in Machine Learning, 11(5-6), 355-607. """ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) nx = get_backend(ms, mt, Cs, Ct) - Cs12 = nx.sqrtm(Cs) + assert ( + ms.shape[0] == Cs.shape[0] + ), "Source Gaussians have different numbers of components" + + assert ( + mt.shape[0] == Ct.shape[0] + ), "Target Gaussians have different numbers of components" + + assert ( + ms.shape[-1] == mt.shape[-1] == Cs.shape[-1] == Ct.shape[-1] + ), "All Gaussians must have the same dimension" + + if log: + bw, log_dict = bures_distance(Cs, Ct, log=log, nx=nx) + Cs12 = log_dict["Cs12"] + else: + bw = bures_distance(Cs, Ct, nx=nx) + + if len(ms.shape) == 1 and len(mt.shape) == 1: + # Return float + squared_dist_m = nx.norm(ms - mt) ** 2 + elif len(ms.shape) == 1: + # Return shape (m,) + squared_dist_m = nx.norm(ms[None] - mt, axis=-1) ** 2 + elif len(mt.shape) == 1: + # Return shape (n,) + squared_dist_m = nx.norm(ms - mt[None], axis=-1) ** 2 + else: + # Return shape (n,m) + squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2 - B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) - W = nx.sqrt(nx.maximum(nx.norm(ms - mt) ** 2 + B, 0)) + W = nx.sqrt(nx.maximum(squared_dist_m + bw**2, 0)) if log: log = {} @@ -338,81 +433,73 @@ def empirical_bures_wasserstein_distance( Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) if log: - W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log) + W, log = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct, log=log) log["Cs"] = Cs log["Ct"] = Ct return W, log else: - W = bures_wasserstein_distance(mxs, mxt, Cs, Ct) + W = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct) return W -def bures_wasserstein_barycenter( - m, C, weights=None, num_iter=1000, eps=1e-7, log=False +def bures_barycenter_fixpoint( + C, weights=None, num_iter=1000, eps=1e-7, log=False, nx=None ): - r"""Return OT linear operator between samples. 
- - The function estimates the optimal barycenter of the - empirical distributions. This is equivalent to resolving the fixed point - algorithm for multiple Gaussian distributions :math:`\left\{\mathcal{N}(\mu,\Sigma)\right\}_{i=1}^n` - :ref:`[1] `. + r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. - The barycenter still following a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)` - where : + The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n` + :ref:`[16] ` by solving .. math:: - \mu_b = \sum_{i=1}^n w_i \mu_i + \Sigma_b = \mathrm{argmin}_{\Sigma \in S_d^{++}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big). - And the barycentric covariance is the solution of the following fixed-point algorithm: + The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)` + where :math:`\Sigma_b` is the solution of the following fixed-point algorithm: .. math:: - \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2} - + \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2}. Parameters ---------- - m : array-like (k,d) - mean of k distributions C : array-like (k,d,d) covariance of k distributions weights : array-like (k), optional weights for each distribution num_iter : int, optional number of iteration for the fixed point algorithm eps : float, optional tolerance for the fixed point algorithm log : bool, optional record log if True - + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrices `C`. Returns ------- - mb : (d,) array-like - mean of the barycenter Cb : (d, d) array-like covariance of the barycenter log : dict log dictionary return only if log==True in parameters - .. _references-OT-mapping-linear-barycenter: + .. _references-OT-bures-barycenter-fixed-point: References ---------- - .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", + .. [16] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. """ - nx = get_backend( - *C, - *m, - ) + if nx is None: + nx = get_backend( + *C, + ) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] - # Compute the mean barycenter - mb = nx.sum(m * weights[:, None], axis=0) - # Init the covariance barycenter Cb = nx.mean(C * weights[:, None, None], axis=0) @@ -420,11 +507,7 @@ def bures_wasserstein_barycenter( # fixed point update Cb12 = nx.sqrtm(Cb) - Cnew = Cb12 @ C @ Cb12 - C_ = [] - for i in range(len(C)): - C_.append(nx.sqrtm(Cnew[i])) - Cnew = nx.stack(C_, axis=0) + Cnew = nx.sqrtm(Cb12 @ C @ Cb12) Cnew *= weights[:, None, None] Cnew = nx.sum(Cnew, axis=0) @@ -433,15 +516,289 @@ def bures_wasserstein_barycenter( if diff <= eps: break Cb = Cnew - else: + + if diff > eps: print("Dit not converge.") if log: log = {} log["num_iter"] = it log["final_diff"] = diff + return Cb, log + else: + return Cb + + +def bures_barycenter_gradient_descent( + C, + weights=None, + num_iter=1000, + eps=1e-7, + log=False, + step_size=1, + batch_size=None, + averaged=False, + nx=None, +): + r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. 
+ + The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n` + by using gradient descent in the Wasserstein space :ref:`[74, 75] ` + on the objective + + .. math:: + \mathcal{L}(\Sigma) = \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0,\Sigma_i)\big). + + Parameters + ---------- + C : array-like (k,d,d) + covariance of k distributions + weights : array-like (k), optional + weights for each distribution + num_iter : int, optional + number of iterations for the gradient descent + eps : float, optional + tolerance for the stopping criterion + log : bool, optional + record log if True + step_size : float, optional + step size for the gradient descent, 1 by default + batch_size : int, optional + batch size if using stochastic gradient descent + averaged : bool, optional + if True, use the averaged procedure of :ref:`[74] ` + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrices `C`. + + Returns + ------- + Cb : (d, d) array-like + covariance of the barycenter + log : dict + log dictionary returned only if log==True in parameters + + + .. _references-OT-bures-barycenter-gradient_descent: + References + ---------- + .. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). + Gradient descent algorithms for Bures-Wasserstein barycenters. + In Conference on Learning Theory (pp. 1276-1304). PMLR. + + .. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). + Averaging on the Bures-Wasserstein manifold: dimension-free convergence + of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. 
+ """ + if nx is None: + nx = get_backend( + *C, + ) + + n = C.shape[0] + + if weights is None: + weights = nx.ones(C.shape[0], type_as=C[0]) / n + + # Init the covariance barycenter + Cb = nx.mean(C * weights[:, None, None], axis=0) + Id = nx.eye(C.shape[-1], type_as=Cb) + + L_diff = [] + + Cb_averaged = nx.copy(Cb) + + for it in range(num_iter): + Cb12 = nx.sqrtm(Cb) + Cb12_ = nx.inv(Cb12) + + if batch_size is not None and batch_size < n: # if stochastic gradient descent + if batch_size <= 0: + raise ValueError( + "batch_size must be an integer between 0 and {}".format(n) + ) + inds = np.random.choice( + n, batch_size, replace=True, p=nx._to_numpy(weights) + ) + M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12)) + ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) + grad_bw = Id - nx.mean(ot_maps, axis=0) + + # step size from [74] (page 15) + step_size = 2 / (0.7 * (it + 2 / 0.7 + 1)) + else: # gradient descent + M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12)) + ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) + grad_bw = Id - nx.sum(ot_maps * weights[:, None, None], axis=0) + + Cnew = exp_bures(Cb, -step_size * grad_bw, nx=nx) + + if averaged: + # ot map between Cb_averaged and Cnew + Cb_averaged12 = nx.sqrtm(Cb_averaged) + Cb_averaged12inv = nx.inv(Cb_averaged12) + M = nx.sqrtm(nx.einsum("ij,jk,kl->il", Cb_averaged12, Cnew, Cb_averaged12)) + ot_map = nx.einsum("ij,jk,kl->il", Cb_averaged12inv, M, Cb_averaged12inv) + map = Id * step_size / (step_size + 1) + ot_map / (step_size + 1) + Cb_averaged = nx.einsum("ij,jk,kl->il", map, Cb_averaged, map) + + # check convergence + L_diff.append(nx.norm(Cb - Cnew)) + + # Criteria to stop + if np.mean(L_diff[-100:]) <= eps: + break + + Cb = Cnew + + if averaged: + Cb = Cb_averaged + + if log: + dict_log = {} + dict_log["num_iter"] = it + dict_log["final_diff"] = L_diff[-1] + return Cb, dict_log + else: + return Cb + + +def bures_wasserstein_barycenter( + m, + C, + weights=None, + method="fixed_point", + num_iter=1000, + eps=1e-7, + log=False, + step_size=1, + batch_size=None, +): + r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions. + + The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\big(\mathcal{N}(\mu_i,\Sigma_i)\big)_{i=1}^n` + :ref:`[16, 74, 75] ` by solving + + .. math:: + (\mu_b, \Sigma_b) = \mathrm{argmin}_{\mu,\Sigma}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(\mu,\Sigma), \mathcal{N}(\mu_i, \Sigma_i)\big) + + The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)` + where: + + .. math:: + \mu_b = \sum_{i=1}^n w_i \mu_i, + + and the barycentric covariance is the solution of the following fixed-point algorithm: + + .. math:: + \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2} + + We propose two solvers: one based on solving the previous fixed-point problem [16]. Another based on + gradient descent in the Bures-Wasserstein space [74,75]. 
+ + Parameters + ---------- + m : array-like (k,d) + mean of k distributions + C : array-like (k,d,d) + covariance of k distributions + weights : array-like (k), optional + weights for each distribution + method : str + method used for the solver, either 'fixed_point', 'gradient_descent', 'stochastic_gradient_descent' or + 'averaged_stochastic_gradient_descent' + num_iter : int, optional + number of iteration for the fixed point algorithm + eps : float, optional + tolerance for the fixed point algorithm + log : bool, optional + record log if True + step_size : float, optional + step size for the gradient descent, 1 by default + batch_size : int, optional + batch size if use a stochastic gradient descent. If not None, use method='gradient_descent' + + + Returns + ------- + mb : (d,) array-like + mean of the barycenter + Cb : (d, d) array-like + covariance of the barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-OT-bures_wasserstein-barycenter: + References + ---------- + .. [16] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space", + SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, + 2011. + + .. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). + Gradient descent algorithms for Bures-Wasserstein barycenters. + In Conference on Learning Theory (pp. 1276-1304). PMLR. + + .. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). + Averaging on the Bures-Wasserstein manifold: dimension-free convergence + of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. + """ + nx = get_backend( + *m, + ) + + if weights is None: + weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] + + # Compute the mean barycenter + mb = nx.sum(m * weights[:, None], axis=0) + + if method == "gradient_descent": + out = bures_barycenter_gradient_descent( + C, + weights=weights, + num_iter=num_iter, + eps=eps, + log=log, + step_size=step_size, + nx=nx, + ) + elif method == "stochastic_gradient_descent": + out = bures_barycenter_gradient_descent( + C, + weights=weights, + num_iter=num_iter, + eps=eps, + log=log, + batch_size=1 if batch_size is None else batch_size, + nx=nx, + ) + elif method == "averaged_stochastic_gradient_descent": + out = bures_barycenter_gradient_descent( + C, + weights=weights, + num_iter=num_iter, + eps=eps, + log=log, + batch_size=1 if batch_size is None else batch_size, + averaged=True, + nx=nx, + ) + elif method == "fixed_point": + out = bures_barycenter_fixpoint( + C, weights=weights, num_iter=num_iter, eps=eps, log=log, nx=nx + ) + else: + raise ValueError("Unknown method '%s'." % method) + + if log: + Cb, log = out return mb, Cb, log else: + Cb = out return mb, Cb @@ -565,9 +922,8 @@ def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False): .. _references-gaussien_gromov_wasserstein_distance: References ---------- - .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein - distances between Gaussian distributions. Journal of Applied Probability, - 59(4), 1178-1198. + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein distances between Gaussian distributions. + Journal of Applied Probability, 59(4), 1178-1198. """ nx = get_backend(Cov_s, Cov_t) @@ -629,9 +985,9 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, wt=None, log .. _references-gaussien_gromov_wasserstein: References ---------- - .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). 
Gromov–Wasserstein - distances between Gaussian distributions. Journal of Applied Probability, - 59(4), 1178-1198. + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). + Gromov–Wasserstein distances between Gaussian distributions. + Journal of Applied Probability, 59(4), 1178-1198. """ xs, xt = list_to_array(xs, xt) nx = get_backend(xs, xt) @@ -697,9 +1053,9 @@ def gaussian_gromov_wasserstein_mapping( .. _references-gaussien_gromov_wasserstein_mapping: References ---------- - .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein - distances between Gaussian distributions. Journal of Applied Probability, - 59(4), 1178-1198. + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). + Gromov–Wasserstein distances between Gaussian distributions. + Journal of Applied Probability, 59(4), 1178-1198. """ nx = get_backend(mu_s, mu_t, Cov_s, Cov_t) @@ -788,12 +1144,13 @@ def empirical_gaussian_gromov_wasserstein_mapping( b : (1, dt) array-like bias + .. _references-empirical_gaussian_gromov_wasserstein_mapping: References ---------- - .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein - distances between Gaussian distributions. Journal of Applied Probability, - 59(4), 1178-1198. + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). + Gromov–Wasserstein distances between Gaussian distributions. + Journal of Applied Probability, 59(4), 1178-1198. """ xs, xt = list_to_array(xs, xt) diff --git a/ot/utils.py b/ot/utils.py index 045ac5a6c..431226910 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1343,6 +1343,37 @@ def proj_SDP(S, nx=None, vmin=0.0): return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1))) +def exp_bures(Sigma, S, nx=None): + r""" + Exponential map in Bures-Wasserstein space at Sigma: + + .. math:: + \exp_\Sigma(S) = (I_d+S)\Sigma(I_d+S). + + Parameters + ---------- + Sigma : array-like (d,d) + SPD matrix + S : array-like (d,d) + Symmetric matrix + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrices `Sigma, S`. + + Returns + ------- + P : array-like (d,d) + SPD matrix obtained as the exponential map of S at Sigma + """ + if nx is None: + nx = get_backend(Sigma, S) + d = S.shape[-1] + Id = nx.eye(d, type_as=S) + C = Id + S + + return nx.einsum("ij,jk,kl -> il", C, Sigma, C) + + def check_number_threads(numThreads): """Checks whether or not the requested number of threads has a valid value. 
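As a reading aid for the solver added above: a minimal NumPy sketch of the Bures-Wasserstein gradient step that `bures_barycenter_gradient_descent` performs, following the update described in its docstring and [74]. It assumes SPD covariances, uses `scipy.linalg.sqrtm` in place of the backend-agnostic `nx.sqrtm`, and the helper name `bw_gradient_step` is purely illustrative, not part of the PR.

```python
import numpy as np
from scipy.linalg import sqrtm


def bw_gradient_step(Sigma, covs, weights, step_size=1.0):
    """One gradient step on L(Sigma) = sum_i w_i W_2^2(N(0, Sigma), N(0, Sigma_i))."""
    d = Sigma.shape[0]
    S12 = np.real(sqrtm(Sigma))
    S12inv = np.linalg.inv(S12)
    # OT maps T_i sending N(0, Sigma) to each N(0, Sigma_i)
    ot_maps = [S12inv @ np.real(sqrtm(S12 @ C @ S12)) @ S12inv for C in covs]
    # Bures-Wasserstein gradient of the objective: Id - sum_i w_i T_i
    grad = np.eye(d) - np.einsum("i,ijk->jk", weights, np.stack(ot_maps))
    # Exponential map exp_Sigma(S) = (Id + S) Sigma (Id + S), cf. ot.utils.exp_bures
    step = np.eye(d) - step_size * grad
    return step @ Sigma @ step


rng = np.random.default_rng(0)
covs = [np.cov(rng.standard_normal((100, 3)).T) for _ in range(4)]
weights = np.full(4, 0.25)
Sigma = np.mean(covs, axis=0)  # same initialization as the fixed-point solver
for _ in range(100):
    Sigma = bw_gradient_step(Sigma, covs, weights)
```

At the barycenter the weighted sum of OT maps equals the identity, so the gradient vanishes and the iteration is stationary, which is the convergence criterion the solver monitors.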
diff --git a/test/test_backend.py b/test/test_backend.py index 9abb83390..ff5685f6a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -614,6 +614,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("trace") + A = nx.trace(nx.stack([SquareMb, SquareMb], axis=0)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("broadcast trace") + A = nx.inv(SquareMb) lst_b.append(nx.to_numpy(A)) lst_name.append("matrix inverse") diff --git a/test/test_gaussian.py b/test/test_gaussian.py index eed562d15..a000a9590 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -92,11 +92,67 @@ def test_bures_wasserstein_distance(nx): msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct) Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True) Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False) + Wb2 = ot.gaussian.bures_distance(Csb, Ctb, log=False) np.testing.assert_allclose( nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2 ) np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(0, Wb2, rtol=1e-2, atol=1e-2) + + +def test_bures_wasserstein_distance_batch(nx): + n = 50 + k = 2 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif("3gauss", n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + Wb = ot.gaussian.bures_wasserstein_distance(m[0, 0], m[1, 0], C[0], C[1], log=False) + + Wb2 = ot.gaussian.bures_wasserstein_distance( + m[0, 0][None], m[1, 0][None], C[0][None], C[1][None] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1, 1)) + + Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[1, 0][None], C, C[1][None]) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) + np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (2, 1)) + + Wb2 = ot.gaussian.bures_wasserstein_distance( + m[0, 0][None], m[1, 0], C[0][None], C[1] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1,)) + + Wb2 = ot.gaussian.bures_wasserstein_distance( + m[0, 0], m[1, 0][None], C[0], C[1][None] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1,)) + + Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[:, 0], C, C) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[1, 0]), atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 1]), atol=1e-5) + np.testing.assert_allclose(0, nx.to_numpy(Wb2[0, 0]), atol=1e-5) + np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 1]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (2, 2)) @pytest.mark.parametrize("bias", [True, False]) @@ -122,7 +178,16 @@ def test_empirical_bures_wasserstein_distance(nx, bias): np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) -def test_bures_wasserstein_barycenter(nx): +@pytest.mark.parametrize( + "method", + [ + "fixed_point", + "gradient_descent", + "stochastic_gradient_descent", + "averaged_stochastic_gradient_descent", + ], +) +def test_bures_wasserstein_barycenter(nx, method): n = 50 k = 10 X = [] @@ -143,23 +208,27 @@ def test_bures_wasserstein_barycenter(nx): m = 
nx.from_numpy(m) C = nx.from_numpy(C) - mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, log=True) - mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, log=False) + mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter( + m, C, method=method, log=True + ) + mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=False) - np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Cb, Cblog, rtol=1e-1, atol=1e-1) np.testing.assert_allclose(mb, mblog, rtol=1e-2, atol=1e-2) # Test weights argument weights = nx.ones(k) / k mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter( - m, C, weights=weights, log=False + m, C, weights=weights, method=method, log=False ) - np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Cbw, Cb, rtol=1e-1, atol=1e-1) # test with closed form for diagonal covariance matrices Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)] Cdiag = nx.stack(Cdiag, axis=0) - mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, log=False) + mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter( + m, Cdiag, method=method, log=False + ) Cdiag_sqrt = [nx.sqrtm(C) for C in Cdiag] Cdiag_sqrt = nx.stack(Cdiag_sqrt, axis=0) @@ -169,6 +238,124 @@ def test_bures_wasserstein_barycenter(nx): np.testing.assert_allclose(Cbdiag, Cdiag_cf, rtol=1e-2, atol=1e-2) +def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx): + n = 50 + k = 10 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif("3gauss", n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + mb, Cb = ot.gaussian.bures_wasserstein_barycenter( + m, C, method="fixed_point", log=False + ) + mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( + m, C, method="gradient_descent", log=False + ) + + np.testing.assert_allclose(mb, mb2, atol=1e-5) + np.testing.assert_allclose(Cb, Cb2, atol=1e-5) + + # Test weights argument + Cbw = ot.gaussian.bures_barycenter_fixpoint(C, weights=None) + Cbw2 = ot.gaussian.bures_barycenter_gradient_descent(C, weights=None) + np.testing.assert_allclose(Cbw, Cb, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(Cbw2, Cb2, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "method", ["stochastic_gradient_descent", "averaged_stochastic_gradient_descent"] +) +def test_stochastic_gd_bures_wasserstein_barycenter(nx, method): + n = 50 + k = 10 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif("3gauss", n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + mb, Cb = ot.gaussian.bures_wasserstein_barycenter( + m, C, method="fixed_point", log=False + ) + + loss = nx.mean(ot.gaussian.bures_wasserstein_distance(mb[None], m, Cb[None], C)) + + n_samples = [1, 5] + for n in n_samples: + mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( + m, C, method=method, log=False, batch_size=n + ) + + loss2 = nx.mean( + ot.gaussian.bures_wasserstein_distance(mb2[None], m, Cb2[None], C) + ) + + np.testing.assert_allclose(mb, mb2, atol=1e-5) + # atol big for now because too slow, need to see if + # it can be improved... 
+ np.testing.assert_allclose(Cb, Cb2, atol=1e-1) + np.testing.assert_allclose(loss, loss2, atol=1e-3) + + with pytest.raises(ValueError): + mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter( + m, C, method=method, log=False, batch_size=-5 + ) + + +def test_not_implemented_method(nx): + n = 50 + k = 10 + X = [] + y = [] + m = [] + C = [] + for _ in range(k): + X_, y_ = make_data_classif("3gauss", n) + m_ = np.mean(X_, axis=0)[None, :] + C_ = np.cov(X_.T) + X.append(X_) + y.append(y_) + m.append(m_) + C.append(C_) + m = np.array(m) + C = np.array(C) + X = nx.from_numpy(*X) + m = nx.from_numpy(m) + C = nx.from_numpy(C) + + not_implemented = "new_method" + with pytest.raises(ValueError): + mb, Cb = ot.gaussian.bures_wasserstein_barycenter( + m, C, method=not_implemented, log=False + ) + + @pytest.mark.parametrize("bias", [True, False]) def test_empirical_bures_wasserstein_barycenter(nx, bias): n = 50 diff --git a/test/test_utils.py b/test/test_utils.py index d50f29915..3f5f9ec65 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -655,3 +655,30 @@ def test_kl_div(nx): kl_mass = nx.kl_div(xb, yb, True) recovered_kl = kl_mass - nx.sum(yb - xb) np.testing.assert_allclose(kl, recovered_kl) + + +def test_exp_bures(nx): + d = 2 + + rng = np.random.RandomState(42) + X = rng.randn(d, d) + z = rng.randn(d) + X, z = nx.from_numpy(X, z) + S = X + nx.transpose(X) + + Sigma = nx.eye(d, type_as=S) + + Lambda = ot.utils.exp_bures(Sigma, S) + + # asserst SPD + np.testing.assert_array_less(np.zeros(1), nx.to_numpy(z.T @ Lambda @ z)) + + # OT map from Lambda to Sigma + Lambda12 = nx.sqrtm(Lambda) + Lambda12inv = nx.inv(Lambda12) + M = nx.sqrtm(nx.einsum("ij, jk, kl -> il", Lambda12, Sigma, Lambda12)) + T = nx.einsum("ij, jk, kl -> il", Lambda12inv, M, Lambda12inv) + + # exp_\Lambda(log_\Lambda(Sigma)) = Sigma + Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T)) + np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5)
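To summarize how the new public functions fit together, here is a short usage sketch mirroring the tests above. It assumes a POT build that includes this PR; the data are synthetic and only shape conventions from the diff are relied on.

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
means = np.stack([rng.randn(3) for _ in range(5)])               # (n, d) = (5, 3)
covs = np.stack([np.cov(rng.randn(3, 100)) for _ in range(5)])   # (n, d, d) = (5, 3, 3)

# Batched Bures-Wasserstein distances: pairwise (5, 5) matrix between the Gaussians
D = ot.gaussian.bures_wasserstein_distance(means, means, covs, covs)

# Covariance-only Bures distance from the first Gaussian to all others, shape (5,)
d0 = ot.gaussian.bures_distance(covs[0], covs)

# Barycenter with the default fixed-point solver or the new gradient-descent solver
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(means, covs, method="gradient_descent")
```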