Skip to content

Commit

Permalink
Move binary_search_range to header (pytorch#1593)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1593

Move `binary_search_range` to header to allow kernels outside of
`jagged_tensor_ops.cu` to use it.

Reviewed By: jianyuh

Differential Revision: D43216577

fbshipit-source-id: 96a681ea0e0f3db929dc6e6d25d971107015aba1
  • Loading branch information
sryap authored and facebook-github-bot committed Feb 24, 2023
1 parent 111f696 commit 7f791ed
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 45 deletions.
45 changes: 45 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3158,4 +3158,49 @@ __inline__ __device__ void inclusive_sum_scan_kernel(
}
}

// Binary search over a sorted (ascending) array.
//
// Writes into *found the smallest index i such that arr[i] > target
// (an upper-bound style search), or -1 when no entry of arr is strictly
// greater than target (i.e. target >= arr[num_entries - 1]).
//
// Precondition (assumed, not checked): arr is sorted in non-decreasing
// order and num_entries >= 1.
//
// Duplicates: when target matches a run of equal entries, a short linear
// scan past the run locates the first strictly-greater element. This is
// cheap for small runs but degrades to O(n) when the number of
// duplicates is large.
template <typename scalar_t>
__device__ __forceinline__ void binary_search_range(
    int* found,
    const scalar_t* arr,
    const scalar_t target,
    const int num_entries) {
  const int last_entry = num_entries - 1;
  int start = 0, end = last_entry;
  int found_ = -1;
  while (start <= end) {
    int mid = start + (end - start) / 2; // overflow-safe midpoint
    scalar_t mid_offset = arr[mid];
    if (target == mid_offset) {
      if (mid != last_entry && target != arr[last_entry]) {
        // Do linear scan in case of duplicate data (We assume that the
        // number of duplicates is small. This can be very bad if the
        // number of duplicates is large)
        for (int i = mid + 1; i < num_entries; i++) {
          if (target != arr[i]) {
            found_ = i;
            break;
          }
        }
      }
      // else: target equals the last entry (or occupies the tail of the
      // array), so no strictly-greater element exists; found_ stays -1.
      break;
    } else if (target < mid_offset) {
      if (mid == 0) {
        // target is smaller than every entry; arr[0] is the upper bound.
        found_ = 0;
        break;
      } else if (target > arr[mid - 1]) {
        // arr[mid - 1] < target < arr[mid]: mid is the upper bound.
        // (mid != 0 is guaranteed by the branch above, so mid - 1 is a
        // valid index — the original redundant `mid - 1 >= 0` check is
        // dropped.)
        found_ = mid;
        break;
      }
      end = mid - 1;
    } else {
      if (mid + 1 <= last_entry && target < arr[mid + 1]) {
        // arr[mid] < target < arr[mid + 1]: mid + 1 is the upper bound.
        found_ = mid + 1;
        break;
      }
      start = mid + 1;
    }
  }
  *found = found_;
}

} // namespace fbgemm_gpu
45 changes: 0 additions & 45 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2324,51 +2324,6 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
return padded_values_per_key;
}
// Binary search over an ascending array: writes into *found the smallest
// index i with arr[i] > target (upper-bound style), or -1 when no entry
// is strictly greater than target.
// NOTE(review): assumes arr is sorted ascending with num_entries >= 1 —
// callers in this file pass jagged-tensor offset arrays, which satisfy this.
// (This copy is being removed in favor of the identical definition moved
// to fbgemm_cuda_utils.cuh.)
template <typename scalar_t>
__device__ __forceinline__ void binary_search_range(
int* found,
const scalar_t* arr,
const scalar_t target,
const int num_entries) {
const int last_entry = num_entries - 1;
int start = 0, end = last_entry;
int found_ = -1;
while (start <= end) {
// Overflow-safe midpoint computation.
int mid = start + (end - start) / 2;
scalar_t mid_offset = arr[mid];
if (target == mid_offset) {
if (mid != last_entry && target != arr[last_entry]) {
// Do linear scan in case of duplicate data (We assume that the
// number of duplicates is small. This can be very bad if the
// number of duplicates is large)
for (int i = mid + 1; i < num_entries; i++) {
if (target != arr[i]) {
found_ = i;
break;
}
}
}
// Otherwise target equals the last entry: no strictly-greater
// element exists, so found_ remains -1.
break;
} else if (target < mid_offset) {
if (mid == 0) {
// target is below every entry; index 0 is the upper bound.
found_ = 0;
break;
} else if (mid - 1 >= 0 && target > arr[mid - 1]) {
// arr[mid - 1] < target < arr[mid]: mid is the upper bound.
// (mid - 1 >= 0 is always true here since mid != 0.)
found_ = mid;
break;
}
end = mid - 1;
} else {
if (mid + 1 <= last_entry && target < arr[mid + 1]) {
// arr[mid] < target < arr[mid + 1]: mid + 1 is the upper bound.
found_ = mid + 1;
break;
}
start = mid + 1;
}
}
*found = found_;
}
template <typename index_t, typename offset_t, typename scalar_t>
__global__ __launch_bounds__(kMaxThreads) void jagged_index_select_2d_kernel(
scalar_t* output,
Expand Down

0 comments on commit 7f791ed

Please sign in to comment.