fix lint
NeosZhang committed Mar 20, 2024
1 parent 2d5104c commit 14fcf78
Showing 1 changed file with 90 additions and 50 deletions.
140 changes: 90 additions & 50 deletions deeplink_ext/internlm_ops/adamw/deeplink.py
@@ -2,23 +2,27 @@

import torch
from torch import Tensor
from typing import List
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "adamw")

def adamw(params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[int],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
norm_coeff_scale: float):

def adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[int],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
norm_coeff_scale: float
):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""
@@ -32,13 +36,35 @@ def adamw(params: List[Tensor],
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
max_exp_avg_sq = max_exp_avg_sqs[i]
param, exp_avg, exp_avg_sq = ext.adamw(param, exp_avg, exp_avg_sq, max_exp_avg_sq, grad, lr,
beta1, beta2, eps, weight_decay, step, amsgrad)
param, exp_avg, exp_avg_sq = ext.adamw(
param,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
grad,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
amsgrad,
)
return params, exp_avgs, exp_avg_sqs
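For readers who cannot inspect the C++ extension, the following is a plain-PyTorch sketch of the per-parameter update that ext.adamw is expected to perform (decoupled weight decay, as in torch.optim.AdamW). It is illustrative only: the helper name adamw_reference_step is not part of deeplink_ext, and the fused kernel's exact handling of maximize and norm_coeff_scale may differ.

import torch

def adamw_reference_step(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq,
                         step, lr, beta1, beta2, eps, weight_decay, amsgrad):
    # Decoupled weight decay: shrink the parameter directly instead of
    # adding weight_decay * param to the gradient.
    param.mul_(1 - lr * weight_decay)
    # Update biased first and second moment estimates of the gradient.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    if amsgrad:
        # Keep the running maximum of the second moment for a monotone denominator.
        torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        denom = (max_exp_avg_sq / bias_correction2).sqrt().add_(eps)
    else:
        denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    # param <- param - lr * exp_avg_hat / (sqrt(exp_avg_sq_hat) + eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
    return param, exp_avg, exp_avg_sq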


class DeepLinkAdamW(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2, amsgrad=False, *, maximize: bool = False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
*,
maximize: bool = False
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -49,16 +75,22 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
maximize=maximize,
)
super(DeepLinkAdamW, self).__init__(params, defaults)

def __setstate__(self, state):
super(DeepLinkAdamW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault("amsgrad", False)
group.setdefault('maximize', False)
group.setdefault("maximize", False)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -79,53 +111,61 @@
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
amsgrad = group["amsgrad"]
beta1, beta2 = group["betas"]

for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
raise RuntimeError("AdamW does not support sparse gradients")
grads.append(p.grad)

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
state["step"] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])

if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
max_exp_avg_sqs.append(state["max_exp_avg_sq"])

# update the steps for each param group update
state['step'] += 1
state["step"] += 1
# record the step after step update
state_steps.append(state['step'])
state_steps.append(state["step"])

# adamw_torch(params_with_grad,
adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=group['maximize'])
adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=group["maximize"],
)
return loss
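For context, here is a minimal usage sketch of the optimizer touched by this commit, assuming the module is importable from the path shown above (deeplink_ext/internlm_ops/adamw/deeplink.py) and that DeepLinkAdamW is intended as a drop-in replacement for torch.optim.AdamW. The toy model and CPU tensors are illustrative only; in practice the parameters would live on the device targeted by deeplink_ext.cpp_extensions.

import torch
from deeplink_ext.internlm_ops.adamw.deeplink import DeepLinkAdamW

model = torch.nn.Linear(16, 4)  # toy model for illustration
optimizer = DeepLinkAdamW(
    model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2
)

inputs = torch.randn(8, 16)
targets = torch.randn(8, 4)

loss = torch.nn.functional.mse_loss(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # applies the fused ext.adamw update to every parameter with a gradient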
