diff --git a/deeplink_ext/easyllm_ops/__init__.py b/deeplink_ext/easyllm_ops/__init__.py
new file mode 100644
index 0000000..439bd0b
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/__init__.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024, DeepLink.
+
+_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
+
+try:
+    from .adamw import AdamW
+except Exception as e:
+    print(_not_impl.format(op_name="adamw"))
+    from torch.optim import AdamW
+
+try:
+    from .flash_attention import (
+        flash_attn_qkvpacked_func,
+        flash_attn_kvpacked_func,
+        flash_attn_func,
+        flash_attn_varlen_qkvpacked_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_func,
+    )
+except Exception as e:
+    print(_not_impl.format(op_name="flash attention"))
+    from .flash_attention_fallback import (
+        flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
+        flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
+        flash_attn_func_torch as flash_attn_func,
+        flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
+        flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_func_torch as flash_attn_varlen_func,
+    )
+
+try:
+    from .rms_norm import rms_norm
+except Exception:
+    print(
+        _not_impl.format(op_name="RMSNorm"),
+    )
+    from .rms_norm_fallback import rms_norm_torch as rms_norm
+
+from .bert_padding import pad_input, unpad_input, index_first_axis
+
+__all__ = [
+    "AdamW",
+    "flash_attn_qkvpacked_func",
+    "flash_attn_kvpacked_func",
+    "flash_attn_func",
+    "flash_attn_varlen_qkvpacked_func",
+    "flash_attn_varlen_kvpacked_func",
+    "flash_attn_varlen_func",
+    "rms_norm",
+    "pad_input",
+    "unpad_input",
+    "index_first_axis",
+]
diff --git a/deeplink_ext/easyllm_ops/adamw.py b/deeplink_ext/easyllm_ops/adamw.py
new file mode 100644
index 0000000..0bf67f5
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/adamw.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2024, DeepLink.
+
+from deeplink_ext.interntrain_ops.adamw import AdamW
+
+
+__all__ = ["AdamW"]
diff --git a/deeplink_ext/easyllm_ops/bert_padding.py b/deeplink_ext/easyllm_ops/bert_padding.py
new file mode 100644
index 0000000..ca59808
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/bert_padding.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2024, DeepLink.
+# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
+
+__all__ = [
+    "pad_input",
+    "unpad_input",
+    "index_first_axis",
+]
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+class IndexFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        # return input[indices]
+        return torch.gather(
+            rearrange(input, "b ... -> b (...)"),
+            0,
+            repeat(indices, "z -> z d", d=second_dim),
+        ).reshape(-1, *other_shape)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        grad_output = rearrange(grad_output, "b ... -> b (...)")
+        grad_input = torch.zeros(
+            [ctx.first_axis_dim, grad_output.shape[1]],
+            device=grad_output.device,
+            dtype=grad_output.dtype,
+        )
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
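+        # scatter_ with the indices expanded along the feature dimension writes each
+        # gathered row's gradient back to its original row; rows that were never
+        # gathered keep a zero gradient.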
+        # grad_input[indices] = grad_output
+        grad_input.scatter_(
+            0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
+        )
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis = IndexFirstAxis.apply
+
+
+class IndexPutFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values, indices, first_axis_dim):
+        ctx.save_for_backward(indices)
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = torch.zeros(
+            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
+        )
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        output[indices] = values
+        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        grad_values = grad_output[indices]
+        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis.apply
+
+
+class IndexFirstAxisResidual(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        output = input[indices]
+        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
+        # memory format to channel_first. In other words, input might not be contiguous.
+        # If we don't detach, PyTorch complains about output being a view that is modified in place.
+        return output, input.detach()
+
+    @staticmethod
+    def backward(ctx, grad_output, grad_residual):
+        (indices,) = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        assert grad_residual.shape[1:] == other_shape
+        grad_input = grad_residual
+        # grad_input[indices] += grad_output
+        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
+        indices = indices.expand_as(grad_output)
+        grad_input.scatter_add_(0, indices, grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis_residual = IndexFirstAxisResidual.apply
+
+
+def unpad_input(hidden_states, attention_mask):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+    """
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+    )
+    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
+    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
+    # so we write custom forward and backward to make it a bit faster.
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
+    """
+    Supports concatenating short samples into one sequence. The attention_mask_in_length is utilized to mask the other short samples. It enables efficient training on variable-length samples (e.g., the supervised fine-tuning task in large language models).
+    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
+
+    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
+    ```
+    [
+        [2, 3, 0, 0, 0, 0],
+        [3, 2, 0, 0, 0, 0],
+        [6, 0, 0, 0, 0, 0]
+    ]
+    ```
+    , which refers to the 3D attention mask:
+    ```
+    [
+        [
+            [1, 0, 0, 0, 0, 0],
+            [1, 1, 0, 0, 0, 0],
+            [0, 0, 1, 0, 0, 0],
+            [0, 0, 1, 1, 0, 0],
+            [0, 0, 1, 1, 1, 0],
+            [0, 0, 0, 0, 0, 1]
+        ],
+        [
+            [1, 0, 0, 0, 0, 0],
+            [1, 1, 0, 0, 0, 0],
+            [1, 1, 1, 0, 0, 0],
+            [0, 0, 0, 1, 0, 0],
+            [0, 0, 0, 1, 1, 0],
+            [0, 0, 0, 0, 0, 1]
+        ],
+        [
+            [1, 0, 0, 0, 0, 0],
+            [1, 1, 0, 0, 0, 0],
+            [1, 1, 1, 0, 0, 0],
+            [1, 1, 1, 1, 0, 0],
+            [1, 1, 1, 1, 1, 0],
+            [1, 1, 1, 1, 1, 1]
+        ]
+    ]
+    ```
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+    """
+    length = attention_mask_in_length.sum(dim=-1)
+    seqlen = attention_mask_in_length.size(-1)
+    attention_mask_2d = torch.arange(
+        seqlen, device=length.device, dtype=length.dtype
+    ).expand(len(length), seqlen) < length.unsqueeze(1)
+    real_indices_idx = torch.nonzero(
+        attention_mask_in_length.flatten(), as_tuple=False
+    ).flatten()
+    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
+    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+    )
+    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
+    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
+    # so we write custom forward and backward to make it a bit faster.
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[-1]
+    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    # output[indices] = hidden_states
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
diff --git a/deeplink_ext/easyllm_ops/flash_attention.py b/deeplink_ext/easyllm_ops/flash_attention.py
new file mode 100644
index 0000000..e352e2b
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/flash_attention.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024, DeepLink.
+
+from deeplink_ext.internevo_ops.flash_attention import (
+    flash_attn_qkvpacked_func,
+    flash_attn_kvpacked_func,
+    flash_attn_func,
+    flash_attn_varlen_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_func,
+)
+
+
+__all__ = [
+    "flash_attn_qkvpacked_func",
+    "flash_attn_kvpacked_func",
+    "flash_attn_func",
+    "flash_attn_varlen_qkvpacked_func",
+    "flash_attn_varlen_kvpacked_func",
+    "flash_attn_varlen_func",
+]
diff --git a/deeplink_ext/easyllm_ops/flash_attention_fallback.py b/deeplink_ext/easyllm_ops/flash_attention_fallback.py
new file mode 100644
index 0000000..e781ae1
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/flash_attention_fallback.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024, DeepLink.
+
+from deeplink_ext.internevo_ops.flash_attention_fallback import (
+    flash_attn_qkvpacked_func_torch,
+    flash_attn_kvpacked_func_torch,
+    flash_attn_func_torch,
+    flash_attn_varlen_qkvpacked_func_torch,
+    flash_attn_varlen_kvpacked_func_torch,
+    flash_attn_varlen_func_torch,
+)
+
+
+__all__ = [
+    "flash_attn_qkvpacked_func_torch",
+    "flash_attn_kvpacked_func_torch",
+    "flash_attn_func_torch",
+    "flash_attn_varlen_qkvpacked_func_torch",
+    "flash_attn_varlen_kvpacked_func_torch",
+    "flash_attn_varlen_func_torch",
+]
diff --git a/deeplink_ext/easyllm_ops/rms_norm.py b/deeplink_ext/easyllm_ops/rms_norm.py
new file mode 100644
index 0000000..02d13c6
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/rms_norm.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2024, DeepLink.
+
+from deeplink_ext.ascend_speed.rms_norm import RMSNorm
+
+__all__ = ["rms_norm"]
+
+
+def rms_norm(x, weight, epsilon):
+    return RMSNorm.apply(x, weight, epsilon)
diff --git a/deeplink_ext/easyllm_ops/rms_norm_fallback.py b/deeplink_ext/easyllm_ops/rms_norm_fallback.py
new file mode 100644
index 0000000..80e9594
--- /dev/null
+++ b/deeplink_ext/easyllm_ops/rms_norm_fallback.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024, DeepLink.
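+# Pure-torch RMSNorm fallback: y = x / sqrt(mean(x^2, dim=-1) + eps) * weight.
+# The mean of squares is accumulated in float32 for numerical stability and the
+# result is cast back to the input dtype.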
+
+import torch
+
+__all__ = ["rms_norm_torch"]
+
+
+def rms_norm_torch(x, weight, epsilon):
+    input_dtype = x.dtype
+    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
+    hidden_states = x * torch.rsqrt(variance + epsilon)
+
+    return (hidden_states * weight).to(input_dtype)
diff --git a/deeplink_ext/internevo_ops/__init__.py b/deeplink_ext/internevo_ops/__init__.py
index 4f6b045..b2c86f6 100644
--- a/deeplink_ext/internevo_ops/__init__.py
+++ b/deeplink_ext/internevo_ops/__init__.py
@@ -20,12 +20,12 @@
 except Exception as e:
     print(_not_impl.format(op_name="flash attention"))
     from .flash_attention_fallback import (
-        torch_attn_qkvpacked_func as flash_attn_qkvpacked_func,
-        torch_attn_kvpacked_func as flash_attn_kvpacked_func,
-        torch_attn_func as flash_attn_func,
-        torch_attn_varlen_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
-        torch_attn_varlen_kvpacked_func as flash_attn_varlen_kvpacked_func,
-        torch_attn_varlen_func as flash_attn_varlen_func,
+        flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
+        flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
+        flash_attn_func_torch as flash_attn_func,
+        flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
+        flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_func_torch as flash_attn_varlen_func,
     )
 
 try:
diff --git a/deeplink_ext/internevo_ops/flash_attention_fallback.py b/deeplink_ext/internevo_ops/flash_attention_fallback.py
index 6c5e7f2..02f1ead 100644
--- a/deeplink_ext/internevo_ops/flash_attention_fallback.py
+++ b/deeplink_ext/internevo_ops/flash_attention_fallback.py
@@ -6,12 +6,12 @@
 __all__ = [
-    "torch_attn_qkvpacked_func",
-    "torch_attn_kvpacked_func",
-    "torch_attn_func",
-    "torch_attn_varlen_qkvpacked_func",
-    "torch_attn_varlen_kvpacked_func",
-    "torch_attn_varlen_func",
+    "flash_attn_qkvpacked_func_torch",
+    "flash_attn_kvpacked_func_torch",
+    "flash_attn_func_torch",
+    "flash_attn_varlen_qkvpacked_func_torch",
+    "flash_attn_varlen_kvpacked_func_torch",
+    "flash_attn_varlen_func_torch",
 ]
@@ -66,7 +66,7 @@ def _pack_output_after_attn(
     return output
 
 
-def torch_attn_qkvpacked_func(
+def flash_attn_qkvpacked_func_torch(
     qkv,
     dropout_p=0.0,
     softmax_scale=None,
@@ -95,7 +95,7 @@
     return output
 
 
-def torch_attn_kvpacked_func(
+def flash_attn_kvpacked_func_torch(
     q,
     kv,
     dropout_p=0.0,
@@ -131,7 +131,7 @@
     return output
 
 
-def torch_attn_func(
+def flash_attn_func_torch(
     q,
     k,
     v,
@@ -144,7 +144,7 @@
     return_attn_probs=False,
 ):
     kv = torch.stack([k, v], dim=2)
-    return torch_attn_kvpacked_func(
+    return flash_attn_kvpacked_func_torch(
         q,
         kv,
         dropout_p,
@@ -157,7 +157,7 @@
     )
 
 
-def torch_attn_varlen_qkvpacked_func(
+def flash_attn_varlen_qkvpacked_func_torch(
     qkv,
     cu_seqlens,
     max_seqlen,
@@ -171,7 +171,7 @@
 ):
     packed_length = qkv.size(dim=0)
     qkv = _unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
-    output = torch_attn_qkvpacked_func(
+    output = flash_attn_qkvpacked_func_torch(
         qkv,
         dropout_p,
         softmax_scale,
@@ -184,7 +184,7 @@
     return _pack_output_after_attn(output, cu_seqlens, packed_length)
 
 
-def torch_attn_varlen_kvpacked_func(
+def flash_attn_varlen_kvpacked_func_torch(
     q,
     kv,
     cu_seqlens_q,
@@ -202,7 +202,7 @@
     packed_length = q.size(dim=0)
     q = _unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
     kv = _unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
-    output = torch_attn_kvpacked_func(
+    output = flash_attn_kvpacked_func_torch(
         q,
         kv,
         dropout_p,
@@ -216,7 +216,7 @@
     return _pack_output_after_attn(output, cu_seqlens_q, packed_length)
 
 
-def torch_attn_varlen_func(
+def flash_attn_varlen_func_torch(
     q,
     k,
     v,
@@ -237,7 +237,7 @@
     kv = torch.stack([k, v], dim=1)
     q = _unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
     kv = _unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
-    output = torch_attn_kvpacked_func(
+    output = flash_attn_kvpacked_func_torch(
         q,
         kv,
         dropout_p,
diff --git a/tests/core.py b/tests/core.py
index f08939e..ffa1ca1 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -4,16 +4,16 @@
 from typing import Callable, Any
 
 __all__ = [
+    "calculate_fwd_and_bwd",
     "call_module",
     "call_autograd_func",
-    "call_normal_func",
     "copy_to_cpu",
     "allclose",
 ]
 
 
-def call_module(module: torch.nn.Module, *forward_args):
-    output_forward = module(*forward_args)
+def calculate_fwd_and_bwd(func: Callable[..., Any], *args: tuple, **kwargs: dict):
+    output_forward = func(*args, **kwargs)
     grads = []
     if torch.is_tensor(output_forward):
         output_forward.backward(torch.ones_like(output_forward))
@@ -24,13 +24,19 @@
         raise RuntimeError(
             "the result of forward is not a tensor or a list or tuple of tensors"
         )
-    for arg in forward_args:
+    for arg in args:
         if torch.is_tensor(arg) and arg.requires_grad:
             grads.append(arg.grad)
     return output_forward, grads
 
 
-def call_autograd_func(f: torch.autograd.Function, device, dtype, *args: tuple):
+def call_module(module: torch.nn.Module, *args: tuple, **kwargs: dict):
+    return calculate_fwd_and_bwd(module, *args, **kwargs)
+
+
+def call_autograd_func(
+    autograd_func: torch.autograd.Function, device, dtype, *args: tuple, **kwargs: dict
+):
     class Module(torch.nn.Module):
         def __init__(self, func):
             super(Module, self).__init__()
@@ -39,25 +45,7 @@ def __init__(self, func):
             return self.func.apply(*args)
 
-    return call_module(Module(f).to(device).to(dtype), *args)
-
-
-def call_normal_func(func: Callable[..., Any], *args: tuple, **kwargs: dict):
-    output_forward = func(*args, **kwargs)
-    grads = []
-    if torch.is_tensor(output_forward):
-        output_forward.backward(torch.ones_like(output_forward))
-    elif isinstance(output_forward, (list, tuple)):
-        assert torch.is_tensor(output_forward[0]), "output_forward[0] is not a tensor"
-        output_forward[0].backward(torch.ones_like(output_forward[0]))
-    else:
-        raise RuntimeError(
-            "the result of forward is not a tensor or list or tuple of tensor"
-        )
-    for arg in args:
-        if torch.is_tensor(arg) and arg.requires_grad:
-            grads.append(arg.grad)
-    return output_forward, grads
+    return call_module(Module(autograd_func).to(device).to(dtype), *args, **kwargs)
 
 
 def copy_to_cpu(tensors: list[torch.Tensor], dtype=None):
@@ -80,8 +68,8 @@ def allclose(expected_vals: list, real_vals: list, rtol=1e-05, atol=1e-08):
         if isinstance(expected_vals[i], torch.Tensor):
             assert isinstance(real_vals[i], torch.Tensor)
             return torch.allclose(
-                expected_vals[i].to(real_vals[i].dtype).cpu(),
-                real_vals[i].cpu(),
+                expected_vals[i].cpu().to(torch.float32),
+                real_vals[i].cpu().to(torch.float32),
                 rtol,
                 atol,
             )
diff --git a/tests/easyllm/test_rms_norm.py b/tests/easyllm/test_rms_norm.py
new file mode 100644
index 0000000..1273b15
--- /dev/null
+++ b/tests/easyllm/test_rms_norm.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024, DeepLink.
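+# Compares the deeplink_ext rms_norm against the pure-torch rms_norm_torch fallback
+# for float16 and bfloat16 inputs, checking both forward outputs and the gradients of
+# the hidden states and the weight.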
+
+import torch
+from tests.core import calculate_fwd_and_bwd, allclose
+from deeplink_ext.easyllm_ops.rms_norm import rms_norm
+from deeplink_ext.easyllm_ops.rms_norm_fallback import rms_norm_torch
+
+
+def test_rms_norm():
+    input_dtype_list = [torch.float16, torch.bfloat16]
+    weight_dtype_list = [torch.float16, torch.bfloat16]
+    for input_dtype, weight_dtype in zip(input_dtype_list, weight_dtype_list):
+        hidden_states_ref = torch.randn(
+            1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
+        )
+        hidden_states_ext = hidden_states_ref.clone().detach().requires_grad_(True)
+
+        weight_ref = torch.nn.Parameter(
+            torch.ones(
+                hidden_states_ref.shape[-1], dtype=weight_dtype, device="cuda"
+            ),
+            requires_grad=True,
+        )
+        weight_ext = weight_ref.clone().detach().requires_grad_(True)
+
+        epsilon = 1e-5
+
+        output_ref, grad_ref = calculate_fwd_and_bwd(
+            rms_norm_torch,
+            hidden_states_ref,
+            weight_ref,
+            epsilon,
+        )
+
+        output_ext, grad_ext = calculate_fwd_and_bwd(
+            rms_norm,
+            hidden_states_ext,
+            weight_ext,
+            epsilon,
+        )
+
+        assert allclose(
+            output_ref, output_ext, rtol=1e-05, atol=1e-5
+        ), f"When input dtype is {input_dtype} and weight dtype is {weight_dtype}, RMSNorm fails to pass the forward test!"
+        assert allclose(
+            grad_ref, grad_ext, rtol=1e-2, atol=1e-2
+        ), f"When input dtype is {input_dtype} and weight dtype is {weight_dtype}, RMSNorm fails to pass the backward test!"
diff --git a/tests/internevo/test_flash_attention.py b/tests/internevo/test_flash_attention.py
index bdfc897..5126551 100644
--- a/tests/internevo/test_flash_attention.py
+++ b/tests/internevo/test_flash_attention.py
@@ -1,12 +1,12 @@
 # Copyright (c) 2024, DeepLink.
 
 import torch
-from tests.core import copy_to_cpu, allclose, call_normal_func
+from tests.core import copy_to_cpu, allclose, calculate_fwd_and_bwd
 
 from deeplink_ext.internevo_ops.flash_attention_fallback import (
-    torch_attn_qkvpacked_func,
-    torch_attn_kvpacked_func,
-    torch_attn_func,
+    flash_attn_qkvpacked_func_torch,
+    flash_attn_kvpacked_func_torch,
+    flash_attn_func_torch,
 )
 from deeplink_ext.internevo_ops.flash_attention import (
     flash_attn_qkvpacked_func,
@@ -31,13 +31,13 @@
         ]
     )
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_qkvpacked_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_qkvpacked_func_torch,
         qkv_cpu[0],
        dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_qkvpacked_func,
         qkv_gpu,
         dropout_p=0.0,
@@ -66,14 +66,14 @@
     )
     q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu])
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_kvpacked_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_kvpacked_func_torch,
         q_cpu,
         kv_cpu,
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_kvpacked_func,
         q_gpu,
         kv_gpu,
@@ -109,15 +109,15 @@
     )
     q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu])
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_func_torch,
         q_cpu,
         k_cpu,
         v_cpu,
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_func,
         q_gpu,
         k_gpu,
diff --git a/tests/internevo/test_varlen_flash_attention.py b/tests/internevo/test_varlen_flash_attention.py
index 6f0f20f..97b8d64 100644
--- a/tests/internevo/test_varlen_flash_attention.py
+++ b/tests/internevo/test_varlen_flash_attention.py
@@ -1,12 +1,12 @@
 # Copyright (c) 2024, DeepLink.
 
 import torch
-from tests.core import allclose, call_normal_func, copy_to_cpu
+from tests.core import allclose, calculate_fwd_and_bwd, copy_to_cpu
 
 from deeplink_ext.internevo_ops.flash_attention_fallback import (
-    torch_attn_varlen_qkvpacked_func,
-    torch_attn_varlen_kvpacked_func,
-    torch_attn_varlen_func,
+    flash_attn_varlen_qkvpacked_func_torch,
+    flash_attn_varlen_kvpacked_func_torch,
+    flash_attn_varlen_func_torch,
 )
 from deeplink_ext.internevo_ops.flash_attention import (
     flash_attn_varlen_qkvpacked_func,
@@ -46,15 +46,15 @@
     )
     max_seqlen = 128
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_varlen_qkvpacked_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_varlen_qkvpacked_func_torch,
         qkv_cpu[0],
         cu_seqlens_cpu,
         max_seqlen,
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_varlen_qkvpacked_func,
         qkv_gpu,
         cu_seqlens_gpu,
@@ -90,15 +90,15 @@
     # the maximum sequence length is 4096
     max_seqlen = 4096
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_varlen_qkvpacked_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_varlen_qkvpacked_func_torch,
         qkv_cpu[0],
         cu_seqlens_cpu,
         max_seqlen,
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_varlen_qkvpacked_func,
         qkv_gpu,
         cu_seqlens_gpu,
@@ -140,8 +140,8 @@
     max_seqlen_q = 128
     max_seqlen_k = 128
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_varlen_kvpacked_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_varlen_kvpacked_func_torch,
         q_cpu,
         kv_cpu,
         cu_seqlens_q_cpu,
@@ -151,7 +151,7 @@
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_varlen_kvpacked_func,
         q_gpu,
         kv_gpu,
@@ -198,8 +198,8 @@
     max_seqlen_q = 4096
     max_seqlen_k = 4096
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_varlen_kvpacked_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_varlen_kvpacked_func_torch,
         q_cpu,
         kv_cpu,
         cu_seqlens_q_cpu,
@@ -209,7 +209,7 @@
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_varlen_kvpacked_func,
         q_gpu,
         kv_gpu,
@@ -260,8 +260,8 @@
     max_seqlen_q = 128
     max_seqlen_k = 128
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_varlen_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_varlen_func_torch,
         q_cpu,
         k_cpu,
         v_cpu,
@@ -272,7 +272,7 @@
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_varlen_func,
         q_gpu,
         k_gpu,
@@ -326,8 +326,8 @@
     max_seqlen_q = 4096
     max_seqlen_k = 4096
 
-    ouput_forward_cpu, grads_cpu = call_normal_func(
-        torch_attn_varlen_func,
+    ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
+        flash_attn_varlen_func_torch,
         q_cpu,
         k_cpu,
         v_cpu,
@@ -338,7 +338,7 @@
         dropout_p=0.0,
         causal=True,
     )
-    ouput_forward_gpu, grads_gpu = call_normal_func(
+    ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
         flash_attn_varlen_func,
         q_gpu,
         k_gpu,