Entropy-regularised Gromov-Wasserstein #165

Open
wants to merge 27 commits into base: master
Changes from 21 commits
Commits
2ef3e2b
first attempt at gromov-wasserstein
zsteve Oct 2, 2021
11efd8c
update
zsteve Mar 8, 2022
3273976
Merge branch 'master' into gromov
zsteve Mar 8, 2022
0956c3b
fixed computation of entropic gromov-wasserstein
zsteve Mar 8, 2022
c22d7e7
fixed computation of entropic gromov-wasserstein
zsteve Mar 8, 2022
ff1a92c
Merge branch 'gromov' of https://github.com/JuliaOptimalTransport/Opt…
zsteve Mar 8, 2022
267dfad
exports and tests
zsteve Mar 8, 2022
21609b0
formatting
zsteve Mar 12, 2022
9699e04
Update test/gpu/simple_gpu.jl
zsteve Mar 12, 2022
8510397
update docstrings
zsteve Mar 12, 2022
2f2428f
Merge branch 'gromov' of https://github.com/JuliaOptimalTransport/Opt…
zsteve Mar 12, 2022
20d5885
delete cache file
zsteve Mar 12, 2022
df41c28
add docs and format
zsteve Mar 12, 2022
a7c1a38
remove unnecessary Logging import
zsteve Mar 12, 2022
19e4cab
fix missing power of 2
zsteve Mar 13, 2022
56c4f9b
pull changes from master
zsteve Aug 28, 2022
6e3ac4c
update version number
zsteve Aug 28, 2022
5c376ae
add docs workflow
zsteve Dec 20, 2022
af2a493
add Gromov-Wasserstein to readme
zsteve Jan 25, 2023
6bc3127
bump Julia ver for CI
zsteve Jan 25, 2023
a806f0f
minor edit to runtests
zsteve Jan 25, 2023
f704397
Update .github/workflows/CI.yml
zsteve Jan 27, 2023
71351b9
Update test/runtests.jl
zsteve Jan 27, 2023
f2acc56
delete junk files/dirs
zsteve Jan 27, 2023
0635305
revert runtests.jl
zsteve Jan 27, 2023
c3efe5a
avoid unnecessary allocations
zsteve Jan 27, 2023
39f0b36
format
zsteve Jan 27, 2023
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -20,7 +20,7 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.6'
+          - '1.8'
           - '1'
           - 'nightly'
         os:
Empty file added .github/workflows/main
Empty file.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "OptimalTransport"
 uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
 authors = ["zsteve <[email protected]>"]
-version = "0.3.20"
+version = "0.3.21"

 [deps]
 ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31"
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@
[![Coveralls](https://coveralls.io/repos/github/JuliaOptimalTransport/OptimalTransport.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaOptimalTransport/OptimalTransport.jl?branch=master)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)

-This package provides some [Julia](https://julialang.org/) implementations of algorithms for computational [optimal transport](https://optimaltransport.github.io/), including the Earth-Mover's (Wasserstein) distance, Sinkhorn algorithm for entropically regularized optimal transport as well as some variants or extensions.
+This package provides some [Julia](https://julialang.org/) implementations of algorithms for computational [optimal transport](https://optimaltransport.github.io/), including the Earth-Mover's (Wasserstein) distance, Sinkhorn algorithm for entropically regularized optimal transport as well as variants and extensions, including unbalanced transport and Gromov-Wasserstein matching.

Notably, OptimalTransport.jl provides GPU acceleration through [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl/) and [NNlibCUDA.jl](https://github.com/FluxML/NNlibCUDA.jl).

9 changes: 9 additions & 0 deletions docs/src/index.md
@@ -68,6 +68,15 @@ Currently the following algorithms for solving quadratically regularised optimal
 QuadraticOTNewton
 ```

+## Gromov-Wasserstein optimal transport
+
+```@docs
+entropic_gromov_wasserstein
+```
+
+Currently, only entropy-regularised Gromov-Wasserstein is supported. For exact computations, we refer the user to
+[PythonOT](https://github.com/JuliaOptimalTransport/PythonOT.jl) to access functionality from the [Python Optimal Transport library](https://pythonot.github.io/).
+
 ## Dual

 ```@docs
129 changes: 129 additions & 0 deletions gpu/Manifest.toml
@@ -0,0 +1,129 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.7.0"
manifest_format = "2.0"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[deps.Conda]]
deps = ["Downloads", "JSON", "VersionParsing"]
git-tree-sha1 = "6e47d11ea2776bc5627421d59cdcc1296c058071"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
version = "1.7.0"

[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.3"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.LinearAlgebra]]
deps = ["Libdl", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.9"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"

[[deps.Parsers]]
deps = ["Dates"]
git-tree-sha1 = "85b5da0fa43588c75bb1ff986493443f821c70b7"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.2.3"

[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[deps.PyCall]]
deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"]
git-tree-sha1 = "1fc929f47d7c151c839c5fc1375929766fb8edcc"
uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
version = "1.93.1"

[[deps.Random]]
deps = ["SHA", "Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.VersionParsing]]
git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868"
uuid = "81def892-9a0e-5fdd-b105-ffc91e053289"
version = "1.3.0"

[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
2 changes: 2 additions & 0 deletions gpu/Project.toml
@@ -0,0 +1,2 @@
[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
4 changes: 4 additions & 0 deletions src/OptimalTransport.jl
@@ -17,12 +17,14 @@ using NNlib: NNlib
 export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
 export SinkhornBarycenterGibbs
 export QuadraticOTNewton
+export EntropicGromovWassersteinSinkhorn

 export sinkhorn, sinkhorn2
 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
 export sinkhorn_unbalanced, sinkhorn_unbalanced2
 export sinkhorn_divergence, sinkhorn_divergence_unbalanced
 export quadreg
+export entropic_gromov_wasserstein

 include("utils.jl")

@@ -42,4 +44,6 @@ include("quadratic_newton.jl")

 include("dual/entropic_dual.jl")

+include("gromov.jl")
+
 end
91 changes: 91 additions & 0 deletions src/gromov.jl
@@ -0,0 +1,91 @@
# Gromov-Wasserstein solver

abstract type EntropicGromovWasserstein end

struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein
    alg_step::Sinkhorn
end

"""
    entropic_gromov_wasserstein(
        μ, ν, Cμ, Cν, ε, alg=EntropicGromovWassersteinSinkhorn(SinkhornGibbs());
        atol = nothing, rtol = nothing, check_convergence = 10, maxiter = 1_000, kwargs...
    )

Computes the transport plan for the entropically regularized Gromov-Wasserstein optimal transport problem with source and target
marginals `μ` and `ν` and corresponding cost matrices `Cμ` and `Cν`. That is, we seek `γ`, a local minimizer of
```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, i', j'} |C^{(\\mu)}_{i,i'} - C^{(\\nu)}_{j,j'}|^2 \\gamma_{i,j} \\gamma_{i',j'} + \\varepsilon \\Omega(\\gamma),
```
where ``\\Omega(\\gamma)`` is the entropic regularization term, see e.g. [`sinkhorn`](@ref).

This function employs the iterative method described in (Section 10.6.4, [^PC19]), which solves a series of Sinkhorn sub-problems to arrive at a solution. Note that the Gromov-Wasserstein problem is non-convex owing to the cross-terms in the
objective function, and thus in general one is only guaranteed to arrive at a local optimum.

Every `check_convergence` steps, the current iterate `γ` is compared with `γ_prev` (the iterate from `check_convergence` steps earlier).
The quantity ``\\| \\gamma - \\gamma_\\text{prev} \\|_1`` is compared against `atol` and `rtol`.

[^PC19]: Peyré, G. and Cuturi, M., 2019. Computational optimal transport: With applications to data science. Foundations and Trends® in Machine Learning, 11(5-6), pp.355-607.

See also: [`sinkhorn`](@ref)
"""
function entropic_gromov_wasserstein(
    μ::AbstractVector,
    ν::AbstractVector,
    Cμ::AbstractMatrix,
    Cν::AbstractMatrix,
    ε::Real,
    alg::EntropicGromovWasserstein=EntropicGromovWassersteinSinkhorn(SinkhornGibbs());
    atol=nothing,
    rtol=nothing,
    check_convergence=10,
    maxiter::Int=1_000,
    kwargs...,
)
    T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν)))
    C = similar(Cμ, T, size(μ, 1), size(ν, 1))
    tmp = similar(C)
    plan = similar(C)
    @. plan = μ * ν'
    plan_prev = similar(C)
    plan_prev .= plan
    norm_plan = sum(plan)

    _atol = atol === nothing ? 0 : atol
    _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol

    function get_new_cost!(C, plan, tmp, Cμ, Cν)
        A_batched_mul_B!(tmp, Cμ, plan)
        return A_batched_mul_B!(C, tmp, -4Cν)
Member

The multiplication with -4 introduces additional allocations. I wonder if this could be avoided, e.g., by updating in-place or some additional cache.

Member Author

Dealt with now (in-place scaling of tmp)
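
One way the allocation of `-4Cν` could be avoided is to scale the intermediate product in place before the second multiplication. The following is a minimal sketch, not the code from this PR: `get_new_cost_sketch!` is a hypothetical name, and `mul!` from `LinearAlgebra` stands in for the package's `A_batched_mul_B!` helper.

```julia
using LinearAlgebra

# Hypothetical stand-in for the PR's `get_new_cost!`; assumes plain matrices
# (no batching) and preallocated buffers `C` and `tmp` of size M×N.
function get_new_cost_sketch!(C, plan, tmp, Cμ, Cν)
    mul!(tmp, Cμ, plan)   # tmp = Cμ * plan, written into the existing buffer
    tmp .*= -4            # in-place scaling, so no scaled copy of Cν is created
    mul!(C, tmp, Cν)      # C = (-4 * Cμ * plan) * Cν
    return C
end

# Small usage example with M = 3 and N = 2
Cμ = [0.0 1.0 4.0; 1.0 0.0 1.0; 4.0 1.0 0.0]
Cν = [0.0 2.0; 2.0 0.0]
plan = fill(1 / 6, 3, 2)
tmp = similar(plan)
C = similar(plan)
get_new_cost_sketch!(C, plan, tmp, Cμ, Cν)
```

Scaling `tmp` rather than `Cν` keeps the cost matrices untouched between outer iterations, at the price of one extra pass over an M×N buffer.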

        # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation?
        # added the factor of 4 here to ensure reproducibility for the same value of ε.
        # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247
    end

    get_new_cost!(C, plan, tmp, Cμ, Cν)
    to_check_step = check_convergence

    isconverged = false
    for iter in 1:maxiter
        # perform Sinkhorn algorithm
        solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...)
        solve!(solver)
        # compute optimal transport plan
        plan = sinkhorn_plan(solver)

        to_check_step -= 1
        if to_check_step == 0 || iter == maxiter
            # reset counter
            to_check_step = check_convergence
            isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan)
Member

norm_plan is never updated it seems but always set to sum(plan) of the initial randomly initialized plan?

Member

Maybe also avoid allocations here by writing:

Suggested change
-            isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan)
+            plan_prev .-= plan # used as a temporary array here to reduce allocations
+            isconverged = sum(abs, plan_prev) < max(_atol, _rtol * norm_plan)

Member Author

> norm_plan is never updated it seems but always set to sum(plan) of the initial randomly initialized plan?

The initial plan is taken to be the independent coupling, and here we only consider the balanced problem, so norm_plan should not change. I agree, however, that this is a special case of the unbalanced problem, where it would not be constant.
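
The argument can be spelled out in one line. Assuming `μ` and `ν` are probability vectors (total mass 1, as in the test added in this PR), every coupling with these marginals has total mass 1, and so does the independent coupling used for initialisation:

```math
\sum_{i,j} \gamma_{i,j} = \sum_i \mu_i = 1 \quad \text{for all } \gamma \in \Pi(\mu, \nu),
\qquad
\sum_{i,j} (\mu \nu^\top)_{i,j} = \Big(\sum_i \mu_i\Big) \Big(\sum_j \nu_j\Big) = 1.
```

In practice the plan returned by each Sinkhorn sub-problem satisfies the marginal constraints only up to the tolerance of the inner solve, so the equality holds approximately rather than exactly.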

Member Author

> Maybe also avoid allocations here by writing:

Good catch, done
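
For reference, the allocation-free check suggested above can be sketched in isolation. This is an illustrative helper only: `plan_converged!` is a hypothetical name that does not appear in the PR, and the tolerances are passed in explicitly instead of being read from the enclosing function.

```julia
# Hypothetical helper (not part of the PR): checks convergence while reusing
# `plan_prev` as scratch space, so `plan - plan_prev` is never allocated.
function plan_converged!(plan_prev, plan, atol, rtol, norm_plan)
    plan_prev .-= plan   # overwrite the buffer with the difference
    return sum(abs, plan_prev) < max(atol, rtol * norm_plan)
end

plan_prev = [0.25 0.25; 0.25 0.25]
plan = [0.3 0.2; 0.2 0.3]
converged = plan_converged!(plan_prev, plan, 0.0, 1e-3, sum(plan))
plan_prev .= plan        # the buffer now holds the difference, so restore it
```

The trade-off is that `plan_prev` no longer holds the previous plan after the check, which is harmless as long as it is overwritten with the current plan before the next comparison.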

            if isconverged
                @debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged"
                break
            end
            plan_prev .= plan
        end
        get_new_cost!(C, plan, tmp, Cμ, Cν)
    end

    return plan
end
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -0,0 +1,2 @@
[deps]
OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
31 changes: 31 additions & 0 deletions test/gromov.jl
@@ -0,0 +1,31 @@
using OptimalTransport

using Distances
using PythonOT: PythonOT

using Random
using Test
using LinearAlgebra

const POT = PythonOT

Random.seed!(100)

@testset "gromov.jl" begin
    @testset "entropic_gromov_wasserstein" begin
        M, N = 250, 200

        μ = fill(1 / M, M)
        μ_spt = rand(M)
        ν = fill(1 / N, N)
        ν_spt = rand(N)

        Cμ = pairwise(SqEuclidean(), μ_spt)
        Cν = pairwise(SqEuclidean(), ν_spt)

        γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence=10)
        γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01)

        @test γ ≈ γ_pot rtol = 1e-6
    end
end
8 changes: 6 additions & 2 deletions test/runtests.jl
@@ -1,5 +1,5 @@
 using OptimalTransport
-using Pkg: Pkg
+using Pkg
 using SafeTestsets

 using Test
@@ -36,10 +36,14 @@ const GROUP = get(ENV, "GROUP", "All")
         @safetestset "Quadratically regularized OT" begin
             include("quadratic.jl")
         end
+
+        @safetestset "Gromov-Wasserstein OT" begin
+            include("gromov.jl")
+        end
     end

 # CUDA requires Julia >= 1.6
-if (GROUP == "All" || GROUP == "GPU") && VERSION >= v"1.6"
+if (GROUP == "All" || GROUP == "GPU") && VERSION >= v"1.8"
     # activate separate environment: CUDA can't be added to test/Project.toml since it
     # is not available on older Julia versions
     Pkg.activate("gpu")