From f09a6eaaa0c3c2da55c75419533c72f362a01d10 Mon Sep 17 00:00:00 2001
From: Deleter-D <867909454@qq.com>
Date: Thu, 4 Jul 2024 17:05:36 +0800
Subject: [PATCH] [DCU] fix some failed ut

---
 .../phi/kernels/gpu/check_numerics_kernel.cu  | 91 +------------------
 paddle/phi/kernels/gpu/group_norm_kernel.cu   | 15 +++
 .../strings/gpu/strings_lower_upper_kernel.cu |  4 +
 .../test_elementwise_op_grad_grad.h           | 16 ++--
 4 files changed, 28 insertions(+), 98 deletions(-)

diff --git a/paddle/phi/kernels/gpu/check_numerics_kernel.cu b/paddle/phi/kernels/gpu/check_numerics_kernel.cu
index f11f3460e7ebe..88f6cb66aea6c 100644
--- a/paddle/phi/kernels/gpu/check_numerics_kernel.cu
+++ b/paddle/phi/kernels/gpu/check_numerics_kernel.cu
@@ -58,82 +58,6 @@ static void InitMultiGPUOpVarMap() {
   multi_op_var2gpu_str_mutex().swap(tmp_multi_mutex);
 }
 
-template <typename T>
-__device__ __forceinline__ void PrintNanInfKernel(const T* value,
-                                                  const size_t numel,
-                                                  int print_num,
-                                                  char* debug_info) {
-  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-
-  __shared__ unsigned int nan_count, inf_count, num_count;
-  if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
-  __syncthreads;
-
-  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
-    unsigned int count = 0;
-    if (isnan(value[i])) {
-      count = atomicAdd(&nan_count, 1);
-    } else if (isinf(value[i])) {
-      count = atomicAdd(&inf_count, 1);
-    } else {
-      count = atomicAdd(&num_count, 1);
-    }
-    // for cuda, print in every block
-    if (count < print_num) {
-      printf("numel:%lu idx:%lu value:%f\n",
-             static_cast<uint64_t>(numel),
-             static_cast<uint64_t>(i),
-             static_cast<float>(value[i]));
-    }
-  }
-  __syncthreads;
-
-#ifdef __HIPCC__
-  if (true && hipThreadIdx_x == 0) {
-    printf("In block %d, there has %u,%u,%u nan,inf,num\n",
-           hipBlockIdx_x,
-           nan_count,
-           inf_count,
-           num_count);
-#else
-  if (true && threadIdx.x == 0) {
-    printf("In block %d, there has %u,%u,%u nan,inf,num\n",
-           blockIdx.x,
-           nan_count,
-           inf_count,
-           num_count);
-#endif
-    PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info);
-  }
-}
-
-// Resnet 2gpus speed test, no check 270 images/s, this check 229 images/s
-template <typename T>
-__global__ void CheckNanInfKernel(const T* value,
-                                  const size_t numel,
-                                  int print_num,
-                                  char* debug_info) {
-  /// step 1, judge wheater has nan or inf
-  __shared__ volatile int has_nan_inf;
-  if (threadIdx.x == 0) has_nan_inf = false;
-  __syncthreads();
-
-  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  T sum = static_cast<T>(0.0);
-  // Todo(wangxi). simd speed up
-  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
-    sum += (value[i] - value[i]);
-  }
-
-  if (isnan(sum) || isinf(sum)) has_nan_inf = true;
-  __syncthreads();
-
-  /// Note. different blocks may behave differently
-  if (!has_nan_inf) return;
-
-  PrintNanInfKernel(value, numel, print_num, debug_info);
-}
-
 template <typename T>
 __device__ T BlockReduce(T value) {
   __shared__ T shared_mem[1024];
@@ -509,19 +433,7 @@ void CheckNumericsKernel(const Context& ctx,
   size_t blocks =
       std::min(static_cast<size_t>(128),
               static_cast<size_t>((tensor.numel() + threads - 1) / threads));
-#ifdef __HIPCC__
-  int print_num = 3;
-
-  hipLaunchKernelGGL(CheckNanInfKernel<T>,
-                     dim3(blocks),
-                     dim3(threads),
-                     0,
-                     ctx.stream(),
-                     tensor.data<T>(),
-                     tensor.numel(),
-                     print_num,
-                     gpu_str_ptr);
-#else
+
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
 
   int64_t numel_max_min = blocks;
@@ -586,7 +498,6 @@ void CheckNumericsKernel(const Context& ctx,
   if (check_nan_inf_level == 0 && stack_height_limit > 0) {
     PrintStack<T>(ctx, *stats, op_type, var_name, dev_id);
   }
-#endif
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu
index 720447ea41a0e..2530d4f7473dd 100644
--- a/paddle/phi/kernels/gpu/group_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu
@@ -880,8 +880,23 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x,
   }
   x_mean /= number * imsize;
   x_var /= number * imsize;
+
+#ifdef __NVCC__
   CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
   CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
+#endif
+#ifdef __HIPCC__
+  // Note(wangyanpeng04): When the block size is less than the warp size,
+  // WarpReduce will result in all zeros. It seems to be an internal problem of
+  // hipcub on DCU.
+  if (blockDim.x < phi::kps::details::kWarpSize) {
+    phi::CudaAtomicAdd(&mean[bid * groups + gid], x_mean);
+    phi::CudaAtomicAdd(&var[bid * groups + gid], x_var);
+  } else {
+    CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
+    CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
+  }
+#endif
 }
 
 template <typename T>
diff --git a/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu b/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu
index 2a238e8a49b4d..7c793c9e4dc0f 100644
--- a/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu
+++ b/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu
@@ -41,7 +41,11 @@ struct AsciiCaseConverter {
                   const pstring* in,
                   pstring* out,
                   size_t num) const {
+#ifdef PADDLE_WITH_HIP
+    dim3 block_size = dim3(256, 1);
+#else
     dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
+#endif
     dim3 grid_size =
         dim3((num + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
     StringCaseConvertCUDAKernel<CharConverter>
diff --git a/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h b/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h
index 3e772aa632e52..6db23533a29a2 100644
--- a/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h
+++ b/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h
@@ -128,13 +128,6 @@ class TestElementwiseOpGradGrad {
     }
     auto *out_ptr = cpu_out.data<float>();
     size_t numel = static_cast<size_t>(common::product(dims_));
-#ifdef PADDLE_WITH_HIP
-    auto is_equal = std::equal(
-        out_ptr,
-        out_ptr + numel,
-        expected_outs_[out_name].data(),
-        [](const float &l, const float &r) { return fabs(l - r) < 1e-8; });
-#else
     bool is_equal;
     if (op_type_ == "elementwise_div_grad_grad") {
       is_equal = std::equal(out_ptr,
                             out_ptr + numel,
                             expected_outs_[out_name].data(),
                             [](const float &l, const float &r) {
                               return fabs(l - r) < 0.0005;
                             });
     } else {
+#ifdef PADDLE_WITH_HIP
+      is_equal = std::equal(
+          out_ptr,
+          out_ptr + numel,
+          expected_outs_[out_name].data(),
+          [](const float &l, const float &r) { return fabs(l - r) < 1e-8; });
+#else
       is_equal = std::equal(
           out_ptr, out_ptr + numel, expected_outs_[out_name].data());
-    }
 #endif
+    }
     if (!is_equal) {
       all_equal = false;
       break;
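
Note for reviewers (not part of the commit): the group_norm hunk works around hipcub's
WarpReduce returning all zeros on DCU when a block is narrower than one wavefront, by
falling back to per-thread atomics. The standalone HIP kernel below is a minimal sketch
of that same guard pattern, not phi code: the name BlockAccumulate, the hard-coded
kWarpSize = 64, and the __shfl_down loop (standing in for the hipcub-based reduction
inside phi's CudaAtomicAddWithWarp) are illustrative assumptions.

#include <hip/hip_runtime.h>

constexpr unsigned kWarpSize = 64;  // wavefront width on DCU/CDNA; 32 on NVIDIA

// Sums n floats into *dst. The guard mirrors the patch: blocks narrower than
// a wavefront skip the warp-level reduction (observed to yield zeros with
// hipcub on DCU) and issue one plain atomicAdd per thread instead.
// Assumes blockDim.x is a multiple of kWarpSize whenever it is >= kWarpSize.
__global__ void BlockAccumulate(const float* in, float* dst, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float val = (i < n) ? in[i] : 0.0f;
  if (blockDim.x < kWarpSize) {
    atomicAdd(dst, val);  // safe fallback: one atomic per thread
  } else {
    // Tree reduction within each wavefront, then one atomic per wavefront.
    for (unsigned offset = kWarpSize / 2; offset > 0; offset /= 2) {
      val += __shfl_down(val, offset, kWarpSize);
    }
    if (threadIdx.x % kWarpSize == 0) atomicAdd(dst, val);
  }
}

The other hunks follow the same DCU-specific theme: a fixed 256-thread block for the
strings kernel under PADDLE_WITH_HIP, and in the elementwise double-grad test the HIP
1e-8 comparison is moved into the else branch so it no longer bypasses the looser
0.0005 tolerance that elementwise_div_grad_grad needs.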