diff --git a/deeplink_ext/internlm_ops/rotary/__init__.py b/deeplink_ext/internlm_ops/rotary/__init__.py
index e1534637..8ca5f9e2 100644
--- a/deeplink_ext/internlm_ops/rotary/__init__.py
+++ b/deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,16 +1,13 @@
 # Copyright (c) 2024, DeepLink.
 
 try:
-    from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
+    from .deeplink import apply_rotary
 except:
     print(
         "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
         end="",
     )
-    from .fallback import (
-        ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
-        ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
-    )
+    from .fallback import apply_rotary
 from . import fallback
 
-__all__ = ["DeepLinkApplyRotaryEmbQKV_", "DeepLinkApplyRotaryEmb", "fallback"]
+__all__ = ["apply_rotary", "fallback"]
diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary/deeplink.py
index d065cba8..670a47b9 100644
--- a/deeplink_ext/internlm_ops/rotary/deeplink.py
+++ b/deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -1,131 +1,65 @@
 # Copyright (c) 2024, DeepLink.
 
+from typing import Optional, Union
 import torch
 from einops import rearrange
 import deeplink_ext.cpp_extensions as ext
 
 assert hasattr(ext, "apply_rotary")
 
+__all__ = ["apply_rotary"]
+
 
-class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
-        batch, seqlen, three, nheads, headdim = qkv.shape
-        assert three == 3
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        cos_k = cos if cos_k is None else cos_k
-        sin_k = sin if sin_k is None else sin_k
-        assert (
-            sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        )
-        q_ro = qkv[:, :, 0, :, :rotary_dim]
-        ext.apply_rotary(
-            q_ro,
-            q_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-            interleaved,
-        )
-        k_ro = qkv[:, :, 1, :, :rotary_dim]
-        ext.apply_rotary(
-            k_ro,
-            k_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-            interleaved,
-        )
-        ctx.save_for_backward(cos, sin, cos_k, sin_k)
-        ctx.interleaved = interleaved
-        return qkv
-
-    @staticmethod
-    def backward(ctx, dqkv):
-        cos, sin, cos_k, sin_k = ctx.saved_tensors
-        interleaved = ctx.interleaved
-        _, seqlen, _, _, headdim = dqkv.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
-        ext.apply_rotary(
-            dq_ro,
-            dq_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-            interleaved,
-        )
-        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
-        ext.apply_rotary(
-            dk_ro,
-            dk_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-            interleaved,
-        )
-        return dqkv, None, None, None, None, None
-
-
-class DeepLinkApplyRotaryEmb(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
-        """
-        x: (batch_size, seqlen, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding to the first rotary_dim of x.
- """ - batch, seqlen, nheads, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - out = torch.empty_like(x) if not inplace else x - out_ro = out[..., :rotary_dim] - ext.apply_rotary( - out_ro, - x_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - interleaved, +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None: + raise NotImplementedError( + "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet" ) + batch, seqlen, nheads, headdim = x.shape + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - if not inplace and rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.inplace = inplace - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - inplace = ctx.inplace - do_ro = do[..., :rotary_dim] - dx = torch.empty_like(do) if not inplace else do - dx_ro = dx[..., :rotary_dim] - ext.apply_rotary( - dx_ro, - do_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ctx.interleaved, - ) - if not inplace and rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + ext.apply_rotary( + output, + x, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + conjugate, + interleaved, + ) + return output diff --git a/deeplink_ext/internlm_ops/rotary/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback.py new file mode 100644 index 00000000..7ce7c652 --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary/fallback.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024, DeepLink. 
+
+from typing import Optional, Union
+import torch
+from einops import rearrange, repeat
+
+__all__ = ["apply_rotary"]
+
+
+def _rotate_half(x: torch.Tensor, interleaved=False):
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def _apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved=False
+):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+
+    data_type = x.dtype
+    x = x.to(torch.float32)
+    cos = cos.to(torch.float32)
+    sin = sin.to(torch.float32)
+
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    ).to(data_type)
+
+
+def apply_rotary(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    interleaved=False,
+    inplace=False,
+    conjugate=False,
+) -> torch.Tensor:
+    """
+    Arguments:
+        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim).
+        cos: (seqlen_ro, rotary_dim / 2)
+        sin: (seqlen_ro, rotary_dim / 2)
+        seqlen_offsets: integer or integer tensor of size (batch,)
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Returns:
+        y: (batch, seqlen, nheads, headdim)
+    """
+    if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None:
+        raise NotImplementedError(
+            "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
+        )
+    batch, seqlen, nheads, headdim = x.shape
+    seqlen_ro, rotary_dim = cos.shape
+    assert sin.shape == cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+    assert headdim <= 256, "Only support headdim <= 256"
+    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+    assert (
+        cos.dtype == sin.dtype
+    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+    assert (
+        x.dtype == cos.dtype
+    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+    if conjugate:
+        sin = -sin
+    out = _apply_rotary_emb_torch(x, cos[:seqlen], sin[:seqlen], interleaved)
+    if inplace:
+        x.copy_(out)
+        out = x
+    return out
diff --git a/deeplink_ext/internlm_ops/rotary/fallback/__init__.py b/deeplink_ext/internlm_ops/rotary/fallback/__init__.py
deleted file mode 100644
index f0166042..00000000
--- a/deeplink_ext/internlm_ops/rotary/fallback/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) 2024, DeepLink.
-
-from .fallback import ApplyRotaryEmbQKV_, ApplyRotaryEmb
-
-__all__ = ["ApplyRotaryEmbQKV_", "ApplyRotaryEmb"]
diff --git a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
deleted file mode 100644
index abf63af6..00000000
--- a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Copyright (c) 2024, DeepLink.
-
-import torch
-from einops import rearrange
-
-
-# Rotary_emb
-# torch 绕过实现函数
-def apply_rotary(x1, x2, cos, sin, conj):
-    data_dtype = x1.dtype
-    x1 = x1.to(torch.float32)
-    x2 = x2.to(torch.float32)
-    cos = cos.to(torch.float32)
-    sin = sin.to(torch.float32)
-    if not conj:
-        out1 = x1 * cos - x2 * sin
-        out2 = x1 * sin + x2 * cos
-    else:
-        out1 = x1 * cos + x2 * sin
-        out2 = -x1 * sin + x2 * cos
-    out1 = out1.to(data_dtype)
-    out2 = out2.to(data_dtype)
-    return out1, out2
-
-
-class ApplyRotaryEmbQKV_(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
-        """
-        qkv: (batch_size, seqlen, 3, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
-            1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
-        """
-        batch, seqlen, three, nheads, headdim = qkv.shape
-        assert three == 3
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        cos_k = cos if cos_k is None else cos_k
-        sin_k = sin if sin_k is None else sin_k
-        assert (
-            sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        )
-        q_ro = qkv[:, :, 0, :, :rotary_dim]
-        q1, q2 = (
-            q_ro.chunk(2, dim=-1)
-            if not interleaved
-            else (q_ro[..., ::2], q_ro[..., 1::2])
-        )
-        q1, q2 = apply_rotary(
-            q1,
-            q2,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-        )
-        qkv[:, :, 0, :, :rotary_dim] = torch.cat((q1, q2), dim=-1)
-        k_ro = qkv[:, :, 1, :, :rotary_dim]
-        k1, k2 = (
-            k_ro.chunk(2, dim=-1)
-            if not interleaved
-            else (k_ro[..., ::2], k_ro[..., 1::2])
-        )
-        k1, k2 = apply_rotary(
-            k1,
-            k2,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-        )
-        qkv[:, :, 1, :, :rotary_dim] = torch.cat((k1, k2), dim=-1)
-        ctx.save_for_backward(cos, sin, cos_k, sin_k)
-        ctx.interleaved = interleaved
-        return qkv
-
-    @staticmethod
-    def backward(ctx, dqkv):
-        cos, sin, cos_k, sin_k = ctx.saved_tensors
-        _, seqlen, _, _, headdim = dqkv.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
-        dq1, dq2 = (
-            dq_ro.chunk(2, dim=-1)
-            if not ctx.interleaved
-            else (dq_ro[..., ::2], dq_ro[..., 1::2])
-        )
-        dq1, dq2 = apply_rotary(
-            dq1,
-            dq2,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-        )
-        dqkv[:, :, 0, :, :rotary_dim] = torch.cat((dq1, dq2), dim=-1)
-        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
-        dk1, dk2 = (
-            dk_ro.chunk(2, dim=-1)
-            if not ctx.interleaved
-            else (dk_ro[..., ::2], dk_ro[..., 1::2])
-        )
-        dk1, dk2 = apply_rotary(
-            dk1,
-            dk2,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-        )
-        dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1)
-        return dqkv, None, None, None, None, None
-
-
-class ApplyRotaryEmb(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
-        """
-        x: (batch_size, seqlen, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding to the first rotary_dim of x.
- """ - - assert ( - interleaved == False - ), "Interleaved rotary embedding fallback is not supported yet" - assert ( - inplace == False - ), "Inplace rotary embedding fallback is not supported yet" - - batch, seqlen, nheads, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = ( - x_ro.chunk(2, dim=-1) - if not interleaved - else (x_ro[..., ::2], x_ro[..., 1::2]) - ) - out = torch.empty_like(x) - - o1, o2 = apply_rotary( - x1, - x2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - - out[..., :rotary_dim] = torch.cat((o1, o2), dim=-1) - if rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.inplace = inplace - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - do_ro = do[..., :rotary_dim] - do1, do2 = ( - do_ro.chunk(2, dim=-1) - if not ctx.interleaved - else (do_ro[..., ::2], do_ro[..., 1::2]) - ) - dx = torch.empty_like(do) - - dx1, dx2 = apply_rotary( - do1, - do2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - - dx[..., :rotary_dim] = torch.cat((dx1, dx2), dim=-1) - - if rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py index 6e4e0a9f..d5ac32b2 100644 --- a/deeplink_ext/patch_internlm.py +++ b/deeplink_ext/patch_internlm.py @@ -71,34 +71,33 @@ def CrossEntropyLossProxy(reduction, **_): def _patch_ops(): import deeplink_ext.internlm_ops as ext + import flash_attn.layers.rotary # type: ignore import internlm.model.embedding # type: ignore + flash_attn.layers.rotary.apply_rotary = ext.rotary.apply_rotary + class NonLegacyRotaryEmbQKV_(torch.autograd.Function): """the first 2 dims of qkv has been squeezed""" @staticmethod def forward(ctx, qkv: torch.Tensor, *args, **kwargs): # type: ignore unsqueezed_qkv = qkv.view([1] + list(qkv.shape)) - out: torch.Tensor = ext.rotary.DeepLinkApplyRotaryEmbQKV_.forward( - ctx, unsqueezed_qkv, *args, **kwargs + out: torch.Tensor = ( + internlm.model.embedding.LegacyApplyRotaryEmbQKV_.forward( + ctx, unsqueezed_qkv, *args, **kwargs + ) ) return out.view(out.shape[1:]) @staticmethod def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore unqueezed_dqkv = dqkv.view([1] + list(dqkv.shape)) - out: tuple = ext.rotary.DeepLinkApplyRotaryEmbQKV_.backward( + out: tuple = internlm.model.embedding.LegacyApplyRotaryEmbQKV_.backward( ctx, unqueezed_dqkv, *args, **kwargs ) return (out[0].view(out[0].shape[1:]),) + out[1:] internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply - internlm.model.embedding.legacy_apply_rotary_embed = ( - ext.rotary.DeepLinkApplyRotaryEmb.apply - ) - internlm.model.embedding.legacy_apply_rotary_embed_qkv = ( - ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply - ) import internlm.model.norm # type: ignore diff --git a/tests/test_rotary_emb_internlm.py b/tests/test_rotary_emb_internlm.py index 7cc4b5e2..b84ea165 100644 --- a/tests/test_rotary_emb_internlm.py +++ b/tests/test_rotary_emb_internlm.py @@ -4,77 +4,21 @@ import deeplink_ext.internlm_ops.rotary as ext -def 
RotaryEmbTest(func_name): - if func_name == "RotaryEmbQKV": - torch_apply = ext.fallback.ApplyRotaryEmbQKV_.apply - dipu_apply = ext.DeepLinkApplyRotaryEmbQKV_.apply - input = torch.randn( - 1, 125, 3, 16, 32, dtype=torch.float16, requires_grad=True - ).cuda() - elif func_name == "RotaryEmb": - torch_apply = ext.fallback.ApplyRotaryEmb.apply - dipu_apply = ext.DeepLinkApplyRotaryEmb.apply - input = torch.randn( - 1, 125, 16, 32, dtype=torch.float16, requires_grad=True - ).cuda() - else: - print(f"{func_name} is not supported.") - return False +def RotaryEmbTest() -> bool: + input = torch.randn(1, 125, 16, 32, dtype=torch.float16).cuda() - loss_fn = torch.nn.MSELoss() - cos = torch.randn(257, 16, dtype=torch.float16).cuda() - sin = torch.randn(257, 16, dtype=torch.float16).cuda() + cos = torch.randn(217, 16, dtype=torch.float16).cuda() + sin = torch.randn(217, 16, dtype=torch.float16).cuda() input1 = input.detach().clone() - input1.requires_grad = True - cos1 = cos.clone() - sin1 = sin.clone() - cos_k = None - sin_k = None + inplace = True interleaved = False - # 调用前向传播 - if func_name == "RotaryEmbQKV": - res1 = torch_apply(input, cos, sin, cos_k, sin_k, interleaved) - res2 = dipu_apply(input1, cos1, sin1, cos_k, sin_k, interleaved) - elif func_name == "RotaryEmb": - res1 = torch_apply(input, cos, sin, interleaved) - res2 = dipu_apply(input1, cos1, sin1, interleaved) - else: - print(f"{func_name} is not supported.") - return False + res1 = ext.fallback.apply_rotary( + input, cos, sin, interleaved=interleaved, inplace=inplace + ) + res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) - # 验证前向传播结果 - forward_correct = torch.allclose(res1, res2) + return torch.allclose(res1, res2) - # 计算第一个损失 - c = torch.ones_like(res1) - loss1 = loss_fn(res1, c) # 将输出的元素求和,得到标量 - input.retain_grad() - loss1.backward() - # 计算第二个损失 - c2 = torch.ones_like(res1) - loss2 = loss_fn(res2, c2) # 将输出的元素求和,得到标量 - input1.retain_grad() - loss2.backward() - - # 验证第一个反向传播梯度 - grad1 = input.grad - grad2 = input1.grad - backward_correct = torch.allclose(grad1, grad2) - # 判断前向和反向传播结果是否都正确 - if forward_correct and backward_correct: - print(f"{func_name} both forward and backward pass tests passed.") - return True - else: - print( - f"{func_name} tests failed: Forward pass:", - forward_correct, - "Backward pass:", - backward_correct, - ) - return False - - -assert RotaryEmbTest("RotaryEmbQKV") -assert RotaryEmbTest("RotaryEmb") +assert RotaryEmbTest()
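
For reference, a minimal usage sketch of the unified apply_rotary interface introduced above, mirroring the new test; the tensor shapes and the CUDA device here are illustrative assumptions, not part of the diff:

    import torch
    import deeplink_ext.internlm_ops.rotary as ext

    # (batch, seqlen, nheads, headdim) input and (seqlen_ro, rotary_dim / 2) cos/sin
    # tables, chosen to match the shapes used in tests/test_rotary_emb_internlm.py.
    x = torch.randn(1, 125, 16, 32, dtype=torch.float16).cuda()
    cos = torch.randn(217, 16, dtype=torch.float16).cuda()
    sin = torch.randn(217, 16, dtype=torch.float16).cuda()

    # ext.apply_rotary resolves to the diopi kernel when available and to the
    # pure-torch fallback otherwise; both expose the same signature.
    out = ext.apply_rotary(x, cos, sin, interleaved=False, inplace=False)
    assert out.shape == x.shape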