Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sortperm with dims #2308

Merged
merged 9 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 97 additions & 36 deletions src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,43 @@ end
end
end

# Return `true` when this CUDA block has no slice of `vals` to work on.
# Non-sorted dimensions are enumerated across the grid's y/z dimensions; a block
# whose linearized (y, z) position exceeds the number of independent slices
# (length ÷ size along `dims`) is extraneous and should exit immediately.
@inline function extraneous_block(vals :: AbstractArray, dims):: Bool
# 1-based linear index of this block over the grid's y/z dimensions.
# NOTE(review): `gridDim().z ÷ blockDim().z` is presumably intended to be just
# `gridDim().z` (the launches in this file use blockDim().z == 1, making the
# division a no-op) — confirm before relying on blockDim().z != 1.
other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z
return other_linear_index > length(vals) ÷ size(vals)[dims]
end

# Tuple fallback for (values, indices) pairs as used by `sortperm`: whether a
# block is extraneous depends only on the values array, so delegate to it.
@inline extraneous_block(vals, dims)::Bool = extraneous_block(first(vals), dims)

# methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably speed up view by returning a CuDeviceArray when possible, just like we do with CuArray

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting a change here or is this an idea for a separate change in the repo?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A separate change.

# 1-D fast path: the array already is the lane to sort, so hand it back untouched
# instead of wrapping it in a `view` (which carries a speed penalty for 1-D input).
@inline view_along_dims(vals::AbstractArray{T, 1}, dimsval::Val{1}) where {T} = vals

# Tuple fast path for 1-D (values, indices): the values array is passed through
# unchanged while the companion array is routed back through `view_along_dims`.
@inline function view_along_dims(pair :: Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T
    values, companion = pair
    return values, view_along_dims(companion, dimsval)
end


# Reduce an N-dimensional array to the 1-D slice this CUDA block is responsible
# for: keep dimension `dims` whole and pin every other dimension to the position
# selected by this block's place in the grid's y/z dimensions.
@inline function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,dims}
# shape of the space of non-sorted positions (the sorted dim collapsed to 1)
otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N)
# 1-based linear index of this block over grid y/z; mirrors extraneous_block.
# NOTE(review): `÷ blockDim().z` is presumably a no-op (blockDim().z == 1 here) — confirm
other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z
other = CartesianIndices(otherdims)[other_linear_index]
# create a view that keeps the sorting dimension but indexes across the others
slicedims = map(Base.Slice, axes(vals))
idxs = ntuple(i->i==dims ? slicedims[i] : other[i], N)
return view(vals, idxs...)
end

# Generic tuple form (values, indices) used by `sortperm`: only the index array is
# sliced down to this block's lane; the values array is returned whole — presumably
# because the indices hold linear positions into the full values array (verify
# against the tuple `compare!` method).
@inline function view_along_dims(c, dimsval::Val{dims}) where dims
    values, indices = c[1], c[2]
    return values, view_along_dims(indices, dimsval)
end


# Functions specifically for "large" bitonic steps (those that cannot use shmem)

@inline function compare!(vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt,
rev) where {T,I}
@inline function compare!(vals::AbstractArray{T, N}, i1::I, i2::I, dir::Bool, by, lt, rev) where {T,I,N}
i1′, i2′ = i1 + one(I), i2 + one(I)
@inbounds if dir != rev_lt(by(vals[i1′]), by(vals[i2′]), lt, rev)
vals[i1′], vals[i2′] = vals[i2′], vals[i1′]
Expand Down Expand Up @@ -645,16 +678,22 @@ Note that to avoid synchronization issues, only one thread from each pair of
indices being swapped will actually move data.
"""
function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2,
rev) where {I,F1,F2}
rev, dimsval :: Val{dims}) where {I,F1,F2,dims}
if extraneous_block(vals, dims)
return nothing
end

index = (blockDim().x * (blockIdx().x - one(I))) + threadIdx().x - one(I)

slice = view_along_dims(vals, dimsval)

lo, n, dir = get_range(length_vals, index, k, j)

if !(lo < zero(I) || n < zero(I)) && !(index >= length_vals)
m = gp2lt(n)
if lo <= index < lo + n - m
i1, i2 = index, index + m
@inbounds compare!(vals, i1, i2, dir, by, lt, rev)
@inbounds compare!(slice, i1, i2, dir, by, lt, rev)
end
end
return
Expand Down Expand Up @@ -804,15 +843,19 @@ a unique range of indices corresponding to a comparator in the sorting network.
Note that this moves the array values copied within shmem, but doesn't copy them
back to global memory the way it does for indices.
"""
function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I,
by::F1, lt::F2, rev) where {I,F1,F2}
function comparator_small_kernel(vals, length_vals::I, k::I, j_0::I, j_f::I,
by::F1, lt::F2, rev, dimsval::Val{dims}) where {I,F1,F2,dims}
if extraneous_block(vals, dims)
return nothing
end
slice = view_along_dims(vals, dimsval)
pseudo_block_idx = (blockIdx().x - one(I)) * blockDim().y + threadIdx().y - one(I)
# immutable info about the range used by this kernel
_lo, _n, dir = block_range(length_c, pseudo_block_idx, k, j_0)
_lo, _n, dir = block_range(length_vals, pseudo_block_idx, k, j_0)
index = _lo + threadIdx().x - one(I)
in_range = (threadIdx().x <= _n && _lo >= zero(I))

swap = initialize_shmem!(c, index, in_range)
swap = initialize_shmem!(slice, index, in_range)

# mutable copies for pseudo-recursion
lo, n = _lo, _n
Expand All @@ -829,7 +872,7 @@ function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I,
sync_threads()
end

finalize_shmem!(c, swap, index, in_range)
finalize_shmem!(slice, swap, index, in_range)
return
end

Expand All @@ -849,20 +892,20 @@ of values and an index array for doing `sortperm!`. Cannot provide a stable
`sort!` although `sortperm!` is properly stable. To reverse, set `rev=true`
rather than `lt=!isless` (otherwise stability of sortperm breaks down).
"""
function bitonic_sort!(c; by = identity, lt = isless, rev = false)
c_len = if typeof(c) <: Tuple
length(c[1])
function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1)
c_len, otherdims_len = if typeof(c) <: Tuple
size(c[1])[dims], length(c[1]) ÷ size(c[1])[dims]
else
length(c)
size(c)[dims], length(c) ÷ size(c)[dims]
end

# compile kernels (using Int32 for indexing, if possible, yielding a 10% speedup)
# compile kernels (using Int32 for indexing, if possible, yielding a 70% speedup)
I = c_len <= typemax(Int32) ? Int32 : Int
args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev))
args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims))
kernel1 = @cuda launch=false comparator_small_kernel(args1...)

config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads))
threads1 = prevpow(2, config1.threads)
args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev))
args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev), Val(dims))
kernel2 = @cuda launch=false comparator_kernel(args2...)
config2 = launch_configuration(kernel2.fun, shmem = threads -> bitonic_shmem(c, threads))
# blocksize for kernel2 MUST be a power of 2
Expand All @@ -877,9 +920,13 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
k0 = ceil(Int, log2(c_len))
for k = k0:-1:1
j_final = 1 + k0 - k

# non-sorting dims are put into blocks along grid y/z. Using sqrt minimizes wasted blocks
other_block_dims = Int(ceil(sqrt(otherdims_len))), Int(ceil(sqrt(otherdims_len)))

for j = 1:j_final
args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev))
args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev))
args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev), Val(dims))
args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev), Val(dims))
if k0 - k - j + 2 <= log_threads
# pseudo_block_length = max(nextpow(2, length(comparator)) for all comparators in this layer of the network)
pseudo_block_length = 1 << abs(j_final + 1 - j)
Expand All @@ -888,15 +935,14 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
pseudo_blocks_per_block = threads2 ÷ pseudo_block_length

# grid dimensions
N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block)
N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block), other_block_dims...
block_size = pseudo_block_length, threads2 ÷ pseudo_block_length

kernel1(args1...; blocks=N_blocks, threads=block_size,
shmem=bitonic_shmem(c, block_size))
break
else
N_blocks = cld(c_len, threads1)
kernel2(args2...; blocks = N_blocks, threads=threads1)
N_blocks = cld(c_len, threads2), other_block_dims...
kernel2(args2...; blocks = N_blocks, threads=threads2)
end
end
end
Expand Down Expand Up @@ -930,29 +976,35 @@ function Base.sort!(c::AnyCuVector, alg::QuickSortAlg; lt=isless, by=identity, r
return c
end

function Base.sort!(c::AnyCuVector, alg::BitonicSortAlg; kwargs...)
function Base.sort!(c::AnyCuArray, alg::BitonicSortAlg; kwargs...)
return bitonic_sort!(c; kwargs...)
end

function Base.sort!(c::AnyCuVector; alg :: SortingAlgorithm = BitonicSort, kwargs...)
function Base.sort!(c::AnyCuArray; alg :: SortingAlgorithm = BitonicSort, kwargs...)
return sort!(c, alg; kwargs...)
end

function Base.sort!(c::AnyCuArray; dims::Integer, lt=isless, by=identity, rev=false)
# for multi dim sorting, only quicksort is supported so no alg keyword
if rev
lt = !lt
# Out-of-place sort for CUDA arrays: duplicate the input and sort the duplicate
# in place, forwarding all keyword arguments to `sort!`.
function Base.sort(c::AnyCuArray; kwargs...)
    scratch = copy(c)
    return sort!(scratch; kwargs...)
end

function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, alg::BitonicSortAlg;
lt=isless, by=identity, rev=false)

function out(k :: OrdinalRange)
return copy(c[k])
end

quicksort!(c; lt, by, dims)
return c
end
# work around disallowed scalar index
maleadt marked this conversation as resolved.
Show resolved Hide resolved
function out(k :: Integer)
return Array(c[k:k])[1]
end

function Base.sort(c::AnyCuArray; kwargs...)
return sort!(copy(c); kwargs...)
sort!(c, alg=alg; lt, by, rev)
return out(k)
end

function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange};
function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, alg::QuickSortAlg;
lt=isless, by=identity, rev=false)
# for reverse sorting, invert the less-than function
if rev
Expand All @@ -972,6 +1024,10 @@ function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange};
return out(k)
end

# Keyword entry point for `partialsort!`: translate the `alg` keyword (defaulting
# to BitonicSort) into the positional algorithm argument and dispatch.
Base.partialsort!(c::AnyCuArray, k::Union{Integer, OrdinalRange}; alg::SortingAlgorithm=BitonicSort, kwargs...) =
    partialsort!(c, k, alg; kwargs...)

# Out-of-place partialsort: work on a duplicate so the caller's array is untouched.
function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs...)
    scratch = copy(c)
    return partialsort!(scratch, k; kwargs...)
end
Expand All @@ -986,6 +1042,11 @@ function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwa
return ix
end

function Base.sortperm(c::AnyCuArray; kwargs...)
function Base.sortperm(c::AnyCuVector; kwargs...)
sortperm!(CuArray(1:length(c)), c; initialized=true, kwargs...)
end

function Base.sortperm(c::AnyCuArray; dims, kwargs...)
# Base errors for Matrices without dims arg, we should too
sortperm!(reshape(CuArray(1:length(c)), size(c)), c; initialized=true, dims=dims, kwargs...)
maleadt marked this conversation as resolved.
Show resolved Hide resolved
end
9 changes: 8 additions & 1 deletion test/base/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ end

# multiple dimensions
@test check_sort!(Int32, (4, 50000, 4); dims=2)
@test check_sort!(Int32, (4, 4, 50000); dims=3, rev=true)
@test check_sort!(Int32, (2, 2, 50000); dims=3, rev=true)

# large sizes
@test check_sort!(Float32, 2^25; alg=CUDA.QuickSort)
Expand Down Expand Up @@ -389,6 +389,13 @@ end
@test check_sortperm(Float64, 1000000; rev=true)
@test check_sortperm(Float64, 1000000; by=x->abs(x-0.5))
@test check_sortperm(Float64, 1000000; rev=true, by=x->abs(x-0.5))

if VERSION >= v"1.9"
# Base.jl didn't implement sortperm(;dims) until 1.9
@test check_sortperm(Float32, (100_000, 16); dims=1)
@test check_sortperm(Float32, (100_000, 16); dims=2)
@test check_sortperm(Float32, (100, 256, 256); dims=1)
end
# check with Int32 indices
@test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000)
# `initialized` kwarg
Expand Down