feat: support easyllm #118

Merged · 27 commits · Aug 2, 2024
48 changes: 48 additions & 0 deletions deeplink_ext/easyllm_ops/__init__.py
@@ -0,0 +1,48 @@
# Copyright (c) 2024, DeepLink.

_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."

try:
from .adamw import AdamW
except Exception as e:
print(_not_impl.format(op_name="adamw"))
from torch.optim import AdamW

try:
from .flash_attention import (
flash_attn_qkvpacked_func,
flash_attn_kvpacked_func,
flash_attn_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_func,
)
except Exception as e:
print(_not_impl.format(op_name="flash attention"))
from .flash_attention_fallback import (
flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
flash_attn_func_torch as flash_attn_func,
flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
flash_attn_varlen_func_torch as flash_attn_varlen_func,
)

try:
from .rms_norm import rms_norm
except Exception:
    print(_not_impl.format(op_name="RMSNorm"))
from .rms_norm_fallback import rms_norm_torch as rms_norm

__all__ = [
"AdamW",
"flash_attn_qkvpacked_func",
"flash_attn_kvpacked_func",
"flash_attn_func",
"flash_attn_varlen_qkvpacked_func",
"flash_attn_varlen_kvpacked_func",
"flash_attn_varlen_func",
"rms_norm",
]
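
For reference, a minimal usage sketch (not part of the diff) of how downstream code consumes these exports. The shapes, dtype, and learning rate are illustrative; it assumes the diopi-backed `AdamW` and `rms_norm` keep the same call signatures as the torch fallbacks they are aliased against, and on a machine without the diopi kernels the fallback path is what actually runs.

```python
import torch
from deeplink_ext.easyllm_ops import AdamW, rms_norm

# (batch, seqlen, hidden) activations and a per-channel RMSNorm weight.
hidden = torch.randn(4, 128, 4096)
weight = torch.ones(4096)

# Same call whether the diopi kernel or the torch fallback was imported above.
out = rms_norm(hidden, weight, 1e-6)

# AdamW falls back to torch.optim.AdamW, so a torch-style constructor is assumed here.
model = torch.nn.Linear(4096, 4096)
optimizer = AdamW(model.parameters(), lr=1e-4)
```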
6 changes: 6 additions & 0 deletions deeplink_ext/easyllm_ops/adamw.py
@@ -0,0 +1,6 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.interntrain_ops.adamw import AdamW


__all__ = ["AdamW"]
226 changes: 226 additions & 0 deletions deeplink_ext/easyllm_ops/bert_padding.py
@@ -0,0 +1,226 @@
# Copyright (c) 2024, DeepLink.
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"),
0,
repeat(indices, "z -> z d", d=second_dim),
).reshape(-1, *other_shape)

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(
0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()

@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply


def unpad_input(hidden_states, attention_mask):
[Collaborator review comment on this line: "Add this to __init__."]

"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
    Supports concatenating short samples into one sequence. The attention_mask_in_length is used to mask out the other short samples, enabling efficient training on variable-length samples (e.g., the supervised fine-tuning task for large language models).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
    which corresponds to the following 3D attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.

Arguments:
hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero value (e.g., 1, 2, 3) gives the length of a concatenated sub-sequence in the b-th batch, and 0 means none.
Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(
seqlen, device=length.device, dtype=length.dtype
).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(
attention_mask_in_length.flatten(), as_tuple=False
).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)


def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
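
A short round-trip sketch (not part of the diff) showing how the `unpad_input` and `pad_input` helpers above compose; the shapes are illustrative and everything runs on the plain torch path.

```python
import torch
from deeplink_ext.easyllm_ops.bert_padding import unpad_input, pad_input

# Two sequences padded to length 5: the first has 3 real tokens, the second 5.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
hidden_states = torch.randn(2, 5, 16)  # (batch, seqlen, hidden)

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
# unpadded: (8, 16) -- only the 3 + 5 real tokens remain
# cu_seqlens: tensor([0, 3, 8], dtype=torch.int32), max_seqlen: 5

repadded = pad_input(unpadded, indices, batch=2, seqlen=5)
assert repadded.shape == hidden_states.shape  # padded slots come back as zeros
```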
20 changes: 20 additions & 0 deletions deeplink_ext/easyllm_ops/flash_attention.py
@@ -0,0 +1,20 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.internevo_ops.flash_attention import (
flash_attn_qkvpacked_func,
flash_attn_kvpacked_func,
flash_attn_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_func,
)


__all__ = [
"flash_attn_qkvpacked_func",
"flash_attn_kvpacked_func",
"flash_attn_func",
"flash_attn_varlen_qkvpacked_func",
"flash_attn_varlen_kvpacked_func",
"flash_attn_varlen_func",
]
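
For context, a hedged sketch (not part of the diff) of how the varlen entry point is typically combined with `unpad_input` from `bert_padding.py`. It assumes these re-exports keep the upstream flash-attn calling convention `flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ..., causal=...)`; check the exact argument list against `deeplink_ext.internevo_ops.flash_attention`.

```python
import torch
from deeplink_ext.easyllm_ops.bert_padding import unpad_input, pad_input
from deeplink_ext.easyllm_ops.flash_attention import flash_attn_varlen_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
device = "cuda"  # illustrative; use whatever device the DeepLink backend exposes
hidden = torch.randn(batch, seqlen, nheads * headdim, dtype=torch.float16, device=device)
attention_mask = torch.ones(batch, seqlen, dtype=torch.int32, device=device)

# Drop padding so attention only runs over real tokens.
x, indices, cu_seqlens, max_seqlen = unpad_input(hidden, attention_mask)
q = k = v = x.view(-1, nheads, headdim)

# Assumed flash-attn-style signature; self-attention, so q and k share cu_seqlens.
out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True)

# Scatter the (total_tokens, nheads, headdim) output back into (batch, seqlen, hidden).
out = pad_input(out.reshape(-1, nheads * headdim), indices, batch, seqlen)
```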
20 changes: 20 additions & 0 deletions deeplink_ext/easyllm_ops/flash_attention_fallback.py
@@ -0,0 +1,20 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.internevo_ops.flash_attention_fallback import (
flash_attn_qkvpacked_func_torch,
flash_attn_kvpacked_func_torch,
flash_attn_func_torch,
flash_attn_varlen_qkvpacked_func_torch,
flash_attn_varlen_kvpacked_func_torch,
flash_attn_varlen_func_torch,
)


__all__ = [
"flash_attn_qkvpacked_func_torch",
"flash_attn_kvpacked_func_torch",
"flash_attn_func_torch",
"flash_attn_varlen_qkvpacked_func_torch",
"flash_attn_varlen_kvpacked_func_torch",
"flash_attn_varlen_func_torch",
]
9 changes: 9 additions & 0 deletions deeplink_ext/easyllm_ops/rms_norm.py
@@ -0,0 +1,9 @@
# Copyright (c) 2024, DeepLink.

from deeplink_ext.ascend_speed.rms_norm import RMSNorm

__all__ = ["rms_norm"]


def rms_norm(x, weight, epsilon):
return RMSNorm.apply(x, weight, epsilon)
13 changes: 13 additions & 0 deletions deeplink_ext/easyllm_ops/rms_norm_fallback.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024, DeepLink.

import torch

__all__ = ["rms_norm_torch"]


def rms_norm_torch(x, weight, epsilon):
input_dtype = x.dtype
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = x * torch.rsqrt(variance + epsilon)

return (hidden_states * weight).to(input_dtype)
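
As a quick sanity check (not part of the diff): the fallback computes `y = x * rsqrt(mean(x^2) + eps) * weight`, and the diopi-backed `rms_norm` in `rms_norm.py` is expected to match it up to floating-point tolerance. The snippet below only exercises the fallback, with illustrative shapes.

```python
import torch
from deeplink_ext.easyllm_ops.rms_norm_fallback import rms_norm_torch

x = torch.randn(2, 16, 64)
weight = torch.ones(64)
eps = 1e-6

ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
out = rms_norm_torch(x, weight, eps)
assert torch.allclose(out, ref, atol=1e-6)
```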
12 changes: 6 additions & 6 deletions deeplink_ext/internevo_ops/__init__.py
@@ -20,12 +20,12 @@
except Exception as e:
print(_not_impl.format(op_name="flash attention"))
from .flash_attention_fallback import (
torch_attn_qkvpacked_func as flash_attn_qkvpacked_func,
torch_attn_kvpacked_func as flash_attn_kvpacked_func,
torch_attn_func as flash_attn_func,
torch_attn_varlen_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
torch_attn_varlen_kvpacked_func as flash_attn_varlen_kvpacked_func,
torch_attn_varlen_func as flash_attn_varlen_func,
flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
flash_attn_func_torch as flash_attn_func,
flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
flash_attn_varlen_func_torch as flash_attn_varlen_func,
)

try: