From d8f0ddc21d64a1205b779898e1499befa2f1a332 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 11 Sep 2024 20:05:04 +0800 Subject: [PATCH] rms_norm and rotary embedding back to combined impl --- deeplink_ext/internevo_ops/rotary_embedding.py | 3 ++- deeplink_ext/interntrain_ops/rms_norm.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py index 7764b9b..1a2a36d 100644 --- a/deeplink_ext/internevo_ops/rotary_embedding.py +++ b/deeplink_ext/internevo_ops/rotary_embedding.py @@ -4,7 +4,8 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - from ._rotary_embedding_npu import ApplyRotaryEmb + # from ._rotary_embedding_npu import ApplyRotaryEmb + from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb elif platform_type == PlatformType.TORCH_DIPU: from ._rotary_embedding_dipu import ApplyRotaryEmb else: diff --git a/deeplink_ext/interntrain_ops/rms_norm.py b/deeplink_ext/interntrain_ops/rms_norm.py index e6834cb..e6e3f06 100644 --- a/deeplink_ext/interntrain_ops/rms_norm.py +++ b/deeplink_ext/interntrain_ops/rms_norm.py @@ -4,7 +4,8 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - from ._mixed_rms_norm_npu import MixedFusedRMSNorm + # from ._mixed_rms_norm_npu import MixedFusedRMSNorm + from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm elif platform_type == PlatformType.TORCH_DIPU: from ._mixed_rms_norm_dipu import MixedFusedRMSNorm else: