
Commit efae8d6 (parent aea84c2)

Add clamp impl back because XLANativeFunctions::hardtanh uses it
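For context, hardtanh(x, min_val, max_val) is mathematically clamp(x, min_val, max_val), which is why XLANativeFunctions::hardtanh needs the clamp implementation restored here. A plausible sketch of that delegation follows; the bridge helpers are torch_xla's usual ATen<->XLA conversion utilities, but this exact body is an assumption, not part of this commit.

// Plausible sketch, NOT part of this commit: how XLANativeFunctions::hardtanh
// can reuse the restored clamp path.
at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
                                        const at::Scalar& min_val,
                                        const at::Scalar& max_val) {
  // hardtanh(x, lo, hi) == clamp(x, lo, hi), so delegate to the clamp
  // tensor method that this commit adds back.
  return bridge::AtenFromXlaTensor(tensor_methods::clamp(
      bridge::GetXlaTensor(self), min_val, max_val));
}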

File tree: 6 files changed (+37, −16 lines)

codegen/xla_native_functions.yaml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 # - https://github.com/pytorch/xla/blob/master/docs/source/contribute/codegen_migration.md
 backend: XLA
 cpp_namespace: torch_xla
-# full_codegen is the prefered method of code generation. Through this config
+# full_codegen is the preferred method of code generation. Through this config
 # ops get implementations (and IR classes) generated. See
 # https://github.com/pytorch/xla/blob/master/docs/source/contribute/codegen_migration.md
 # for more details on differences on what gets generated or not.

test/cpp/BUILD

Lines changed: 0 additions & 15 deletions

@@ -130,21 +130,6 @@ ptxla_cc_test(
     ],
 )
 
-# ptxla_cc_test(
-#     name = "test_aten_xla_tensor_5",
-#     size = "enormous",
-#     srcs = ["test_aten_xla_tensor_5.cpp"],
-#     deps = [
-#         ":cpp_test_util",
-#         ":torch_xla_test",
-#         "//torch_xla/csrc/runtime:metrics",
-#         "//torch_xla/csrc:tensor",
-#         "//torch_xla/csrc:aten_cuda_functions",
-#         "@com_google_googletest//:gtest_main",
-#         "@xla//xla:permutation_util",
-#     ],
-# )
-
 # This test is very large so it's split into shards.
 # To make it run fast, please add new shards when needed.
 [

torch_xla/csrc/ops/ops.cpp

Lines changed: 17 additions & 0 deletions

@@ -195,6 +195,23 @@ torch::lazy::NodePtr SoftmaxBackwardOp(const torch::lazy::Value& grad_output,
                                        dim, GetXlaShape(grad_output).dimensions_size()));
 }
 
+torch::lazy::NodePtr Clamp(const torch::lazy::Value& input,
+                           const torch::lazy::Value& min,
+                           const torch::lazy::Value& max) {
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1));
+    xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2));
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type);
+    xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type);
+    return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max},
+                   GetXlaShape(input), std::move(lower_fn));
+}
+
 torch::lazy::NodePtr Celu(const torch::lazy::Value& input,
                           const at::Scalar& alpha) {
   auto lower_fn = [=](const XlaNode& node,
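Two details of this lowering are easy to miss: xla::Clamp takes its operands in (min, operand, max) order, and both bounds are first converted to the input's element type. A standalone plain-C++ illustration of the resulting semantics (exposition only, not torch_xla code):

// Illustration of the Clamp lowering's semantics: bounds are cast to the
// input's element type, then each element is clamped into [min, max].
// Mirrors the ConvertTo(...) calls plus xla::Clamp(min, input, max) above.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> input = {-2.5f, -0.5f, 0.0f, 0.5f, 2.5f};
  double min = -1.0, max = 1.0;  // bounds may arrive in a wider type
  float lo = static_cast<float>(min);  // the ConvertTo step in the diff
  float hi = static_cast<float>(max);
  for (float& v : input) v = std::clamp(v, lo, hi);
  for (float v : input) std::printf("%.1f ", v);  // -1.0 -0.5 0.0 0.5 1.0
  std::printf("\n");
  return 0;
}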

torch_xla/csrc/ops/ops.h

Lines changed: 4 additions & 0 deletions

@@ -121,6 +121,10 @@ torch::lazy::NodePtr SoftmaxBackwardOp(const torch::lazy::Value& grad_output,
                                        const torch::lazy::Value& output,
                                        int64_t dim);
 
+torch::lazy::NodePtr Clamp(const torch::lazy::Value& input,
+                           const torch::lazy::Value& min,
+                           const torch::lazy::Value& max);
+
 torch::lazy::NodePtr Celu(const torch::lazy::Value& input,
                           const at::Scalar& alpha);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 8 additions & 0 deletions

@@ -1226,6 +1226,14 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) {
   input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha));
 }
 
+XLATensorPtr clamp(const XLATensorPtr& input,
+                   const std::optional<at::Scalar>& min,
+                   const std::optional<at::Scalar>& max) {
+  MinMaxValues min_max = GetMinMaxValues(input, min, max);
+  return input->CreateFrom(
+      Clamp(input->GetIrValue(), min_max.min, min_max.max));
+}
+
 XLATensorPtr clone(const XLATensorPtr& input) {
   XLATensorPtr cloned = input->CreateFrom(input->GetIrValue());
   if (input->sharding_spec() != nullptr) {
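The hunk above leans on a GetMinMaxValues helper that is defined elsewhere and not touched by this commit. Its presumed contract: at least one bound must be provided, and a missing bound falls back to the element type's extreme value so that one-sided clamps still work. A hypothetical standalone sketch of that contract (the names and fallback behavior here are assumptions; the real helper produces scalar IR values, not floats):

// Hypothetical sketch of the presumed GetMinMaxValues contract, in plain C++.
#include <cassert>
#include <limits>
#include <optional>

struct MinMax {
  float min;
  float max;
};

MinMax GetMinMaxValuesSketch(std::optional<float> min,
                             std::optional<float> max) {
  assert(min || max);  // at least one bound is required, as in torch.clamp
  // Assumption: a missing bound defaults to the dtype's extreme value,
  // which makes clamp behave as a one-sided clamp.
  return {min.value_or(std::numeric_limits<float>::lowest()),
          max.value_or(std::numeric_limits<float>::max())};
}

int main() {
  MinMax mm = GetMinMaxValuesSketch(std::nullopt, 1.0f);  // upper bound only
  assert(mm.max == 1.0f);
  return 0;
}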

torch_xla/csrc/tensor_methods.h

Lines changed: 7 additions & 0 deletions

@@ -320,6 +320,13 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor);
 XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha);
 void celu_(XLATensorPtr& input, const at::Scalar& alpha);
 
+XLATensorPtr clamp(const XLATensorPtr& input,
+                   const std::optional<at::Scalar>& min,
+                   const std::optional<at::Scalar>& max);
+XLATensorPtr clamp(const XLATensorPtr& input,
+                   const std::optional<at::Tensor>& min,
+                   const std::optional<at::Tensor>& max);
+
 XLATensorPtr clone(const XLATensorPtr& input);
 
 XLATensorPtr conj(const XLATensorPtr& input);
