diff --git a/deeplink_ext/internlm_ops/adamw/deeplink.py b/deeplink_ext/internlm_ops/adamw/deeplink.py
index 7bc5b2b2..e2bc8732 100644
--- a/deeplink_ext/internlm_ops/adamw/deeplink.py
+++ b/deeplink_ext/internlm_ops/adamw/deeplink.py
@@ -2,23 +2,27 @@
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+
 assert hasattr(ext, "adamw")
 
-def adamw(params: List[Tensor],
-          grads: List[Tensor],
-          exp_avgs: List[Tensor],
-          exp_avg_sqs: List[Tensor],
-          max_exp_avg_sqs: List[Tensor],
-          state_steps: List[int],
-          *,
-          amsgrad: bool,
-          beta1: float,
-          beta2: float,
-          lr: float,
-          weight_decay: float,
-          eps: float,
-          maximize: bool,
-          norm_coeff_scale: float):
+
+def adamw(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    max_exp_avg_sqs: List[Tensor],
+    state_steps: List[int],
+    *,
+    amsgrad: bool,
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+    maximize: bool,
+    norm_coeff_scale: float
+):
     r"""Functional API that performs AdamW algorithm computation.
     See :class:`~torch.optim.AdamW` for details.
     """
@@ -32,13 +36,35 @@ def adamw(params: List[Tensor],
         exp_avg_sq = exp_avg_sqs[i]
         step = state_steps[i]
         max_exp_avg_sq = max_exp_avg_sqs[i]
-        param, exp_avg, exp_avg_sq = ext.adamw(param, exp_avg, exp_avg_sq, max_exp_avg_sq, grad, lr,
-                                               beta1, beta2, eps, weight_decay, step, amsgrad)
+        param, exp_avg, exp_avg_sq = ext.adamw(
+            param,
+            exp_avg,
+            exp_avg_sq,
+            max_exp_avg_sq,
+            grad,
+            lr,
+            beta1,
+            beta2,
+            eps,
+            weight_decay,
+            step,
+            amsgrad,
+        )
     return params, exp_avgs, exp_avg_sq
 
+
 class DeepLinkAdamW(torch.optim.optimizer):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=1e-2, amsgrad=False, *, maximize: bool = False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        *,
+        maximize: bool = False
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -49,16 +75,22 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
         if not 0.0 <= weight_decay:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+            maximize=maximize,
+        )
         super(DeeplinkAdamW, self).__init__(params, defaults)
-    
+
     def __setstate__(self, state):
         super(DeeplinkAdamW, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault("amsgrad", False)
-            group.setdefault('maximize', False)
-    
+            group.setdefault("maximize", False)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -79,53 +111,61 @@ def step(self, closure=None):
             state_sums = []
             max_exp_avg_sqs = []
             state_steps = []
-            amsgrad = group['amsgrad']
-            beta1, beta2 = group['betas']
+            amsgrad = group["amsgrad"]
+            beta1, beta2 = group["betas"]
 
-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is None:
                     continue
                 params_with_grad.append(p)
                 if p.grad.is_sparse:
-                    raise RuntimeError('AdamW does not support sparse gradients')
+                    raise RuntimeError("AdamW does not support sparse gradients")
                 grads.append(p.grad)
 
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
-                    state['step'] = 0
+                    state["step"] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        state["max_exp_avg_sq"] = torch.zeros_like(
+                            p, memory_format=torch.preserve_format
+                        )
 
-                exp_avgs.append(state['exp_avg'])
-                exp_avg_sqs.append(state['exp_avg_sq'])
+                exp_avgs.append(state["exp_avg"])
+                exp_avg_sqs.append(state["exp_avg_sq"])
 
                 if amsgrad:
-                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
 
                 # update the steps for each param group update
-                state['step'] += 1
+                state["step"] += 1
                 # record the step after step update
-                state_steps.append(state['step'])
+                state_steps.append(state["step"])
 
             # adamw_torch(params_with_grad,
-            adamw(params_with_grad,
-                  grads,
-                  exp_avgs,
-                  exp_avg_sqs,
-                  max_exp_avg_sqs,
-                  state_steps,
-                  amsgrad=amsgrad,
-                  beta1=beta1,
-                  beta2=beta2,
-                  lr=group['lr'],
-                  weight_decay=group['weight_decay'],
-                  eps=group['eps'],
-                  maximize=group['maximize'])
+            adamw(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                max_exp_avg_sqs,
+                state_steps,
+                amsgrad=amsgrad,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+                eps=group["eps"],
+                maximize=group["maximize"],
+            )
         return loss
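For context, `DeepLinkAdamW` mirrors the `torch.optim.AdamW` interface while delegating the per-parameter update to the `ext.adamw` kernel from `deeplink_ext.cpp_extensions`. Below is a minimal usage sketch, separate from the diff itself; it assumes `deeplink_ext` is installed, that `DeepLinkAdamW` is importable from this module, and that the DeepLink/DIPU accelerator is exposed through the `cuda` device type. The model, tensor shapes, and loss are illustrative only.

```python
import torch
import torch.nn as nn

# Assumed import path; mirrors the module edited in this diff.
from deeplink_ext.internlm_ops.adamw.deeplink import DeepLinkAdamW

# Illustrative model; any module with dense gradients works, since
# step() raises on sparse gradients.
model = nn.Linear(128, 64).cuda()

# Same constructor signature as torch.optim.AdamW.
optimizer = DeepLinkAdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-2,
    amsgrad=False,
)

x = torch.randn(32, 128, device="cuda")
target = torch.randn(32, 64, device="cuda")

for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    # step() gathers params, grads, and the EMA state per param group,
    # then hands them to the functional adamw(), which calls ext.adamw.
    optimizer.step()
```

One caveat for anyone trying this: the functional `adamw()` declares `norm_coeff_scale` as a required keyword-only argument that `step()` never passes, and the surrounding context lines still read `torch.optim.optimizer` and `DeeplinkAdamW` rather than `torch.optim.Optimizer` and `DeepLinkAdamW`, so the sketch assumes those are resolved elsewhere in the file or in a follow-up change.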