@@ -283,7 +283,7 @@ static inline int cudaGetBlocks(const int N) {
283283#define DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR (T, Funcname, op ) \
284284 __global__ void _Kernel_##T##_##Funcname( \
285285 T* dst, const T* src, const int n) { \
286- for (int i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
286+ for (auto i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
287287 i += blockDim .x * gridDim .x ) { \
288288 dst[i] = dst[i] op src[i]; \
289289 } \
@@ -301,7 +301,7 @@ static inline int cudaGetBlocks(const int N) {
301301#define DELEGATE_HALF_PRECISION_CUDA_BINARY_OPERATOR (Funcname, op ) \
302302 __global__ void _Kernel_half_##Funcname( \
303303 half* dst, const half* src, const int n) { \
304- for (int i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
304+ for (auto i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
305305 i += blockDim .x * gridDim .x ) { \
306306 float r = __half2float (dst[i]) op __half2float (src[i]); \
307307 dst[i] = __float2half (r); \
@@ -337,7 +337,7 @@ DELEGATE_HALF_PRECISION_CUDA_BINARY_OPERATOR(cudaProduct, *);
337337#define DELEGATE_SIMPLE_CUDA_BINARY_COMPARE (T, Funcname, op ) \
338338 __global__ void _Kernel_##T##_##Funcname( \
339339 T* dst, const T* src, const int n) { \
340- for (int i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
340+ for (auto i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
341341 i += blockDim .x * gridDim .x ) { \
342342 if (src[i] op dst[i]) { \
343343 dst[i] = src[i]; \
@@ -357,7 +357,7 @@ DELEGATE_HALF_PRECISION_CUDA_BINARY_OPERATOR(cudaProduct, *);
357357#define DELEGATE_HALF_PRECISION_CUDA_BINARY_COMPARE (Funcname, op ) \
358358 __global__ void _Kernel_half_##Funcname( \
359359 half* dst, const half* src, const int n) { \
360- for (int i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
360+ for (auto i = blockIdx .x * blockDim .x + threadIdx .x ; i < (n); \
361361 i += blockDim .x * gridDim .x ) { \
362362 if (__half2float (src[i]) op __half2float (dst[i])) { \
363363 dst[i] = src[i]; \
@@ -398,6 +398,12 @@ DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaSum, +);
398398DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR (BFloat16, cudaProduct, *);
399399DELEGATE_SIMPLE_CUDA_BINARY_COMPARE (BFloat16, cudaMin, <);
400400DELEGATE_SIMPLE_CUDA_BINARY_COMPARE (BFloat16, cudaMax, >);
401+ using Half = c10::Half;
402+ INSTANTIATE_COPY_ASYNC (Half);
403+ DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR (Half, cudaSum, +);
404+ DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR (Half, cudaProduct, *);
405+ DELEGATE_SIMPLE_CUDA_BINARY_COMPARE (Half, cudaMin, <);
406+ DELEGATE_SIMPLE_CUDA_BINARY_COMPARE (Half, cudaMax, >);
401407#endif
402408
403409} // namespace gloo
0 commit comments