Skip to content

[MRG] Partial optimal transport 1d solver #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 39 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
aef4b55
Implemented the partial optimal transport 1d solver from Chapel & Tav…
rtavenar Jul 1, 2025
c36a45d
bugfix in cythonize
rtavenar Jul 1, 2025
5a6ef86
yet another bugfix in setup.py
rtavenar Jul 1, 2025
978e0f4
relative imports
rtavenar Jul 1, 2025
189a98b
make partial_wasserstein_1d visible at the subpackage level
rtavenar Jul 1, 2025
e873da6
minor fix
rtavenar Jul 1, 2025
ea5f59b
add tests
rtavenar Jul 2, 2025
d0689aa
fix data gen in test
rtavenar Jul 2, 2025
ae504fa
bugfix
rtavenar Jul 2, 2025
5214d4b
make costs double
rtavenar Jul 2, 2025
8245cd9
renaming
rtavenar Jul 2, 2025
8a78d6d
minor
rtavenar Jul 2, 2025
796d09d
remove unused log arg
rtavenar Jul 2, 2025
a0c241d
check precommit
rtavenar Jul 2, 2025
5b37670
better docs and test
rtavenar Jul 2, 2025
6ef4ca3
linting
rtavenar Jul 2, 2025
7eee412
info
rtavenar Jul 2, 2025
b68aa7d
empty commit for co-authorship
rtavenar Jul 2, 2025
843324c
add gallery example
rtavenar Jul 2, 2025
ff42483
minor refactor
rtavenar Jul 2, 2025
ec7d9c1
bugfix: use heapq also at init step
rtavenar Jul 2, 2025
bbc930b
define a function for plotting
rtavenar Jul 2, 2025
2ba8d56
bugfix
rtavenar Jul 2, 2025
9958a31
example fig tweaking
rtavenar Jul 2, 2025
27ec673
minor docs
rtavenar Jul 2, 2025
f11575d
figsize
rtavenar Jul 3, 2025
98967fb
minor
rtavenar Jul 3, 2025
63b5c05
removed pure python types as much as possible (yet to be tested prope…
rtavenar Jul 3, 2025
6ccad78
bugfix in insert_new_chain
rtavenar Jul 4, 2025
11ab6b3
added tests
rtavenar Jul 4, 2025
eba542d
should not change anything
rtavenar Jul 4, 2025
86031ad
bugfix in outofbounds
rtavenar Aug 5, 2025
6463d22
Merge branch 'master' into partial_1d
rtavenar Aug 5, 2025
fd732da
added bib reference
rtavenar Aug 5, 2025
29754f7
install sphinx-gallery from pypi
rtavenar Aug 5, 2025
59654e8
remove Cython-generated C++ code
rtavenar Aug 6, 2025
4668633
do not include Cython-generated cpp in future commits
rtavenar Aug 6, 2025
a0d2d87
faster tests
rtavenar Aug 6, 2025
b96ed78
better labelling for partial
rtavenar Aug 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ ot.solvers:

ot.partial:
- changed-files:
- any-glob-to-any-file: ot/partial.py
- any-glob-to-any-file: ot/partial/**

ot.sliced:
- changed-files:
Expand All @@ -94,4 +94,4 @@ ot.dr:

ot.gnn:
- changed-files:
- any-glob-to-any-file: ot/gnn/**
- any-glob-to-any-file: ot/gnn/**
2 changes: 1 addition & 1 deletion .github/workflows/build_doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
python -m pip install --user --upgrade --progress-bar off pip
python -m pip install --user --upgrade --progress-bar off -r requirements_all.txt
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler
python -m pip install --user -e .
# Look at what we have and fail early if there is some library conflict
- name: Check installation
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ docs/modules/

# Cython output
ot/lp/emd_wrap.cpp
ot/partial/partial_cython.cpp

# Distribution / packaging
.Python
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,5 @@ Artificial Intelligence.
[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR.

[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.

[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Backend implementation of `ot.dist` for (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)
- Fix jax version for auto-grad (PR #732)
- Implement 1d solver for partial optimal transport (PR #741)
- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)
- Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743)
- Removed release information from quickstart guide (PR #744)
Expand Down
85 changes: 85 additions & 0 deletions examples/unbalanced-partial/plot_partial_1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
=========================
Partial Wasserstein in 1D
=========================

This script demonstrates how to compute and visualize the Partial Wasserstein distance between two 1D discrete distributions using `ot.partial.partial_wasserstein_1d`.

We illustrate the intermediate transport plans for all `k = 1...n`, where `n = min(len(x_a), len(x_b))`.
"""

# sphinx_gallery_thumbnail_number = 5

import numpy as np
import matplotlib.pyplot as plt
from ot.partial import partial_wasserstein_1d


def plot_partial_transport(
ax, x_a, x_b, indices_a=None, indices_b=None, marginal_costs=None
):
y_a = np.ones_like(x_a)
y_b = -np.ones_like(x_b)
min_min = min(x_a.min(), x_b.min())
max_max = max(x_a.max(), x_b.max())

ax.plot([min_min - 1, max_max + 1], [1, 1], "k-", lw=0.5, alpha=0.5)
ax.plot([min_min - 1, max_max + 1], [-1, -1], "k-", lw=0.5, alpha=0.5)

# Plot transport lines
if indices_a is not None and indices_b is not None:
subset_a = np.sort(x_a[indices_a])
subset_b = np.sort(x_b[indices_b])

for x_a_i, x_b_j in zip(subset_a, subset_b):
ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7)

# Plot all points
ax.plot(x_a, y_a, "o", color="C0", label="x_a", markersize=8)
ax.plot(x_b, y_b, "o", color="C1", label="x_b", markersize=8)

if marginal_costs is not None:
k = len(marginal_costs)
ax.set_title(
f"Partial Transport - k = {k}, Cumulative Cost = {sum(marginal_costs):.2f}",
fontsize=16,
)
else:
ax.set_title("Original 1D Discrete Distributions", fontsize=16)
ax.legend(loc="upper right", fontsize=14)
ax.set_yticks([])
ax.set_xticks([])
ax.set_ylim(-2, 2)
ax.set_xlim(min(x_a.min(), x_b.min()) - 1, max(x_a.max(), x_b.max()) + 1)
ax.axis("off")


# Simulate two 1D discrete distributions
np.random.seed(0)
n = 6
x_a = np.sort(np.random.uniform(0, 10, size=n))
x_b = np.sort(np.random.uniform(0, 10, size=n))

# Plot original distributions
plt.figure(figsize=(6, 2))
plot_partial_transport(plt.gca(), x_a, x_b)
plt.show()

# %%
indices_a, indices_b, marginal_costs = partial_wasserstein_1d(x_a, x_b)

# Compute cumulative cost
cumulative_costs = np.cumsum(marginal_costs)

# Visualize all partial transport plans
for k in range(n):
plt.figure(figsize=(6, 2))
plot_partial_transport(
plt.gca(),
x_a,
x_b,
indices_a[: k + 1],
indices_b[: k + 1],
marginal_costs[: k + 1],
)
plt.show()
37 changes: 37 additions & 0 deletions ot/partial/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
"""
Efficient 1D solver for the partial optimal transport problem.
"""

# Author: Romain Tavenard <[email protected]>
#
# License: MIT License

# import compiled emd
from .partial_solvers import (
partial_wasserstein_lagrange,
partial_wasserstein,
partial_wasserstein2,
entropic_partial_wasserstein,
gwgrad_partial,
gwloss_partial,
partial_gromov_wasserstein,
partial_gromov_wasserstein2,
entropic_partial_gromov_wasserstein,
entropic_partial_gromov_wasserstein2,
partial_wasserstein_1d,
)

__all__ = [
"partial_wasserstein_1d",
"partial_wasserstein_lagrange",
"partial_wasserstein",
"partial_wasserstein2",
"entropic_partial_wasserstein",
"gwgrad_partial",
"gwloss_partial",
"partial_gromov_wasserstein",
"partial_gromov_wasserstein2",
"entropic_partial_gromov_wasserstein",
"entropic_partial_gromov_wasserstein2",
]
Loading
Loading