From 1c23ede23e1dafc8f689927779fe0268c8ad99a3 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 4 Sep 2024 02:53:08 +0000 Subject: [PATCH] fix py format --- deeplink_ext/ascend_speed/_rotary_embedding_npu.py | 6 +++++- tests/conftest.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py index 646f7e9..a8f8226 100644 --- a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py +++ b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py @@ -27,4 +27,8 @@ def forward(ctx, x, cos, sin): @staticmethod def backward(ctx, grad_output): out, cos, sin = ctx.saved_tensors - return torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], None, None + return ( + torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], + None, + None + ) diff --git a/tests/conftest.py b/tests/conftest.py index cd2aae5..6d3a6f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + @pytest.fixture(scope='session', autouse=True) def import_module(): platform = deeplink_ext_get_platform_type()