Commit

speed_rmsandrope
shenhao committed Mar 12, 2024
1 parent d9d0ba3 commit ac982b5
Showing 5 changed files with 83 additions and 47 deletions.
csrc/extensions.cpp (5 additions, 1 deletion)
@@ -47,7 +47,11 @@ auto extRmsNorm(const at::Tensor& input,
   auto input_shape = input.sizes();
   std::vector<int64_t> input_size(input_shape.begin(), input_shape.end());
   input_size.back() = 1;
-  auto inv_rms = at::empty(input_size, input.options());
+  at::ScalarType acc_type = input.scalar_type();
+  if (acc_type == at::kBFloat16 || acc_type == at::kHalf) {
+    acc_type = at::kFloat;
+  }
+  auto inv_rms = at::empty(input_size, input.options().dtype(acc_type));
   auto output = at::empty_like(input);
   callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
             bias, eps);
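For readers not tracking the C++: the change keeps the saved inverse-RMS statistic in fp32 whenever the activations are fp16 or bf16, instead of allocating it in the input dtype. A minimal Python sketch of the same dtype rule (alloc_inv_rms is a hypothetical helper for illustration, not part of the extension):

    import torch

    def alloc_inv_rms(input: torch.Tensor) -> torch.Tensor:
        # Mirror of the C++ change: keep the inverse-RMS buffer in fp32 when
        # the activations are half/bfloat16, otherwise reuse the input dtype.
        acc_dtype = input.dtype
        if acc_dtype in (torch.float16, torch.bfloat16):
            acc_dtype = torch.float32
        stat_shape = list(input.shape)
        stat_shape[-1] = 1  # one statistic per normalized row
        return torch.empty(stat_shape, dtype=acc_dtype, device=input.device)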
deeplink_ext/internlm_ops/rms_norm/deeplink.py (6 additions, 17 deletions)
@@ -29,35 +29,20 @@ class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
     @staticmethod
     def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
         output, inv_rms = ext.rms_norm(
-            hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps
+            hidden_states, normalized_shape, weight, bias, eps
         )
-        output = output.half()
-        inv_rms = inv_rms.half()
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
-        hidden_states = hidden_states.half()
-        weight = weight.half()
-        bias = bias.half()
         ctx.intermediate_results = normalized_shape
         return output

     @staticmethod
     def backward(ctx, grad_output):
+        normalized_shape = ctx.intermediate_results
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
-        normalized_shape = ctx.intermediate_results
-        hidden_states = hidden_states.float()
-        inv_rms = inv_rms.float()
-        weight = weight.float()
-        bias = bias.float()
-        grad_output = grad_output.float()
         grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
             hidden_states, grad_output, inv_rms, normalized_shape, weight, bias, eps
         )
-        grad_output = grad_output.half()
-        hidden_states = hidden_states.half()
-        inv_rms = inv_rms.half()
-        weight = weight.half()
-        bias = bias.half()
         return grad_input, grad_weight, grad_bias, None, None


@@ -70,6 +55,8 @@ def __init__(self, hidden_size, eps=1e-5):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
+        if (hidden_states.dtype != self.weight.dtype):
+            hidden_states = hidden_states.to(dtype=self.weight.dtype)
         return _DeepLinkRMSNormFunction.apply(
             hidden_states, self.weight, self.bias, self.variance_epsilon
         )
@@ -83,6 +70,8 @@ def __init__(self, hidden_size, eps=1e-5):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
+        if (hidden_states.dtype != self.weight.dtype):
+            hidden_states = hidden_states.to(dtype=self.weight.dtype)
         return _DeepLinkRMSNormFunctionWithNormalizedShape.apply(
             hidden_states,
             self.weight,
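With the float()/half() round-trips gone, the autograd function now calls the extension with the original dtypes, and the module forward only casts the input to the parameter dtype when they differ. The numerical intent follows the usual mixed-precision RMSNorm recipe, sketched below for reference (rms_norm_reference is an illustrative fp32-accumulation version, not the DIOPI kernel that ext.rms_norm dispatches to):

    import torch

    def rms_norm_reference(hidden_states, weight, bias, eps):
        # Accumulate the statistic in fp32 even for fp16/bf16 inputs,
        # then cast the normalized result back to the input dtype.
        acc = hidden_states.float()
        inv_rms = torch.rsqrt(acc.pow(2).mean(dim=-1, keepdim=True) + eps)
        normed = (acc * inv_rms).to(hidden_states.dtype)
        return normed * weight + bias, inv_rms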
deeplink_ext/internlm_ops/rotary/deeplink.py (40 additions, 25 deletions)
@@ -13,64 +13,72 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
         batch, seqlen, three, nheads, headdim = qkv.shape
         assert three == 3
         rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
         cos_k = cos if cos_k is None else cos_k
         sin_k = sin if sin_k is None else sin_k
         assert (
-            sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
+            sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim)
         )
         assert qkv.size(-1) == rotary_dim
         q_ro = qkv[:, :, 0, :, :rotary_dim]
+        cos_qk = None
+        sin_qk = None
+        if seqlen == rotary_seqlen:
+            cos_qk = cos
+            sin_qk = sin
+        else: # <
+            cos_qk = rearrange(cos[:seqlen], "s d -> s 1 d")
+            sin_qk = rearrange(sin[:seqlen], "s d -> s 1 d")
         ext.apply_rotary(
             q_ro,
             q_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            cos_qk,
+            sin_qk,
             False,
             interleaved,
         )
         k_ro = qkv[:, :, 1, :, :rotary_dim]
         ext.apply_rotary(
             k_ro,
             k_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            cos_qk,
+            sin_qk,
             False,
             interleaved,
         )
-        ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        ctx.save_for_backward(cos_qk, sin_qk, cos_k, sin_k)
         ctx.interleaved = interleaved
         return qkv

     @staticmethod
     def backward(ctx, dqkv):
-        cos, sin, cos_k, sin_k = ctx.saved_tensors
+        cos_qk, sin_qk, cos_k, sin_k = ctx.saved_tensors
         interleaved = ctx.interleaved
         _, seqlen, _, _, headdim = dqkv.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
+        rotary_dim = cos_qk.size(-1)
         assert dqkv.size(-1) == rotary_dim
+        assert seqlen == cos_qk.size(0)
         dq_ro = dqkv[:, :, 0, :, :rotary_dim]
         ext.apply_rotary(
             dq_ro,
             dq_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            cos_qk,
+            sin_qk,
             True,
             interleaved,
         )
         dk_ro = dqkv[:, :, 1, :, :rotary_dim]
         ext.apply_rotary(
             dk_ro,
             dk_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            cos_qk,
+            sin_qk,
             True,
             interleaved,
         )
         return dqkv, None, None, None, None, None


 class DeepLinkApplyRotaryEmb(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
@@ -84,45 +92,52 @@ def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
assert sin.shape == (rotary_seqlen, rotary_dim)
x_ro = x[..., :rotary_dim]
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]

cos_qk = None
sin_qk = None
if seqlen == rotary_seqlen:
cos_qk = cos
sin_qk = sin
else: # <
cos_qk = rearrange(cos[:seqlen], "s d -> s 1 d")
sin_qk = rearrange(sin[:seqlen], "s d -> s 1 d")

ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
cos_qk,
sin_qk,
False,
interleaved,
)

if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.save_for_backward(cos_qk, sin_qk)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
cos_qk, sin_qk = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
rotary_dim = cos_qk.size(-1)
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
dx = torch.empty_like(do) if not inplace else do
dx_ro = dx[..., :rotary_dim]
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
cos_qk,
sin_qk,
True,
ctx.interleaved,
)
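Both rotary classes now take full-width (rotary_seqlen, rotary_dim) cos/sin tables and pick the kernel arguments once: the tables are reused directly when the sequence length already matches, and sliced plus broadcast over heads only in the shorter-sequence case. The shared pattern is roughly the following (select_cos_sin is a hypothetical name for logic the diff inlines in each forward); saving cos_qk/sin_qk for backward also avoids repeating the rearrange calls in both passes:

    from einops import rearrange

    def select_cos_sin(cos, sin, seqlen):
        # Reuse the full tables when seqlen matches; otherwise slice once and
        # add a singleton head axis so the kernel can broadcast over heads.
        if seqlen == cos.shape[0]:
            return cos, sin
        return (
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
        )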
tests/test_rms_internlm.py (28 additions, 0 deletions)
@@ -26,12 +26,40 @@ def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3):

     return np.allclose(grad_x_base, grad_x_intern, rtol, atol, True)

+def test_rms_normfp16(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-2, atol=1e-2):
+    x_base = torch.randn(5, 5, requires_grad=True).to(dtype=torch.float16).cuda()
+    x_base.retain_grad()
+
+    x_intern = x_base.clone()
+    x_intern.retain_grad()
+
+    hidden_size = 5
+
+    model_base = BaseRmsNorm(hidden_size).to(dtype=torch.float16).cuda()
+    out_base = model_base(x_base)
+    out_base.backward(torch.ones_like(x_base))
+    grad_x_base = x_base.grad.cpu().numpy()
+
+    model_deeplink = DeeplinkRmsNorm(hidden_size).to(dtype=torch.float16).cuda()
+    out_deeplink = model_deeplink(x_intern)
+    out_deeplink.backward(torch.ones_like(x_base))
+    grad_x_intern = x_intern.grad.cpu().numpy()
+
+    return np.allclose(grad_x_base, grad_x_intern, rtol, atol, True)

 print(
     "Test case: normalized_shape == None: grad_inputs closed ? ",
     test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNorm),
 )
+print(
+    "Test case fp16: normalized_shape == None: grad_inputs closed ? ",
+    test_rms_normfp16(ext.fallback.RMSNorm, ext.DeepLinkRMSNorm),
+)
 print(
     "Test case: normalized_shape == weight.size(): grad_inputs closed ? ",
     test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNormWithNormalizedShape),
 )
+print(
+    "Test case fp16: normalized_shape == weight.size(): grad_inputs closed ? ",
+    test_rms_normfp16(ext.fallback.RMSNorm, ext.DeepLinkRMSNormWithNormalizedShape),
+)
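One note on the assertions: the trailing True in the np.allclose calls is the positional equal_nan argument, so each test's return line is equivalent to:

    np.allclose(grad_x_base, grad_x_intern, rtol=rtol, atol=atol, equal_nan=True)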
tests/test_rotary_emb_internlm.py (4 additions, 4 deletions)
@@ -35,16 +35,16 @@ def RotaryEmbTest(func_name):
     # Run the forward pass
     if func_name == "RotaryEmbQKV":
         res1 = torch_apply(input, cos, sin, cos_k, sin_k, interleaved)
-        res2 = dipu_apply(input1, cos1, sin1, cos_k, sin_k, interleaved)
+        res2 = dipu_apply(input1, cos1.repeat(1, 2), sin1.repeat(1, 2), cos_k, sin_k, interleaved)
     elif func_name == "RotaryEmb":
         res1 = torch_apply(input, cos, sin, interleaved)
-        res2 = dipu_apply(input1, cos1, sin1, interleaved)
+        res2 = dipu_apply(input1, cos1.repeat(1, 2), sin1.repeat(1, 2), interleaved)
     else:
         print(f"{func_name} is not supported.")
         return False

     # Check the forward pass result
-    forward_correct = torch.allclose(res1, res2)
+    forward_correct = torch.allclose(res1, res2, rtol=1e-2, atol=1e-2)

     # Compute the first loss
     c = torch.ones_like(res1)
@@ -61,7 +61,7 @@
     # Check the first backward gradient
     grad1 = input.grad
     grad2 = input1.grad
-    backward_correct = torch.allclose(grad1, grad2)
+    backward_correct = torch.allclose(grad1, grad2, rtol=1e-2, atol=1e-2)
     # Check whether both the forward and backward results are correct
     if forward_correct and backward_correct:
         print(f"{func_name} both forward and backward pass tests passed.")
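The cos1.repeat(1, 2) / sin1.repeat(1, 2) calls adapt the half-width tables used by the torch reference to the full-width (seqlen, rotary_dim) layout that the DeepLink path now asserts on, and the relaxed rtol/atol of 1e-2 match half-precision accuracy rather than the fp32 defaults. A small self-contained illustration of the repeat (the shapes are made up for the example):

    import torch

    cos_half = torch.randn(64, 32)    # hypothetical (seqlen, rotary_dim // 2) table
    cos_full = cos_half.repeat(1, 2)  # tile the last dim -> (seqlen, rotary_dim)
    assert cos_full.shape == (64, 64)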
