feat: refactor for internevo #70

Closed · wants to merge 3 commits
@@ -35,4 +35,5 @@
"RMSNorm",
"RMSNormWithNormalizedShape",
"apply_rotary",
"adamw",
]
62 changes: 62 additions & 0 deletions deeplink_ext/internevo_ops/adamw.py
@@ -0,0 +1,62 @@
from typing import List
import torch
import deeplink_ext.cpp_extensions as ext


__all__ = ["adamw"]

assert hasattr(ext, "adamw")


def adamw(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
exp_avg_sqs: List[torch.Tensor],
max_exp_avg_sqs: List[torch.Tensor],
state_steps: List[int],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
norm_coeff_scale: float
):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""

assert (
maximize == False
), "The maximize parameter is not supported by diopiAdamW yet"

for i, param in enumerate(params):
if norm_coeff_scale is not None:
grad = grads[i].float() * norm_coeff_scale
else:
grad = grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
if not max_exp_avg_sqs:
max_exp_avg_sq = torch.Tensor().cuda()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]
ext.adamw(
param,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
grad,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
amsgrad,
)
return params, exp_avgs, exp_avg_sqs
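
For review, a minimal usage sketch of the new functional adamw (the import path is taken from this diff; the tensor shapes, device, and hyperparameter values below are illustrative assumptions, not part of the PR):

# Illustrative call only: shapes, device, and hyperparameters are assumed.
import torch
from deeplink_ext.internevo_ops import adamw

param = torch.randn(16, device="cuda", requires_grad=True)
grad = torch.randn_like(param)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)

adamw(
    [param],
    [grad],
    [exp_avg],
    [exp_avg_sq],
    max_exp_avg_sqs=[],       # empty list: amsgrad is disabled below
    state_steps=[1],
    amsgrad=False,
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    weight_decay=0.01,
    eps=1e-8,
    maximize=False,           # diopiAdamW does not support maximize yet
    norm_coeff_scale=None,    # or a float to rescale grads in fp32 before the update
)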
File renamed without changes.
116 changes: 116 additions & 0 deletions deeplink_ext/internevo_ops/rms_norm.py
@@ -0,0 +1,116 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext


__all__ = ["RMSNorm", "RMSNormWithNormalizedShape"]

assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")


class _RMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps):
output = torch.empty_like(hidden_states)
input_dtype = hidden_states.dtype
acc_dtype = (
torch.float32
if input_dtype in [torch.bfloat16, torch.float16]
else input_dtype
)
inv_rms = torch.empty_like(hidden_states, dtype=acc_dtype)
ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps)
ctx.save_for_backward(hidden_states, inv_rms, weight, bias)
ctx.eps = eps
return output

@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias = ctx.saved_tensors
grad_input = torch.empty_like(hidden_states)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
hidden_states,
grad_output,
inv_rms,
None,
weight,
bias,
ctx.eps,
)
return grad_input, grad_weight, grad_bias, None


class _RMSNormFunctionWithNormalizedShape(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
output = torch.empty_like(hidden_states)
input_dtype = hidden_states.dtype
acc_dtype = (
torch.float32
if input_dtype in [torch.bfloat16, torch.float16]
else input_dtype
)
inv_rms = torch.empty_like(hidden_states, dtype=acc_dtype)
ext.rms_norm(
output, inv_rms, hidden_states, normalized_shape, weight, bias, eps
)
ctx.save_for_backward(hidden_states, inv_rms, weight, bias)
ctx.eps = eps
ctx.normalized_shape = normalized_shape
return output

@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias = ctx.saved_tensors
grad_input = torch.empty_like(hidden_states)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
hidden_states,
grad_output,
inv_rms,
ctx.normalized_shape,
weight,
bias,
ctx.eps,
)
return grad_input, grad_weight, grad_bias, None, None


class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.bias = torch.zeros(hidden_size).cuda()
self.variance_epsilon = eps

def forward(self, hidden_states):
return _RMSNormFunction.apply(
hidden_states, self.weight, self.bias, self.variance_epsilon
)


class RMSNormWithNormalizedShape(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.bias = torch.zeros(hidden_size).cuda()
self.variance_epsilon = eps

def forward(self, hidden_states):
return _RMSNormFunctionWithNormalizedShape.apply(
hidden_states,
self.weight,
self.bias,
self.variance_epsilon,
self.weight.size(),
)
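
A minimal smoke-test sketch for the two new modules (hidden size, input shape, and the backward call are illustrative assumptions; note that `bias` is created on CUDA in `__init__`, so `.cuda()` only needs to move `weight`):

# Illustrative smoke test: hidden_size and input shape are assumed.
import torch
from deeplink_ext.internevo_ops.rms_norm import RMSNorm, RMSNormWithNormalizedShape

hidden_size = 128
x = torch.randn(4, 32, hidden_size, device="cuda", requires_grad=True)

norm = RMSNorm(hidden_size).cuda()          # moves weight to CUDA; bias already lives there
out = norm(x)
out.sum().backward()                        # exercises _RMSNormFunction.backward

norm_ns = RMSNormWithNormalizedShape(hidden_size).cuda()
out_ns = norm_ns(x)                         # forwards weight.size() as normalized_shape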
86 changes: 0 additions & 86 deletions deeplink_ext/internlm_ops/rms_norm.py

This file was deleted.

4 changes: 2 additions & 2 deletions deeplink_ext/patch_internlm.py
@@ -55,7 +55,7 @@ def _force_fallback():
delattr(cpp_ext, attr)

def _patch_flash_attn():
import deeplink_ext.internlm_ops as ext
import deeplink_ext.internevo_ops as ext
import flash_attn.losses.cross_entropy # type: ignore
import torch.nn

@@ -72,7 +72,7 @@ def CrossEntropyLossProxy(reduction, **_):
flash_attn.modules.mha.FlashCrossAttention = ext.mha.CrossAttention

def _patch_ops():
import deeplink_ext.internlm_ops as ext
import deeplink_ext.internevo_ops as ext
import flash_attn.layers.rotary # type: ignore
import internlm.model.embedding # type: ignore

2 changes: 1 addition & 1 deletion tests/test_mha_internlm.py
@@ -1,7 +1,7 @@
# Copyright (c) 2023, DeepLink.

import torch
import deeplink_ext.internlm_ops.mha as ext
import deeplink_ext.internevo_ops.mha as ext


def _run_self_attention(self_attn_module: type, qkv_data: torch.Tensor):
4 changes: 2 additions & 2 deletions tests/test_rms_internlm.py
@@ -2,8 +2,8 @@

import torch
import numpy as np
from deeplink_ext.internlm_ops.rms_norm import RMSNorm, RMSNormWithNormalizedShape
from deeplink_ext.internlm_ops.rms_norm_fallback import (
from deeplink_ext.internevo_ops.rms_norm import RMSNorm, RMSNormWithNormalizedShape
from deeplink_ext.internevo_ops.rms_norm_fallback import (
RMSNorm as RMSNorm_fb,
RMSNormWithNormalizedShape as RMSNormWithNormalizedShape_fb,
)
2 changes: 1 addition & 1 deletion tests/test_rotary_emb_internlm.py
@@ -1,7 +1,7 @@
# Copyright (c) 2023, DeepLink.

import torch
from deeplink_ext.internlm_ops.rotary_embedding import apply_rotary
from deeplink_ext.internevo_ops.rotary_embedding import apply_rotary
from deeplink_ext.internlm_ops.rotary_embeddinig_fallback import (
apply_rotary as apply_rotary_fb,
)