diff --git a/deeplink_ext/ascend_speed/_flash_attention_dipu.py b/deeplink_ext/ascend_speed/_flash_attention_dipu.py
index e5ee61d..d6f3b41 100644
--- a/deeplink_ext/ascend_speed/_flash_attention_dipu.py
+++ b/deeplink_ext/ascend_speed/_flash_attention_dipu.py
@@ -9,7 +9,6 @@
 
 
 class FlashSelfAttention(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout
diff --git a/deeplink_ext/ascend_speed/_rms_norm_dipu.py b/deeplink_ext/ascend_speed/_rms_norm_dipu.py
index 16d3502..7d4c237 100644
--- a/deeplink_ext/ascend_speed/_rms_norm_dipu.py
+++ b/deeplink_ext/ascend_speed/_rms_norm_dipu.py
@@ -9,7 +9,6 @@
 
 
 class RMSNorm(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, hidden_states, weight, eps):
         output = torch.empty_like(hidden_states)
diff --git a/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py b/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py
index b20f7ee..47f324d 100644
--- a/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py
+++ b/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py
@@ -11,7 +11,6 @@
 
 
 class ScaledMaskedSoftmax(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input, mask, scale, fixed_triu_mask):
         out = torch.empty_like(input)
diff --git a/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py b/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py
index b02e3c8..a4a1d06 100644
--- a/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py
+++ b/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py
@@ -7,7 +7,6 @@
 
 
 class ScaledMaskedSoftmax(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input, mask, scale, fixed_triu_mask):
         out = torch_npu.npu_scaled_masked_softmax(input, mask, scale, fixed_triu_mask)
diff --git a/deeplink_ext/ops/adamw/_adamw_dipu.py b/deeplink_ext/ops/adamw/_adamw_dipu.py
index 25a7b67..855c084 100644
--- a/deeplink_ext/ops/adamw/_adamw_dipu.py
+++ b/deeplink_ext/ops/adamw/_adamw_dipu.py
@@ -60,7 +60,6 @@ def fused_adamw(
 
 
 class AdamW(Optimizer):
-
     def __init__(
         self,
         params,
diff --git a/deeplink_ext/ops/bert_padding/__init__.py b/deeplink_ext/ops/bert_padding/__init__.py
index b6fbd3f..246dc2c 100644
--- a/deeplink_ext/ops/bert_padding/__init__.py
+++ b/deeplink_ext/ops/bert_padding/__init__.py
@@ -1,2 +1,3 @@
 from .bert_padding import pad_input, unpad_input, index_first_axis
+
 __all__ = ["pad_input", "unpad_input", "index_first_axis"]
diff --git a/deeplink_ext/ops/flash_attention/__init__.py b/deeplink_ext/ops/flash_attention/__init__.py
index 0b5011e..486b9c3 100644
--- a/deeplink_ext/ops/flash_attention/__init__.py
+++ b/deeplink_ext/ops/flash_attention/__init__.py
@@ -24,5 +24,9 @@
     from .interntrain_flash_attention import FlashSelfAttention, FlashCrossAttention
 except Exception as e:
     print(_not_impl.format(op_name="flash attention"))
-    from .interntrain_flash_attention_fallback import SelfAttention as FlashSelfAttention
-    from .interntrain_flash_attention_fallback import CrossAttention as FlashCrossAttention
+    from .interntrain_flash_attention_fallback import (
+        SelfAttention as FlashSelfAttention,
+    )
+    from .interntrain_flash_attention_fallback import (
+        CrossAttention as FlashCrossAttention,
+    )
diff --git a/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py b/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py
index 03be3ba..09dff62 100644
--- a/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py
+++ b/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py
@@ -22,7 +22,6 @@
 
 
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -108,7 +107,6 @@ def backward(ctx, dout, *args):
 
 
 class CustomizedFlashAttnQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -254,7 +252,6 @@ def flash_attn_qkvpacked_func(
 
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -344,7 +341,6 @@ def backward(ctx, dout, *args):
 
 
 class CustomizedFlashAttnKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -498,7 +494,6 @@ def flash_attn_kvpacked_func(
 
 
 class FlashAttnFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -592,7 +587,6 @@ def backward(ctx, dout, *args):
 
 
 class CustomizedFlashAttnFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -753,7 +747,6 @@ def flash_attn_func(
 
 
 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -850,7 +843,6 @@ def backward(ctx, dout, *args):
 
 
 class CustomizedFlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -1007,7 +999,6 @@ def flash_attn_varlen_qkvpacked_func(
 
 
 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -1112,7 +1103,6 @@ def backward(ctx, dout, *args):
 
 
 class CustomizedFlashAttnVarlenKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -1288,7 +1278,6 @@ def flash_attn_varlen_kvpacked_func(
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -1414,7 +1403,6 @@ def backward(ctx, dout, *args):
 
 
 class CustomizedFlashAttnVarlenFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
diff --git a/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py b/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py
index 2e63aea..0ba83c6 100644
--- a/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py
+++ b/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py
@@ -4,7 +4,10 @@
 platform_type = deeplink_ext_get_platform_type()
 
 if platform_type == PlatformType.TORCH_DIPU:
-    from .interntrain_flash_attention_dipu import FlashSelfAttention, FlashCrossAttention
+    from .interntrain_flash_attention_dipu import (
+        FlashSelfAttention,
+        FlashCrossAttention,
+    )
 else:
     raise ImportError
 
diff --git a/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py b/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py
index a4f59d2..5b3822a 100644
--- a/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py
+++ b/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py
@@ -16,7 +16,6 @@
 
 
 class CustomizedFlashAttentionQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -206,7 +205,6 @@ def backward(ctx, dout):
 
 
 class FlashAttentionQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -359,7 +357,6 @@ def backward(ctx, dout):
 
 
 class CustomizedFlashAttentionVarlenQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -560,7 +557,6 @@ def backward(ctx, dout):
 
 
 class FlashAttentionVarlenQKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -738,7 +734,6 @@ def backward(ctx, dout):
 
 
 class CustomizedFlashAttentionKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
         assert q.device == kv.device, "the devices of q and kv should be same"
@@ -842,7 +837,6 @@ def backward(ctx, dout):
 
 
 class FlashAttentionKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
         assert q.device == kv.device, "the devices of q and kv should be same"
@@ -920,7 +914,6 @@ def backward(ctx, dout):
 
 
 class CustomizedFlashAttentionVarlenKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -1045,7 +1038,6 @@ def backward(ctx, dout):
 
 
 class FlashAttentionVarlenKVPackedFunc(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx,
@@ -1266,7 +1258,6 @@ def forward(
 
 
 class FlashCrossAttention(nn.Module):
-
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.causal = causal
diff --git a/deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py b/deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py
index 16d3502..7d4c237 100644
--- a/deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py
+++ b/deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py
@@ -9,7 +9,6 @@
 
 
 class RMSNorm(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, hidden_states, weight, eps):
         output = torch.empty_like(hidden_states)
diff --git a/deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py b/deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py
index 3782e85..74118cd 100644
--- a/deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py
+++ b/deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py
@@ -14,7 +14,6 @@
 # as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
 # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
 class _MixedFusedRMSNormFunction(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, hidden_states, weight, eps, normalized_shape):
         # ascend currently does not support dtype of hidden_states with higher precision than weight.
@@ -94,7 +93,6 @@ def backward(ctx, grad_output):
 
 
 class MixedFusedRMSNorm(torch.nn.Module):
-
     def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False):
         # TODO: Further optimization when there are device and dtype available.
         # factory_kwargs = {"device": device, "dtype": dtype}
diff --git a/deeplink_ext/ops/rotary_embedding/__init__.py b/deeplink_ext/ops/rotary_embedding/__init__.py
index e023c57..b152832 100644
--- a/deeplink_ext/ops/rotary_embedding/__init__.py
+++ b/deeplink_ext/ops/rotary_embedding/__init__.py
@@ -4,11 +4,17 @@
     from .internevo_rotary_embedding import ApplyRotaryEmb
 except:
     print(_not_impl.format(op_name="rotary embedding"))
-    from .internevo_rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
+    from .internevo_rotary_embedding_fallback import (
+        ApplyRotaryEmbTorch as ApplyRotaryEmb,
+    )
 
 try:
     from .interntrain_rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_
 except:
     print(_not_impl.format(op_name="rotary embedding"))
-    from .interntrain_rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
-    from .interntrain_rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_
+    from .interntrain_rotary_embedding_fallback import (
+        ApplyRotaryEmbTorch as ApplyRotaryEmb,
+    )
+    from .interntrain_rotary_embedding_fallback import (
+        ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_,
+    )