
[WIP] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters #680

Open. Wants to merge 32 commits into base: master.

Changes shown from 4 commits. All 32 commits:
7eb14d2  bw barycenter with batched sqrtm (clbonet, Oct 17, 2024)
869955c  BWGD for barycenters (clbonet, Oct 19, 2024)
be985d1  sbwgd for barycenters (clbonet, Oct 19, 2024)
9a43369  Test fixed_point vs gradient_descent (clbonet, Oct 19, 2024)
016704b  Merge branch 'master' into bwgd_barycenter (cedricvincentcuaz, Oct 23, 2024)
8afc00b  fix test bwgd (clbonet, Oct 25, 2024)
b2b0bca  nx exp_bures (clbonet, Oct 25, 2024)
d287a2a  update doc (clbonet, Oct 25, 2024)
9377405  Merge branch 'master' into bwgd_barycenter (clbonet, Oct 31, 2024)
4f648bb  fix merge (clbonet, Oct 31, 2024)
b821ee8  doc exp bw (clbonet, Oct 31, 2024)
d22028b  First tests stochastic + exp (clbonet, Nov 5, 2024)
dffa0cf  exp_bures with einsum (clbonet, Nov 5, 2024)
f3e911a  type Id test (clbonet, Nov 6, 2024)
97f2261  up test stochastic (clbonet, Nov 7, 2024)
7594393  test weights (clbonet, Nov 7, 2024)
6c48b3c  Add BW distance with batchs (clbonet, Nov 7, 2024)
ba806ff  step size SGD BW Barycenter (clbonet, Nov 11, 2024)
7ab365a  Merge branch 'master' into bwgd_barycenter (rflamary, Nov 12, 2024)
447a1a6  Merge branch 'master' into bwgd_barycenter (rflamary, Nov 19, 2024)
d4045f1  batchable BW distance (clbonet, Nov 19, 2024)
f669a8e  Merge branch 'master' into bwgd_barycenter (cedricvincentcuaz, Dec 1, 2024)
6c0a2a0  Merge branch 'master' into bwgd_barycenter (rflamary, Dec 17, 2024)
5da317f  Merge branch 'master' into bwgd_barycenter (cedricvincentcuaz, Jan 13, 2025)
2b317e2  Merge branch 'master' into bwgd_barycenter (clbonet, Jan 27, 2025)
50994ed  RELEASES.md (clbonet, Feb 12, 2025)
bad385f  precommit (clbonet, Feb 12, 2025)
0b20759  Add ot.gaussian.bures (clbonet, Feb 12, 2025)
fe3d9db  Add arg backend (clbonet, Feb 13, 2025)
506a524  up stop criteria sgd Gaussian barycenter (clbonet, Feb 16, 2025)
c640ecb  Fix release (clbonet, Feb 16, 2025)
41ebffc  fix doc (clbonet, Feb 17, 2025)
4 changes: 4 additions & 0 deletions README.md
@@ -384,3 +384,7 @@ Artificial Intelligence.
[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).

[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.

[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR.

[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.
227 changes: 198 additions & 29 deletions ot/gaussian.py
@@ -9,9 +9,10 @@
# License: MIT License

import warnings
import numpy as np

from .backend import get_backend
from .utils import dots, is_all_finite, list_to_array, exp_bures


def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
@@ -344,79 +345,62 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
return W


def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=False):
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.

The function estimates the (Bures-)Wasserstein barycenter between centered Gaussian distributions :math:`\left\{\mathcal{N}(0,\Sigma_i)\right\}_{i=1}^n`
:ref:`[1] <references-OT-mapping-linear-barycenter>` by solving

.. math::
\Sigma_b = \argmin_{\Sigma \in S_d^{+}(\mathbb{R})}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big)

The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)`,
where :math:`\Sigma_b` is the solution of the following fixed-point equation:

.. math::
\Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}


Parameters
----------
C : array-like (k,d,d)
covariance of k distributions
weights : array-like (k), optional
weights for each distribution
num_iter : int, optional
number of iterations for the fixed-point algorithm
eps : float, optional
tolerance for the fixed-point algorithm
log : bool, optional
record log if True


Returns
-------
Cb : (d, d) array-like
covariance of the barycenter
log : dict
log dictionary returned only if log==True in parameters

.. _references-OT-mapping-linear-barycenter:
References
----------
.. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
2011.
"""
nx = get_backend(*C)

if weights is None:
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]

# Init the covariance barycenter
Cb = nx.mean(C * weights[:, None, None], axis=0)

for it in range(num_iter):
# fixed point update
Cb12 = nx.sqrtm(Cb)

Cnew = nx.sqrtm(Cb12 @ C @ Cb12)
Cnew *= weights[:, None, None]
Cnew = nx.sum(Cnew, axis=0)

@@ -425,15 +409,200 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo
if diff <= eps:
break
Cb = Cnew

if diff > eps:
warnings.warn("Fixed-point iterations did not converge.")

if log:
log = {}
log['num_iter'] = it
log['final_diff'] = diff
return Cb, log
else:
return Cb
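
For readers of the diff, a minimal standalone NumPy sketch of the fixed-point iteration above (the helper `_sqrtm_spd` and the function name are hypothetical; the PR itself relies on the backend's batched `nx.sqrtm`):

```python
import numpy as np

def _sqrtm_spd(A):
    # symmetric PSD matrix square root via eigendecomposition
    w, V = np.linalg.eigh(A)
    return (V * np.sqrt(np.maximum(w, 0.0))) @ V.T

def bures_barycenter_fixpoint_np(C, weights=None, num_iter=1000, eps=1e-7):
    # C: (k, d, d) stack of SPD covariance matrices
    k = C.shape[0]
    w = np.full(k, 1.0 / k) if weights is None else weights
    Cb = np.sum(C * w[:, None, None], axis=0)  # init with a weighted mean
    for _ in range(num_iter):
        Cb12 = _sqrtm_spd(Cb)
        # fixed-point map: sum_i w_i (Cb^{1/2} C_i Cb^{1/2})^{1/2}
        Cnew = sum(w[i] * _sqrtm_spd(Cb12 @ C[i] @ Cb12) for i in range(k))
        if np.linalg.norm(Cb - Cnew) <= eps:
            return Cnew
        Cb = Cnew
    return Cb
```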


def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None):
r"""Return OT linear operator between covariances.

The function estimates the (Bures-)Wasserstein barycenter between centered Gaussian distributions :math:`\left\{\mathcal{N}(0,\Sigma_i)\right\}_{i=1}^n`
by (stochastic) gradient descent on the Bures-Wasserstein manifold [74, 75], iterating

.. math::
\Sigma_{t+1} = \exp_{\Sigma_t}\big(-\gamma\, \mathrm{grad}\, F(\Sigma_t)\big), \quad F(\Sigma) = \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0, \Sigma_i)\big)

The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(0,\Sigma_b)`

Parameters
----------
C : array-like (k,d,d)
covariance of k distributions
weights : array-like (k), optional
weights for each distribution
num_iter : int, optional
number of iterations of (stochastic) gradient descent
eps : float, optional
tolerance on the difference between iterates used as stopping criterion
log : bool, optional
record log if True
step_size : float, optional
step size of the gradient descent, 1 by default
batch_size : int, optional
batch size for stochastic gradient descent; if None, full gradient descent is used

Returns
-------
Cb : (d, d) array-like
covariance of the barycenter
log : dict
log dictionary returned only if log==True in parameters

References
----------
.. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020).
Gradient descent algorithms for Bures-Wasserstein barycenters.
In Conference on Learning Theory (pp. 1276-1304). PMLR.

.. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021).
Averaging on the Bures-Wasserstein manifold: dimension-free convergence
of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145.
"""
nx = get_backend(*C)

n = C.shape[0]

if weights is None:
weights = nx.ones(C.shape[0], type_as=C[0]) / n

# Init the covariance barycenter
Cb = nx.mean(C * weights[:, None, None], axis=0)
Id = nx.eye(C.shape[-1], type_as=Cb)

for it in range(num_iter):
Cb12 = nx.sqrtm(Cb)
Cb12_ = nx.inv(Cb12)

if batch_size is not None and batch_size < n: # if stochastic gradient descent
if batch_size <= 0:
raise ValueError("batch_size must be an integer between 1 and {}".format(n))
inds = np.random.choice(n, batch_size, replace=True, p=nx.to_numpy(weights))
M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12))
grad_bw = Id - nx.mean(nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_), axis=0)
else: # gradient descent
M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12))
grad_bw = Id - nx.sum(nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) * weights[:, None, None], axis=0)

Cnew = exp_bures(Cb, - step_size * grad_bw)

# check convergence: well suited to full GD; for SGD the stochastic
# iterates fluctuate, so this criterion may only be met at num_iter
diff = nx.norm(Cb - Cnew)
if diff <= eps:
break

Cb = Cnew

if diff > eps:
warnings.warn("Gradient descent did not converge.")

if log:
log = {}
log['num_iter'] = it
log['final_diff'] = diff
return Cb, log
else:
return Cb
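
To make the update explicit: each iteration computes the Riemannian gradient of the barycenter functional at the current iterate and moves along the Bures-Wasserstein exponential map. A sketch of one full-gradient step in plain NumPy (reusing the hypothetical `_sqrtm_spd` helper from the sketch above; this mirrors, but is not, the backend code):

```python
import numpy as np

def bwgd_step(Cb, C, w, step_size=1.0):
    # one Bures-Wasserstein gradient step on the barycenter functional
    d = Cb.shape[0]
    Cb12 = _sqrtm_spd(Cb)             # Cb^{1/2}
    Cb12_inv = np.linalg.inv(Cb12)    # Cb^{-1/2}
    # T = sum_i w_i Cb^{-1/2} (Cb^{1/2} C_i Cb^{1/2})^{1/2} Cb^{-1/2}
    T = sum(w[i] * Cb12_inv @ _sqrtm_spd(Cb12 @ C[i] @ Cb12) @ Cb12_inv
            for i in range(len(C)))
    grad = np.eye(d) - T              # Riemannian gradient at Cb
    M = np.eye(d) - step_size * grad  # I + S with S = -step_size * grad
    return M @ Cb @ M                 # exponential map exp_Cb(S)
```

Both solvers share the same fixed points: at the barycenter T = I, so the gradient vanishes and the update leaves Cb unchanged.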


def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None):
r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions.

The function estimates the (Bures-)Wasserstein barycenter between Gaussian distributions :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n`
:ref:`[1] <references-OT-mapping-linear-barycenter>` by solving

.. math::
(\mu_b, \Sigma_b) = \argmin_{\mu,\Sigma}\ \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(\mu,\Sigma), \mathcal{N}(\mu_i, \Sigma_i)\big)

The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)`
where

.. math::
\mu_b = \sum_{i=1}^n w_i \mu_i

and the barycentric covariance is the solution of the following fixed-point equation:

.. math::
\Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}

Two solvers are provided: one solves the fixed-point equation above [1]; the other performs
gradient descent on the Bures-Wasserstein manifold [74, 75].

Parameters
----------
m : array-like (k,d)
mean of k distributions
C : array-like (k,d,d)
covariance of k distributions
weights : array-like (k), optional
weights for each distribution
method : str
method used for the solver, either 'fixed_point' or 'gradient_descent'
num_iter : int, optional
number of iterations for the chosen solver
eps : float, optional
tolerance for the chosen solver
log : bool, optional
record log if True
step_size : float, optional
step size of the gradient descent, 1 by default (only used with method='gradient_descent')
batch_size : int, optional
batch size for stochastic gradient descent; if not None, the gradient descent solver is used regardless of `method`


Returns
-------
mb : (d,) array-like
mean of the barycenter
Cb : (d, d) array-like
covariance of the barycenter
log : dict
log dictionary returned only if log==True in parameters

.. _references-OT-bures_wasserstein-barycenter:
References
----------
.. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
2011.

.. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020).
Gradient descent algorithms for Bures-Wasserstein barycenters.
In Conference on Learning Theory (pp. 1276-1304). PMLR.

.. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021).
Averaging on the Bures-Wasserstein manifold: dimension-free convergence
of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145.
"""
nx = get_backend(*C, *m)

if weights is None:
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]

# Compute the mean barycenter
mb = nx.sum(m * weights[:, None], axis=0)

if method == "gradient_descent" or batch_size is not None:
out = bures_barycenter_gradient_descent(C, weights=weights, num_iter=num_iter, eps=eps, log=log, step_size=step_size, batch_size=batch_size)
elif method == "fixed_point":
out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, eps=eps, log=log)
else:
raise ValueError("Unknown method '%s'." % method)

if log:
Cb, log = out
return mb, Cb, log
else:
Cb = out
return mb, Cb
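
A hedged usage sketch of the API added in this PR (assuming the branch is installed; the random SPD construction below is purely illustrative):

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
k, d = 5, 3
A = rng.randn(k, d, d)
C = A @ np.transpose(A, (0, 2, 1)) + 1e-3 * np.eye(d)  # random SPD covariances
m = rng.randn(k, d)

# both solvers should return (approximately) the same barycenter
mb_fp, Cb_fp = ot.gaussian.bures_wasserstein_barycenter(m, C, method="fixed_point")
mb_gd, Cb_gd = ot.gaussian.bures_wasserstein_barycenter(m, C, method="gradient_descent")
np.testing.assert_allclose(Cb_fp, Cb_gd, atol=1e-5)

# passing batch_size forces the stochastic gradient descent solver
mb_sgd, Cb_sgd = ot.gaussian.bures_wasserstein_barycenter(
    m, C, batch_size=2, step_size=0.5, num_iter=5000)
```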


12 changes: 12 additions & 0 deletions ot/utils.py
@@ -1308,3 +1308,15 @@ def proj_SDP(S, nx=None, vmin=0.):
Q = nx.einsum('ijk,ik->ijk', P, w) # Q[i] = P[i] @ diag(w[i])
# R[i] = Q[i] @ P[i].T
return nx.einsum('ijk,ikl->ijl', Q, nx.transpose(P, (0, 2, 1)))


def exp_bures(Sigma, S):
r"""
Exponential map of the Bures-Wasserstein space at :math:`\Sigma`:

.. math::
\exp_\Sigma(S) = (I_d + S)\Sigma(I_d + S)
"""
nx = get_backend(S)
d = S.shape[-1]
Id = nx.eye(d, type_as=S)
C = Id + S

# (I_d + S) Sigma (I_d + S)
return dots(C, Sigma, C)
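
Two quick sanity checks for `exp_bures` (a sketch, assuming the PR branch is installed): the exponential map at the zero tangent vector returns the base point, and for S = t * I it rescales Sigma by (1 + t)^2.

```python
import numpy as np
from ot.utils import exp_bures  # added by this PR

Sigma = np.array([[2.0, 0.5],
                  [0.5, 1.0]])
# exp_Sigma(0) = Sigma
assert np.allclose(exp_bures(Sigma, np.zeros((2, 2))), Sigma)
# exp_Sigma(t * I) = (1 + t)^2 * Sigma since (I + tI) Sigma (I + tI) = (1 + t)^2 Sigma
t = 0.1
assert np.allclose(exp_bures(Sigma, t * np.eye(2)), (1 + t) ** 2 * Sigma)
```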
37 changes: 33 additions & 4 deletions test/test_gaussian.py
@@ -108,6 +108,7 @@ def test_empirical_bures_wasserstein_distance(nx, bias):
np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("method", ["fixed_point", "gradient_descent"])
def test_bures_wasserstein_barycenter(nx, method):
n = 50
k = 10
@@ -129,21 +130,21 @@ def test_bures_wasserstein_barycenter(nx):
m = nx.from_numpy(m)
C = nx.from_numpy(C)

mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=True)
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=False)

np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2)
np.testing.assert_allclose(mb, mblog, rtol=1e-2, atol=1e-2)

# Test weights argument
weights = nx.ones(k) / k
mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, method=method, log=False)
np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2)

# test with closed form for diagonal covariance matrices
Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)]
Cdiag = nx.stack(Cdiag, axis=0)
mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, method=method, log=False)

Cdiag_sqrt = [nx.sqrtm(C) for C in Cdiag]
Cdiag_sqrt = nx.stack(Cdiag_sqrt, axis=0)
@@ -153,6 +154,34 @@
np.testing.assert_allclose(Cbdiag, Cdiag_cf, rtol=1e-2, atol=1e-2)


def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx):
n = 50
k = 10
X = []
y = []
m = []
C = []
for _ in range(k):
X_, y_ = make_data_classif('3gauss', n)
m_ = np.mean(X_, axis=0)[None, :]
C_ = np.cov(X_.T)
X.append(X_)
y.append(y_)
m.append(m_)
C.append(C_)
m = np.array(m)
C = np.array(C)
X = nx.from_numpy(*X)
m = nx.from_numpy(m)
C = nx.from_numpy(C)

mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method="fixed_point", log=False)
mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(m, C, method="gradient_descent", log=False)

np.testing.assert_allclose(mb, mb2, atol=1e-5)
np.testing.assert_allclose(Cb, Cb2, atol=1e-5)


@pytest.mark.parametrize("bias", [True, False])
def test_empirical_bures_wasserstein_barycenter(nx, bias):
n = 50