
Commit

optimize all
yangbofun committed Mar 22, 2024
1 parent 345784e commit 4a5a5c4
Showing 2 changed files with 38 additions and 38 deletions.
72 changes: 36 additions & 36 deletions csrc/extensions.cpp
@@ -40,31 +40,31 @@ at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault(

} // namespace

-// auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
-//               at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
-//               float beta1, float beta2, float epsilon, float weight_decay,
-//               int64_t step, bool amsgrad) {
-//   // the diopiAdamW func has no "maximize" param
-//   callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
-//             beta1, beta2, epsilon, weight_decay, step, amsgrad);
-// }
-
-// auto extScaledMaskedSoftmax(const at::Tensor& input, const at::Tensor& mask,
-//                             double scale, bool fixed_triu_mask) {
-//   auto out = at::empty_like(input);
-//   callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
-//   return out;
-// }
-
-// auto extScaledMaskedSoftmaxBackward(const at::Tensor& grad_output,
-//                                     const at::Tensor& out,
-//                                     const at::Tensor& mask, double scale,
-//                                     bool fixed_triu_mask) {
-//   auto grad_input = at::empty_like(grad_output);
-//   callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
-//             mask, scale, fixed_triu_mask);
-//   return grad_input;
-// }
+auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
+              at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
+              float beta1, float beta2, float epsilon, float weight_decay,
+              int64_t step, bool amsgrad) {
+  // the diopiAdamW func has no "maximize" param
+  callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
+            beta1, beta2, epsilon, weight_decay, step, amsgrad);
+}
+
+auto extScaledMaskedSoftmax(const at::Tensor& input, const at::Tensor& mask,
+                            double scale, bool fixed_triu_mask) {
+  auto out = at::empty_like(input);
+  callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
+  return out;
+}
+
+auto extScaledMaskedSoftmaxBackward(const at::Tensor& grad_output,
+                                    const at::Tensor& out,
+                                    const at::Tensor& mask, double scale,
+                                    bool fixed_triu_mask) {
+  auto grad_input = at::empty_like(grad_output);
+  callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
+            mask, scale, fixed_triu_mask);
+  return grad_input;
+}

auto extRmsNorm(const at::Tensor& input,
                const OptionalIntArray& normalized_shape,
@@ -414,17 +414,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  if (&diopiApplyPenalty != nullptr) {
    m.def("apply_penalty", &extApplyPenalty, "deeplink ext_apply_penalty");
  }
-  // if (&diopiAdamW != nullptr) {
-  //   m.def("adamw", &extAdamW, "deeplink ext_adamw");
-  // }
-  // if (&diopiScaledMaskedSoftmax != nullptr) {
-  //   m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
-  //         "deeplink ext_scaled_masked_softmax_fwd");
-  // }
-  // if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
-  //   m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
-  //         "deeplink ext_scaled_masked_softmax_bwd");
-  // }
+  if (&diopiAdamW != nullptr) {
+    m.def("adamw", &extAdamW, "deeplink ext_adamw");
+  }
+  if (&diopiScaledMaskedSoftmax != nullptr) {
+    m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
+          "deeplink ext_scaled_masked_softmax_fwd");
+  }
+  if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
+    m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
+          "deeplink ext_scaled_masked_softmax_bwd");
+  }
}

} // namespace dipu::dipu_ext
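
For reference, below is a minimal Python sketch of how the re-enabled bindings might be exercised once the extension is built. The import path, tensor shapes, dtypes, and device placement are assumptions (in practice the tensors would live on the DIPU device); the argument order follows the C++ signatures in the diff above.

# Usage sketch only; the import path is an assumption, not the repo's documented API.
import torch
import deeplink_ext.cpp_extensions as ext  # assumed name of the compiled module

# AdamW step: argument order mirrors extAdamW(param, exp_avg, exp_avg_sq,
# max_exp_avg_sq, grad, lr, beta1, beta2, epsilon, weight_decay, step, amsgrad).
param = torch.randn(1024)
grad = torch.randn_like(param)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)
max_exp_avg_sq = torch.zeros_like(param)
ext.adamw(param, exp_avg, exp_avg_sq, max_exp_avg_sq, grad,
          1e-3, 0.9, 0.999, 1e-8, 0.01, 1, False)

# Scaled masked softmax: the forward returns the output tensor, the backward
# returns the gradient w.r.t. the input, matching the wrappers above.
x = torch.randn(2, 8, 128, 128)
mask = torch.zeros(2, 8, 128, 128, dtype=torch.bool)  # assumed mask shape/dtype
out = ext.scaled_masked_softmax_fwd(x, mask, 1.0, False)
grad_in = ext.scaled_masked_softmax_bwd(torch.ones_like(out), out, mask, 1.0, False)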
4 changes: 2 additions & 2 deletions deeplink_ext/llm_ops_for_ascend_speed.py
@@ -7,8 +7,8 @@
assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")
assert hasattr(ext, "apply_rotary")
assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")
# assert hasattr(ext, "adamw")
# assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(ext, "scaled_masked_softmax_bwd")
assert hasattr(ext, "adamw")
assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(ext, "scaled_masked_softmax_bwd")


class DeepLinkFlashSelfAttention(torch.autograd.Function):
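
Since the module-level asserts now require scaled_masked_softmax_fwd and scaled_masked_softmax_bwd, a natural consumer is a torch.autograd.Function in the same style as the DeepLinkFlashSelfAttention class shown in context. The sketch below is hypothetical and not part of this commit; the class name and the import path for ext are assumptions (the real file already imports torch and the extension module at the top).

# Hypothetical wrapper, mirroring the torch.autograd.Function pattern used in
# llm_ops_for_ascend_speed.py; imports repeated here so the sketch is self-contained.
import torch
import deeplink_ext.cpp_extensions as ext  # assumed import path


class DeepLinkScaledMaskedSoftmax(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, input, mask, scale, fixed_triu_mask):
        out = ext.scaled_masked_softmax_fwd(input, mask, scale, fixed_triu_mask)
        ctx.save_for_backward(out, mask)
        ctx.scale = scale
        ctx.fixed_triu_mask = fixed_triu_mask
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, mask = ctx.saved_tensors
        grad_input = ext.scaled_masked_softmax_bwd(
            grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
        )
        # Only the input is differentiable; mask, scale and the flag get None.
        return grad_input, None, None, None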
