Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Vectorization implementation for GPU #223

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.3"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
julia = "1"
ArrayInterface = "3.1.17"
johnnychen94 marked this conversation as resolved.
Show resolved Hide resolved
StatsAPI = "1"
julia = "1"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand Down
1 change: 1 addition & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Distances

using ArrayInterface: device, AbstractDevice, GPU
using LinearAlgebra
using Statistics
import StatsAPI: pairwise, pairwise!
Expand Down
21 changes: 21 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,27 @@ _eltype(::Type{Union{Missing, T}}) where {T} = Union{Missing, T}
# Element type of `a` when its iterator trait is `Base.HasEltype`.
function __eltype(::Base.HasEltype, a)
    return _eltype(eltype(a))
end

# Iterators with unknown eltype: fall back to the type of the first element.
function __eltype(::Base.EltypeUnknown, a)
    return _eltype(typeof(first(a)))
end


# Supertype of the traits that select how a distance is evaluated over its inputs.
abstract type AbstractEvaluateStrategy end
# Broadcast-then-reduce evaluation; selected when an input reports a GPU device.
struct Vectorization <: AbstractEvaluateStrategy end
# Scalar (element-wise) evaluation; selected for every other input pairing.
struct ScalarMapReduce <: AbstractEvaluateStrategy end

"""
    infer_evaluate_strategy(d::PreMetric, a, b)

Infer the evaluation strategy for distance `d` from the devices reported by
`device(a)` and `device(b)`. Returns an `AbstractEvaluateStrategy` instance.
"""
function infer_evaluate_strategy(d::PreMetric, a, b)
    # `d` is already constrained by the method signature, so it is forwarded as-is
    # instead of being re-asserted with `d::PreMetric` at the call site.
    return _infer_evaluate_strategy(d, device(a), device(b))
end
# Default rule: scalar map-reduce. The `Nothing` slots cover the case where one (or
# both) of the inputs is a scalar type rather than an array.
@inline function _infer_evaluate_strategy(
    ::PreMetric,
    ::Union{Nothing,AbstractDevice},
    ::Union{Nothing,AbstractDevice},
)
    return ScalarMapReduce()
end
# Scalar indexing is much slower when either of the given arrays is a GPU array,
# so those pairings use the vectorized strategy instead.
@inline _infer_evaluate_strategy(::PreMetric, ::GPU, ::GPU) = Vectorization()
@inline _infer_evaluate_strategy(::PreMetric, ::GPU, ::AbstractDevice) = Vectorization()
@inline _infer_evaluate_strategy(::PreMetric, ::AbstractDevice, ::GPU) = Vectorization()


# Generic column-wise evaluation

"""
Expand Down
42 changes: 29 additions & 13 deletions src/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,26 @@ result_type(dist::UnionMetrics, ::Type{Ta}, ::Type{Tb}, ::Nothing) where {Ta,Tb}
# Result type when parameters are present: evaluate once on unit values of the two
# element types and of the parameter element type, and take the type of the outcome.
function result_type(dist::UnionMetrics, ::Type{Ta}, ::Type{Tb}, p) where {Ta,Tb}
    unit_a = oneunit(Ta)
    unit_b = oneunit(Tb)
    unit_p = oneunit(_eltype(p))
    return typeof(_evaluate(dist, unit_a, unit_b, unit_p))
end

Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b)
_evaluate(d, a, b, parameters(d))
# Entry point for UnionMetrics: pick an evaluation strategy from the inputs, then
# forward to the strategy-specific method. `p` defaults to `parameters(d)`.
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p=parameters(d))
    strategy = infer_evaluate_strategy(d, a, b)
    return _evaluate(strategy, d, a, b, p)
end
for M in (metrics..., weightedmetrics...)
@eval @inline (dist::$M)(a, b) = _evaluate(dist, a, b)
Comment on lines +226 to +227
Copy link
Contributor Author

@johnnychen94 johnnychen94 Jun 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This `for` loop was moved from L327-L329.

end

# breaks the implementation into eval_start, eval_op, eval_reduce and eval_end
# Vectorized evaluation without parameters: broadcast `eval_op` over the inputs, fold
# the elementwise results with `eval_reduce` starting from `eval_start`, and finish
# with `eval_end`. Throws `DimensionMismatch` when the inputs differ in length.
function _evaluate(::Vectorization, d::UnionMetrics, a, b, ::Nothing)
    # Broadcasting would silently stretch a length-1 input, whereas the
    # ScalarMapReduce methods reject mismatched lengths; check up front so that
    # both strategies agree.
    if length(a) != length(b)
        throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    map_op(x, y) = eval_op(d, x, y)
    reduce_op(s1, s2) = eval_reduce(d, s1, s2)
    partials = map_op.(a, b)
    return eval_end(d, reduce(reduce_op, partials; init=eval_start(d, a, b)))
end
# Vectorized evaluation with parameters `p`: broadcast `eval_op` over the inputs and
# the parameters, fold the elementwise results with `eval_reduce` starting from
# `eval_start`, and finish with `eval_end`. Throws `DimensionMismatch` when the
# inputs or the parameters differ in length.
function _evaluate(::Vectorization, d::UnionMetrics, a, b, p)
    # Broadcasting would silently stretch a length-1 input or parameter collection;
    # reject mismatched lengths explicitly, as the ScalarMapReduce methods do.
    if length(a) != length(b)
        throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    if length(a) != length(p)
        throw(DimensionMismatch("arrays have length $(length(a)) but parameters have length $(length(p))."))
    end
    # The closure argument is named `w` so it does not shadow the outer `p`.
    map_op(x, y, w) = eval_op(d, x, y, w)
    reduce_op(s1, s2) = eval_reduce(d, s1, s2)
    partials = map_op.(a, b, p)
    return eval_end(d, reduce(reduce_op, partials; init=eval_start(d, a, b)))
end

Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, ::Nothing)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a, b, ::Nothing)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -239,7 +252,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, ::Nothing)
end
return eval_end(d, s)
end
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray, ::Nothing)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a::AbstractArray, b::AbstractArray, ::Nothing)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -263,7 +276,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b
end
end

Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a, b, p)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -279,7 +292,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p)
end
return eval_end(d, s)
end
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray, p::AbstractArray)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a::AbstractArray, b::AbstractArray, p::AbstractArray)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand Down Expand Up @@ -308,8 +321,8 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b
end
end

_evaluate(dist::UnionMetrics, a::Number, b::Number, ::Nothing) = eval_end(dist, eval_op(dist, a, b))
function _evaluate(dist::UnionMetrics, a::Number, b::Number, p)
# Scalar inputs without parameters: a single `eval_op` followed by `eval_end`.
function _evaluate(::ScalarMapReduce, dist::UnionMetrics, a::Number, b::Number, ::Nothing)
    return eval_end(dist, eval_op(dist, a, b))
end

# Scalar inputs with parameters: `p` must hold exactly one element, which is used
# as the parameter of the single `eval_op` call.
function _evaluate(::ScalarMapReduce, dist::UnionMetrics, a::Number, b::Number, p)
    if length(p) != 1
        throw(DimensionMismatch("inputs are scalars but parameters have length $(length(p))."))
    end
    return eval_end(dist, eval_op(dist, a, b, first(p)))
end
Expand All @@ -324,10 +337,6 @@ _eval_start(d::UnionMetrics, ::Type{Ta}, ::Type{Tb}, p) where {Ta,Tb} =
# Default accumulation step for UnionMetrics: partial results are added.
function eval_reduce(::UnionMetrics, s1, s2)
    return s1 + s2
end

# Default finalization for UnionMetrics: the accumulated value is returned unchanged.
function eval_end(::UnionMetrics, s)
    return s
end

for M in (metrics..., weightedmetrics...)
@eval @inline (dist::$M)(a, b) = _evaluate(dist, a, b, parameters(dist))
end

# Euclidean
# Per-element contribution for Euclidean: squared magnitude of the difference.
@inline function eval_op(::Euclidean, ai, bi)
    return abs2(ai - bi)
end

# Euclidean finalization: square root of the accumulated value.
function eval_end(::Euclidean, s)
    return sqrt(s)
end
Expand Down Expand Up @@ -373,7 +382,14 @@ totalvariation(a, b) = TotalVariation()(a, b)
# Per-element contribution for Chebyshev: absolute difference.
@inline function eval_op(::Chebyshev, ai, bi)
    return abs(ai - bi)
end

# Chebyshev accumulation keeps the larger of the two partial results.
@inline function eval_reduce(::Chebyshev, s1, s2)
    return max(s1, s2)
end
# if only NaN, will output NaN
Base.@propagate_inbounds eval_start(::Chebyshev, a, b) = abs(first(a) - first(b))
# Initial accumulator for Chebyshev. Yields NaN (converted to the result type) when
# either input contains a NaN; otherwise zero, the lower bound of the distance.
Base.@propagate_inbounds function eval_start(d::Chebyshev, a, b)
    T = result_type(d, a, b)
    has_nan = any(isnan, a) || any(isnan, b)
    return has_nan ? convert(T, NaN) : zero(T)
end
Comment on lines +386 to +393
Copy link
Contributor Author

@johnnychen94 johnnychen94 Jun 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rewritten to support CuArray; scalar indexing via `first` is slow for CuArray.

# Convenience wrapper: evaluate a default-constructed `Chebyshev` on `a` and `b`.
function chebyshev(a, b)
    metric = Chebyshev()
    return metric(a, b)
end

# Minkowski
Expand Down
9 changes: 8 additions & 1 deletion test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ end

function test_metricity(dist, x, y, z)
@testset "Test metricity of $(typeof(dist))" begin
@test dist(x, y) == evaluate(dist, x, y)
d = dist(x, y)
@test d == evaluate(dist, x, y)
# Check the metric object itself: `d` is the computed distance value, so testing
# `d isa Distances.UnionMetrics` could never be true and the comparison was skipped.
if dist isa Distances.UnionMetrics
    # currently only UnionMetrics supports this strategy trait
    d_vec = Distances._evaluate(Distances.Vectorization(), dist, x, y, Distances.parameters(dist))
    d_scalar = Distances._evaluate(Distances.ScalarMapReduce(), dist, x, y, Distances.parameters(dist))
    @test d_vec ≈ d_scalar
end

dxy = dist(x, y)
dxz = dist(x, z)
Expand Down