diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm.py b/python/paddle/incubate/nn/functional/fused_rms_norm.py
index 54c6e1dfba0215..3995cd4a4087d0 100644
--- a/python/paddle/incubate/nn/functional/fused_rms_norm.py
+++ b/python/paddle/incubate/nn/functional/fused_rms_norm.py
@@ -15,7 +15,7 @@
 
 import paddle
 from paddle import _C_ops
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import LayerHelper, in_dynamic_or_pir_mode
 
 
 def fused_rms_norm(
@@ -63,7 +63,7 @@ def fused_rms_norm(
         epsilon = 1e-6
         paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.rms_norm(
             x,
             bias,
diff --git a/test/legacy_test/test_rms_norm_op.py b/test/legacy_test/test_rms_norm_op.py
index cd9fa001e83628..79e20e906d92ce 100644
--- a/test/legacy_test/test_rms_norm_op.py
+++ b/test/legacy_test/test_rms_norm_op.py
@@ -342,6 +342,49 @@ def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype):
         )
         return out_s[0], paddle_naive_rmsnorm_out
 
+    def test_rmsnorm_pir(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x_np.astype("float32"))
+        gamma = paddle.to_tensor(self.norm_weight_np.astype("float32"))
+        beta = paddle.to_tensor(self.norm_bias_np.astype("float32"))
+
+        paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon)
+        paddle.enable_static()
+
+        with paddle.pir_utils.IrGuard():
+            x_static = paddle.static.data(
+                name="x_static", shape=[self.batch, self.cols], dtype="float32"
+            )
+            gamma_static = paddle.static.data(
+                name="gamma_static", shape=[self.cols], dtype="float32"
+            )
+            beta_static = paddle.static.data(
+                name="beta_static", shape=[self.cols], dtype="float32"
+            )
+            out, _ = paddle.incubate.nn.functional.fused_rms_norm(
+                x_static,
+                gamma_static,
+                beta_static,
+                self.epsilon,
+                begin_norm_axis=1,
+            )
+            exe = base.Executor(self.place)
+            out_s = exe.run(
+                feed={
+                    "x_static": self.x_np.astype("float32"),
+                    "gamma_static": self.norm_weight_np.astype("float32"),
+                    "beta_static": self.norm_bias_np.astype("float32"),
+                },
+                fetch_list=[out],
+            )
+
+            np.testing.assert_allclose(
+                out_s[0],
+                paddle_naive_rmsnorm_out.numpy(),
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
     def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype):
         paddle.disable_static()
         x = paddle.to_tensor(x_np.astype(dtype))