forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathIndexKernel.cu
160 lines (131 loc) · 5.54 KB
/
IndexKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/ExpandUtils.h>
namespace at { namespace native {
// Second argument given to C10_LAUNCH_BOUNDS_2 on index_elementwise_kernel.
// gpu_index_kernel additionally passes it as the `vt` template argument of
// launch_kernel, i.e. the number of elements each thread may process.
static constexpr int launch_bound2{4};
// Passed by gpu_index_kernel as the `nt` template argument of launch_kernel,
// which uses it as the number of threads per block.
static constexpr int launch_size_nd{128};
// Calls f(i) for linear indices i in [0, N).
//
// Template parameters:
//   nt      stride between consecutive elements handled by one thread; the
//           launcher in this file uses it as the block size
//   vt      maximum number of elements handled by one thread
//   func_t  callable invoked as f(int)
//
// Each block covers a window of nt * vt consecutive indices. A thread starts
// at its own slot in that window and advances by nt, stopping as soon as it
// steps past N.
template<int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, launch_bound2)
__global__ void index_elementwise_kernel(int N, func_t f) {
  const int lane = threadIdx.x;
  const int elems_per_block = nt * vt;
  int linear = elems_per_block * blockIdx.x + lane;
  #pragma unroll
  for (int step = 0; step < vt; ++step) {
    // Indices only grow, so once one is out of range all later ones are too.
    if (linear >= N) {
      break;
    }
    f(linear);
    linear += nt;
  }
}
// Host-side launcher for index_elementwise_kernel.
//
// Template parameters:
//   nt  threads per block
//   vt  elements handled per thread
// Parameters:
//   N   number of elements; must fit in a 32-bit signed int because the
//       kernel takes it as `int`
//   f   functor forwarded to the kernel by value
//
// Does nothing when N is 0. Launches on the current CUDA stream and reports
// launch errors through AT_CUDA_CHECK.
template<int nt, int vt, typename func_t>
static void launch_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N != 0) {
    auto stream = at::cuda::getCurrentCUDAStream();
    const dim3 block(nt);
    // One block covers nt * vt elements; round the block count up.
    const auto elems_per_block = block.x * vt;
    const dim3 grid((N + elems_per_block - 1) / elems_per_block);
    index_elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
    AT_CUDA_CHECK(cudaGetLastError());
  }
}
// Runs an indexing functor over every element of `iter` on the GPU.
//
// Operand layout read from `iter` (see the data_ptr calls below):
//   operand 0      output
//   operand 1      input
//   operand 2 + i  i-th index tensor
//
// Parameters:
//   iter          iterator over the output, the input and the index tensors
//   index_size    extent of each indexed dimension; used to bounds-check the
//                 indices and to wrap negative ones
//   index_stride  stride of each indexed dimension; it is multiplied by the
//                 index and the sum is handed to `f` as `offset`. The functors
//                 in this file add that offset to char* pointers, so the
//                 strides are presumably in bytes -- confirm at the caller.
//   f             device functor called as f(out_data, in_data, offset) once
//                 per element
//
// Raises (TORCH_CHECK) when there are more index tensors than the fixed-size
// kernel argument arrays can hold. An iterator with zero elements is a no-op.
// Iterators too large for 32-bit indexing are split and processed piecewise.
template <typename func_t>
void gpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, const func_t& f) {
  // Capacity of the fixed-size arrays that carry per-index metadata into the
  // kernel by value.
  constexpr int max_indices = 25;

  int num_indices = index_size.size();
  AT_ASSERT(num_indices == index_stride.size());
  AT_ASSERT(num_indices == iter.ntensors() - 2);

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_index_kernel(sub_iter, index_size, index_stride, f);
    }
    return;
  }

  // at::detail::Array is indexed without bounds checking here, so filling it
  // with more than max_indices entries would overrun host memory. Fail loudly
  // instead.
  TORCH_CHECK(num_indices <= max_indices,
              "index: got ", num_indices, " index tensors, but at most ",
              max_indices, " are supported on CUDA");

  auto sizes = at::detail::Array<int64_t, max_indices>(0);
  auto strides = at::detail::Array<int64_t, max_indices>(0);
  auto index_ptrs = at::detail::Array<char*, max_indices>(nullptr);
  for (int i = 0; i < num_indices; i++) {
    sizes[i] = index_size[i];
    strides[i] = index_stride[i];
    index_ptrs[i] = (char*)iter.data_ptr(i + 2);
  }

  char* out_ptr = (char*)iter.data_ptr(0);
  char* in_ptr = (char*)iter.data_ptr(1);

  // Three offsets per element: [0] output, [1] input, [2] index tensors.
  auto offset_calc = make_offset_calculator<3>(iter);
  // Template arguments are <nt, vt>: launch_size_nd threads per block and
  // launch_bound2 elements per thread.
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), [=]__device__(int idx) {
    auto offsets = offset_calc.get(idx);
    char* out_data = out_ptr + offsets[0];
    char* in_data = in_ptr + offsets[1];

    int64_t offset = 0;
    #pragma unroll
    for (int i = 0; i < num_indices; i++) {
      // Every index tensor is read at offsets[2] and interpreted as int64_t.
      // NOTE(review): this assumes all index operands share one layout and
      // hold 64-bit integers -- confirm against the code that builds `iter`.
      int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
      CUDA_KERNEL_ASSERT(index >= -sizes[i] && index < sizes[i] && "index out of bounds");
      // Negative indices count from the end of the dimension.
      if (index < 0) {
        index += sizes[i];
      }
      offset += index * strides[i];
    }

    f(out_data, in_data, offset);
  });
}
// Stand-in element type: N raw bytes, aligned to N. The copy kernels are
// instantiated on this instead of the real scalar type, so dtypes that share
// a size also share one kernel instantiation.
template <int N>
struct alignas(N) OpaqueType {
  char data[N];
};
// Gather: for each element, copies one scalar_t from the input at the
// computed index offset into the output slot.
//
//   scalar_t      element type used for the copy
//   iter, index_size, index_stride
//                 forwarded unchanged to gpu_index_kernel
template <typename scalar_t>
void index_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
  auto gather = []C10_DEVICE(char* dst, char* src, int64_t src_offset) {
    // The offset is applied to the source pointer; the destination is written as-is.
    const scalar_t* from = reinterpret_cast<const scalar_t*>(src + src_offset);
    *reinterpret_cast<scalar_t*>(dst) = *from;
  };
  gpu_index_kernel(iter, index_size, index_stride, gather);
}
// Scatter: for each element, copies one scalar_t from the input slot into the
// output at the computed index offset.
//
//   scalar_t      element type used for the copy
//   iter, index_size, index_stride
//                 forwarded unchanged to gpu_index_kernel
template <typename scalar_t>
void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
  auto scatter = []C10_DEVICE(char* dst, char* src, int64_t dst_offset) {
    // The offset is applied to the destination pointer; the source is read as-is.
    scalar_t* to = reinterpret_cast<scalar_t*>(dst + dst_offset);
    *to = *reinterpret_cast<const scalar_t*>(src);
  };
  gpu_index_kernel(iter, index_size, index_stride, scatter);
}
// Dtype dispatch for the gather kernel. Accepts the standard and complex
// types plus Half, Bool and BFloat16, keyed on iter.dtype().
static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(), "index_cuda", [&] {
        // Only the element size is used: instantiate on the opaque type of
        // matching width rather than on scalar_t itself.
        index_kernel_impl<OpaqueType<sizeof(scalar_t)>>(iter, index_size, index_stride);
      });
}
// Dtype dispatch for the scatter kernel. Accepts the standard and complex
// types plus Half, Bool and BFloat16, keyed on iter.dtype().
//
// Raises (TORCH_CHECK) when `accumulate` is true; this path only overwrites.
static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
  TORCH_CHECK(!accumulate, "index_put does not support accumulate=true");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(), "index_put", [&] {
        // Only the element size is used: instantiate on the opaque type of
        // matching width rather than on scalar_t itself.
        index_put_kernel_impl<OpaqueType<sizeof(scalar_t)>>(iter, index_size, index_stride);
      });
}
// Shared implementation of masked_select for CUDA.
//
//   result  tensor handed to index_out as the output; must have the same
//           scalar type as `self`
//   self    source tensor
//   mask    Byte or Bool tensor
//
// Returns `result`. Raises (TORCH_CHECK) on a mask of any other dtype or on a
// dtype mismatch between `self` and `result`.
static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
  // NOTE(review): presumably suspends named-tensor handling for this scope -- confirm.
  NoNamesGuard guard;

  const auto mask_dtype = mask.scalar_type();
  TORCH_CHECK(mask_dtype == ScalarType::Byte || mask_dtype == ScalarType::Bool,
              "masked_select: expected BoolTensor or ByteTensor for mask");
  TORCH_CHECK(self.scalar_type() == result.scalar_type(),
              "masked_select(): self and result must have the same scalar type");

  // Promote 0-dim operands to 1-dim before expanding them against each other.
  Tensor mask_nd = mask;
  if (mask.dim() == 0) {
    mask_nd = mask.unsqueeze(0);
  }
  Tensor self_nd = self;
  if (self.dim() == 0) {
    self_nd = self.unsqueeze(0);
  }
  std::tie(mask_nd, self_nd) = expand_outplace(mask_nd, self_nd);

  // The expanded mask is passed to index_out as the index.
  at::native::index_out(result, self_nd, mask_nd);
  return result;
}
// Out-of-place masked_select: allocates an empty tensor with the options of
// `self`, fills it through masked_select_out_cuda_impl and returns it.
Tensor masked_select_cuda(const Tensor & self, const Tensor & mask) {
  // Return value is discarded; the call is kept for whatever checking it performs.
  namedinference::compute_broadcast_outnames(self, mask);
  Tensor selected = at::empty({0}, self.options());
  masked_select_out_cuda_impl(selected, self, mask);
  return selected;
}
// masked_select writing into a caller-provided `result`, which is returned.
Tensor & masked_select_out_cuda(Tensor & result, const Tensor & self, const Tensor & mask) {
  // Return value is discarded; the call is kept for whatever checking it performs.
  namedinference::compute_broadcast_outnames(self, mask);
  // The impl returns the same `result` reference it is given.
  masked_select_out_cuda_impl(result, self, mask);
  return result;
}
// Register this file's index_kernel and index_put_kernel for index_stub and
// index_put_stub respectively.
// NOTE(review): the stubs are presumably declared in
// ATen/native/TensorAdvancedIndexing.h -- confirm.
REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
}} // namespace at::native