
feat: refact-test-case #126

Merged: 4 commits, Sep 10, 2024
7 changes: 4 additions & 3 deletions .github/workflows/static.yml
@@ -107,10 +107,11 @@ jobs:
       run: |
         source /mnt/cache/share/platform/cienv/dipu_latest_ci
         cd ${DEEPLINK_PATH}/${{ github.run_number }}/DeepLinkExt

         export PYTHONPATH=$PWD:$PYTHONPATH

+        cd tests/
         export DEEPLINK_EXT_PLATFORM_TYPE=torch_dipu
-        python -m pytest tests/dipu
+        python -m pytest -v ./

         export DEEPLINK_EXT_PLATFORM_TYPE=torch_npu
-        python -m pytest tests/npu
+        python -m pytest -v ./
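
With this change, the two platform-specific pytest targets (tests/dipu and tests/npu) collapse into a single shared test directory, and the backend is chosen per invocation through the DEEPLINK_EXT_PLATFORM_TYPE environment variable. A minimal sketch of running the same matrix locally, mirroring the CI step (illustrative only, not part of the PR):

import os
import subprocess

# Run the shared test suite once per backend, as the workflow step does.
for platform in ("torch_dipu", "torch_npu"):
    env = dict(os.environ, DEEPLINK_EXT_PLATFORM_TYPE=platform)
    subprocess.run(
        ["python", "-m", "pytest", "-v", "./"],
        cwd="tests",
        env=env,
        check=True,
    )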
47 changes: 9 additions & 38 deletions deeplink_ext/ascend_speed/_rotary_embedding_npu.py
@@ -6,38 +6,6 @@
 __all__ = ["RotaryEmbedding"]


-def _unsqueeze_to_4d(x: torch.Tensor):
-    while x.dim() < 4:
-        x = x.unsqueeze(0)
-    return x
-
-
-def apply_rotary(x: torch.Tensor, cos, sin, confj=False, interleaved=False):
-    assert interleaved == False, "interleaved not support by torch_npu"
-
-    x_view = _unsqueeze_to_4d(x)
-    cos_view = _unsqueeze_to_4d(cos)
-    sin_view = _unsqueeze_to_4d(sin)
-
-    cos_cat = torch.cat([cos_view, cos_view], -1)
-    sin_cat = torch.cat([sin_view, sin_view], -1)
-
-    if confj:
-        sin_cat.neg_()
-
-    x_view_chunks = x_view.chunk(2, -1)
-    x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)
-
-    print(cos_cat.shape)
-    print(x_view.shape)
-
-    cos_x = torch.mul(cos_cat, x_view)
-    sin_x = torch.mul(sin_cat, x_view_new)
-    out = cos_x + sin_x
-
-    return out


 class RotaryEmbedding(torch.autograd.Function):
     """
     Apply rotary positional embedding to input tensor x.
@@ -52,12 +20,15 @@ class RotaryEmbedding(torch.autograd.Function):

     @staticmethod
     def forward(ctx, x, cos, sin):
-        cos, _ = torch.chunk(cos, 2, -1)
-        sin, _ = torch.chunk(sin, 2, -1)
-        ctx.save_for_backward(cos, sin)
-        return apply_rotary(x, cos, sin)
+        out = torch_npu.npu_rotary_mul(x, cos, sin)
+        ctx.save_for_backward(out, cos, sin)
+        return out

     @staticmethod
     def backward(ctx, grad_output):
-        cos, sin = ctx.saved_tensors
-        return apply_rotary(grad_output, cos, sin, conjugate=True), None, None
+        out, cos, sin = ctx.saved_tensors
+        return (
+            torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0],
+            None,
+            None,
+        )
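
For context, the fused torch_npu kernels replace the hand-written half-rotation math that the deleted apply_rotary helper implemented. A minimal pure-PyTorch reference of that convention, handy for checking the fused path numerically (function names here are illustrative, not part of the PR):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) on the last dimension, matching the deleted helper.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin carry angles for half the last dimension; widening them mirrors
    # torch.cat([cos_view, cos_view], -1) in the removed code.
    cos_full = torch.cat((cos, cos), dim=-1)
    sin_full = torch.cat((sin, sin), dim=-1)
    return x * cos_full + rotate_half(x) * sin_full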
16 changes: 16 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,16 @@
import pytest
import torch

from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type


@pytest.fixture(scope="session", autouse=True)
def import_module():
    platform = deeplink_ext_get_platform_type()
    if platform == PlatformType.TORCH_NPU:
        import torch_npu
        from torch_npu.contrib import transfer_to_npu
    elif platform == PlatformType.TORCH_DIPU:
        import torch_dipu
    else:
        raise ValueError("backend platform is not supported by deeplink_ext")
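
The fixture relies on deeplink_ext.utils, which is outside this diff. A plausible sketch, assuming the platform type is derived from the DEEPLINK_EXT_PLATFORM_TYPE environment variable that the workflow exports (a hypothetical reconstruction, not the actual module):

import os
from enum import Enum

class PlatformType(Enum):
    TORCH_DIPU = "torch_dipu"
    TORCH_NPU = "torch_npu"

def deeplink_ext_get_platform_type() -> PlatformType:
    # Defaulting to torch_dipu is an assumption made for this sketch.
    name = os.environ.get("DEEPLINK_EXT_PLATFORM_TYPE", "torch_dipu")
    return PlatformType(name)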
33 changes: 33 additions & 0 deletions tests/fusion_result.json
@@ -0,0 +1,33 @@
[{
    "graph_fusion": {
        "RefreshInt64ToInt32FusionPass": {
            "effect_times": "1",
            "match_times": "1"
        }
    },
    "session_and_graph_id": "0_0"
},{
    "graph_fusion": {
        "RefreshInt64ToInt32FusionPass": {
            "effect_times": "1",
            "match_times": "1"
        }
    },
    "session_and_graph_id": "1_1"
},{
    "graph_fusion": {
        "RefreshInt64ToInt32FusionPass": {
            "effect_times": "1",
            "match_times": "1"
        }
    },
    "session_and_graph_id": "2_2"
},{
    "graph_fusion": {
        "RefreshInt64ToInt32FusionPass": {
            "effect_times": "1",
            "match_times": "1"
        }
    },
    "session_and_graph_id": "3_3"
}]
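
Each record in fusion_result.json reports one graph's fusion-pass statistics from an Ascend run; here RefreshInt64ToInt32FusionPass matched and took effect once in each of four graphs. A small sketch for aggregating such dumps, assuming this exact schema (the helper name is illustrative):

import json
from collections import Counter

def summarize_fusion(path: str) -> Counter:
    # Sum effect_times per fusion pass across all session/graph records.
    effects = Counter()
    with open(path) as f:
        for record in json.load(f):
            for name, stats in record.get("graph_fusion", {}).items():
                effects[name] += int(stats["effect_times"])
    return effects

print(summarize_fusion("tests/fusion_result.json"))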
@@ -34,6 +34,3 @@ def test_MixedFusedRMSNorm():
         assert allclose(
             grad_ref, grad_ext, rtol=1e-2, atol=1e-2
         ), f"When input dtype is {input_dtype} and weight dtype is {weight_dtype}, MixedRMSNorm fails to pass the backward test!"
-
-
-test_MixedFusedRMSNorm()
48 changes: 0 additions & 48 deletions tests/npu/easyllm/test_rms_norm_npu.py

This file was deleted.

130 changes: 0 additions & 130 deletions tests/npu/internevo/test_flash_attention_npu.py

This file was deleted.

47 changes: 0 additions & 47 deletions tests/npu/internevo/test_rotary_embedding_npu.py

This file was deleted.
