GPU compatible pairwisel2 #11
Comments
nvm, I did not see all of today's commits |
oh, nice! maybe this would be faster than the current implementation? |
Actually, the way it is written right now is faster, probably because of scalar operations when using pairwise.
using Distances
using CuArrays
using BenchmarkTools
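# Variant of Distances.pairwise that allocates the result with `similar`,
# so the output has the same array type as `a` (e.g. a CuArray).
function pairwise_gpu_compatible(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
                                 dims::Union{Nothing,Integer}=nothing)
    dims = Distances.deprecated_dims(dims)
    dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
    m = size(a, dims)
    n = size(b, dims)
    # `similar(a, ...)` keeps the array type of `a` instead of creating a CPU Matrix
    r = similar(a, result_type(metric, a, b), m, n)
    pairwise!(r, metric, a, b, dims=dims)
end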
pairwisel2_n(x,y) = pairwise_gpu_compatible(SqEuclidean(), x, y, dims=2)
pairwisel2(x::CuMatrix, y::CuMatrix) = -2 .* x' * y .+ sum(x.^2, dims=1)' .+ sum(y.^2,dims=1)
x = CuArray(randn(50,100));
y = CuArray(randn(50,100));
@benchmark pairwisel2_n($x,$y)
BenchmarkTools.Trial:
memory estimate: 3.68 MiB
allocs estimate: 90590
--------------
minimum time: 189.294 ms (0.00% GC)
median time: 228.505 ms (0.00% GC)
mean time: 225.154 ms (0.11% GC)
maximum time: 282.577 ms (0.00% GC)
--------------
samples: 23
evals/sample: 1
@benchmark pairwisel2($x,$y)
BenchmarkTools.Trial:
memory estimate: 22.13 KiB
allocs estimate: 581
--------------
minimum time: 83.320 μs (0.00% GC)
median time: 166.189 μs (0.00% GC)
mean time: 181.380 μs (2.42% GC)
maximum time: 19.423 ms (47.59% GC)
--------------
samples: 10000
evals/sample: 1
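For reference, the one-line pairwisel2 above relies on expanding the squared Euclidean distance. This is a sketch of that identity, assuming the columns of x and y are the points (written here as x_i and y_j, with m and n columns respectively); the symbols D, s_x, s_y are introduced only for illustration:

```latex
\begin{aligned}
\lVert x_i - y_j \rVert_2^2
  &= \lVert x_i \rVert_2^2 - 2\, x_i^\top y_j + \lVert y_j \rVert_2^2, \\
D &= -2\, X^\top Y + s_x \mathbf{1}_n^\top + \mathbf{1}_m s_y^\top,
\qquad (s_x)_i = \lVert x_i \rVert_2^2,\quad (s_y)_j = \lVert y_j \rVert_2^2 .
\end{aligned}
```

Here `sum(x.^2, dims=1)'` plays the role of the column vector s_x and `sum(y.^2, dims=1)` the row vector s_y^T, so the whole result comes from one matrix product and two broadcasts.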
|
woah, nice! thank you for checking! |
pairwisel2 is currently not GPU compatible because of pairwise from Distances. We could fix it (see https://github.com/JuliaStas/Distances.jl/pull/142) by
Then there will still be the issue with potentially slow scalar operations on GPU, but that can probably be fixed.
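One possible way around the scalar operations is to express the squared distances only through a matrix product, reductions, and broadcasts, which should map to whole-array operations on a GPU array. Below is a minimal CPU-only sketch that checks such a broadcast form against a naive loop. The names `pairwisel2_dense` and `pairwisel2_loop` are hypothetical and not part of this package or Distances, and points are assumed to be stored as columns.

```julia
# Broadcast form: squared Euclidean distances between the columns of x and y.
# Uses only a matrix product, column-wise reductions, and broadcasting.
pairwisel2_dense(x::AbstractMatrix, y::AbstractMatrix) =
    -2 .* (x' * y) .+ sum(abs2, x, dims=1)' .+ sum(abs2, y, dims=1)

# Naive reference implementation with explicit loops (CPU only, for checking).
function pairwisel2_loop(x::AbstractMatrix, y::AbstractMatrix)
    r = zeros(size(x, 2), size(y, 2))
    for j in axes(y, 2), i in axes(x, 2)
        r[i, j] = sum(abs2, view(x, :, i) .- view(y, :, j))
    end
    return r
end

x = randn(5, 7)
y = randn(5, 9)
@assert pairwisel2_dense(x, y) ≈ pairwisel2_loop(x, y)
```

Because of floating-point cancellation, the broadcast form can return tiny negative values for nearly identical points, so a caller that takes square roots may want to clamp the result at zero first.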