Skip to content

Commit

Permalink
fix: add missing s_dmask's gradient option
Browse files Browse the repository at this point in the history
  • Loading branch information
ApsarasX committed Apr 17, 2024
1 parent caaba0d commit 52f1d1f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,8 @@ void InitXlaModuleBindings(py::module m) {
if (return_softmax) {
at::Tensor s_dmask =
bridge::AtenFromXlaTensor(std::move(std::get<3>(result_tensors)));
result_tuple[3] = torch::autograd::make_variable(s_dmask, false);
result_tuple[3] =
torch::autograd::make_variable(s_dmask, query.requires_grad());
}
return result_tuple;
},
Expand Down Expand Up @@ -1084,7 +1085,8 @@ void InitXlaModuleBindings(py::module m) {
if (return_softmax) {
at::Tensor s_dmask =
bridge::AtenFromXlaTensor(std::move(std::get<3>(result_tensors)));
result_tuple[3] = torch::autograd::make_variable(s_dmask, false);
result_tuple[3] =
torch::autograd::make_variable(s_dmask, query.requires_grad());
}
return result_tuple;
},
Expand Down

0 comments on commit 52f1d1f

Please sign in to comment.