From 47f00bc78b7d3ef87248317855b4c734a1baaa94 Mon Sep 17 00:00:00 2001 From: Alex Ellison Date: Sun, 31 Mar 2024 13:37:05 -0400 Subject: [PATCH 1/9] sortperm with dims --- src/sorting.jl | 110 +++++++++++++++++++++++++++++-------------- test/base/sorting.jl | 5 ++ 2 files changed, 80 insertions(+), 35 deletions(-) diff --git a/src/sorting.jl b/src/sorting.jl index eeb1206936..9985f11def 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -578,10 +578,43 @@ end end end +function extraneous_block(vals :: AbstractArray, dims):: Bool + other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z + return other_linear_index > length(vals) ÷ size(vals)[dims] +end + +function extraneous_block(vals, dims) :: Bool + return extraneous_block(vals[1], dims) +end + +# methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays +function view_along_dims(vals :: AbstractArray, dimsval::Val{1}) + return vals +end + +function view_along_dims(vals, dimsval::Val{1}) + return vals[1], view_along_dims(vals[2], dimsval) +end + + +function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,dims} + otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N) + other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z + other = CartesianIndices(otherdims)[other_linear_index] + # create a view that keeps the sorting dimension but indexes across the others + slicedims = map(Base.Slice, axes(vals)) + idxs = ntuple(i->i==dims ? slicedims[i] : other[i], N) + return view(vals, idxs...) +end + +function view_along_dims(vals, dimsval::Val{dims}) where dims + return vals[1], view_along_dims(vals[2], dimsval) +end + + # Functions specifically for "large" bitonic steps (those that cannot use shmem) -@inline function compare!(vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt, - rev) where {T,I} +@inline function compare!(vals::AbstractArray{T, N}, i1::I, i2::I, dir::Bool, by, lt, rev) where {T,I,N} i1′, i2′ = i1 + one(I), i2 + one(I) @inbounds if dir != rev_lt(by(vals[i1′]), by(vals[i2′]), lt, rev) vals[i1′], vals[i2′] = vals[i2′], vals[i1′] @@ -645,16 +678,22 @@ Note that to avoid synchronization issues, only one thread from each pair of indices being swapped will actually move data. """ function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2, - rev) where {I,F1,F2} + rev, dimsval :: Val{dims}) where {I,F1,F2,dims} + if extraneous_block(vals, dims) + return nothing + end + index = (blockDim().x * (blockIdx().x - one(I))) + threadIdx().x - one(I) + slice = view_along_dims(vals, dimsval) + lo, n, dir = get_range(length_vals, index, k, j) if !(lo < zero(I) || n < zero(I)) && !(index >= length_vals) m = gp2lt(n) if lo <= index < lo + n - m i1, i2 = index, index + m - @inbounds compare!(vals, i1, i2, dir, by, lt, rev) + @inbounds compare!(slice, i1, i2, dir, by, lt, rev) end end return @@ -804,15 +843,19 @@ a unique range of indices corresponding to a comparator in the sorting network. Note that this moves the array values copied within shmem, but doesn't copy them back to global the way it does for indices. """ -function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I, - by::F1, lt::F2, rev) where {I,F1,F2} +function comparator_small_kernel(vals, length_vals::I, k::I, j_0::I, j_f::I, + by::F1, lt::F2, rev, dimsval::Val{dims}) where {I,F1,F2,dims} + if extraneous_block(vals, dims) + return nothing + end + slice = view_along_dims(vals, dimsval) pseudo_block_idx = (blockIdx().x - one(I)) * blockDim().y + threadIdx().y - one(I) # immutable info about the range used by this kernel - _lo, _n, dir = block_range(length_c, pseudo_block_idx, k, j_0) + _lo, _n, dir = block_range(length_vals, pseudo_block_idx, k, j_0) index = _lo + threadIdx().x - one(I) in_range = (threadIdx().x <= _n && _lo >= zero(I)) - swap = initialize_shmem!(c, index, in_range) + swap = initialize_shmem!(slice, index, in_range) # mutable copies for pseudo-recursion lo, n = _lo, _n @@ -829,7 +872,7 @@ function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I, sync_threads() end - finalize_shmem!(c, swap, index, in_range) + finalize_shmem!(slice, swap, index, in_range) return end @@ -849,20 +892,19 @@ of values and an index array for doing `sortperm!`. Cannot provide a stable `sort!` although `sortperm!` is properly stable. To reverse, set `rev=true` rather than `lt=!isless` (otherwise stability of sortperm breaks down). """ -function bitonic_sort!(c; by = identity, lt = isless, rev = false) - c_len = if typeof(c) <: Tuple - length(c[1]) +function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1) + c_len, otherdims_len = if typeof(c) <: Tuple + size(c[1])[dims], length(c[1]) ÷ size(c[1])[dims] else - length(c) + size(c)[dims], length(c) ÷ size(c)[dims] end # compile kernels (using Int32 for indexing, if possible, yielding a 10% speedup) I = c_len <= typemax(Int32) ? Int32 : Int - args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev)) + args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims)) kernel1 = @cuda launch=false comparator_small_kernel(args1...) config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads)) - threads1 = prevpow(2, config1.threads) - args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev)) + args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev), Val(dims)) kernel2 = @cuda launch=false comparator_kernel(args2...) config2 = launch_configuration(kernel2.fun, shmem = threads -> bitonic_shmem(c, threads)) # blocksize for kernel2 MUST be a power of 2 @@ -877,9 +919,13 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false) k0 = ceil(Int, log2(c_len)) for k = k0:-1:1 j_final = 1 + k0 - k + + # non-sorting dims are put into blocks along grid y/z. Using sqrt minimizes wasted blocks + other_block_dims = Int(ceil(sqrt(otherdims_len))), Int(ceil(sqrt(otherdims_len))) + for j = 1:j_final - args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev)) - args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev)) + args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev), Val(dims)) + args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev), Val(dims)) if k0 - k - j + 2 <= log_threads # pseudo_block_length = max(nextpow(2, length(comparator)) for all comparators in this layer of the network) pseudo_block_length = 1 << abs(j_final + 1 - j) @@ -888,15 +934,14 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false) pseudo_blocks_per_block = threads2 ÷ pseudo_block_length # grid dimensions - N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block) + N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block), other_block_dims... block_size = pseudo_block_length, threads2 ÷ pseudo_block_length - kernel1(args1...; blocks=N_blocks, threads=block_size, shmem=bitonic_shmem(c, block_size)) break else - N_blocks = cld(c_len, threads1) - kernel2(args2...; blocks = N_blocks, threads=threads1) + N_blocks = cld(c_len, threads2), other_block_dims... + kernel2(args2...; blocks = N_blocks, threads=threads2) end end end @@ -930,24 +975,14 @@ function Base.sort!(c::AnyCuVector, alg::QuickSortAlg; lt=isless, by=identity, r return c end -function Base.sort!(c::AnyCuVector, alg::BitonicSortAlg; kwargs...) +function Base.sort!(c::AnyCuArray, alg::BitonicSortAlg; kwargs...) return bitonic_sort!(c; kwargs...) end -function Base.sort!(c::AnyCuVector; alg :: SortingAlgorithm = BitonicSort, kwargs...) +function Base.sort!(c::AnyCuArray; alg :: SortingAlgorithm = BitonicSort, kwargs...) return sort!(c, alg; kwargs...) end -function Base.sort!(c::AnyCuArray; dims::Integer, lt=isless, by=identity, rev=false) - # for multi dim sorting, only quicksort is supported so no alg keyword - if rev - lt = !lt - end - - quicksort!(c; lt, by, dims) - return c -end - function Base.sort(c::AnyCuArray; kwargs...) return sort!(copy(c); kwargs...) end @@ -986,6 +1021,11 @@ function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwa return ix end -function Base.sortperm(c::AnyCuArray; kwargs...) +function Base.sortperm(c::AnyCuVector; kwargs...) sortperm!(CuArray(1:length(c)), c; initialized=true, kwargs...) end + +function Base.sortperm(c::AnyCuArray; dims, kwargs...) + # Base errors for Matrices without dims arg, we should too + sortperm!(reshape(CuArray(1:length(c)), size(c)), c; initialized=true, dims=dims, kwargs...) +end diff --git a/test/base/sorting.jl b/test/base/sorting.jl index 69d025d962..bfd51bf55c 100644 --- a/test/base/sorting.jl +++ b/test/base/sorting.jl @@ -389,6 +389,11 @@ end @test check_sortperm(Float64, 1000000; rev=true) @test check_sortperm(Float64, 1000000; by=x->abs(x-0.5)) @test check_sortperm(Float64, 1000000; rev=true, by=x->abs(x-0.5)) + + @test check_sortperm(Float32, (100_000, 16); dims=1) + @test check_sortperm(Float32, (100_000, 16); dims=2) + @test check_sortperm(Float32, (100, 256, 256); dims=1) + # check with Int32 indices @test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000) # `initialized` kwarg From f52dc421d1f6b4c1f829c58b9887b090254c0681 Mon Sep 17 00:00:00 2001 From: xaellison Date: Sun, 31 Mar 2024 16:31:00 -0400 Subject: [PATCH 2/9] delete view_along_dims methods --- src/sorting.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/sorting.jl b/src/sorting.jl index 9985f11def..9d3b05a723 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -587,15 +587,6 @@ function extraneous_block(vals, dims) :: Bool return extraneous_block(vals[1], dims) end -# methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays -function view_along_dims(vals :: AbstractArray, dimsval::Val{1}) - return vals -end - -function view_along_dims(vals, dimsval::Val{1}) - return vals[1], view_along_dims(vals[2], dimsval) -end - function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,dims} otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N) From 380d2ce0b860b6fae3a196dc1a9f834f364418d3 Mon Sep 17 00:00:00 2001 From: xaellison Date: Sun, 31 Mar 2024 16:53:09 -0400 Subject: [PATCH 3/9] restore with fixed signatures --- src/sorting.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/sorting.jl b/src/sorting.jl index 9d3b05a723..3a701f7fb5 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -587,6 +587,15 @@ function extraneous_block(vals, dims) :: Bool return extraneous_block(vals[1], dims) end +# methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays +function view_along_dims(vals :: AbstractArray{T, 1}, dimsval::Val{1}) where T + return vals +end + +function view_along_dims(vals :: Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T + return vals[1], view_along_dims(vals[2], dimsval) +end + function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,dims} otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N) @@ -894,6 +903,7 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1) I = c_len <= typemax(Int32) ? Int32 : Int args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims)) kernel1 = @cuda launch=false comparator_small_kernel(args1...) + config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads)) args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev), Val(dims)) kernel2 = @cuda launch=false comparator_kernel(args2...) From 781425bb85c0cc026c1a414c22543149c5e03857 Mon Sep 17 00:00:00 2001 From: Alex Ellison Date: Sun, 31 Mar 2024 21:15:04 -0400 Subject: [PATCH 4/9] dont test sortperm dims below j1.9 --- test/base/sorting.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/base/sorting.jl b/test/base/sorting.jl index bfd51bf55c..f946ef29b1 100644 --- a/test/base/sorting.jl +++ b/test/base/sorting.jl @@ -390,10 +390,12 @@ end @test check_sortperm(Float64, 1000000; by=x->abs(x-0.5)) @test check_sortperm(Float64, 1000000; rev=true, by=x->abs(x-0.5)) - @test check_sortperm(Float32, (100_000, 16); dims=1) - @test check_sortperm(Float32, (100_000, 16); dims=2) - @test check_sortperm(Float32, (100, 256, 256); dims=1) - + if VERSION >= v"1.9" + # Base.jl didn't implement sortperm(;dims) until 1.9 + @test check_sortperm(Float32, (100_000, 16); dims=1) + @test check_sortperm(Float32, (100_000, 16); dims=2) + @test check_sortperm(Float32, (100, 256, 256); dims=1) + end # check with Int32 indices @test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000) # `initialized` kwarg From 6e95a1310ad2e97f5fffbbf6b061158e4eae8f93 Mon Sep 17 00:00:00 2001 From: Alex Ellison Date: Mon, 1 Apr 2024 10:10:43 -0400 Subject: [PATCH 5/9] free speed --- src/sorting.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sorting.jl b/src/sorting.jl index 3a701f7fb5..e45afb06b5 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -578,26 +578,26 @@ end end end -function extraneous_block(vals :: AbstractArray, dims):: Bool +@inline function extraneous_block(vals :: AbstractArray, dims):: Bool other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z return other_linear_index > length(vals) ÷ size(vals)[dims] end -function extraneous_block(vals, dims) :: Bool +@inline function extraneous_block(vals, dims) :: Bool return extraneous_block(vals[1], dims) end # methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays -function view_along_dims(vals :: AbstractArray{T, 1}, dimsval::Val{1}) where T +@inline function view_along_dims(vals :: AbstractArray{T, 1}, dimsval::Val{1}) where T return vals end -function view_along_dims(vals :: Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T +@inline function view_along_dims(vals :: Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T return vals[1], view_along_dims(vals[2], dimsval) end -function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,dims} +@inline function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,dims} otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N) other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z other = CartesianIndices(otherdims)[other_linear_index] @@ -607,7 +607,7 @@ function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,di return view(vals, idxs...) end -function view_along_dims(vals, dimsval::Val{dims}) where dims +@inline function view_along_dims(vals, dimsval::Val{dims}) where dims return vals[1], view_along_dims(vals[2], dimsval) end From 10f430e348d08044f46d71a1d0f235a0b6647552 Mon Sep 17 00:00:00 2001 From: Alex Ellison Date: Sat, 6 Apr 2024 10:24:57 -0400 Subject: [PATCH 6/9] partialsort by either alg --- src/sorting.jl | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/sorting.jl b/src/sorting.jl index e45afb06b5..61b5edaad3 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -899,7 +899,7 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1) size(c)[dims], length(c) ÷ size(c)[dims] end - # compile kernels (using Int32 for indexing, if possible, yielding a 10% speedup) + # compile kernels (using Int32 for indexing, if possible, yielding a 70% speedup) I = c_len <= typemax(Int32) ? Int32 : Int args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims)) kernel1 = @cuda launch=false comparator_small_kernel(args1...) @@ -989,7 +989,7 @@ function Base.sort(c::AnyCuArray; kwargs...) end function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}; - lt=isless, by=identity, rev=false) + lt=isless, by=identity, rev=false, alg::QuickSortAlg) # for reverse sorting, invert the less-than function if rev lt = !lt @@ -1008,6 +1008,22 @@ function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}; return out(k) end +function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}; + lt=isless, by=identity, rev=false, alg::SortingAlgorithm=BitonicSort) + + function out(k :: OrdinalRange) + return copy(c[k]) + end + + # work around disallowed scalar index + function out(k :: Integer) + return Array(c[k:k])[1] + end + + sort!(c, alg=alg; lt, by, rev) + return out(k) +end + function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs...) return partialsort!(copy(c), k; kwargs...) end From 3de239a1cc2d377f45c37f2e17aaebf1b42ce572 Mon Sep 17 00:00:00 2001 From: Alex Ellison Date: Sat, 6 Apr 2024 16:35:40 -0400 Subject: [PATCH 7/9] dispatch fix: locally passes 1.8 and 1.10 --- src/sorting.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/sorting.jl b/src/sorting.jl index 61b5edaad3..f78ceedb98 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -988,12 +988,8 @@ function Base.sort(c::AnyCuArray; kwargs...) return sort!(copy(c); kwargs...) end -function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}; - lt=isless, by=identity, rev=false, alg::QuickSortAlg) - # for reverse sorting, invert the less-than function - if rev - lt = !lt - end +function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, alg::BitonicSortAlg; + lt=isless, by=identity, rev=false) function out(k :: OrdinalRange) return copy(c[k]) @@ -1004,12 +1000,16 @@ function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}; return Array(c[k:k])[1] end - quicksort!(c; lt, by, dims=1, partial_k=k) + sort!(c, alg=alg; lt, by, rev) return out(k) end -function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}; - lt=isless, by=identity, rev=false, alg::SortingAlgorithm=BitonicSort) +function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, alg::QuickSortAlg; + lt=isless, by=identity, rev=false) + # for reverse sorting, invert the less-than function + if rev + lt = !lt + end function out(k :: OrdinalRange) return copy(c[k]) @@ -1020,10 +1020,14 @@ function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}; return Array(c[k:k])[1] end - sort!(c, alg=alg; lt, by, rev) + quicksort!(c; lt, by, dims=1, partial_k=k) return out(k) end +function Base.partialsort!(c::AnyCuArray, k::Union{Integer, OrdinalRange}; alg::SortingAlgorithm=BitonicSort, kwargs...) + return partialsort!(c, k, alg; kwargs...) +end + function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs...) return partialsort!(copy(c), k; kwargs...) end From 555b5dbe42bac5bef806ef8301b6cb44dc502a1c Mon Sep 17 00:00:00 2001 From: Alex Ellison Date: Sat, 6 Apr 2024 17:15:46 -0400 Subject: [PATCH 8/9] reduce resources needed for heisenbug test case --- test/base/sorting.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/sorting.jl b/test/base/sorting.jl index f946ef29b1..7c42b942f4 100644 --- a/test/base/sorting.jl +++ b/test/base/sorting.jl @@ -315,7 +315,7 @@ end # multiple dimensions @test check_sort!(Int32, (4, 50000, 4); dims=2) - @test check_sort!(Int32, (4, 4, 50000); dims=3, rev=true) + @test check_sort!(Int32, (2, 2, 50000); dims=3, rev=true) # large sizes @test check_sort!(Float32, 2^25; alg=CUDA.QuickSort) From 6f5fdd00e573fca575f69c6989ad62ab7a120e10 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 9 Apr 2024 14:18:24 +0200 Subject: [PATCH 9/9] Clean-ups and formatting. --- src/sorting.jl | 80 ++++++++++++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/src/sorting.jl b/src/sorting.jl index f78ceedb98..7dd5638319 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -69,7 +69,8 @@ affected by `parity`. See `flex_lt`. `swap` is an array for exchanging values and `sums` is an array of Ints used during the merge sort. Uses block y index to decide which values to operate on. """ -@inline function batch_partition(values, pivot, swap, sums, lo, hi, parity, lt::F1, by::F2) where {F1,F2} +@inline function batch_partition(values, pivot, swap, sums, lo, hi, parity, + lt::F1, by::F2) where {F1,F2} sync_threads() blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x @@ -88,7 +89,11 @@ Uses block y index to decide which values to operate on. cumsum!(sums) @inbounds if idx0 <= hi - dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x] + dest_idx = @inbounds if comparison + blockDim().x - sums[end] + sums[threadIdx().x] + else + threadIdx().x - sums[threadIdx().x] + end if dest_idx <= length(swap) swap[dest_idx] = val end @@ -211,7 +216,7 @@ Finds the median of `vals` starting after `lo` and going for `blockDim().x` elements spaced by `stride`. Performs bitonic sort in shmem, returns middle value. Faster than bubble sort, but not as flexible. Does not modify `vals` """ -function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, by::F2) where {T,F1,F2} +function bitonic_median(vals::AbstractArray{T}, swap, lo, L, stride, lt::F1, by::F2) where {T,F1,F2} sync_threads() bitonic_lt(i1, i2) = @inbounds flex_lt(swap[i1 + 1], swap[i2 + 1], false, lt, by) @@ -337,7 +342,7 @@ Quicksort recursion condition For a full sort, `partial` is nothing so it shouldn't affect whether recursion happens. """ -function partial_range_overlap(lo, hi, partial :: Nothing) +function partial_range_overlap(lo, hi, partial::Nothing) true end @@ -374,7 +379,8 @@ it's possible that the first pivot will be that value, which could lead to an in early end to recursion if we started `stuck` at 0. """ function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sync_depth, - prev_pivot, lt::F1, by::F2, ::Val{dims}, partial=nothing, stuck=-1) where {T, N, S, F1, F2, dims} + prev_pivot, lt::F1, by::F2, ::Val{dims}, partial=nothing, + stuck=-1) where {T, N, S, F1, F2, dims} b_sums = CuDynamicSharedArray(Int, blockDim().x) swap = CuDynamicSharedArray(T, blockDim().x, sizeof(b_sums)) shmem = sizeof(b_sums) + sizeof(swap) @@ -449,7 +455,7 @@ function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sy return end -function sort_args(args, partial_k :: Nothing) +function sort_args(args, partial_k::Nothing) return args end @@ -578,26 +584,26 @@ end end end -@inline function extraneous_block(vals :: AbstractArray, dims):: Bool +@inline function extraneous_block(vals::AbstractArray, dims):: Bool other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z return other_linear_index > length(vals) ÷ size(vals)[dims] end -@inline function extraneous_block(vals, dims) :: Bool +@inline function extraneous_block(vals, dims)::Bool return extraneous_block(vals[1], dims) end # methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays -@inline function view_along_dims(vals :: AbstractArray{T, 1}, dimsval::Val{1}) where T +@inline function view_along_dims(vals::AbstractArray{T, 1}, dimsval::Val{1}) where T return vals end -@inline function view_along_dims(vals :: Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T +@inline function view_along_dims(vals::Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T return vals[1], view_along_dims(vals[2], dimsval) end -@inline function view_along_dims(vals :: AbstractArray{T, N}, ::Val{dims}) where {T,N,dims} +@inline function view_along_dims(vals::AbstractArray{T, N}, ::Val{dims}) where {T,N,dims} otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N) other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z other = CartesianIndices(otherdims)[other_linear_index] @@ -626,7 +632,8 @@ end i1′, i2′ = i1 + one(I), i2 + one(I) vals, inds = vals_inds # comparing tuples of (value, index) guarantees stability of sort - @inbounds if dir != rev_lt((by(vals[inds[i1′]]), inds[i1′]), (by(vals[inds[i2′]]), inds[i2′]), lt, rev) + @inbounds if dir != rev_lt((by(vals[inds[i1′]]), inds[i1′]), + (by(vals[inds[i2′]]), inds[i2′]), lt, rev) inds[i1′], inds[i2′] = inds[i2′], inds[i1′] end end @@ -678,11 +685,11 @@ Note that to avoid synchronization issues, only one thread from each pair of indices being swapped will actually move data. """ function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2, - rev, dimsval :: Val{dims}) where {I,F1,F2,dims} + rev, dimsval::Val{dims}) where {I,F1,F2,dims} if extraneous_block(vals, dims) return nothing end - + index = (blockDim().x * (blockIdx().x - one(I))) + threadIdx().x - one(I) slice = view_along_dims(vals, dimsval) @@ -903,7 +910,7 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1) I = c_len <= typemax(Int32) ? Int32 : Int args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims)) kernel1 = @cuda launch=false comparator_small_kernel(args1...) - + config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads)) args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev), Val(dims)) kernel2 = @cuda launch=false comparator_kernel(args2...) @@ -922,13 +929,14 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1) j_final = 1 + k0 - k # non-sorting dims are put into blocks along grid y/z. Using sqrt minimizes wasted blocks - other_block_dims = Int(ceil(sqrt(otherdims_len))), Int(ceil(sqrt(otherdims_len))) + other_block_dims = Int(ceil(sqrt(otherdims_len))), Int(ceil(sqrt(otherdims_len))) for j = 1:j_final args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev), Val(dims)) args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev), Val(dims)) if k0 - k - j + 2 <= log_threads - # pseudo_block_length = max(nextpow(2, length(comparator)) for all comparators in this layer of the network) + # pseudo_block_length = max(nextpow(2, length(comparator)) + # for all comparators in this layer of the network) pseudo_block_length = 1 << abs(j_final + 1 - j) # N_pseudo_blocks = how many pseudo-blocks are in this layer of the network N_pseudo_blocks = nextpow(2, c_len) ÷ pseudo_block_length @@ -980,7 +988,7 @@ function Base.sort!(c::AnyCuArray, alg::BitonicSortAlg; kwargs...) return bitonic_sort!(c; kwargs...) end -function Base.sort!(c::AnyCuArray; alg :: SortingAlgorithm = BitonicSort, kwargs...) +function Base.sort!(c::AnyCuArray; alg::SortingAlgorithm = BitonicSort, kwargs...) return sort!(c, alg; kwargs...) end @@ -988,35 +996,26 @@ function Base.sort(c::AnyCuArray; kwargs...) return sort!(copy(c); kwargs...) end -function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, alg::BitonicSortAlg; - lt=isless, by=identity, rev=false) - - function out(k :: OrdinalRange) - return copy(c[k]) - end +function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, + alg::BitonicSortAlg; lt=isless, by=identity, rev=false) - # work around disallowed scalar index - function out(k :: Integer) - return Array(c[k:k])[1] - end - - sort!(c, alg=alg; lt, by, rev) - return out(k) + sort!(c, alg; lt, by, rev) + return @allowscalar copy(c[k]) end -function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, alg::QuickSortAlg; - lt=isless, by=identity, rev=false) +function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, + alg::QuickSortAlg; lt=isless, by=identity, rev=false) # for reverse sorting, invert the less-than function if rev lt = !lt end - function out(k :: OrdinalRange) + function out(k::OrdinalRange) return copy(c[k]) end # work around disallowed scalar index - function out(k :: Integer) + function out(k::Integer) return Array(c[k:k])[1] end @@ -1024,7 +1023,8 @@ function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange}, alg: return out(k) end -function Base.partialsort!(c::AnyCuArray, k::Union{Integer, OrdinalRange}; alg::SortingAlgorithm=BitonicSort, kwargs...) +function Base.partialsort!(c::AnyCuArray, k::Union{Integer, OrdinalRange}; + alg::SortingAlgorithm=BitonicSort, kwargs...) return partialsort!(c, k, alg; kwargs...) end @@ -1032,8 +1032,10 @@ function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs return partialsort!(copy(c), k; kwargs...) end -function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwargs...) where T - axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))")) +function Base.sortperm!(ix::AnyCuArray, A::AnyCuArray; initialized=false, kwargs...) + if axes(ix) != axes(A) + throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))")) + end if !initialized ix .= LinearIndices(A) @@ -1048,5 +1050,5 @@ end function Base.sortperm(c::AnyCuArray; dims, kwargs...) # Base errors for Matrices without dims arg, we should too - sortperm!(reshape(CuArray(1:length(c)), size(c)), c; initialized=true, dims=dims, kwargs...) + sortperm!(reshape(CuArray(1:length(c)), size(c)), c; initialized=true, dims, kwargs...) end