[CINN] Support pd_op.reduce_var with Welford algorithm in backend (demo) #71057

Open
wants to merge 1 commit into
base: develop
from

Conversation

lshpku
Contributor

@lshpku lshpku commented Feb 8, 2025

PR Category

CINN

PR Types

Improvements

Description

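Implements variance computation with the Welford algorithm in the CINN backend.

Because the CINN frontend does not yet have a reduce_var operator, this demo temporarily borrows reduce_prod to invoke it.

Example code: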

import numpy as np
import paddle

@paddle.jit.to_static(full_graph=True, backend='CINN')
def welford_func(x):
    return paddle.prod(x, axis=1)  # borrow the prod interface to invoke the var operator

x = paddle.randn([128, 4096]) + 1e2

np_out = np.var(x.numpy(), axis=1)  # NumPy's var serves as the reference result
welford_out = welford_func(x)       # Welford algorithm
one_pass_out = paddle.mean(x * x, axis=1) - paddle.mean(x, axis=1) ** 2  # CINN's original one-pass algorithm

print('Validating Welford result')
np.testing.assert_allclose(np_out, welford_out.numpy(), rtol=1e-5, atol=1e-5)  # passes

print('Validating One-Pass result')
np.testing.assert_allclose(np_out, one_pass_out.numpy(), rtol=1e-5, atol=1e-5)  # raises an AssertionError!
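
The one-pass check fails because E[x^2] - E[x]^2 subtracts two large, nearly equal numbers, and the 1e2 offset in the input makes that cancellation significant in float32. The following is a minimal pure-Python sketch of the effect, not CINN's implementation: the function names are illustrative, and since Python floats are double precision it assumes a larger offset (1e9) to provoke the same loss of digits.

```python
def one_pass_var(xs):
    """Population variance via E[x^2] - E[x]^2 (prone to cancellation)."""
    n = len(xs)
    total = 0.0
    total_sq = 0.0
    for x in xs:
        total += x
        total_sq += x * x
    mean = total / n
    return total_sq / n - mean * mean


def welford_var(xs):
    """Population variance via Welford's running mean / M2 update."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
    return m2 / count


offset = 1e9  # assumed offset, large enough to break the one-pass formula in float64
data = [offset + v for v in (0.0, 1.0, 2.0, 3.0, 4.0)]  # true variance is 2.0

print(welford_var(data))   # 2.0
print(one_pass_var(data))  # far from 2.0: the low-order digits were rounded away
```

Welford's update only ever works with deviations from the running mean, so the magnitude of the offset never enters a squared term.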

Generated CUDA code:

__global__
void __launch_bounds__(1024) fn_reduce_prod_yield_store___kernel(
  const float* __restrict__ var,
  float* __restrict__ var_1
) {
  welford_fp32 _var_0_rf_temp_buffer [ 1 ];
  welford_fp32 _var_0_temp_buffer [ 1 ];
  extern __shared__ uint8_t dyn_shared_buffer[];
  welford_fp32 *shm32__welford_fp32_reduce = (welford_fp32*)&dyn_shared_buffer[ 0 ];
  welford_fp32* var_0 = _var_0_temp_buffer;
  welford_fp32* var_0_rf = _var_0_rf_temp_buffer;
  welford_fp32* var_0_rf__reduce_init = _var_0_rf_temp_buffer;

  var_0_rf__reduce_init[0] = 0.00000000f;
  for (int32_t k = 0; k < 4; k += 1) {
    float var_local = var[((k * 1024) + (int)threadIdx.x) + ((int)blockIdx.x * 4096)];
    var_0_rf[0] = cinn_welford_add_fp32(var_0_rf[0], var_local);
  }
  var_0[0] = cinn_partial_block_reduce_sum_welford_fp32_internal_shm(var_0_rf[0], shm32__welford_fp32_reduce, false);
  if ((int)threadIdx.x == 0) {
    var_1[(int)blockIdx.x] = var_0[0];
  }
}
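
In the kernel above, each thread folds its 4 strided elements into a private `welford_fp32` state, and the block reduction then combines the 1024 per-thread states. The description does not show `welford_fp32`, `cinn_welford_add_fp32`, or the reduction internals, so the Python sketch below is only an assumption about how such states can be combined: it uses the standard pairwise merge formula (Chan et al.), and `WelfordState`, `merge`, `add`, and `block_variance` are hypothetical names, not CINN APIs.

```python
from dataclasses import dataclass
from functools import reduce


@dataclass(frozen=True)
class WelfordState:
    """Hypothetical stand-in for a Welford accumulator: count, mean, M2."""
    count: int = 0
    mean: float = 0.0
    m2: float = 0.0  # sum of squared deviations from the mean

    @property
    def variance(self):
        return self.m2 / self.count


def merge(a, b):
    """Combine two partial states with the pairwise (Chan et al.) formula."""
    if a.count == 0:
        return b
    if b.count == 0:
        return a
    count = a.count + b.count
    delta = b.mean - a.mean
    mean = a.mean + delta * b.count / count
    m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / count
    return WelfordState(count, mean, m2)


def add(state, x):
    """Fold one element in; equivalent to merging a single-element state."""
    return merge(state, WelfordState(1, x, 0.0))


def block_variance(row, num_threads):
    """Mimic the kernel's shape: strided per-thread loop, then a reduction."""
    partials = []
    for tid in range(num_threads):
        state = WelfordState()
        for x in row[tid::num_threads]:  # thread tid reads tid, tid + T, ...
            state = add(state, x)
        partials.append(state)
    return reduce(merge, partials).variance


row = [100.0 + (i % 7) for i in range(32)]
print(block_variance(row, num_threads=8))
```

Because `merge` is associative up to rounding, the per-thread states can be combined in any tree order, which is what makes the accumulator usable in a shared-memory block reduction.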

Pcard-85711


paddle-bot bot commented Feb 8, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.
