Commit 21ce53b

Add inf norm support for _foreach_norm (pytorch#118441)

Fixes pytorch#117803

Pull Request resolved: pytorch#118441
Approved by: https://github.com/mlazos
janeyx99 authored and pytorchmergebot committed Jan 31, 2024
1 parent e87ac82 commit 21ce53b
Showing 3 changed files with 130 additions and 22 deletions.
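Not part of the commit itself: a minimal usage sketch of what this change enables, assuming a CUDA device (the fused path is CUDA-only; on CPU, or for unsupported `ord` values, the op still falls back to the per-tensor slow path).

```python
import torch

tensors = [torch.randn(3, 4, device="cuda"), torch.randn(5, device="cuda")]

# With this change, ord=inf is handled by the fused multi-tensor kernel
# instead of the per-tensor slow path.
fused = torch._foreach_norm(tensors, ord=float("inf"))

# Reference: one vector_norm call per tensor.
reference = [torch.linalg.vector_norm(t, ord=float("inf")) for t in tensors]

for out, ref in zip(fused, reference):
    torch.testing.assert_close(out, ref)
```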
87 changes: 69 additions & 18 deletions aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -20,16 +20,16 @@
 
 namespace at::native {
 
+// _foreach_norm supports only L1, L2, and inf norm
+enum class NormType { L1, L2, LInf };
+
 template <
     typename T,
-    int NormType,
+    NormType norm_type,
     int depth = 1,
     int r_args_depth = 1,
     int res_arg_index = 0>
 struct LpNormFunctor {
-  static_assert(
-      NormType == 1 || NormType == 2,
-      "foreach_norm supports only L1 and L2 norm");
   using opmath_t = typename at::opmath_type<T>;
   __device__ __forceinline__ void operator()(
       int chunk_size,
@@ -61,7 +61,11 @@ struct LpNormFunctor {
 #pragma unroll
         for (int ii = 0; ii < kILP; ii++) {
           opmath_t next = static_cast<opmath_t>(r_x[ii]);
-          vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+          if constexpr (norm_type == NormType::LInf) {
+            vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+          } else {
+            vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+          }
         }
       }
     } else {
@@ -72,27 +76,40 @@ struct LpNormFunctor {
           int i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             opmath_t next = static_cast<opmath_t>(x[i]);
-            vals[ii] += NormType == 1 ? ::abs(next) : next * next;
+            if constexpr (norm_type == NormType::LInf) {
+              vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
+            } else {
+              vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next;
+            }
           }
         }
       }
     }
 
     auto val = opmath_t(0);
     for (int i = 0; i < kILP; i++) {
-      val += vals[i];
+      if constexpr (norm_type == NormType::LInf) {
+        val = max_propagate_nan(val, vals[i]);
+      } else {
+        val += vals[i];
+      }
     }
-    auto final = at::native::cuda_utils::BlockReduceSum(val, s_vals);
+    auto final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+        ? at::native::cuda_utils::BlockReduceSum(val, s_vals)
+        : at::native::cuda_utils::BlockReduceMax(val, s_vals);
 
     if (threadIdx.x == 0) {
       output_per_tensor
           [(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
-           chunk_idx] = final;
+           chunk_idx] = final_val;
     }
   }
 };
 
-template <typename T, int NormType, typename opmath_t = at::opmath_type<T>>
+template <
+    typename T,
+    NormType norm_type,
+    typename opmath_t = at::opmath_type<T>>
 __global__ void lpnorm_cleanup(
     const opmath_t* output_per_tensor,
     T* ret_per_tensor,
@@ -103,11 +120,20 @@ __global__ void lpnorm_cleanup(
       output_per_tensor + blockIdx.x * max_chunks_per_tensor;
   opmath_t val = 0;
   for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
-    val += output_this_tensor[i];
+    if constexpr (norm_type == NormType::LInf) {
+      val = max_propagate_nan(val, output_this_tensor[i]);
+    } else {
+      val += output_this_tensor[i];
+    }
   }
-  opmath_t final = at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals);
+  opmath_t final_val = norm_type == NormType::L1 || norm_type == NormType::L2
+      ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
+      : at::native::cuda_utils::BlockReduceMax(val, vals);
   if (threadIdx.x == 0) {
-    ret_per_tensor[blockIdx.x] = NormType == 1 ? final : ::sqrt(final);
+    ret_per_tensor[blockIdx.x] =
+        norm_type == NormType::L1 || norm_type == NormType::LInf
+        ? final_val
+        : ::sqrt(final_val);
   }
 }
 
@@ -135,7 +161,8 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
             at::isComplexType(scalar_type);
       });
   if (!can_use_fast_route(tensors) || has_int_or_complex ||
-      !(p == static_cast<double>(1) || p == static_cast<double>(2))) {
+      !(p == static_cast<double>(1) || p == static_cast<double>(2) ||
+        p == std::numeric_limits<double>::infinity())) {
     return foreach_tensor_norm_slow(tensors, ord);
   }
 
@@ -166,14 +193,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           using opmath_t = typename at::opmath_type<scalar_t>;
           multi_tensor_apply<1>(
               tensor_lists,
-              LpNormFunctor<scalar_t, 1>(),
+              LpNormFunctor<scalar_t, NormType::L1>(),
               output_per_tensor.mutable_data_ptr<opmath_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, 1><<<ntensors, 512, 0, stream>>>(
+          lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
               output_per_tensor.const_data_ptr<opmath_t>(),
               ret_per_tensor.mutable_data_ptr<scalar_t>(),
               max_chunks_per_tensor);
@@ -189,19 +216,43 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           using opmath_t = typename at::opmath_type<scalar_t>;
           multi_tensor_apply<1>(
               tensor_lists,
-              LpNormFunctor<scalar_t, 2>(),
+              LpNormFunctor<scalar_t, NormType::L2>(),
               output_per_tensor.mutable_data_ptr<opmath_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, 2><<<ntensors, 512, 0, stream>>>(
+          lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
               output_per_tensor.const_data_ptr<opmath_t>(),
               ret_per_tensor.mutable_data_ptr<scalar_t>(),
               max_chunks_per_tensor);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
+  } else if (p == std::numeric_limits<double>::infinity()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf,
+        kBFloat16,
+        tensor_lists[0][0].scalar_type(),
+        "foreach_tensor_norm_cuda",
+        [&]() {
+          using opmath_t = typename at::opmath_type<scalar_t>;
+          multi_tensor_apply<1>(
+              tensor_lists,
+              LpNormFunctor<scalar_t, NormType::LInf>(),
+              output_per_tensor.mutable_data_ptr<opmath_t>(),
+              max_chunks_per_tensor);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          const at::cuda::OptionalCUDAGuard device_guard(
+              device_of(output_per_tensor));
+          auto stream = at::cuda::getCurrentCUDAStream();
+          lpnorm_cleanup<scalar_t, NormType::LInf>
+              <<<ntensors, 512, 0, stream>>>(
+                  output_per_tensor.const_data_ptr<opmath_t>(),
+                  ret_per_tensor.mutable_data_ptr<scalar_t>(),
+                  max_chunks_per_tensor);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
   } else {
     TORCH_CHECK(
         false,
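Not from the diff: a small Python sketch of the two-pass scheme the kernels above implement for `ord=inf`. `multi_tensor_apply` runs `LpNormFunctor` to leave one partial value per chunk in `output_per_tensor`, and `lpnorm_cleanup` then reduces those partials per tensor; for the inf norm both passes take a NaN-propagating max and skip the final `sqrt`. The function name and `chunk_size` below are illustrative, not PyTorch API.

```python
import torch

def chunked_inf_norm(x: torch.Tensor, chunk_size: int = 65536) -> torch.Tensor:
    # Pass 1 (LpNormFunctor): each chunk reduces to a partial max of |x|.
    # Tensor.max propagates NaN, like max_propagate_nan in the kernel.
    partials = [chunk.abs().max() for chunk in x.flatten().split(chunk_size)]
    # Pass 2 (lpnorm_cleanup): reduce the per-chunk partials with max again.
    # Unlike the L2 branch, no final sqrt is applied.
    result = partials[0]
    for p in partials[1:]:
        result = torch.maximum(result, p)
    return result

x = torch.randn(1_000_000)
assert torch.equal(chunked_inf_norm(x), x.abs().max())
```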
38 changes: 38 additions & 0 deletions aten/src/ATen/native/cuda/block_reduce.cuh
@@ -29,6 +29,19 @@ __inline__ __device__ T WarpReduceSum(T val) {
   return val;
 }
 
+// Picks the maximum `val` accross all threads in a warp.
+//
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceMax(T val) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
+  }
+  return val;
+}
+
 struct Block1D {
   static __forceinline__ __device__ int Tid() { return threadIdx.x; }
 
@@ -72,6 +85,31 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) {
   return val;
 }
 
+// Picks out the maximum `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+//   - `shared` should be a pointer to shared memory with size of, at least,
+//     `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceMax(T val, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduceMax(val);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
+  if (wid == 0) {
+    val = WarpReduceMax(val);
+  }
+  return val;
+}
+
 template <typename T, class ReduceOp>
 __inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
 #pragma unroll
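For readers unfamiliar with the shuffle-down idiom, here is a host-side Python model (a sketch, not CUDA) of the NaN-propagating max that `WarpReduceMax` computes: each step folds the upper half of the warp onto the lower half until lane 0 holds the result. It assumes the lane count is a power of two, as `C10_WARP_SIZE` is; the helper names are illustrative.

```python
import math

def nan_max(a, b):
    # Mirrors max_propagate_nan: if either input is NaN, the result is NaN.
    if math.isnan(a) or math.isnan(b):
        return float("nan")
    return a if a > b else b

def warp_reduce_max(lanes):
    # Mirrors WarpReduceMax: halve the stride each step, combining lane i
    # with lane i + offset (the WARP_SHFL_DOWN pattern).
    lanes = list(lanes)
    offset = len(lanes) // 2
    while offset > 0:
        for lane in range(offset):
            lanes[lane] = nan_max(lanes[lane], lanes[lane + offset])
        offset //= 2
    return lanes[0]

print(warp_reduce_max([1.0, 7.0, -3.0, 2.0]))           # 7.0
print(warp_reduce_max([1.0, 7.0, -3.0, float("nan")]))  # nan
```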
27 changes: 23 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
@@ -9146,10 +9146,10 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, *
         assert "num_input_tensors" not in kwargs
         _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
         _foreach_inputs_kwargs["requires_grad"] = requires_grad
-        for ord in (0, 1, 2, -1, -2):
+        for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')):
             input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
             disable_fastpath = True
-            if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
                 disable_fastpath = False
             yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
 
@@ -9159,13 +9159,32 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
         _foreach_inputs_kwargs["requires_grad"] = requires_grad
 
-        for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2)):
+        for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2, float('inf'), float('-inf'))):
             input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
             disable_fastpath = True
-            if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
                 disable_fastpath = False
             yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
 
+        # Also test nan propagation with a single tensor, but skip autograd testing
+        if not requires_grad:
+            nan_inputs = [
+                [float('nan')],
+                [float('nan'), 1.0],
+                [1.0, float('nan')],
+                [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0],
+                [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0],
+                [3.0, float('nan'), float('nan'), -1.5, 6.0],
+            ]
+            for input in nan_inputs:
+                x = torch.tensor(input, device=device)
+                disable_fastpath = True
+                if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
+                    disable_fastpath = False
+                yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
+
+
+
 
 class foreach_lerp_sample_func(foreach_inputs_sample_func):
     def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):
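A hypothetical spot check in the spirit of the new NaN samples (not taken from the test suite): with max_propagate_nan semantics, the inf norm of a tensor containing NaN is NaN on both the foreach path and the vector_norm reference.

```python
import torch

t = torch.tensor([1.0, 2.0, float("nan"), -1.5])
out = torch._foreach_norm([t], ord=float("inf"))[0]
ref = torch.linalg.vector_norm(t, ord=float("inf"))
assert out.isnan().item() and ref.isnan().item()
```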
