From 2c7250435cab19c5e0b88861ea4fa88b1a5ddf00 Mon Sep 17 00:00:00 2001 From: David Nies Date: Fri, 14 Feb 2020 22:56:03 +0100 Subject: [PATCH 1/6] Remove obsolete comment --- test/test_dists.jl | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/test/test_dists.jl b/test/test_dists.jl index 06d2bc1..32c63e7 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -595,20 +595,3 @@ end end @test bregman(F, ∇, p, q) ≈ ISdist(p, q) end - -#= -@testset "zero allocation colwise!" begin - d = Euclidean() - a = rand(2, 41) - b = rand(2, 41) - z = zeros(41) - colwise!(z, d, a, b) - # This fails when bounds checking is enforced - bounds = Base.JLOptions().check_bounds - if bounds == 0 - @test (@allocated colwise!(z, d, a, b)) == 0 - else - @test_broken (@allocated colwise!(z, d, a, b)) == 0 - end -end -=# From ee6f804f92c8c147d464f662c319f5679bf75b00 Mon Sep 17 00:00:00 2001 From: David Nies Date: Fri, 14 Feb 2020 23:15:29 +0100 Subject: [PATCH 2/6] Add function and test stubs for Wasserstein distance --- src/Distances.jl | 3 +++ src/wasserstein.jl | 16 ++++++++++++++++ test/test_dists.jl | 22 ++++++++++++++++++++++ 3 files changed, 41 insertions(+) create mode 100644 src/wasserstein.jl diff --git a/src/Distances.jl b/src/Distances.jl index faef4f8..62dea96 100644 --- a/src/Distances.jl +++ b/src/Distances.jl @@ -58,6 +58,7 @@ export RMSDeviation, NormRMSDeviation, Bregman, + Wasserstein, # convenient functions euclidean, @@ -91,6 +92,7 @@ export bhattacharyya, hellinger, bregman, + wasserstein, haversine, @@ -107,5 +109,6 @@ include("haversine.jl") include("mahalanobis.jl") include("bhattacharyya.jl") include("bregman.jl") +include("wasserstein.jl") end # module end diff --git a/src/wasserstein.jl b/src/wasserstein.jl new file mode 100644 index 0000000..cee3c11 --- /dev/null +++ b/src/wasserstein.jl @@ -0,0 +1,16 @@ +# Wasserstein distance + +struct Wasserstein <: Metric + p::Float64 + + function Wasserstein(p::Float64) + @assert p >= 1 + new(p) + end +end + +Wasserstein() = Wasserstein(1.0) + +function (dist::Wasserstein)(a::AbstractVector, b::AbstractVector) + throw("implement me") +end diff --git a/test/test_dists.jl b/test/test_dists.jl index 32c63e7..cd31ede 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -1,5 +1,12 @@ # Unit tests for Distances +"Mapx `x` to the probability simplex" +function positive_and_normed(x) + abs_x = abs.(x) + M = sum(abs_x) + abs_x./M +end + function test_metricity(dist, x, y, z) @testset "Test metricity of $(typeof(dist))" begin @test dist(x, y) == evaluate(dist, x, y) @@ -117,6 +124,11 @@ end test_metricity(RenyiDivergence(2), p, q, r) test_metricity(RenyiDivergence(10), p, q, r) test_metricity(JSDivergence(), p, q, r) + + let x′, y′, z′ + x′, y′, z′ = positive_and_normed.([x, y, z]) + test_metricity(Wasserstein(), x′, y′, z′) + end end @testset "individual metrics" begin @@ -177,6 +189,13 @@ end @test weuclidean(x, y, w) == sqrt(wsqeuclidean(x, y, w)) @test wcityblock(x, y, w) ≈ dot(abs.(x - vec(y)), w) @test wminkowski(x, y, w, 2) ≈ weuclidean(x, y, w) + + let x′, y′ = positive_and_normed.[x, y] + @test wasserstein(x′, y′) == 88.0 + @test wasserstein(x′, y′, 2) == 89.0 + end + @test_throws AssertionError wasserstein(x, y) + @test_throws AssertionError Wasserstein(0.5) end # Test ChiSq doesn't give NaN at zero @@ -267,6 +286,7 @@ end #testset @test isa(renyi_divergence(a, b, 2.0), T) @test braycurtis(a, b) == 0.0 @test isa(braycurtis(a, b), T) + @test isa(wasserstein(a, b), T) w = T[] @test isa(whamming(a, b, w), T) @@ -478,6 +498,7 @@ end test_colwise(SqMahalanobis(Q), X, Y, T) test_colwise(Mahalanobis(Q), X, Y, T) + test_colwise(Wasserstein(), X, Y, T) end function test_pairwise(dist, x, y, T) @@ -555,6 +576,7 @@ end test_pairwise(SqMahalanobis(Q), X, Y, T) test_pairwise(Mahalanobis(Q), X, Y, T) + test_pairwise(Wasserstein, X, Y, T) end @testset "Euclidean precision" begin From 15669518f9bea40aacc43c364d5a81749f892c1d Mon Sep 17 00:00:00 2001 From: David Nies Date: Fri, 14 Feb 2020 23:19:59 +0100 Subject: [PATCH 3/6] Add `Cbc` and `JuMP` to dependencies --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index fc3d6b7..f9cab8b 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,8 @@ uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" version = "0.8.2" [deps] +Cbc = "9961bab8-2fa3-5c5a-9d89-47fab24efd76" +JuMP = "4076af6c-e467-56ae-b986-b466b2749572" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" From f85ef89e140cb23aba5e5cd99c3596211b79db35 Mon Sep 17 00:00:00 2001 From: David Nies Date: Sun, 16 Feb 2020 00:24:28 +0100 Subject: [PATCH 4/6] Implement model for linear program and solve it for p-Wasserstein --- src/wasserstein.jl | 56 +++++++++++++++++++++++++++++++++++++++++++++- test/test_dists.jl | 4 +++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/src/wasserstein.jl b/src/wasserstein.jl index cee3c11..56fc8cc 100644 --- a/src/wasserstein.jl +++ b/src/wasserstein.jl @@ -1,5 +1,11 @@ # Wasserstein distance +using JuMP: Model, AffExpr, with_optimizer, @variable, @constraint, + @objective, add_to_expression!, optimize!, termination_status, + objective_value +import JuMP +import Cbc + struct Wasserstein <: Metric p::Float64 @@ -12,5 +18,53 @@ end Wasserstein() = Wasserstein(1.0) function (dist::Wasserstein)(a::AbstractVector, b::AbstractVector) - throw("implement me") + @assert length(a) == length(b) + @assert sum(a) ≈ 1.0 atol=1e-6 + @assert sum(b) ≈ 1.0 atol=1e-6 + + model = make_wasserstein_model(a, b, dist.p) + optimize!(model) + @assert termination_status(model) == JuMP.MOI.OPTIMAL + + objective_value(model)^(1/dist.p) +end + +""" +Create JuMP `Model` for linear program to calculate the p-Wasserstein distance +of two discrete vectors from the same probability simplex. See also formula +(2.5) in [Optimal Transport on Discrete Domains](https://arxiv.org/abs/1801.07745). +""" +function make_wasserstein_model(a::AbstractVector, b::AbstractVector, p::Float64) :: Model + model = Model(with_optimizer(Cbc.Optimizer, logLevel=0)) + + N = length(a) + T = @variable(model, T[1:N, 1:N] >= 0) + + for i in 1:N + row_expression = AffExpr() + for j in 1:N + add_to_expression!(row_expression, 1.0, T[i, j]) + end + @constraint(model, row_expression == a[i]) + end + + for j in 1:N + column_expression = AffExpr() + for i in 1:N + add_to_expression!(column_expression, 1.0, T[i, j]) + end + @constraint(model, column_expression == b[j]) + end + + objective_expression = AffExpr() + for i in 1:N + for j in 1:N + add_to_expression!(objective_expression, abs(i - j)^p, T[i, j]) + end + end + @objective(model, Min, objective_expression) + + model end + +wasserstein(a::AbstractVector, b::AbstractVector, p::Float64=1.0) = Wasserstein(p)(a, b) \ No newline at end of file diff --git a/test/test_dists.jl b/test/test_dists.jl index cd31ede..ee38d1a 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -190,8 +190,10 @@ end @test wcityblock(x, y, w) ≈ dot(abs.(x - vec(y)), w) @test wminkowski(x, y, w, 2) ≈ weuclidean(x, y, w) - let x′, y′ = positive_and_normed.[x, y] + let x′, y′ + x′, y′ = positive_and_normed.([x, y]) @test wasserstein(x′, y′) == 88.0 + @test wasserstein(x′, y′) != wasserstein(x′, y′, 2) @test wasserstein(x′, y′, 2) == 89.0 end @test_throws AssertionError wasserstein(x, y) From 58ceee2216d8ce8a0d8ce99e05345d9dc58886c0 Mon Sep 17 00:00:00 2001 From: David Nies Date: Sun, 16 Feb 2020 01:45:00 +0100 Subject: [PATCH 5/6] Fix tests --- src/wasserstein.jl | 13 ++++++++----- test/test_dists.jl | 31 ++++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/wasserstein.jl b/src/wasserstein.jl index 56fc8cc..adbea9c 100644 --- a/src/wasserstein.jl +++ b/src/wasserstein.jl @@ -17,10 +17,13 @@ end Wasserstein() = Wasserstein(1.0) -function (dist::Wasserstein)(a::AbstractVector, b::AbstractVector) +function (dist::Wasserstein)(a::AbstractArray{T}, b::AbstractArray{T}) where {T} @assert length(a) == length(b) - @assert sum(a) ≈ 1.0 atol=1e-6 - @assert sum(b) ≈ 1.0 atol=1e-6 + + isempty(a) && return zero(T) + + @assert isapprox(sum(a), 1.0, atol=1e-6) "sum(a) needs to be ~1 but is $(sum(a))" + @assert isapprox(sum(b), 1.0, atol=1e-6) "sum(b) needs to be ~1 but is $(sum(b))" model = make_wasserstein_model(a, b, dist.p) optimize!(model) @@ -34,7 +37,7 @@ Create JuMP `Model` for linear program to calculate the p-Wasserstein distance of two discrete vectors from the same probability simplex. See also formula (2.5) in [Optimal Transport on Discrete Domains](https://arxiv.org/abs/1801.07745). """ -function make_wasserstein_model(a::AbstractVector, b::AbstractVector, p::Float64) :: Model +function make_wasserstein_model(a::AbstractArray, b::AbstractArray, p::Float64) :: Model model = Model(with_optimizer(Cbc.Optimizer, logLevel=0)) N = length(a) @@ -67,4 +70,4 @@ function make_wasserstein_model(a::AbstractVector, b::AbstractVector, p::Float64 model end -wasserstein(a::AbstractVector, b::AbstractVector, p::Float64=1.0) = Wasserstein(p)(a, b) \ No newline at end of file +wasserstein(a::AbstractArray, b::AbstractArray, p::Float64=1.0) = Wasserstein(p)(a, b) \ No newline at end of file diff --git a/test/test_dists.jl b/test/test_dists.jl index ee38d1a..1a568f4 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -192,9 +192,9 @@ end let x′, y′ x′, y′ = positive_and_normed.([x, y]) - @test wasserstein(x′, y′) == 88.0 - @test wasserstein(x′, y′) != wasserstein(x′, y′, 2) - @test wasserstein(x′, y′, 2) == 89.0 + @test wasserstein(x′, y′) ≈ 0.471861471 atol=1e-6 + @test wasserstein(x′, y′) != wasserstein(x′, y′, 2.0) + @test wasserstein(x′, y′, 2.0) ≈ 0.68692173634 atol=1e-6 end @test_throws AssertionError wasserstein(x, y) @test_throws AssertionError Wasserstein(0.5) @@ -444,6 +444,19 @@ function test_colwise(dist, x, y, T) end end +function positive_and_normed_colwise(x, T) + rows, cols = size(x) + X = zeros(T, rows, cols) + @assert size(x) == size(X) + for i in 1:cols + normed_col = positive_and_normed(x[:, i]) + for j in 1:rows + X[j, i] = normed_col[j] + end + end + X +end + @testset "column-wise metrics on $T" for T in (Float64, F64) m = 5 n = 8 @@ -500,7 +513,11 @@ end test_colwise(SqMahalanobis(Q), X, Y, T) test_colwise(Mahalanobis(Q), X, Y, T) - test_colwise(Wasserstein(), X, Y, T) + let X′, Y′ + X′ = positive_and_normed_colwise(X, T) + Y′ = positive_and_normed_colwise(Y, T) + test_colwise(Wasserstein(), X′, Y′, T) + end end function test_pairwise(dist, x, y, T) @@ -578,7 +595,11 @@ end test_pairwise(SqMahalanobis(Q), X, Y, T) test_pairwise(Mahalanobis(Q), X, Y, T) - test_pairwise(Wasserstein, X, Y, T) + let X′, Y′ + X′ = positive_and_normed_colwise(X, T) + Y′ = positive_and_normed_colwise(Y, T) + test_pairwise(Wasserstein(), X′, Y′, T) + end end @testset "Euclidean precision" begin From fdfce1b776fb4de203e7dd0b0ef6ba7f98804d21 Mon Sep 17 00:00:00 2001 From: David Nies Date: Sun, 16 Feb 2020 01:59:44 +0100 Subject: [PATCH 6/6] Add notes to README.md for p-Wasserstein --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 1711927..209f6cc 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ This package also provides optimized functions to compute column-wise and pairwi * Normalized root mean squared deviation * Bray-Curtis dissimilarity * Bregman divergence +* Wasserstein distance For `Euclidean distance`, `Squared Euclidean distance`, `Cityblock distance`, `Minkowski distance`, and `Hamming distance`, a weighted version is also provided. @@ -173,6 +174,7 @@ Each distance corresponds to a distance type. The type name and the correspondin | WeightedMinkowski | `wminkowski(x, y, w, p)` | `sum(abs(x - y).^p .* w) ^ (1/p)` | | WeightedHamming | `whamming(x, y, w)` | `sum((x .!= y) .* w)` | | Bregman | `bregman(F, ∇, x, y; inner = LinearAlgebra.dot)` | `F(x) - F(y) - inner(∇(y), x - y)` | +| p-Wasserstein | `wasserstein(a, b, p)` | See (2,5) [here](https://arxiv.org/abs/1801.07745) | **Note:** The formulas above are using *Julia*'s functions. These formulas are mainly for conveying the math concepts in a concise way. The actual implementation may use a faster way. The arguments `x` and `y` are arrays of real numbers; `k` and `l` are arrays of distinct elements of any kind; a and b are arrays of Bools; and finally, `p` and `q` are arrays forming a discrete probability distribution and are therefore both expected to sum to one. @@ -203,6 +205,11 @@ julia> pairwise(Euclidean(1e-12), x, x) 0.0 ``` +## Notes on Wasserstein distance + +The p-Wasserstein distances can only be calculated for values of the same probability simplex (i.e. non-negative real values with sum 1) + +The calculation of the p-Wasserstein distance contains the solution of a linear program in `N^2` variables. This metric is quite expensive to calculate. ## Benchmarks