Skip to content

Commit

Permalink
partialsort by either alg
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Ellison committed Apr 6, 2024
1 parent be35ba8 commit 54a5ff9
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1)
size(c)[dims], length(c) ÷ size(c)[dims]
end

# compile kernels (using Int32 for indexing, if possible, yielding a 10% speedup)
# compile kernels (using Int32 for indexing, if possible, yielding a 70% speedup)
I = c_len <= typemax(Int32) ? Int32 : Int
args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims))
kernel1 = @cuda launch=false comparator_small_kernel(args1...)
Expand Down Expand Up @@ -989,7 +989,7 @@ function Base.sort(c::AnyCuArray; kwargs...)
end

function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange};
lt=isless, by=identity, rev=false)
lt=isless, by=identity, rev=false, alg::QuickSortAlg)
# for reverse sorting, invert the less-than function
if rev
lt = !lt
Expand All @@ -1008,6 +1008,22 @@ function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange};
return out(k)
end

function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange};
lt=isless, by=identity, rev=false, alg::SortingAlgorithm=BitonicSort)

function out(k :: OrdinalRange)
return copy(c[k])
end

# work around disallowed scalar index
function out(k :: Integer)
return Array(c[k:k])[1]
end

sort!(c, alg=alg; lt, by, rev)
return out(k)
end

function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs...)
return partialsort!(copy(c), k; kwargs...)
end
Expand Down

0 comments on commit 54a5ff9

Please sign in to comment.