feat: adapt to InternTrain and InternEvo respectively (#121)
Adapt to InternTrain and InternEvo respectively.
POI-WX authored Jul 29, 2024
1 parent 6de1ba5 commit 5769514
Showing 26 changed files with 5,170 additions and 2,455 deletions.
36 changes: 24 additions & 12 deletions deeplink_ext/internevo_ops/__init__.py
@@ -6,15 +6,27 @@
    from .adamw import AdamW
except Exception as e:
    print(_not_impl.format(op_name="adamw"))
    from torch.optim import AdamW as AdamW
    from torch.optim import AdamW

try:
    from .flash_attention import FlashSelfAttention, FlashCrossAttention
    from .flash_attention import (
        flash_attn_qkvpacked_func,
        flash_attn_kvpacked_func,
        flash_attn_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_func,
    )
except Exception as e:
    print(_not_impl.format(op_name="flash attention"))
    from .flash_attention_fallback import SelfAttention as FlashSelfAttention
    from .flash_attention_fallback import CrossAttention as FlashCrossAttention

    from .flash_attention_fallback import (
        torch_attn_qkvpacked_func as flash_attn_qkvpacked_func,
        torch_attn_kvpacked_func as flash_attn_kvpacked_func,
        torch_attn_func as flash_attn_func,
        torch_attn_varlen_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
        torch_attn_varlen_kvpacked_func as flash_attn_varlen_kvpacked_func,
        torch_attn_varlen_func as flash_attn_varlen_func,
    )

try:
    from .rms_norm import MixedFusedRMSNorm
@@ -24,20 +36,20 @@
    )
    from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm


try:
    from .rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_
    from .rotary_embedding import ApplyRotaryEmb
except:
    print(_not_impl.format(op_name="rotary embedding"))
    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
    from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_


__all__ = [
    "AdamW",
    "FlashSelfAttention",
    "FlashCrossAttention",
    "flash_attn_qkvpacked_func",
    "flash_attn_kvpacked_func",
    "flash_attn_func",
    "flash_attn_varlen_qkvpacked_func",
    "flash_attn_varlen_kvpacked_func",
    "flash_attn_varlen_func",
    "MixedFusedRMSNorm",
    "ApplyRotaryEmb",
    "ApplyRotaryEmbQKV_",
]
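With this commit, internevo_ops exports the functional flash attention interface (flash_attn_*_func) with torch_attn_* fallbacks, instead of the FlashSelfAttention/FlashCrossAttention modules. A minimal usage sketch, assuming these exports mirror the upstream flash_attn call signatures and the (batch, seqlen, nheads, headdim) layout; the shapes, dtype, and device below are illustrative, not taken from this diff:

# Illustrative only: tensor sizes, dtype, and device are assumptions, and the
# call signatures are assumed to follow the upstream flash_attn package that
# these exports (and their torch_attn_* fallbacks) mirror.
import torch

from deeplink_ext.internevo_ops import flash_attn_func, flash_attn_qkvpacked_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
# Swap "cuda" for your accelerator's device string; the fused path typically
# expects half-precision tensors on the device.
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Unpacked variant: separate q/k/v tensors, causal mask enabled.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# Packed variant: q, k, v stacked into one (batch, seqlen, 3, nheads, headdim) tensor.
qkv = torch.stack([q, k, v], dim=2)
out_packed = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)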
166 changes: 1 addition & 165 deletions deeplink_ext/internevo_ops/adamw.py
@@ -1,170 +1,6 @@
# Copyright (c) 2024, DeepLink.

import torch
from torch.optim.optimizer import Optimizer
from typing import List
import deeplink_ext.cpp_extensions as ext
from deeplink_ext.interntrain_ops.adamw import AdamW


assert hasattr(ext, "adamw")

__all__ = ["AdamW"]


def fused_adamw(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    max_exp_avg_sqs: List[torch.Tensor],
    step: int,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.
    See :class:`~torch.optim.AdamW` for details.
    """
    if maximize is True:
        raise RuntimeError(
            "Deeplink Adamw with fused=True does not support maximize=True!"
        )

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        if amsgrad and len(max_exp_avg_sqs):
            max_exp_avg_sq = max_exp_avg_sqs[i]
        else:
            max_exp_avg_sq = None

        ext.adamw(
            param,
            exp_avg,
            exp_avg_sq,
            max_exp_avg_sq,
            grad,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            step,
            amsgrad,
        )


class AdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        *,
        maximize: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
        )
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]

            if "step" in group:
                group["step"] += 1
            else:
                group["step"] = 1

            for p in group["params"]:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            fused_adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                group["step"],
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
            )

        return loss
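The local fused_adamw wrapper around ext.adamw and the AdamW subclass above are deleted; the module now simply re-exports the InternTrain implementation through the single added import, from deeplink_ext.interntrain_ops.adamw import AdamW, and the package __init__.py falls back to torch.optim.AdamW when that import fails. A minimal usage sketch; the toy model and data are assumptions, the constructor defaults match the removed code, and note that the removed fused path rejected maximize=True (whether the InternTrain version keeps that restriction is not shown in this diff):

# Illustrative sketch: the model and batch are placeholders; the keyword
# defaults mirror the removed implementation shown above.
import torch

from deeplink_ext.internevo_ops import AdamW  # InternTrain AdamW, or torch.optim.AdamW fallback

model = torch.nn.Linear(16, 16)
optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()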