Merged
147 commits
f024f46
Move segmented sort kernels to separate header
NaderAlAwar Aug 11, 2025
ca729cf
Remove unused code
NaderAlAwar Aug 11, 2025
7e875c8
Move continuation back to dispatch file
NaderAlAwar Aug 12, 2025
fbcbc83
Begin working on dynamic dispatch for segmented sort
NaderAlAwar Aug 12, 2025
6be7b9d
Fix compilation errors
NaderAlAwar Aug 13, 2025
9ca779b
Add initial segmented sort c parallel implementation
NaderAlAwar Aug 14, 2025
8382203
Add error checks
NaderAlAwar Aug 14, 2025
233a73a
Make segment selector operator() device only (for c.parallel) and rev…
NaderAlAwar Aug 14, 2025
3a6da74
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Aug 14, 2025
2f579ab
move three way partition kernels to separate header since they are us…
NaderAlAwar Aug 14, 2025
d760737
Enable dynamic cub dispatch in three way partition
NaderAlAwar Aug 14, 2025
93d5fda
Move three way partition kernels to separate file
NaderAlAwar Aug 14, 2025
48b0996
Call partition through dispatch and add template params for partition…
NaderAlAwar Aug 14, 2025
85305ea
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Aug 14, 2025
cce3899
Various compilation fixes
NaderAlAwar Aug 14, 2025
b4c720d
Begin work to reuse CUB policies in c.parallel
NaderAlAwar Aug 14, 2025
ed7d6aa
Add encoded policy method
NaderAlAwar Aug 14, 2025
a72278a
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Aug 14, 2025
888b7d7
Merge branch 'segmented-sort-c-parallel' of github.com:NaderAlAwar/cc…
NaderAlAwar Aug 14, 2025
c210f84
Fix other call to device partition
NaderAlAwar Aug 14, 2025
b8a7146
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Aug 14, 2025
e4f7748
Fix issue with tuning policies
NaderAlAwar Aug 14, 2025
c725a18
Add missing runtime policies and use advance iterators function inste…
NaderAlAwar Aug 14, 2025
3a69fac
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Aug 14, 2025
96194dc
Begin fixing runtime policy in segmented sort
NaderAlAwar Aug 18, 2025
eb3c174
Fix missing comma
NaderAlAwar Aug 18, 2025
0b4ba62
Split SmallAndMediumSegmentedSortPolicy into separate small and mediu…
NaderAlAwar Aug 18, 2025
21c3832
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Aug 18, 2025
57764c6
Continue fixing runtime segmented sort policies
NaderAlAwar Aug 19, 2025
01cc9d1
Fix compilation errors
NaderAlAwar Sep 3, 2025
fa60f52
Change policies to make them work from c.parallel
NaderAlAwar Sep 3, 2025
bb8eb8f
Add one more level for preprocessor for each
NaderAlAwar Sep 3, 2025
a63ebfd
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 3, 2025
750c3fc
Account for double buffers and sort order and use test tuple params s…
NaderAlAwar Sep 3, 2025
5fddefe
Add missing include
NaderAlAwar Sep 3, 2025
a9582c7
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 3, 2025
bfb47e5
Pass partition max policy to three way partition dispatch
NaderAlAwar Sep 4, 2025
5a20e6e
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 4, 2025
389f9b3
Move streaming_context_t and other typedefs to three way partition ke…
NaderAlAwar Sep 4, 2025
5842b63
Encode delay constructor info into ptx-json. This requires defining ne…
NaderAlAwar Sep 4, 2025
3888b8d
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 4, 2025
ce9ee5e
Get key size from kernel source instead of sizeof directly
NaderAlAwar Sep 4, 2025
797caad
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 4, 2025
11a8007
Move segmented test utils to common header
NaderAlAwar Sep 4, 2025
5bae494
Add three way partition kernels and policy and implement KeySize
NaderAlAwar Sep 4, 2025
c75d306
Use common utils from test_util
NaderAlAwar Sep 4, 2025
04a4a6d
Delete partition tuning policy
NaderAlAwar Sep 5, 2025
686ea77
Make parameter tuple member functions constexpr
NaderAlAwar Sep 6, 2025
be17ec0
Rename row_size to segment_size and fix error in key value pair corre…
NaderAlAwar Sep 6, 2025
c8f3d5f
Allow passing in custom types as items and pass segment selectors thr…
NaderAlAwar Sep 10, 2025
d2b22bb
Expand testing of segmented sort
NaderAlAwar Sep 10, 2025
5dd308f
Pass segment selectors through kernel source
NaderAlAwar Sep 10, 2025
7f0a001
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 10, 2025
18c9a48
Merge branch 'main' into segmented-sort-c-parallel
NaderAlAwar Sep 10, 2025
6162d17
remove merge leftovers
NaderAlAwar Sep 10, 2025
773119c
Pass large and small selector ops through kernel source. This is cons…
NaderAlAwar Sep 12, 2025
56227c2
Merge branch 'main' into segmented-sort-dynamic-cub-dispatch
NaderAlAwar Sep 12, 2025
9e2cb0f
Fix merge leftover and set offset through kernel source
NaderAlAwar Sep 12, 2025
d69a4bf
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 12, 2025
d584bdc
Clean up segmented sort c parallel tests
NaderAlAwar Sep 13, 2025
d657f16
Implement dynamic dispatch for three_way_partition
NaderAlAwar Sep 14, 2025
b798d3b
Replace threshold with actual offset type
NaderAlAwar Sep 14, 2025
a1f4355
Use global_segment_offset_t type instead of long long
NaderAlAwar Sep 14, 2025
cf175f3
Use void* for iterator types
NaderAlAwar Sep 14, 2025
a08541b
Make selector op states part of build instead of static storage
NaderAlAwar Sep 14, 2025
c284814
Use existing type alias instead of redefining one
NaderAlAwar Sep 15, 2025
727f13e
Fix dangling pointer error in indirect arg
NaderAlAwar Sep 15, 2025
12d987d
Avoid calling static function and don't store op state since it is al…
NaderAlAwar Sep 15, 2025
d956f96
Refactor to avoid code duplication
NaderAlAwar Sep 15, 2025
4c36550
Continue cleaning up code
NaderAlAwar Sep 15, 2025
795d59e
Replace CPP_SOURCE op with LTOIR
NaderAlAwar Sep 15, 2025
bc0695f
Add missing util_device include
NaderAlAwar Sep 17, 2025
dfd3ec9
Remove OffsetSize from kernel source
NaderAlAwar Sep 17, 2025
9acc485
Add missing enable_if include
NaderAlAwar Sep 17, 2025
cb66b06
Add ptx_json for delay constructors
NaderAlAwar Sep 17, 2025
89a612c
Add check for offset iterator types
NaderAlAwar Sep 18, 2025
591858a
Remove unused code
NaderAlAwar Sep 19, 2025
a23aab0
Move sorting algorithms to a new directory and add segmented_sort bin…
NaderAlAwar Sep 19, 2025
4406abc
Fix cython compilation errors
NaderAlAwar Sep 19, 2025
253c804
Add initial python wrappers for segmented_sort
NaderAlAwar Sep 19, 2025
dd783ba
Add missing imports
NaderAlAwar Sep 19, 2025
97fbfc4
Adjust segmented_sort build to not need the output arrays
NaderAlAwar Sep 19, 2025
dee4a17
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Sep 19, 2025
8964529
Separate num_segments and num_items properly
NaderAlAwar Sep 20, 2025
c376097
Add fp16 include during policy wrapper creation
NaderAlAwar Sep 20, 2025
c492b21
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Sep 20, 2025
c4ae3e2
Don't use num_segments == 0 in tests since this causes an issue where…
NaderAlAwar Sep 21, 2025
bcfd6b9
Add typename to avoid benchmark compilation error
NaderAlAwar Sep 22, 2025
e62a346
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 22, 2025
27a2d56
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Sep 22, 2025
baa01a3
Skip SASS check when the dtype is int64 since it exists in C++ as well
NaderAlAwar Sep 22, 2025
8c3e808
Rename parameter list for clarity
NaderAlAwar Sep 22, 2025
79f3f4a
c.parallel: enable dynamic policies in scan.
griwes Sep 22, 2025
418b1c8
Add a missing include.
griwes Sep 22, 2025
529f462
Merge remote-tracking branch 'origin/main' into feature/scan-dynamic-…
griwes Sep 22, 2025
8d416cc
Merge branch 'main' into three-way-partition-dynamic-cub-dispatch
NaderAlAwar Sep 22, 2025
2ec8219
Update cub/cub/device/dispatch/dispatch_three_way_partition.cuh
bernhardmgruber Sep 25, 2025
a22bab3
Merge branch 'main' into scan-dynamic-policy-main-merge
NaderAlAwar Sep 25, 2025
dac82b3
Merge branch 'scan-dynamic-policy-main-merge' into three-way-partitio…
NaderAlAwar Sep 25, 2025
509a0b0
Undo some CUB changes from #5960 (they will be added in another PR)
NaderAlAwar Sep 25, 2025
24c43dc
Merge branch 'three-way-partition-dynamic-cub-dispatch' into segmente…
NaderAlAwar Sep 25, 2025
eb456e6
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 25, 2025
fb77098
Update retrieval of delay constructor
NaderAlAwar Sep 25, 2025
6fcb492
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Sep 25, 2025
8d28b6c
Add comments at end of idefs
NaderAlAwar Sep 27, 2025
0f2c191
Merge branch 'main' into three-way-partition-dynamic-cub-dispatch
NaderAlAwar Sep 29, 2025
c5026dd
Merge branch 'three-way-partition-dynamic-cub-dispatch' into segmente…
NaderAlAwar Sep 29, 2025
43e95f5
Merge branch 'main' into segmented-sort-dynamic-cub-dispatch
NaderAlAwar Sep 30, 2025
39d7314
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 30, 2025
0d0baf6
remove redundant ptx json from delay constructors
NaderAlAwar Sep 30, 2025
82d030e
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 30, 2025
0d47ca4
Fix variable shadowing warning in MSVC
NaderAlAwar Sep 30, 2025
9b7cb1c
Fix other variable shadowing warning in MSVC
NaderAlAwar Sep 30, 2025
6db94bc
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 30, 2025
56de6c1
Add methods to retrieve enums used for asserts
NaderAlAwar Sep 30, 2025
60bf551
Replace Check*() methods with CUB_DETAIL_STATIC_ISH_ASSERT
NaderAlAwar Sep 30, 2025
8ee191d
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 30, 2025
c3ab74e
static cast to handle msvc warning
NaderAlAwar Sep 30, 2025
fdd3581
Merge branch 'segmented-sort-dynamic-cub-dispatch' into segmented-sor…
NaderAlAwar Sep 30, 2025
1133177
Merge branch 'main' into segmented-sort-c-parallel
NaderAlAwar Oct 1, 2025
777cbd9
Use different OffsetT to not break windows build
NaderAlAwar Oct 1, 2025
754fc0f
Fix merge conflict
NaderAlAwar Oct 1, 2025
fac3403
Address review comments
NaderAlAwar Oct 28, 2025
68ad687
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Oct 28, 2025
cc8c21a
Merge branch 'main' into segmented-sort-c-parallel
NaderAlAwar Oct 30, 2025
7e383a0
Implement single compilation for segmented_sort. This required making…
NaderAlAwar Oct 31, 2025
2a5a678
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Oct 31, 2025
b96b1ad
Add missing imports and fix imports in test
NaderAlAwar Oct 31, 2025
f7d50b9
Fix MSVC error
NaderAlAwar Oct 31, 2025
ca1c1cc
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Oct 31, 2025
e951835
Fix MSVC CI errors
NaderAlAwar Nov 3, 2025
10d0185
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Nov 3, 2025
8fdb4fe
Add comment explaining selector op compilation
NaderAlAwar Nov 3, 2025
5dedbd5
Use dummy global variable instead of &op in indirect_arg_t constructor
NaderAlAwar Nov 3, 2025
521cce2
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Nov 3, 2025
169b5fc
Revert change made to step counting iterator that was causing segment…
NaderAlAwar Nov 4, 2025
a77826d
Merge branch 'segmented-sort-c-parallel' into segmented-sort-python-w…
NaderAlAwar Nov 4, 2025
4481f5c
Merge branch 'main' into segmented-sort-python-wrappers
NaderAlAwar Nov 4, 2025
499e772
Fix missing args in call
NaderAlAwar Nov 4, 2025
27ba765
Use cccl_type_name_from_nvrtc to avoid windows errors
NaderAlAwar Nov 4, 2025
a588a5d
Fix incorrect file name
miscco Nov 4, 2025
80b42ed
Merge branch 'main' into segmented-sort-python-wrappers
NaderAlAwar Nov 4, 2025
e8fcef4
fix
bernhardmgruber Nov 4, 2025
4cf0750
Merge remote-tracking branch 'miscco/fix_invalid_file_name' into segm…
NaderAlAwar Nov 4, 2025
1b312bd
Merge remote-tracking branch 'miscco/fix_invalid_file_name' into segm…
NaderAlAwar Nov 4, 2025
e836931
Address reviewer feedback
NaderAlAwar Nov 4, 2025
0c03945
Address reviewer feedback
NaderAlAwar Nov 4, 2025
26 changes: 13 additions & 13 deletions c/parallel/src/segmented_sort.cu
@@ -12,7 +12,7 @@
#include <cub/detail/launcher/cuda_driver.cuh> // cub::detail::CudaDriverLauncherFactory
#include <cub/detail/ptx-json-parser.h>
#include <cub/device/dispatch/dispatch_segmented_sort.cuh> // cub::DispatchSegmentedSort
-#include <cub/device/dispatch/kernels/segmented_sort.cuh> // DeviceSegmentedSort kernels
+#include <cub/device/dispatch/kernels/kernel_segmented_sort.cuh> // DeviceSegmentedSort kernels
#include <cub/device/dispatch/tuning/tuning_segmented_sort.cuh> // policy_hub
#include <cub/thread/thread_load.cuh> // cub::LoadModifier

@@ -56,10 +56,10 @@ std::string get_device_segmented_sort_fallback_kernel_name(
cccl_sort_order_t sort_order)
{
std::string chained_policy_t;
-check(nvrtcGetTypeName<device_segmented_sort_policy>(&chained_policy_t));
+check(cccl_type_name_from_nvrtc<device_segmented_sort_policy>(&chained_policy_t));

std::string offset_t;
-check(nvrtcGetTypeName<OffsetT>(&offset_t));
+check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

/*
template <SortOrder Order, // 0 (ascending)
@@ -90,10 +90,10 @@ std::string get_device_segmented_sort_kernel_small_name(
cccl_sort_order_t sort_order)
{
std::string chained_policy_t;
-check(nvrtcGetTypeName<device_segmented_sort_policy>(&chained_policy_t));
+check(cccl_type_name_from_nvrtc<device_segmented_sort_policy>(&chained_policy_t));

std::string offset_t;
-check(nvrtcGetTypeName<OffsetT>(&offset_t));
+check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

/*
template <SortOrder Order, // 0 (ascending)
@@ -124,10 +124,10 @@ std::string get_device_segmented_sort_kernel_large_name(
cccl_sort_order_t sort_order)
{
std::string chained_policy_t;
-check(nvrtcGetTypeName<device_segmented_sort_policy>(&chained_policy_t));
+check(cccl_type_name_from_nvrtc<device_segmented_sort_policy>(&chained_policy_t));

std::string offset_t;
-check(nvrtcGetTypeName<OffsetT>(&offset_t));
+check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

/*
template <SortOrder Order, // 0 (ascending)
@@ -182,11 +182,11 @@ cccl_op_t make_segments_selector_op(
cccl_op_t selector_op{};
auto selector_op_state = std::make_unique<selector_state_t>();
std::string offset_t;
-check(nvrtcGetTypeName<OffsetT>(&offset_t));
+check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

const std::string code = std::format(
R"XXX(
-#include <cub/device/dispatch/kernels/segmented_sort.cuh>
+#include <cub/device/dispatch/kernels/kernel_segmented_sort.cuh>

extern "C" __device__ void {0}(void* state_ptr, const void* arg_ptr, void* result_ptr)
{{
@@ -297,7 +297,7 @@ std::string get_three_way_partition_init_kernel_name()
std::string get_three_way_partition_kernel_name(std::string_view large_selector_t, std::string_view small_selector_t)
{
std::string chained_policy_t;
-check(nvrtcGetTypeName<device_three_way_partition_policy>(&chained_policy_t));
+check(cccl_type_name_from_nvrtc<device_three_way_partition_policy>(&chained_policy_t));

static constexpr std::string_view input_it_t =
"thrust::counting_iterator<cub::detail::segmented_sort::local_segment_index_t>";
@@ -308,7 +308,7 @@ std::string get_three_way_partition_kernel_name(std::string_view large_selector_
static constexpr std::string_view num_selected_it_t = "cub::detail::segmented_sort::local_segment_index_t*";
static constexpr std::string_view scan_tile_state_t = "cub::detail::three_way_partition::ScanTileStateT";
std::string offset_t;
-check(nvrtcGetTypeName<OffsetT>(&offset_t));
+check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

static constexpr std::string_view per_partition_offset_t = "cub::detail::three_way_partition::per_partition_offset_t";
static constexpr std::string_view streaming_context_t =
@@ -638,9 +638,9 @@ CUresult cccl_device_segmented_sort_build_ex(

const std::string final_src = std::format(
R"XXX(
-#include <cub/device/dispatch/kernels/segmented_sort.cuh>
+#include <cub/device/dispatch/kernels/kernel_segmented_sort.cuh>
#include <cub/device/dispatch/tuning/tuning_segmented_sort.cuh>
-#include <cub/device/dispatch/kernels/three_way_partition.cuh>
+#include <cub/device/dispatch/kernels/kernel_three_way_partition.cuh>
#include <cub/device/dispatch/tuning/tuning_three_way_partition.cuh>

{0}
4 changes: 4 additions & 0 deletions python/cuda_cccl/cuda/compute/__init__.py
@@ -17,13 +17,15 @@
make_radix_sort,
make_reduce_into,
make_segmented_reduce,
+make_segmented_sort,
make_three_way_partition,
make_unary_transform,
make_unique_by_key,
merge_sort,
radix_sort,
reduce_into,
segmented_reduce,
+segmented_sort,
three_way_partition,
unary_transform,
unique_by_key,
@@ -59,6 +61,7 @@
"make_radix_sort",
"make_reduce_into",
"make_segmented_reduce",
+"make_segmented_sort",
"make_three_way_partition",
"make_unary_transform",
"make_unique_by_key",
@@ -69,6 +72,7 @@
"reduce_into",
"ReverseIterator",
"segmented_reduce",
+"segmented_sort",
"SortOrder",
"TransformIterator",
"TransformOutputIterator",
23 changes: 23 additions & 0 deletions python/cuda_cccl/cuda/compute/_bindings.pyi
@@ -447,6 +447,29 @@ class DeviceHistogramBuildResult:
stream,
) -> None: ...

+# -----------------
+# DeviceSegmentedSort
+# -----------------
+
+class DeviceSegmentedSortBuildResult:
+    def __init__(self): ...
+    def compute(
+        self,
+        temp_storage_ptr: int | None,
+        temp_storage_nbytes: int,
+        d_in_keys: Iterator,
+        d_out_keys: Iterator,
+        d_in_values: Iterator,
+        d_out_values: Iterator,
+        num_items: int,
+        num_segments: int,
+        d_begin_offsets: Iterator,
+        d_end_offsets: Iterator,
+        is_overwrite_okay: bool,
+        selector: int,
+        stream,
+    ) -> tuple[int, int]: ...

# ---------------------
# DeviceThreeWayPartition
# ---------------------
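Reviewer note: the `compute` stub above accepts `temp_storage_ptr: int | None` and returns a `(bytes, selector)` pair, which is consistent with CUB's usual two-call convention: a first call without storage reports the required size, a second call with storage does the work. The sketch below illustrates only that calling pattern. `FakeSegmentedSortBuild`, its `REQUIRED_BYTES` value, and its shortened `compute` signature are hypothetical stand-ins, not the bindings added in this PR.

```python
class FakeSegmentedSortBuild:
    """Hypothetical stand-in used only to show the two-call pattern."""

    REQUIRED_BYTES = 256  # arbitrary size chosen for this sketch

    def compute(self, temp_storage_ptr, temp_storage_nbytes, keys, selector):
        if temp_storage_ptr is None:
            # Size query: report the required bytes and leave the data untouched.
            return self.REQUIRED_BYTES, selector
        if temp_storage_nbytes < self.REQUIRED_BYTES:
            raise RuntimeError("temporary storage is too small")
        keys.sort()  # stand-in for the real device-side work
        return temp_storage_nbytes, selector


def run_two_phase(build, keys):
    # Phase 1: no storage pointer, so only the required size comes back.
    nbytes, _ = build.compute(None, 0, keys, selector=-1)
    # Phase 2: allocate that many bytes and call again to do the work.
    storage = bytearray(nbytes)  # stand-in for a device allocation
    _, selector = build.compute(storage, len(storage), keys, selector=-1)
    return nbytes, selector


keys = [3, 1, 2]
nbytes, selector = run_two_phase(FakeSegmentedSortBuild(), keys)
```

In the real wrapper the second element of the returned pair is the selector written back by the C call; here it is simply passed through.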
144 changes: 143 additions & 1 deletion python/cuda_cccl/cuda/compute/_bindings_impl.pyx
@@ -1779,7 +1779,7 @@ cdef class DeviceRadixSortBuildResult:

if status != 0:
raise RuntimeError(
-f"Failed executing ascending radix_sort, error code: {status}"
+f"Failed executing radix_sort, error code: {status}"
)
return <object>storage_sz, <object>selector_int

@@ -2271,3 +2271,145 @@ cdef class DeviceThreeWayPartitionBuildResult:
<const char*>self.build_data.cubin,
self.build_data.cubin_size
)


+# -------------------
+# DeviceSegmentedSort
+# -------------------
+
+cdef extern from "cccl/c/segmented_sort.h":
+    cdef struct cccl_device_segmented_sort_build_result_t 'cccl_device_segmented_sort_build_result_t':
+        const char* cubin
+        size_t cubin_size
+
+    cdef CUresult cccl_device_segmented_sort_build(
+        cccl_device_segmented_sort_build_result_t *build_ptr,
+        cccl_sort_order_t sort_order,
+        cccl_iterator_t d_keys_in,
+        cccl_iterator_t d_keys_out,
+        cccl_iterator_t begin_offset_in,
+        cccl_iterator_t end_offset_in,
+        int, int, const char *, const char *, const char *, const char *
+    ) nogil
+
+    cdef CUresult cccl_device_segmented_sort(
+        cccl_device_segmented_sort_build_result_t build,
+        void* d_temp_storage,
+        size_t* temp_storage_bytes,
+        cccl_iterator_t d_keys_in,
+        cccl_iterator_t d_keys_out,
+        cccl_iterator_t d_values_in,
+        cccl_iterator_t d_values_out,
+        int64_t num_items,
+        int64_t num_segments,
+        cccl_iterator_t start_offset_in,
+        cccl_iterator_t end_offset_in,
+        bint is_overwrite_okay,
+        int* selector,
+        CUstream stream
+    ) nogil
+
+    cdef CUresult cccl_device_segmented_sort_cleanup(
+        cccl_device_segmented_sort_build_result_t* build_ptr
+    ) nogil
+
+cdef class DeviceSegmentedSortBuildResult:
+    cdef cccl_device_segmented_sort_build_result_t build_data
+
+    def __dealloc__(DeviceSegmentedSortBuildResult self):
+        cdef CUresult status = -1
+        with nogil:
+            status = cccl_device_segmented_sort_cleanup(&self.build_data)
+        if (status != 0):
+            print(f"Return code {status} encountered during segmented_sort result cleanup")
+
+    def __cinit__(
+        DeviceSegmentedSortBuildResult self,
+        cccl_sort_order_t order,
+        Iterator d_keys_in,
+        Iterator d_values_in,
+        Iterator begin_offset_in,
+        Iterator end_offset_in,
+        CommonData common_data,
+    ):
+        cdef CUresult status = -1
+        cdef int cc_major = common_data.get_cc_major()
+        cdef int cc_minor = common_data.get_cc_minor()
+        cdef const char *cub_path = common_data.cub_path_get_c_str()
+        cdef const char *thrust_path = common_data.thrust_path_get_c_str()
+        cdef const char *libcudacxx_path = common_data.libcudacxx_path_get_c_str()
+        cdef const char *ctk_path = common_data.ctk_path_get_c_str()
+
+        memset(&self.build_data, 0, sizeof(cccl_device_segmented_sort_build_result_t))
+        with nogil:
+            status = cccl_device_segmented_sort_build(
+                &self.build_data,
+                order,
+                d_keys_in.iter_data,
+                d_values_in.iter_data,
+                begin_offset_in.iter_data,
+                end_offset_in.iter_data,
+                cc_major,
+                cc_minor,
+                cub_path,
+                thrust_path,
+                libcudacxx_path,
+                ctk_path,
+            )
+        if status != 0:
+            raise RuntimeError(
+                f"Failed building segmented_sort, error code: {status}"
+            )
+
+    cpdef tuple compute(
+        DeviceSegmentedSortBuildResult self,
+        temp_storage_ptr,
+        temp_storage_bytes,
+        Iterator d_keys_in,
+        Iterator d_keys_out,
+        Iterator d_values_in,
+        Iterator d_values_out,
+        size_t num_items,
+        size_t num_segments,
+        Iterator start_offset_in,
+        Iterator end_offset_in,
+        bint is_overwrite_okay,
+        selector,
+        stream
+    ):
+        cdef CUresult status = -1
+        cdef void *storage_ptr = (<void *><size_t>temp_storage_ptr) if temp_storage_ptr else NULL
+        cdef size_t storage_sz = <size_t>temp_storage_bytes
+        cdef int selector_int = <int>selector
+        cdef CUstream c_stream = <CUstream><size_t>(stream) if stream else NULL
+
+        with nogil:
+            status = cccl_device_segmented_sort(
+                self.build_data,
+                storage_ptr,
+                &storage_sz,
+                d_keys_in.iter_data,
+                d_keys_out.iter_data,
+                d_values_in.iter_data,
+                d_values_out.iter_data,
+                <uint64_t>num_items,
+                <uint64_t>num_segments,
+                start_offset_in.iter_data,
+                end_offset_in.iter_data,
+                is_overwrite_okay,
+                &selector_int,
+                c_stream
+            )
+
+        if status != 0:
+            raise RuntimeError(
+                f"Failed executing segmented_sort, error code: {status}"
+            )
+        return <object>storage_sz, <object>selector_int
+
+
+    def _get_cubin(self):
+        return PyBytes_FromStringAndSize(
+            <const char*>self.build_data.cubin,
+            self.build_data.cubin_size
+        )
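Reviewer note: for readers unfamiliar with the operation being wrapped, the following is a minimal host-side reference for what a segmented sort computes from `num_segments` pairs of begin/end offsets. It is a sketch for checking results in tests, not the code path in this PR. The function name `segmented_sort_reference` and its parameters are hypothetical; elements outside every segment are copied through unchanged, and ties keep their input order, both of which are choices of this sketch rather than guarantees of the device algorithm.

```python
def segmented_sort_reference(keys, begin_offsets, end_offsets,
                             values=None, descending=False):
    """Sort each [begin, end) slice of keys independently (hypothetical helper)."""
    out_keys = list(keys)
    out_values = None if values is None else list(values)
    for begin, end in zip(begin_offsets, end_offsets):
        # Indices of this segment, ordered by their key.
        order = sorted(range(begin, end), key=lambda i: keys[i],
                       reverse=descending)
        out_keys[begin:end] = [keys[i] for i in order]
        if out_values is not None:
            # Values travel with their keys.
            out_values[begin:end] = [values[i] for i in order]
    return out_keys, out_values


sorted_keys, sorted_values = segmented_sort_reference(
    [3, 1, 2, 9, 7, 8], [0, 3], [3, 6], values=list("abcdef"))
```

A test for the Python wrapper could compare device output copied back to the host against this reference, segment by segment.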
14 changes: 9 additions & 5 deletions python/cuda_cccl/cuda/compute/algorithms/__init__.py
@@ -5,11 +5,6 @@

from ._histogram import histogram_even as histogram_even
from ._histogram import make_histogram_even as make_histogram_even
-from ._merge_sort import make_merge_sort as make_merge_sort
-from ._merge_sort import merge_sort as merge_sort
-from ._radix_sort import DoubleBuffer, SortOrder
-from ._radix_sort import make_radix_sort as make_radix_sort
-from ._radix_sort import radix_sort as radix_sort
from ._reduce import make_reduce_into as make_reduce_into
from ._reduce import reduce_into as reduce_into
from ._scan import exclusive_scan as exclusive_scan
@@ -18,6 +13,13 @@
from ._scan import make_inclusive_scan as make_inclusive_scan
from ._segmented_reduce import make_segmented_reduce as make_segmented_reduce
from ._segmented_reduce import segmented_reduce
+from ._sort import DoubleBuffer, SortOrder
+from ._sort import make_merge_sort as make_merge_sort
+from ._sort import make_radix_sort as make_radix_sort
+from ._sort import make_segmented_sort as make_segmented_sort
+from ._sort import merge_sort as merge_sort
+from ._sort import radix_sort as radix_sort
+from ._sort import segmented_sort as segmented_sort
from ._three_way_partition import make_three_way_partition as make_three_way_partition
from ._three_way_partition import three_way_partition as three_way_partition
from ._transform import binary_transform, unary_transform
@@ -47,6 +49,8 @@
"make_segmented_reduce",
"unique_by_key",
"make_unique_by_key",
+"segmented_sort",
+"make_segmented_sort",
"three_way_partition",
"make_three_way_partition",
"DoubleBuffer",
23 changes: 23 additions & 0 deletions python/cuda_cccl/cuda/compute/algorithms/_sort/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._merge_sort import make_merge_sort as make_merge_sort
+from ._merge_sort import merge_sort as merge_sort
+from ._radix_sort import make_radix_sort as make_radix_sort
+from ._radix_sort import radix_sort as radix_sort
+from ._segmented_sort import make_segmented_sort as make_segmented_sort
+from ._segmented_sort import segmented_sort as segmented_sort
+from ._sort_common import DoubleBuffer, SortOrder
+
+__all__ = [
+    "make_merge_sort",
+    "merge_sort",
+    "make_radix_sort",
+    "radix_sort",
+    "make_segmented_sort",
+    "segmented_sort",
+    "DoubleBuffer",
+    "SortOrder",
+]
@@ -7,19 +7,19 @@

import numba

-from .. import _bindings
-from .. import _cccl_interop as cccl
-from .._caching import CachableFunction, cache_with_key
-from .._cccl_interop import call_build, set_cccl_iterator_state
-from .._utils import protocols
-from .._utils.protocols import (
+from ... import _bindings
+from ... import _cccl_interop as cccl
+from ..._caching import CachableFunction, cache_with_key
+from ..._cccl_interop import call_build, set_cccl_iterator_state
+from ..._utils import protocols
+from ..._utils.protocols import (
get_data_pointer,
validate_and_get_stream,
)
-from .._utils.temp_storage_buffer import TempStorageBuffer
-from ..iterators._iterators import IteratorBase
-from ..op import OpKind
-from ..typing import DeviceArrayLike
+from ..._utils.temp_storage_buffer import TempStorageBuffer
+from ...iterators._iterators import IteratorBase
+from ...op import OpKind
+from ...typing import DeviceArrayLike


def make_cache_key(
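Reviewer note: the last hunk only adds one dot to each relative import, which follows from the sorting modules moving one package deeper (into `algorithms/_sort/`). A small sketch using the standard library's `importlib.util.resolve_name` shows why the extra dot is needed. The package strings are inferred from the file paths in this diff; nothing here imports the actual `cuda.compute` package.

```python
from importlib.util import resolve_name

old_pkg = "cuda.compute.algorithms"        # package of the module before the move
new_pkg = "cuda.compute.algorithms._sort"  # package after the move into _sort/

# Before: two dots climb from algorithms/ to cuda/compute/.
old_target = resolve_name(".._bindings", old_pkg)
# After: three dots are needed to reach the same place from _sort/.
new_target = resolve_name("..._bindings", new_pkg)
# Leaving the old two-dot form in the moved file would stop one level too low.
stale_target = resolve_name(".._bindings", new_pkg)

print(old_target, new_target, stale_target)
```

The first two names resolve to the same module, while the stale form points at a `_bindings` module under `algorithms`, which this diff does not show existing.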