feat: support easyllm #118
Merged
Changes from all commits (27 commits):
c447c69 support easyllm (POI-WX)
172ea00 add adamw (POI-WX)
b081169 fix format (POI-WX)
c0e8a20 add bert padding (POI-WX)
45f23f3 Merge branch 'main' into wx/support_easyllm (yangbofun)
b73c668 fix bug (POI-WX)
4ad4022 add Copyright (POI-WX)
a39598c fix format and add rmsnorm (POI-WX)
1a008b1 update (POI-WX)
9c88058 share same code with internevo on adamw op (POI-WX)
d0c51e7 share same code with ascend speed on rms norm op (POI-WX)
eda8886 add test cases of rms norm op for easyllm (liujingfeng4A069)
6c99b01 modify flash attention for ascend speed (liujingfeng4A069)
8ecef35 modify flash attention for interevo (liujingfeng4A069)
d7d723f fix format (POI-WX)
f254dd9 fix according to clang tidy (POI-WX)
e41f1f6 Merge branch 'ljf/adapt_for_newest_diopiCustomizedFlashAttention' of … (POI-WX)
ad9377c update (POI-WX)
16f587a Merge branch 'main' of https://github.com/DeepLink-org/DeepLinkExt in… (POI-WX)
7f05a9f Merge branch 'main' of https://github.com/DeepLink-org/DeepLinkExt in… (POI-WX)
086be46 rename (POI-WX)
07125d9 update easyllm (POI-WX)
f6f170e update (POI-WX)
bb345a3 update (POI-WX)
d6823de update (POI-WX)
057df1a fix (POI-WX)
bf7fffb modify (POI-WX)
@@ -0,0 +1,53 @@
# Copyright (c) 2024, DeepLink.

_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."

try:
    from .adamw import AdamW
except Exception as e:
    print(_not_impl.format(op_name="adamw"))
    from torch.optim import AdamW

try:
    from .flash_attention import (
        flash_attn_qkvpacked_func,
        flash_attn_kvpacked_func,
        flash_attn_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_func,
    )
except Exception as e:
    print(_not_impl.format(op_name="flash attention"))
    from .flash_attention_fallback import (
        flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
        flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
        flash_attn_func_torch as flash_attn_func,
        flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
        flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_func_torch as flash_attn_varlen_func,
    )

try:
    from .rms_norm import rms_norm
except:
    print(
        _not_impl.format(op_name="RMSNorm"),
    )
    from .rms_norm_fallback import rms_norm_torch as rms_norm

from .bert_padding import pad_input, unpad_input, index_first_axis

__all__ = [
    "AdamW",
    "flash_attn_qkvpacked_func",
    "flash_attn_kvpacked_func",
    "flash_attn_func",
    "flash_attn_varlen_qkvpacked_func",
    "flash_attn_varlen_kvpacked_func",
    "flash_attn_varlen_func",
    "rms_norm",
    "pad_input",
    "unpad_input",
    "index_first_axis",
]
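The try/except blocks above let callers import every op from one place; if a diopi kernel is missing, a warning is printed and the import falls back to stock PyTorch. A minimal sketch of how to check which implementation was picked, assuming the package is importable as deeplink_ext.easyllm_ops (the path is inferred from this PR, not shown in the diff):

```python
# Hypothetical sketch: inspect which implementation the package resolved to.
import deeplink_ext.easyllm_ops as ops  # package path is an assumption

# If the diopi kernel is missing, AdamW is torch.optim.AdamW and rms_norm is the
# torch fallback; otherwise the deeplink implementations are used.
print(ops.AdamW.__module__)
print(ops.rms_norm.__module__)
```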
@@ -0,0 +1,6 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.interntrain_ops.adamw import AdamW


__all__ = ["AdamW"]
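Since this module only re-exports the InternTrain AdamW, usage is the familiar optimizer loop. The sketch below assumes the re-exported class keeps the torch.optim.AdamW interface (which the fallback path in __init__ suggests) and that the package is importable as deeplink_ext.easyllm_ops; with the torch fallback it runs on CPU, while the diopi-backed optimizer may need a supported device.

```python
import torch

# Hypothetical sketch: assumes AdamW is a drop-in for torch.optim.AdamW.
from deeplink_ext.easyllm_ops import AdamW  # package path is an assumption

model = torch.nn.Linear(32, 32)
optimizer = AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), weight_decay=0.1)

loss = model(torch.randn(8, 32)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```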
@@ -0,0 +1,232 @@
# Copyright (c) 2024, DeepLink.
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

__all__ = [
    "pad_input",
    "unpad_input",
    "index_first_axis",
]

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"),
            0,
            repeat(indices, "z -> z d", d=second_dim),
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(
            0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
        )
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, Pytorch complains about output being a view and is being modified inplace
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply


def unpad_input(hidden_states, attention_mask):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
    )
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
        ```
        [
          [2, 3, 0, 0, 0, 0],
          [3, 2, 0, 0, 0, 0],
          [6, 0, 0, 0, 0, 0]
        ]
        ```
    , which refers to the 3D-attention mask:
        ```
        [
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
          ]
        ]
        ```.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(
        seqlen, device=length.device, dtype=length.dtype
    ).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(
        attention_mask_in_length.flatten(), as_tuple=False
    ).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
    )
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
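To make the helpers above concrete, here is a small round-trip sketch: unpad a batch with an attention mask, then pad it back. The shapes, values, and module path are made up for illustration.

```python
import torch

# Hypothetical round-trip sketch for the padding helpers above.
from deeplink_ext.easyllm_ops.bert_padding import pad_input, unpad_input  # path assumed

batch, seqlen, dim = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 1 = valid token, 0 = padding

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
assert unpadded.shape == (5, dim)        # 3 + 2 valid tokens
assert cu_seqlens.tolist() == [0, 3, 5]  # cumulative sequence lengths
assert max_seqlen == 3

repadded = pad_input(unpadded, indices, batch, seqlen)
assert repadded.shape == hidden_states.shape
# Valid positions survive the round trip; padded positions come back as zeros.
assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])
```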
@@ -0,0 +1,20 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.internevo_ops.flash_attention import (
    flash_attn_qkvpacked_func,
    flash_attn_kvpacked_func,
    flash_attn_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_func,
)


__all__ = [
    "flash_attn_qkvpacked_func",
    "flash_attn_kvpacked_func",
    "flash_attn_func",
    "flash_attn_varlen_qkvpacked_func",
    "flash_attn_varlen_kvpacked_func",
    "flash_attn_varlen_func",
]
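For illustration, a dense (non-varlen) call is sketched below. This diff only re-exports the functions, so treat the signature as an assumption: it is taken to mirror the upstream flash-attn API, i.e. q, k, v shaped (batch, seqlen, nheads, headdim) in fp16/bf16 on the accelerator.

```python
import torch

# Hypothetical sketch: assumes the upstream flash-attn calling convention.
from deeplink_ext.easyllm_ops import flash_attn_func  # package path is an assumption

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
assert out.shape == (batch, seqlen, nheads, headdim)
```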
@@ -0,0 +1,20 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.internevo_ops.flash_attention_fallback import (
    flash_attn_qkvpacked_func_torch,
    flash_attn_kvpacked_func_torch,
    flash_attn_func_torch,
    flash_attn_varlen_qkvpacked_func_torch,
    flash_attn_varlen_kvpacked_func_torch,
    flash_attn_varlen_func_torch,
)


__all__ = [
    "flash_attn_qkvpacked_func_torch",
    "flash_attn_kvpacked_func_torch",
    "flash_attn_func_torch",
    "flash_attn_varlen_qkvpacked_func_torch",
    "flash_attn_varlen_kvpacked_func_torch",
    "flash_attn_varlen_func_torch",
]
@@ -0,0 +1,9 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.ascend_speed.rms_norm import RMSNorm

__all__ = ["rms_norm"]


def rms_norm(x, weight, epsilon):
    return RMSNorm.apply(x, weight, epsilon)
@@ -0,0 +1,13 @@
# Copyright (c) 2024, DeepLink.

import torch

__all__ = ["rms_norm_torch"]


def rms_norm_torch(x, weight, epsilon):
    input_dtype = x.dtype
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = x * torch.rsqrt(variance + epsilon)

    return (hidden_states * weight).to(input_dtype)
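The fallback doubles as a reference for testing the diopi-backed op (compare the "add test cases of rms norm op for easyllm" commit). Below is a parity-check sketch; the module paths are assumed from this PR, and the diopi path additionally needs a device its kernel supports.

```python
import torch

# Hypothetical parity check between the diopi-backed rms_norm and the torch fallback.
from deeplink_ext.easyllm_ops.rms_norm import rms_norm                  # path assumed
from deeplink_ext.easyllm_ops.rms_norm_fallback import rms_norm_torch  # path assumed

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
weight = torch.ones(64, device="cuda", dtype=torch.float16)

ref = rms_norm_torch(x, weight, 1e-6)
out = rms_norm(x, weight, 1e-6)
assert torch.allclose(out, ref, rtol=1e-3, atol=1e-3)
```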
Review comment: Add it inside __init__.