// ForeachTensorAddScalar.cu
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/ForeachUtils.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at { namespace native {

namespace {
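// Functor run per (tensor, chunk) pair by multi_tensor_apply: each CUDA block
// looks up the tensor and chunk it owns in the TensorListMetadata, reads its
// slice of the input (address list 0), adds `scalar`, and writes the result
// to the matching output tensor (address list 1).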
template<typename x_t, typename out_t>
struct AddScalarFunctor {
  __device__ void operator() (
      int chunk_size,
      TensorListMetadata<2>& tl,
      x_t scalar) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx * chunk_size;

    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
    out += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;
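    // x and out now point at this block's chunk; n counts the elements left
    // in the tensor from the start of the chunk (the loops below also clamp
    // to chunk_size).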
    x_t r_x[kILP];
    out_t r_out[kILP];

    // to make things simple, we put aligned case in a different code path
    if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) {
      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_x, x, 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_out[ii] = static_cast<x_t>(r_x[ii]) + scalar;
        }
        // store
        load_store(out, r_out, i_start, 0);
      }
    }
    else {
      // Non-divergent exit condition for __syncthreads, not necessary here
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_x[ii] = 0;
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if(i < n && i < chunk_size) {
            r_x[ii] = x[i];
          }
        }
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_out[ii] = static_cast<x_t>(r_x[ii]) + scalar;
        }
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if(i < n && i < chunk_size)
            out[i] = r_out[ii];
        }
      }
    }
  }
};
} // namespace
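
// Host entry point: falls back to the per-tensor reference implementation
// when check_fast_route rejects the inputs; otherwise allocates one output
// tensor per input and applies AddScalarFunctor to the whole list via
// multi_tensor_apply.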
std::vector<Tensor> foreach_tensor_add_scalar_kernel_cuda(TensorList tensors, Scalar scalar) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");

  if (!check_fast_route(tensors, scalar)) {
    return at::native::foreach_add_scalar_kernel_fallback(tensors, scalar);
  }
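
  // tensor_lists[0] holds the inputs; tensor_lists[1] holds the freshly
  // allocated outputs that are returned to the caller.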
  std::vector<std::vector<at::Tensor>> tensor_lists;
  std::vector<at::Tensor> vec_res;
  for (const auto& t : tensors) {
    vec_res.emplace_back(at::native::empty_like(t));
  }

  tensor_lists.emplace_back(std::move(tensors.vec()));
  tensor_lists.emplace_back(std::move(vec_res));

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_tensor_add_scalar_kernel_cuda", [&]() {
    multi_tensor_apply<2>(tensor_lists, AddScalarFunctor<scalar_t, scalar_t>(), scalar.to<scalar_t>());
  });

  return tensor_lists[1];
}
}} // namespace at::native