diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index d8af951afa7018..5e0a9d8352bd4a 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -20,16 +20,16 @@
 
 namespace at::native {
 
+// _foreach_norm supports only L1, L2, and inf norm
+enum class NormType { L1, L2, LInf };
+
 template <
     typename T,
-    int NormType,
+    NormType norm_type,
     int depth = 1,
     int r_args_depth = 1,
    int res_arg_index = 0>
 struct LpNormFunctor {
-  static_assert(
-      NormType == 1 || NormType == 2,
-      "foreach_norm supports only L1 and L2 norm");
   using opmath_t = typename at::opmath_type<T>;
   __device__ __forceinline__ void operator()(
       int chunk_size,
@@ -61,7 +61,11 @@ struct LpNormFunctor {
 #pragma unroll
         for (int ii = 0; ii < kILP; ii++) {
           opmath_t next = static_cast<opmath_t>(r_x[ii]);
-          vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+          if constexpr (norm_type == NormType::LInf) {
+            vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+          } else {
+            vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+          }
         }
       }
     } else {
@@ -72,7 +76,11 @@ struct LpNormFunctor {
           int i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             opmath_t next = static_cast<opmath_t>(x[i]);
-            vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+            if constexpr (norm_type == NormType::LInf) {
+              vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+            } else {
+              vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+            }
           }
         }
       }
@@ -80,19 +88,28 @@
 
     auto val = opmath_t(0);
     for (int i = 0; i < kILP; i++) {
-      val += vals[i];
+      if constexpr (norm_type == NormType::LInf) {
+        val = max_propagate_nan(val, vals[i]);
+      } else {
+        val += vals[i];
+      }
     }
-    auto final = at::native::cuda_utils::BlockReduceSum(val, s_vals);
+    auto final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+        ? at::native::cuda_utils::BlockReduceSum(val, s_vals)
+        : at::native::cuda_utils::BlockReduceMax(val, s_vals);
 
     if (threadIdx.x == 0) {
       output_per_tensor
           [(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
-           chunk_idx] = final;
+           chunk_idx] = final_val;
     }
   }
 };
 
-template <typename T, typename opmath_t = at::opmath_type<T>>
+template <
+    typename T,
+    NormType norm_type,
+    typename opmath_t = at::opmath_type<T>>
 __global__ void lpnorm_cleanup(
     const opmath_t* output_per_tensor,
     T* ret_per_tensor,
@@ -103,11 +120,20 @@ __global__ void lpnorm_cleanup(
       output_per_tensor + blockIdx.x * max_chunks_per_tensor;
   opmath_t val = 0;
   for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
-    val += output_this_tensor[i];
+    if constexpr (norm_type == NormType::LInf) {
+      val = max_propagate_nan(val, output_this_tensor[i]);
+    } else {
+      val += output_this_tensor[i];
+    }
   }
-  opmath_t final = at::native::cuda_utils::BlockReduceSum(val, vals);
+  opmath_t final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+      ? at::native::cuda_utils::BlockReduceSum(val, vals)
+      : at::native::cuda_utils::BlockReduceMax(val, vals);
   if (threadIdx.x == 0) {
-    ret_per_tensor[blockIdx.x] = NormType == 1 ? final : ::sqrt(final);
+    ret_per_tensor[blockIdx.x] =
+        norm_type == NormType::L1 || norm_type == NormType::LInf
+        ? final_val
+        : ::sqrt(final_val);
   }
 }
 
@@ -135,7 +161,8 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
             at::isComplexType(scalar_type);
       });
   if (!can_use_fast_route(tensors) || has_int_or_complex ||
-      !(p == static_cast<double>(1) || p == static_cast<double>(2))) {
+      !(p == static_cast<double>(1) || p == static_cast<double>(2) ||
+        p == std::numeric_limits<double>::infinity())) {
     return foreach_tensor_norm_slow(tensors, ord);
   }
 
@@ -166,14 +193,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           using opmath_t = typename at::opmath_type<scalar_t>;
           multi_tensor_apply<1>(
               tensor_lists,
-              LpNormFunctor<scalar_t, 1>(),
+              LpNormFunctor<scalar_t, NormType::L1>(),
               output_per_tensor.mutable_data_ptr<opmath_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, 1><<<ntensors, 512, 0, stream>>>(
+          lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
               output_per_tensor.const_data_ptr<opmath_t>(),
               ret_per_tensor.mutable_data_ptr<scalar_t>(),
               max_chunks_per_tensor);
@@ -189,19 +216,43 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           using opmath_t = typename at::opmath_type<scalar_t>;
           multi_tensor_apply<1>(
               tensor_lists,
-              LpNormFunctor<scalar_t, 2>(),
+              LpNormFunctor<scalar_t, NormType::L2>(),
               output_per_tensor.mutable_data_ptr<opmath_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, 2><<<ntensors, 512, 0, stream>>>(
+          lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
               output_per_tensor.const_data_ptr<opmath_t>(),
               ret_per_tensor.mutable_data_ptr<scalar_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
+  } else if (p == std::numeric_limits<double>::infinity()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf,
+        kBFloat16,
+        tensor_lists[0][0].scalar_type(),
+        "foreach_tensor_norm_cuda",
+        [&]() {
+          using opmath_t = typename at::opmath_type<scalar_t>;
+          multi_tensor_apply<1>(
+              tensor_lists,
+              LpNormFunctor<scalar_t, NormType::LInf>(),
+              output_per_tensor.mutable_data_ptr<opmath_t>(),
+              max_chunks_per_tensor);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          const at::cuda::OptionalCUDAGuard device_guard(
+              device_of(output_per_tensor));
+          auto stream = at::cuda::getCurrentCUDAStream();
+          lpnorm_cleanup<scalar_t, NormType::LInf>
+              <<<ntensors, 512, 0, stream>>>(
+                  output_per_tensor.const_data_ptr<opmath_t>(),
+                  ret_per_tensor.mutable_data_ptr<scalar_t>(),
+                  max_chunks_per_tensor);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
   } else {
     TORCH_CHECK(
         false,
diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh
index fa75c71f8acafd..a21588a1329c9d 100644
--- a/aten/src/ATen/native/cuda/block_reduce.cuh
+++ b/aten/src/ATen/native/cuda/block_reduce.cuh
@@ -29,6 +29,19 @@ __inline__ __device__ T WarpReduceSum(T val) {
   return val;
 }
 
+// Picks the maximum `val` across all threads in a warp.
+//
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceMax(T val) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
+  }
+  return val;
+}
+
 struct Block1D {
     static __forceinline__ __device__ int Tid() { return threadIdx.x; }
 
@@ -72,6 +85,31 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) {
   return val;
 }
 
+// Picks out the maximum `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+//   - `shared` should be a pointer to shared memory with size of, at least,
+//     `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceMax(T val, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduceMax(val);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
+  if (wid == 0) {
+    val = WarpReduceMax(val);
+  }
+  return val;
+}
+
 template <typename T, class ReduceOp>
 __inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
 #pragma unroll
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8bba46d01777f8..3607db5ccfb143 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9146,10 +9146,10 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, *
         assert "num_input_tensors" not in kwargs
         _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
         _foreach_inputs_kwargs["requires_grad"] = requires_grad
-        for ord in (0, 1, 2, -1, -2):
+        for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')):
             input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
             disable_fastpath = True
-            if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
                 disable_fastpath = False
             yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
 
@@ -9159,13 +9159,32 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
         _foreach_inputs_kwargs["requires_grad"] = requires_grad
-        for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2)):
+        for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2, float('inf'), float('-inf'))):
             input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
             disable_fastpath = True
-            if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
                 disable_fastpath = False
             yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
+            # Also test nan propagation with a single tensor, but skip autograd testing
+            if not requires_grad:
+                nan_inputs = [
+                    [float('nan')],
+                    [float('nan'), 1.0],
+                    [1.0, float('nan')],
+                    [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0],
+                    [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0],
+                    [3.0, float('nan'), float('nan'), -1.5, 6.0],
+                ]
+                for input in nan_inputs:
+                    x = torch.tensor(input, device=device)
+                    disable_fastpath = True
+                    if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
+                        disable_fastpath = False
+                    yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
+
+
+
 
 
 class foreach_lerp_sample_func(foreach_inputs_sample_func):
     def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):
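For context (not part of the patch): the change routes torch._foreach_norm(tensors, float('inf')) through the fused multi_tensor_apply fast path for floating dtypes on CUDA, with max_propagate_nan keeping NaN propagation consistent with the per-tensor reference. Below is a minimal sketch of how the new behavior can be exercised; the helper name, tensor shapes, and the CUDA availability guard are illustrative assumptions rather than anything taken from the PR.

# Illustrative sketch (assumed helper, not from the patch): compare the
# foreach inf-norm against per-tensor torch.linalg.vector_norm, including
# inputs that contain NaN, which the inf norm should propagate.
import math

import torch


def check_foreach_inf_norm(device="cuda", dtype=torch.float32):
    tensors = [
        torch.randn(1024, device=device, dtype=dtype),
        torch.tensor([1.0, float("nan"), -3.0], device=device, dtype=dtype),
        torch.full((7,), -2.5, device=device, dtype=dtype),
    ]
    out = torch._foreach_norm(tensors, float("inf"))
    ref = [torch.linalg.vector_norm(t, ord=math.inf) for t in tensors]
    for got, expected in zip(out, ref):
        # A NaN anywhere in the input should surface as a NaN inf-norm.
        assert got.isnan().item() == expected.isnan().item()
        if not expected.isnan().item():
            torch.testing.assert_close(got, expected)


if torch.cuda.is_available():
    check_foreach_inf_norm()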