diff --git a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py index 646f7e9..a8f8226 100644 --- a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py +++ b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py @@ -27,4 +27,8 @@ def forward(ctx, x, cos, sin): @staticmethod def backward(ctx, grad_output): out, cos, sin = ctx.saved_tensors - return torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], None, None + return ( + torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], + None, + None + ) diff --git a/tests/conftest.py b/tests/conftest.py index cd2aae5..6d3a6f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + @pytest.fixture(scope='session', autouse=True) def import_module(): platform = deeplink_ext_get_platform_type()