Skip to content

Commit

Permalink
[NewIR] No.9 Migrate rms_norm into pir (PaddlePaddle#57156)
Browse files Browse the repository at this point in the history
* [NewIR] No.9 Migrate rms_norm into pir

* update test

* update

* update

* fix bug
  • Loading branch information
GreatV authored Sep 25, 2023
1 parent ea090ff commit 2998ba2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/paddle/incubate/nn/functional/fused_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import paddle
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.framework import LayerHelper, in_dynamic_or_pir_mode


def fused_rms_norm(
Expand Down Expand Up @@ -63,7 +63,7 @@ def fused_rms_norm(
epsilon = 1e-6
paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.rms_norm(
x,
bias,
Expand Down
43 changes: 43 additions & 0 deletions test/legacy_test/test_rms_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,49 @@ def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype):
)
return out_s[0], paddle_naive_rmsnorm_out

def test_rmsnorm_pir(self):
paddle.disable_static()
x = paddle.to_tensor(self.x_np.astype("float32"))
gamma = paddle.to_tensor(self.norm_weight_np.astype("float32"))
beta = paddle.to_tensor(self.norm_bias_np.astype("float32"))

paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon)
paddle.enable_static()

with paddle.pir_utils.IrGuard():
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype="float32"
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype="float32"
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype="float32"
)
out, _ = paddle.incubate.nn.functional.fused_rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
)
exe = base.Executor(self.place)
out_s = exe.run(
feed={
"x_static": self.x_np.astype("float32"),
"gamma_static": self.norm_weight_np.astype("float32"),
"beta_static": self.norm_bias_np.astype("float32"),
},
fetch_list=[out],
)

np.testing.assert_allclose(
out_s[0],
paddle_naive_rmsnorm_out.numpy(),
rtol=1e-3,
atol=1e-3,
)

def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
Expand Down

0 comments on commit 2998ba2

Please sign in to comment.