From bf3636d54d4c5b71127c591c87d2c9b3a4a9eeef Mon Sep 17 00:00:00 2001 From: YPOI-WX Date: Tue, 2 Apr 2024 10:51:34 +0800 Subject: [PATCH 1/7] refactor rms_norm --- csrc/extensions.cpp | 7 ++--- deeplink_ext/ascend_speed/__init__.py | 3 +- deeplink_ext/ascend_speed/rms_norm.py | 43 +++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 deeplink_ext/ascend_speed/rms_norm.py diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index 23f8fbcb..02c72fc2 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -36,17 +36,16 @@ void extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq, beta1, beta2, epsilon, weight_decay, step, amsgrad); } -auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms, +void extRmsNorm(at::Tensor& output, at::Tensor& inv_rms, const at::Tensor& input, const OptionalIntArray& normalized_shape, const at::Tensor& weight, const at::Tensor& bias, double eps) { at::OptionalIntArrayRef normalized_shape_at = *normalized_shape; callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, bias, eps); - return std::make_tuple(std::move(output), std::move(inv_rms)); } -auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight, +void extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight, at::Tensor& grad_bias, const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, const at::Tensor& inv_rms, @@ -55,8 +54,6 @@ auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight, callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias, grad_output, input, weight, bias, inv_rms, normalized_shape_at, eps); - return std::make_tuple(std::move(grad_input), std::move(grad_weight), - std::move(grad_bias)); } void extApplyRotary(at::Tensor& output, const at::Tensor& input, diff --git a/deeplink_ext/ascend_speed/__init__.py b/deeplink_ext/ascend_speed/__init__.py index fab32da5..78c5a08d 100644 --- a/deeplink_ext/ascend_speed/__init__.py +++ b/deeplink_ext/ascend_speed/__init__.py @@ -1,4 +1,5 @@ from .rotary_embedding import apply_rotary, RotaryEmbedding from .adamw import adamw +from .rms_norm import RMSNorm -__all__ = ["apply_rotary", "RotaryEmbedding", "adamw"] +__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "RMSNorm"] diff --git a/deeplink_ext/ascend_speed/rms_norm.py b/deeplink_ext/ascend_speed/rms_norm.py new file mode 100644 index 00000000..ace38358 --- /dev/null +++ b/deeplink_ext/ascend_speed/rms_norm.py @@ -0,0 +1,43 @@ +import torch +import deeplink_ext.cpp_extensions as ext + + +assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward") + + +class RMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden_states, weight, eps): + bias = torch.Tensor().cuda() + output = torch.empty_like(hidden_states) + input_dtype = hidden_states.dtype + acc_dtype = ( + torch.float32 + if (input_dtype == torch.bfloat16 or input_dtype == torch.float16) + else input_dtype + ) + inv_rms = torch.empty_like(hidden_states, dtype=acc_dtype) + ext.rms_norm(output, hidden_states, None, weight, bias, eps) + ctx.save_for_backward(hidden_states, inv_rms, weight, bias) + ctx.eps = eps + return output + + @staticmethod + def backward(ctx, grad_output): + hidden_states, inv_rms, weight, bias = ctx.saved_tensors + grad_input = torch.empty_like(hidden_states) + grad_weight = torch.empty_like(weight) + grad_bias = torch.empty_like(bias) + ext.rms_norm_backward( + grad_input, + grad_weight, + grad_bias, + 
hidden_states, + grad_output, + inv_rms, + None, + weight, + bias, + ctx.eps, + ) + return grad_input, grad_weight, None, None From c2a01307b07fec9abf05ae130baa38d924a42a24 Mon Sep 17 00:00:00 2001 From: YPOI-WX Date: Tue, 2 Apr 2024 11:03:24 +0800 Subject: [PATCH 2/7] update --- deeplink_ext/ascend_speed/rms_norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deeplink_ext/ascend_speed/rms_norm.py b/deeplink_ext/ascend_speed/rms_norm.py index ace38358..815f363c 100644 --- a/deeplink_ext/ascend_speed/rms_norm.py +++ b/deeplink_ext/ascend_speed/rms_norm.py @@ -13,11 +13,11 @@ def forward(ctx, hidden_states, weight, eps): input_dtype = hidden_states.dtype acc_dtype = ( torch.float32 - if (input_dtype == torch.bfloat16 or input_dtype == torch.float16) + if input_dtype in [torch.bfloat16, torch.float16] else input_dtype ) inv_rms = torch.empty_like(hidden_states, dtype=acc_dtype) - ext.rms_norm(output, hidden_states, None, weight, bias, eps) + ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps) ctx.save_for_backward(hidden_states, inv_rms, weight, bias) ctx.eps = eps return output From af9e76df54b423c229bb3d99e934c7f8bb031696 Mon Sep 17 00:00:00 2001 From: YPOI-WX Date: Tue, 2 Apr 2024 11:43:11 +0800 Subject: [PATCH 3/7] update --- deeplink_ext/ascend_speed/rms_norm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deeplink_ext/ascend_speed/rms_norm.py b/deeplink_ext/ascend_speed/rms_norm.py index 815f363c..f7ca53aa 100644 --- a/deeplink_ext/ascend_speed/rms_norm.py +++ b/deeplink_ext/ascend_speed/rms_norm.py @@ -17,7 +17,7 @@ def forward(ctx, hidden_states, weight, eps): else input_dtype ) inv_rms = torch.empty_like(hidden_states, dtype=acc_dtype) - ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps) + ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps) ctx.save_for_backward(hidden_states, inv_rms, weight, bias) ctx.eps = eps return output @@ -32,12 +32,12 @@ def backward(ctx, grad_output): grad_input, grad_weight, grad_bias, - hidden_states, grad_output, - inv_rms, - None, + hidden_states, weight, bias, + inv_rms, + weight.shape, ctx.eps, ) return grad_input, grad_weight, None, None From f55ac873591327bb353064bb916ebfee99686bff Mon Sep 17 00:00:00 2001 From: YPOI-WX Date: Tue, 2 Apr 2024 15:47:28 +0800 Subject: [PATCH 4/7] add for flash attention --- deeplink_ext/ascend_speed/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deeplink_ext/ascend_speed/__init__.py b/deeplink_ext/ascend_speed/__init__.py index fdaa7452..1a74253a 100644 --- a/deeplink_ext/ascend_speed/__init__.py +++ b/deeplink_ext/ascend_speed/__init__.py @@ -2,5 +2,6 @@ from .adamw import adamw from .scaled_masked_softmax import ScaledMaskedSoftmax from .rms_norm import RMSNorm +from .flash_attention import FlashSelfAttention -__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "ScaledMaskedSoftmax", "RMSNorm"] +__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "ScaledMaskedSoftmax", "RMSNorm", "FlashSelfAttention"] From 084f053d1e18ff750ae607cd094ecd1a8e744d5a Mon Sep 17 00:00:00 2001 From: YPOI-WX Date: Tue, 2 Apr 2024 16:18:31 +0800 Subject: [PATCH 5/7] fix bug of construct inv_rms due to shape --- deeplink_ext/ascend_speed/rms_norm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deeplink_ext/ascend_speed/rms_norm.py b/deeplink_ext/ascend_speed/rms_norm.py index f7ca53aa..e78beb57 100644 --- a/deeplink_ext/ascend_speed/rms_norm.py +++ 
b/deeplink_ext/ascend_speed/rms_norm.py
@@ -16,7 +16,11 @@ def forward(ctx, hidden_states, weight, eps):
             if input_dtype in [torch.bfloat16, torch.float16]
             else input_dtype
         )
-        inv_rms = torch.empty_like(hidden_states, dtype=acc_dtype)
+        inv_rms = torch.empty(
+            hidden_states.shape[:-1] + (1,),
+            dtype=acc_dtype,
+            device=hidden_states.device,
+        )
         ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps)
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias)
         ctx.eps = eps
         return output

From 61fd608d71cef837980fb9e4c64b95109c7b69f0 Mon Sep 17 00:00:00 2001
From: YPOI-WX
Date: Tue, 2 Apr 2024 16:21:58 +0800
Subject: [PATCH 6/7] update

---
 deeplink_ext/ascend_speed/__init__.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/deeplink_ext/ascend_speed/__init__.py b/deeplink_ext/ascend_speed/__init__.py
index 1a74253a..f7a83059 100644
--- a/deeplink_ext/ascend_speed/__init__.py
+++ b/deeplink_ext/ascend_speed/__init__.py
@@ -4,4 +4,11 @@ from .rms_norm import RMSNorm
 from .flash_attention import FlashSelfAttention
 
-__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "ScaledMaskedSoftmax", "RMSNorm", "FlashSelfAttention"]
+__all__ = [
+    "apply_rotary",
+    "RotaryEmbedding",
+    "adamw",
+    "ScaledMaskedSoftmax",
+    "RMSNorm",
+    "FlashSelfAttention",
+]

From bd22e7d81e9f12814c31af044941998fd2c1cbe0 Mon Sep 17 00:00:00 2001
From: YPOI-WX
Date: Tue, 2 Apr 2024 16:46:12 +0800
Subject: [PATCH 7/7] update according to review

---
 deeplink_ext/ascend_speed/rms_norm.py | 2 +-
 deeplink_ext/internlm_ops/__init__.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/deeplink_ext/ascend_speed/rms_norm.py b/deeplink_ext/ascend_speed/rms_norm.py
index e78beb57..201eecae 100644
--- a/deeplink_ext/ascend_speed/rms_norm.py
+++ b/deeplink_ext/ascend_speed/rms_norm.py
@@ -17,7 +17,7 @@ def forward(ctx, hidden_states, weight, eps):
             else input_dtype
         )
         inv_rms = torch.empty(
-            hidden_states.shape[:-1] + (1,),
+            list(hidden_states.shape[:-1]) + [1],
             dtype=acc_dtype,
             device=hidden_states.device,
         )
diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py
index 61f616ad..ccbc1b9b 100644
--- a/deeplink_ext/internlm_ops/__init__.py
+++ b/deeplink_ext/internlm_ops/__init__.py
@@ -1,7 +1,5 @@
 # Copyright (c) 2024, DeepLink.
 
-from . import mha
-
 _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
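
Note on usage: after PATCH 7/7, the refactored RMSNorm autograd function takes (hidden_states, weight, eps) in forward and is re-exported from deeplink_ext.ascend_speed. The snippet below is a minimal usage sketch, not part of the patch series: the RMSNormLayer wrapper, the hidden size, and the device/dtype choices are illustrative assumptions, and it only runs where deeplink_ext.cpp_extensions is built with the diopi rms_norm / rms_norm_backward kernels available.

# Usage sketch (assumed wrapper, not included in these patches).
import torch
import torch.nn as nn

from deeplink_ext.ascend_speed import RMSNorm


class RMSNormLayer(nn.Module):
    """Hypothetical module wrapper around the RMSNorm autograd.Function."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        # Matches RMSNorm.forward(ctx, hidden_states, weight, eps) from the patches.
        return RMSNorm.apply(hidden_states, self.weight, self.eps)


if __name__ == "__main__":
    # Illustrative shapes; assumes a device where the extension kernels run
    # (the patched forward itself allocates its empty bias tensor with .cuda()).
    layer = RMSNormLayer(hidden_size=4096).cuda().half()
    x = torch.randn(2, 128, 4096, dtype=torch.float16,
                    device="cuda", requires_grad=True)
    out = layer(x)          # calls ext.rms_norm via RMSNorm.forward
    out.sum().backward()    # calls ext.rms_norm_backward via RMSNorm.backward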