/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* ************************************************************************
* Modified from pytorch_scatter
* Copyright (c) 2020 Matthias Fey <[email protected]>
* See https://github.com/rusty1s/pytorch_scatter/blob/master/LICENSE for details
* ************************************************************************
*/
#include "TensorInfo.cuh"
#include "common/dimsHelpers.h"
#include "reducer.cuh"
#include "scatterElementsPluginKernel.h"
#include <sstream>
#include <thrust/device_vector.h>
namespace nvinfer1
{
namespace plugin
{
#define THREADS 256
#define BLOCKS(N) (N + THREADS - 1) / THREADS
using detail::TensorInfo;
using detail::getTensorInfo;
using nvinfer1::pluginInternal::volume;
// Scatters one updates element per thread into outData along the collapsed
// (batch, axis, inner) layout: nE = updates extent on the scatter axis,
// nK = product of trailing dims, nN = output extent on the scatter axis.
// Writes use Reducer's atomic reduction so colliding indices combine safely.
// Launched 1-D with BLOCKS(nbElements) x THREADS; threads past nbElements exit.
template <typename TScalar, ReductionType tReduce>
__global__ void scatterElements_kernel(const TScalar* updatesData, const TensorInfo<int64_t, int32_t> indexInfo,
    TScalar* outData, int32_t nE, int32_t nK, int32_t nN, int32_t nbElements)
{
    int32_t const tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= nbElements)
    {
        return;
    }
    // Decompose the flat index: b picks the batch slice, k the innermost position.
    int32_t const b = tid / (nE * nK);
    int32_t const k = tid % nK;
    // Translate the flat index into an offset inside the (possibly strided) indices tensor.
    int32_t const idxOffset = detail::IndexToOffset<int64_t, int32_t, -1>::get(tid, indexInfo);
    int64_t const scatterIdx = indexInfo.data[idxOffset];
    Reducer<TScalar, tReduce>::atomic_write(outData + b * nN * nK + scatterIdx * nK + k, updatesData[tid]);
}
// Returns true when the current CUDA device supports native bfloat16 atomic add,
// i.e. compute capability 8.0 (Ampere) or newer.
// Fix: the original ignored the return codes of cudaGetDevice /
// cudaGetDeviceProperties, so deviceProp.major could be read uninitialized when
// either query failed. On failure we now conservatively report no support.
bool hasBfloat16AtomicAdd()
{
    int deviceId = 0;
    if (cudaGetDevice(&deviceId) != cudaSuccess)
    {
        return false;
    }
    cudaDeviceProp deviceProp{};
    if (cudaGetDeviceProperties(&deviceProp, deviceId) != cudaSuccess)
    {
        return false;
    }
    return deviceProp.major >= 8;
}
// Returns the size in bytes of one element of the given TensorRT data type.
// kINT4 is sub-byte and cannot be expressed as a whole-byte size, so it aborts
// via PLUGIN_FAIL; the trailing return 0 only silences compiler warnings.
inline uint32_t getElementSize(nvinfer1::DataType t) noexcept
{
    switch (t)
    {
    case nvinfer1::DataType::kINT64:
        return 8;
    case nvinfer1::DataType::kFLOAT:
    case nvinfer1::DataType::kINT32:
        return 4;
    case nvinfer1::DataType::kHALF:
    case nvinfer1::DataType::kBF16:
        return 2;
    case nvinfer1::DataType::kINT8:
    case nvinfer1::DataType::kUINT8:
    case nvinfer1::DataType::kBOOL:
    case nvinfer1::DataType::kFP8:
        return 1;
    case nvinfer1::DataType::kINT4:
        PLUGIN_FAIL("Unsupported data type");
    }
    return 0;
}
// Collapses the updates tensor around the scatter axis into a
// (batch, axis, inner) geometry and launches scatterElements_kernel with one
// thread per updates element on the given stream.
// dataDataPtr/dataDesc are unused here: the caller pre-copies the input data
// into the output buffer before dispatch.
template <typename TScalar>
void dispatchScatterElementsKernel(void* outDataPtr, void const* dataDataPtr, void const* updatesDataPtr,
    void const* indicesDataPtr, PluginTensorDesc const& outDesc, PluginTensorDesc const& dataDesc,
    PluginTensorDesc const& updatesDesc, PluginTensorDesc const& indicesDesc, int64_t axis, ReductionType reduction,
    cudaStream_t stream)
{
    auto const updatesNumEl = volume(updatesDesc.dims);

    // nB: product of all updates dimensions preceding the scatter axis.
    auto nB = 1;
    for (auto d = 0; d < axis; ++d)
    {
        nB *= updatesDesc.dims.d[d];
    }
    // nE: updates extent on the scatter axis; nK: product of the trailing dims.
    auto const nE = updatesDesc.dims.d[axis];
    auto const nK = updatesNumEl / (nB * nE);
    // nN: output extent on the scatter axis.
    auto const nN = outDesc.dims.d[axis];

    auto const indexInfo = getTensorInfo<int64_t, int32_t>(indicesDataPtr, indicesDesc);
    auto const updatesData = (TScalar*) updatesDataPtr;
    auto const outData = (TScalar*) outDataPtr;

    // Expand the runtime reduction enum into the compile-time REDUCE constant
    // expected by the kernel template.
    AT_DISPATCH_REDUCTION_TYPES(reduction, [&] {
        scatterElements_kernel<TScalar, REDUCE>
            <<<BLOCKS(updatesNumEl), THREADS, 0, stream>>>(updatesData, indexInfo, outData, nE, nK, nN, updatesNumEl);
    });
}
#define DISPATCH_RUN_KERNEL(TYPE) \
dispatchScatterElementsKernel<TYPE>(outDataPtr, dataDataPtr, updatesDataPtr, indicesDataPtr, outDesc, dataDesc, \
updatesDesc, indicesDesc, axis, reduction, stream)
// Entry point for the ScatterElements plugin: seeds the output with the input
// data, then scatters the updates on top, dispatching on the element type.
// All device work is enqueued on `stream`; nothing here synchronizes.
// Fix: the unsupported-type branch declared a local `std::ostringstream stream`
// that shadowed the cudaStream_t `stream` parameter — renamed to `msg` and
// scoped with braces so the non-trivially-constructed local cannot interact
// with the switch's jump semantics.
void runScatterElementsKernel(void* outDataPtr, void const* dataDataPtr, void const* updatesDataPtr,
    void const* indicesDataPtr, PluginTensorDesc const& outDesc, PluginTensorDesc const& dataDesc,
    PluginTensorDesc const& updatesDesc, PluginTensorDesc const& indicesDesc, int64_t axis, ReductionType reduction,
    cudaStream_t stream)
{
    auto updatesNumEl = volume(updatesDesc.dims);
    auto outNumEl = volume(outDesc.dims);

    // Copy the input data into the output buffer first; the scatter kernel then
    // overwrites/reduces only the indexed positions.
    cudaMemcpyAsync(outDataPtr, dataDataPtr, getElementSize(outDesc.type) * outNumEl, cudaMemcpyDeviceToDevice, stream);

    // No updates to scatter: the copy above already produced the final result.
    if (updatesNumEl == 0)
    {
        return;
    }

    switch (outDesc.type)
    {
    case nvinfer1::DataType::kFLOAT: DISPATCH_RUN_KERNEL(float); break;
    case nvinfer1::DataType::kHALF: DISPATCH_RUN_KERNEL(__half); break;
    case nvinfer1::DataType::kINT32: DISPATCH_RUN_KERNEL(int32_t); break;
    case nvinfer1::DataType::kINT64: DISPATCH_RUN_KERNEL(int64_t); break;
    case nvinfer1::DataType::kBF16: DISPATCH_RUN_KERNEL(__nv_bfloat16); break;
    case nvinfer1::DataType::kBOOL:
    case nvinfer1::DataType::kUINT8:
    case nvinfer1::DataType::kINT8:
    case nvinfer1::DataType::kINT4:
    case nvinfer1::DataType::kFP8:
    {
        std::ostringstream msg;
        msg << "Unsupported data type:" << (int) outDesc.type << std::endl;
        PLUGIN_FAIL(msg.str().c_str());
        break;
    }
    }
}
} // namespace plugin
} // namespace nvinfer1