
support torch npu for ascend speed
jingguo-st committed Aug 22, 2024
1 parent 1820a5a commit 0d9ddf2
Showing 10 changed files with 191 additions and 100 deletions.
8 changes: 2 additions & 6 deletions deeplink_ext/ascend_speed/_rms_norm_dipu.py
@@ -3,22 +3,18 @@
import torch
import deeplink_ext.cpp_extensions as ext


assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")

__all__ = ["RMSNorm"]


class RMSNorm(torch.autograd.Function):

@staticmethod
def forward(ctx, hidden_states, weight, eps):
output = torch.empty_like(hidden_states)
input_dtype = hidden_states.dtype
acc_dtype = (
torch.float32
if input_dtype in [torch.bfloat16, torch.float16]
else input_dtype
)
acc_dtype = (torch.float32 if input_dtype in [torch.bfloat16, torch.float16] else input_dtype)
n = weight.dim()
inv_rms = torch.empty(
list(hidden_states.shape[:-n]),
4 changes: 3 additions & 1 deletion deeplink_ext/ascend_speed/_rms_norm_npu.py
@@ -18,5 +18,7 @@ def forward(ctx, hidden_states, weight, eps):
@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight = ctx.saved_tensors
grad_input, grad_weight = npu_rms_norm_backward(grad_output, hidden_states, weight, inv_rms)
grad_input, grad_weight = npu_rms_norm_backward(
grad_output, hidden_states, weight, inv_rms
)
return grad_input, grad_weight, None, None
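
For orientation, a minimal usage sketch of the RMSNorm autograd function touched above; the shapes, dtype and eps value are illustrative assumptions, and the NPU backend is imported directly only for brevity.

import torch
import torch_npu  # registers the "npu" device with PyTorch
from deeplink_ext.ascend_speed._rms_norm_npu import RMSNorm

# Shapes, dtype and eps below are assumptions, not values taken from this diff.
hidden_states = torch.randn(4, 128, 4096, dtype=torch.float16, device="npu")
weight = torch.ones(4096, dtype=torch.float16, device="npu")
output = RMSNorm.apply(hidden_states, weight, 1e-6)  # forward signature: (hidden_states, weight, eps)
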
44 changes: 44 additions & 0 deletions deeplink_ext/ascend_speed/_rotary_embedding_dipu.py
@@ -0,0 +1,44 @@
# Copyright (c) 2024, DeepLink.

import torch
from typing import Optional, Union
import deeplink_ext.cpp_extensions as ext

__all__ = ["RotaryEmbedding"]


def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
interleaved=False,
conjugate=False,
) -> torch.Tensor:
output = torch.empty_like(x)
ext.apply_rotary(output, x, cos, sin, conjugate, interleaved)
return output


class RotaryEmbedding(torch.autograd.Function):
"""
Apply rotary positional embedding to input tensor x.
Args:
x (Tensor): Input tensor x is of shape [seq_length, ... , dim]
cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim]
sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim]
Returns:
Tensor: The input tensor after applying RoPE
"""

@staticmethod
def forward(ctx, x, cos, sin):
cos, _ = torch.chunk(cos, 2, -1)
sin, _ = torch.chunk(sin, 2, -1)
ctx.save_for_backward(cos, sin)
return apply_rotary(x, cos, sin)

@staticmethod
def backward(ctx, grad_output):
cos, sin = ctx.saved_tensors
return apply_rotary(grad_output, cos, sin, conjugate=True), None, None
63 changes: 63 additions & 0 deletions deeplink_ext/ascend_speed/_rotary_embedding_npu.py
@@ -0,0 +1,63 @@
# Copyright (c) 2024, DeepLink.

import torch
import torch_npu

__all__ = ["RotaryEmbedding"]


def _unsqueeze_to_4d(x: torch.Tensor):
while x.dim() < 4:
x = x.unsqueeze(0)
return x


def apply_rotary(x: torch.Tensor, cos, sin, conjugate=False, interleaved=False):
    assert not interleaved, "interleaved rotary embedding is not supported by torch_npu"

x_view = _unsqueeze_to_4d(x)
cos_view = _unsqueeze_to_4d(cos)
sin_view = _unsqueeze_to_4d(sin)

cos_cat = torch.cat([cos_view, cos_view], -1)
sin_cat = torch.cat([sin_view, sin_view], -1)

    if conjugate:
sin_cat.neg_()

x_view_chunks = x_view.chunk(2, -1)
x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)

cos_x = torch.mul(cos_cat, x_view)
sin_x = torch.mul(sin_cat, x_view_new)
out = cos_x + sin_x

return out


class RotaryEmbedding(torch.autograd.Function):
"""
Apply rotary positional embedding to input tensor x.
Args:
x (Tensor): Input tensor x is of shape [seq_length, ... , dim]
cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim]
sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim]
Returns:
Tensor: The input tensor after applying RoPE
"""

@staticmethod
def forward(ctx, x, cos, sin):
cos, _ = torch.chunk(cos, 2, -1)
sin, _ = torch.chunk(sin, 2, -1)
ctx.save_for_backward(cos, sin)
return apply_rotary(x, cos, sin)

@staticmethod
def backward(ctx, grad_output):
cos, sin = ctx.saved_tensors
return apply_rotary(grad_output, cos, sin, conjugate=True), None, None
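
A minimal usage sketch of the RotaryEmbedding function defined above (the dipu variant has the same interface); the tensor shapes and the cos/sin construction are illustrative assumptions.

import torch
import torch_npu
from deeplink_ext.ascend_speed._rotary_embedding_npu import RotaryEmbedding

seq_len, batch, num_heads, head_dim = 128, 2, 8, 64
q = torch.randn(seq_len, batch, num_heads, head_dim, dtype=torch.float16, device="npu")

# Build cos/sin with the duplicated second half that forward() chunks away.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device="npu").float() / head_dim))
t = torch.arange(seq_len, device="npu").float()
freqs = torch.outer(t, inv_freq)            # [seq_len, head_dim // 2]
emb = torch.cat([freqs, freqs], dim=-1)     # [seq_len, head_dim]
cos = emb.cos()[:, None, None, :].half()    # broadcast over batch and heads
sin = emb.sin()[:, None, None, :].half()
q_rotated = RotaryEmbedding.apply(q, cos, sin)
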
27 changes: 27 additions & 0 deletions deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py
@@ -0,0 +1,27 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(ext, "scaled_masked_softmax_bwd")

__all__ = ["ScaledMaskedSoftmax"]


class ScaledMaskedSoftmax(torch.autograd.Function):

@staticmethod
def forward(ctx, input, mask, scale, fixed_triu_mask):
out = torch.empty_like(input)
ext.scaled_masked_softmax_fwd(out, input, mask, scale, fixed_triu_mask)
ctx.save_for_backward(out, mask)
ctx.scale = scale
ctx.fixed_triu_mask = fixed_triu_mask
return out

@staticmethod
def backward(ctx, grad_output):
out, mask = ctx.saved_tensors
grad_input = torch.empty_like(grad_output)
ext.scaled_masked_softmax_bwd(grad_input, grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask)
return grad_input, None, None, None
28 changes: 28 additions & 0 deletions deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py
@@ -0,0 +1,28 @@
# Copyright (c) 2024, DeepLink.

import torch
import torch_npu

__all__ = ["ScaledMaskedSoftmax"]


class ScaledMaskedSoftmax(torch.autograd.Function):

@staticmethod
def forward(ctx, input, mask, scale, fixed_triu_mask):
out = torch_npu.npu_scaled_masked_softmax(input, mask, scale, fixed_triu_mask)

ctx.save_for_backward(out, mask)
ctx.scale = scale
ctx.fixed_triu_mask = fixed_triu_mask
return out

@staticmethod
def backward(ctx, grad_output):
out, mask = ctx.saved_tensors
        grad_input = torch_npu.npu_scaled_masked_softmax_backward(
            grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
        )

return grad_input, None, None, None
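
A minimal usage sketch of ScaledMaskedSoftmax; the shapes, scale value and boolean mask layout are illustrative assumptions.

import torch
import torch_npu
from deeplink_ext.ascend_speed._scaled_masked_softmax_npu import ScaledMaskedSoftmax

batch, heads, seq = 2, 8, 128
scores = torch.randn(batch, heads, seq, seq, dtype=torch.float16, device="npu", requires_grad=True)
causal_mask = torch.triu(
    torch.ones(batch, heads, seq, seq, dtype=torch.bool, device="npu"), diagonal=1
)
probs = ScaledMaskedSoftmax.apply(scores, causal_mask, 1.0, False)  # scale=1.0, fixed_triu_mask=False
probs.sum().backward()  # drives the custom backward defined above
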
26 changes: 5 additions & 21 deletions deeplink_ext/ascend_speed/adamw.py
@@ -7,7 +7,6 @@
platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
import torch_npu
# from torch_npu import npu_apply_adam_w as deeplink_ext_adamw
deeplink_ext_adamw = torch.ops.npu.npu_apply_adam_w
elif platform_type == PlatformType.TORCH_DIPU:
# import torch_dipu
@@ -19,30 +18,15 @@
__all__ = ["adamw"]


def adamw(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
exp_avg_sqs: List[torch.Tensor],
max_exp_avg_sqs: List[torch.Tensor],
state_steps: List[int],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
norm_coeff_scale: float
):
def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor],
exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int], *,
amsgrad: bool, beta1: float, beta2: float, lr: float, weight_decay: float, eps: float, maximize: bool,
norm_coeff_scale: float):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""

assert (
maximize == False
), "The maximize parameter is not supported by diopiAdamW yet"
assert (maximize == False), "The maximize parameter is not supported by diopiAdamW yet"

for i, param in enumerate(params):
if norm_coeff_scale is not None:
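
A minimal call sketch of the functional adamw API shown above; the single-tensor state and the hyper-parameter values are illustrative assumptions, and norm_coeff_scale=None skips the scaling branch visible in the truncated hunk.

import torch
import torch_npu
from deeplink_ext.ascend_speed.adamw import adamw

param = torch.randn(1024, device="npu")
grad = torch.randn_like(param)
exp_avg, exp_avg_sq, max_exp_avg_sq = (torch.zeros_like(param) for _ in range(3))

adamw(
    [param], [grad], [exp_avg], [exp_avg_sq], [max_exp_avg_sq], [1],
    amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
    weight_decay=0.01, eps=1e-8, maximize=False, norm_coeff_scale=None,
)
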
9 changes: 3 additions & 6 deletions deeplink_ext/ascend_speed/flash_attention.py
@@ -9,14 +9,11 @@


class FlashSelfAttention(torch.autograd.Function):

@staticmethod
def forward(
ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout
):
def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout):
out = torch.empty_like(q)
assert (
q.device == k.device and k.device == v.device
), "the devices of q, k and v are not same"
        assert (q.device == k.device and k.device == v.device), "the devices of q, k and v are not the same"
gen = torch.Generator(device=q.device)
(
dropout_mask,
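
A hypothetical call sketch for FlashSelfAttention based only on the forward signature above; the shapes, mask, softmax scale and the "SBH" layout string are assumptions rather than values confirmed by this diff.

import torch
import torch_npu
from deeplink_ext.ascend_speed.flash_attention import FlashSelfAttention

seq, batch, heads, head_dim = 128, 2, 8, 64
q = torch.randn(seq, batch, heads * head_dim, dtype=torch.float16, device="npu")
k = torch.randn_like(q)
v = torch.randn_like(q)
mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device="npu"), diagonal=1)
# Argument order: q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout.
out = FlashSelfAttention.apply(q, k, v, mask, 0.0, head_dim ** -0.5, heads, "SBH")
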
48 changes: 8 additions & 40 deletions deeplink_ext/ascend_speed/rotary_embedding.py
@@ -1,45 +1,13 @@
# Copyright (c) 2024, DeepLink.

import torch
from typing import Optional, Union
import deeplink_ext.cpp_extensions as ext
from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._rotary_embedding_npu import RotaryEmbedding
elif platform_type == PlatformType.TORCH_DIPU:
from ._rotary_embedding_dipu import RotaryEmbedding
else:
raise ImportError

__all__ = ["RotaryEmbedding"]


def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
interleaved=False,
conjugate=False,
) -> torch.Tensor:
output = torch.empty_like(x)
ext.apply_rotary(output, x, cos, sin, conjugate, interleaved)
return output


class RotaryEmbedding(torch.autograd.Function):
"""
Apply rotary positional embedding to input tensor x.
Args:
x (Tensor): Input tensor x is of shape [seq_length, ... , dim]
cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim]
sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim]
Returns:
Tensor: The input tensor after applying RoPE
"""

@staticmethod
def forward(ctx, x, cos, sin):
cos, _ = torch.chunk(cos, 2, -1)
sin, _ = torch.chunk(sin, 2, -1)
ctx.save_for_backward(cos, sin)
return apply_rotary(x, cos, sin)

@staticmethod
def backward(ctx, grad_output):
cos, sin = ctx.saved_tensors
return apply_rotary(grad_output, cos, sin, conjugate=True), None, None
34 changes: 8 additions & 26 deletions deeplink_ext/ascend_speed/scaled_masked_softmax.py
@@ -1,31 +1,13 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext
from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type


assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(
ext, "scaled_masked_softmax_bwd"
)
platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._scaled_masked_softmax_npu import ScaledMaskedSoftmax
elif platform_type == PlatformType.TORCH_DIPU:
from ._scaled_masked_softmax_dipu import ScaledMaskedSoftmax
else:
raise ImportError

__all__ = ["ScaledMaskedSoftmax"]


class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, scale, fixed_triu_mask):
out = torch.empty_like(input)
ext.scaled_masked_softmax_fwd(out, input, mask, scale, fixed_triu_mask)
ctx.save_for_backward(out, mask)
ctx.scale = scale
ctx.fixed_triu_mask = fixed_triu_mask
return out

@staticmethod
def backward(ctx, grad_output):
out, mask = ctx.saved_tensors
grad_input = torch.empty_like(grad_output)
ext.scaled_masked_softmax_bwd(
grad_input, grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
)
return grad_input, None, None, None
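
Both dispatcher modules above follow the same platform-selection pattern. A sketch of how a further backend-specific op could be wired in; _my_op_npu, _my_op_dipu and MyOp are hypothetical names used only to illustrate the pattern.

from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
    from ._my_op_npu import MyOp  # hypothetical NPU implementation
elif platform_type == PlatformType.TORCH_DIPU:
    from ._my_op_dipu import MyOp  # hypothetical DIPU implementation
else:
    raise ImportError("unsupported platform for deeplink_ext.ascend_speed")

__all__ = ["MyOp"]
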
