add scaled_masked_softmax
POI-WX committed Mar 21, 2024
1 parent 2a461b3 commit 9048208
Showing 2 changed files with 43 additions and 0 deletions.
25 changes: 25 additions & 0 deletions csrc/extensions.cpp
@@ -49,6 +49,23 @@ auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
                beta1, beta2, epsilon, weight_decay, step, amsgrad);
}

auto extScaledMaskedSoftmax(const at::Tensor& input, const at::Tensor& mask,
                            double scale, bool fixed_triu_mask) {
  auto out = at::empty_like(input);
  callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
  return out;
}

auto extScaledMaskedSoftmaxBackward(const at::Tensor& grad_output,
                                    const at::Tensor& out,
                                    const at::Tensor& mask, double scale,
                                    bool fixed_triu_mask) {
  auto grad_input = at::empty_like(grad_output);
  callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
            mask, scale, fixed_triu_mask);
  return grad_input;
}

auto extRmsNorm(const at::Tensor& input,
                const OptionalIntArray& normalized_shape,
                const at::Tensor& weight, const at::Tensor& bias, double eps) {
@@ -400,6 +417,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  if (&diopiAdamW != nullptr) {
    m.def("adamw", &extAdamW, "deeplink ext_adamw");
  }
  if (&diopiScaledMaskedSoftmax != nullptr) {
    m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
          "deeplink ext_scaled_masked_softmax_fwd");
  }
  if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
    m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
          "deeplink ext_scaled_masked_softmax_bwd");
  }
}

} // namespace dipu::dipu_ext
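
Note: the `if (&diopiXxx != nullptr)` guards presumably rely on the DIOPI entry points being declared as weak symbols, so each binding is registered only when the target device actually implements the op. For orientation, here is a minimal pure-PyTorch sketch of what the fused forward and backward are expected to compute; the mask convention (nonzero means "masked out", Megatron-LM style), the `fixed_triu_mask` behavior, and the placement of `scale` in the backward are assumptions for illustration, not confirmed by this diff.

import torch

def scaled_masked_softmax_ref(input, mask, scale, fixed_triu_mask=False):
    # Scale the raw attention scores first.
    scores = input * scale
    if fixed_triu_mask:
        # Assumed behavior: ignore `mask` and apply a causal
        # upper-triangular mask generated on the fly.
        sq, sk = input.shape[-2], input.shape[-1]
        mask = torch.ones(sq, sk, dtype=torch.bool, device=input.device).triu(1)
    # Masked positions get -inf so softmax assigns them ~0 probability.
    scores = scores.masked_fill(mask.bool(), float("-inf"))
    return torch.softmax(scores, dim=-1)

def scaled_masked_softmax_bwd_ref(grad_output, out, scale):
    # Standard softmax backward, then the forward's scale factor; masked
    # positions come out as zero automatically because `out` is zero there.
    grad = (grad_output - (grad_output * out).sum(dim=-1, keepdim=True)) * out
    return grad * scale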
18 changes: 18 additions & 0 deletions deeplink_ext/llm_ops_for_ascend_speed.py
@@ -182,3 +182,21 @@ def adamw_for_ascend_speed(
        amsgrad,
    )
    return params, exp_avgs, exp_avg_sqs


class DeepLinkScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mask, scale, fixed_triu_mask):
        out = ext.scaled_masked_softmax_fwd(input, mask, scale, fixed_triu_mask)
        # The softmax output (not the input) is all the backward kernel needs.
        ctx.save_for_backward(out, mask)
        ctx.scale = scale
        ctx.fixed_triu_mask = fixed_triu_mask
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, mask = ctx.saved_tensors
        grad_input = ext.scaled_masked_softmax_bwd(
            grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
        )
        # One gradient per forward input; mask, scale, and fixed_triu_mask
        # are not differentiable.
        return grad_input, None, None, None
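
A hypothetical call site might look like the following; the tensor shapes (the usual (batch, heads, seq_q, seq_k) attention-score layout), the device string, and the scale value are illustrative assumptions only.

scores = torch.randn(2, 8, 128, 128, device="cuda", requires_grad=True)
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
probs = DeepLinkScaledMaskedSoftmax.apply(scores, mask, 0.125, False)
probs.sum().backward()  # routes through scaled_masked_softmax_bwd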
