-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support adamw op for ascend speed (#65)
Support adamw op for ascend speed. --------- Co-authored-by: Zhangzefeng <[email protected]>
- Loading branch information
1 parent
8261278
commit 220857e
Showing
3 changed files
with
76 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .rotary_embedding import apply_rotary, RotaryEmbedding | ||
from .adamw import adamw | ||
|
||
__all__ = ["apply_rotary", "RotaryEmbedding"] | ||
__all__ = ["apply_rotary", "RotaryEmbedding", "adamw"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import List | ||
import torch | ||
import deeplink_ext.cpp_extensions as ext | ||
|
||
|
||
assert hasattr(ext, "adamw") | ||
|
||
|
||
def adamw( | ||
params: List[torch.Tensor], | ||
grads: List[torch.Tensor], | ||
exp_avgs: List[torch.Tensor], | ||
exp_avg_sqs: List[torch.Tensor], | ||
max_exp_avg_sqs: List[torch.Tensor], | ||
state_steps: List[int], | ||
*, | ||
amsgrad: bool, | ||
beta1: float, | ||
beta2: float, | ||
lr: float, | ||
weight_decay: float, | ||
eps: float, | ||
maximize: bool, | ||
norm_coeff_scale: float | ||
): | ||
r"""Functional API that performs AdamW algorithm computation. | ||
See :class:`~torch.optim.AdamW` for details. | ||
""" | ||
|
||
assert maximize == False, "ascend diopiAdamW only support False 'maximize'." | ||
assert amsgrad == False, "ascend diopiAdamW only support False 'amsgrad'." | ||
|
||
for i, param in enumerate(params): | ||
if norm_coeff_scale is not None: | ||
grad = grads[i].float() * norm_coeff_scale | ||
else: | ||
grad = grads[i] | ||
exp_avg = exp_avgs[i] | ||
exp_avg_sq = exp_avg_sqs[i] | ||
step = state_steps[i] | ||
if not max_exp_avg_sqs: | ||
max_exp_avg_sq = torch.Tensor().cuda() | ||
else: | ||
max_exp_avg_sq = max_exp_avg_sqs[i] | ||
ext.adamw( | ||
param, | ||
exp_avg, | ||
exp_avg_sq, | ||
max_exp_avg_sq, | ||
grad, | ||
lr, | ||
beta1, | ||
beta2, | ||
eps, | ||
weight_decay, | ||
step, | ||
amsgrad, | ||
) | ||
return params, exp_avgs, exp_avg_sqs |