From a82ca108c4612a55c0a7cdb6b1a54dbc0f5db237 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 8 May 2025 23:24:32 +0000 Subject: [PATCH 1/3] Lower clamp --- codegen/xla_native_functions.yaml | 8 ++-- torch_xla/csrc/aten_xla_type.cpp | 22 ---------- torch_xla/csrc/ops/ops.cpp | 17 -------- torch_xla/csrc/ops/ops.h | 4 -- torch_xla/csrc/ops/ops_lower_fn.cpp | 49 ++++++++++++++++++++++ torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 54 +++++++++++++++++++++++++ torch_xla/csrc/ops/ops_xla_shape_fn.h | 10 +++++ torch_xla/csrc/tensor_methods.cpp | 8 ---- torch_xla/csrc/tensor_methods.h | 7 ---- 9 files changed, 117 insertions(+), 62 deletions(-) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index 0f510975d982..801e5276a411 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -4,7 +4,7 @@ # - https://github.com/pytorch/xla/blob/master/docs/source/contribute/codegen_migration.md backend: XLA cpp_namespace: torch_xla -# full_codegen is the prefered method of code generation. Through this config +# full_codegen is the preferred method of code generation. Through this config # ops get implementations (and IR classes) generated. See # https://github.com/pytorch/xla/blob/master/docs/source/contribute/codegen_migration.md # for more details on differences on what gets generated or not. @@ -33,8 +33,11 @@ full_codegen: - bitwise_right_shift.Tensor - ceil - cholesky + - clamp - clamp.Tensor + - clamp_max - clamp_max.Tensor + - clamp_min - clamp_min.Tensor - _conj_copy - cos @@ -183,9 +186,6 @@ supported: - cat - celu - celu_ - - clamp - - clamp_max - - clamp_min - clone - constant_pad_nd - convolution_backward_overrideable diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 3cb1f6c51b95..b04169cc29ae 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1312,28 +1312,6 @@ at::Tensor& XLANativeFunctions::celu_(at::Tensor& self, return self; } -at::Tensor XLANativeFunctions::clamp(const at::Tensor& self, - const std::optional& min, - const std::optional& max) { - TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(bridge::GetXlaTensor(self), min, max)); -} - -at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self, - const at::Scalar& max) { - TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(bridge::GetXlaTensor(self), std::nullopt, max)); -} - -at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self, - const at::Scalar& min) { - TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(bridge::GetXlaTensor(self), min, std::nullopt)); -} - at::Tensor XLANativeFunctions::clone( const at::Tensor& self, std::optional /* memory_format */) { diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 527b07af31d0..5fe7da8c0fec 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -195,23 +195,6 @@ torch::lazy::NodePtr SoftmaxBackwardOp(const torch::lazy::Value& grad_output, dim, GetXlaShape(grad_output).dimensions_size())); } -torch::lazy::NodePtr Clamp(const torch::lazy::Value& input, - const torch::lazy::Value& min, - const torch::lazy::Value& max) { - auto lower_fn = [](const XlaNode& node, - LoweringContext* loctx) -> XlaOpVector { - xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); - xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1)); - xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2)); - xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type); - xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type); - return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx); - }; - return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max}, - GetXlaShape(input), std::move(lower_fn)); -} - torch::lazy::NodePtr Celu(const torch::lazy::Value& input, const at::Scalar& alpha) { auto lower_fn = [=](const XlaNode& node, diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 020bafdbf560..ddb11288ec39 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -121,10 +121,6 @@ torch::lazy::NodePtr SoftmaxBackwardOp(const torch::lazy::Value& grad_output, const torch::lazy::Value& output, int64_t dim); -torch::lazy::NodePtr Clamp(const torch::lazy::Value& input, - const torch::lazy::Value& min, - const torch::lazy::Value& max); - torch::lazy::NodePtr Celu(const torch::lazy::Value& input, const at::Scalar& alpha); diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index ed25cf087f23..b38c621c4c1f 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -327,6 +327,43 @@ torch_xla::XlaOpVector Cholesky::Lower(LoweringContext* loctx) const { return ReturnOp(output, loctx); } +torch_xla::XlaOpVector Clamp::Lower(LoweringContext* loctx) const { + XLA_CHECK(has_min || has_max) + << "At least one of \'min\' or \'max\' must not be None"; + + // This is little bit ugly due to min and max tensors being optional, + // and operand[1] can be either min or max: + // if !has_min and has_max -> operand[1] is max + // if has_min and !has_max -> operand[1] is min + xla::XlaOp res = loctx->GetOutputOp(operand(0)); + if (has_min && has_max) { + auto promoted_min = + XlaHelpers::Promote(res, loctx->GetOutputOp(operand(1))); + res = xla::Max(promoted_min.first, promoted_min.second, + XlaHelpers::getBroadcastDimensions(promoted_min.first, + promoted_min.second)); + auto promoted_max = + XlaHelpers::Promote(res, loctx->GetOutputOp(operand(2))); + res = xla::Min(promoted_max.first, promoted_max.second, + XlaHelpers::getBroadcastDimensions(promoted_max.first, + promoted_max.second)); + } else if (has_min) { + auto promoted_min = + XlaHelpers::Promote(res, loctx->GetOutputOp(operand(1))); + res = xla::Max(promoted_min.first, promoted_min.second, + XlaHelpers::getBroadcastDimensions(promoted_min.first, + promoted_min.second)); + } else if (has_max) { + auto promoted_max = + XlaHelpers::Promote(res, loctx->GetOutputOp(operand(1))); + res = xla::Min(promoted_max.first, promoted_max.second, + XlaHelpers::getBroadcastDimensions(promoted_max.first, + promoted_max.second)); + } + + return ReturnOp(res, loctx); +} + torch_xla::XlaOpVector ClampTensor::Lower(LoweringContext* loctx) const { XLA_CHECK(has_min || has_max) << "At least one of \'min\' or \'max\' must not be None"; @@ -364,12 +401,24 @@ torch_xla::XlaOpVector ClampTensor::Lower(LoweringContext* loctx) const { return ReturnOp(res, loctx); } +torch_xla::XlaOpVector ClampMax::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + xla::XlaOp xla_other = loctx->GetOutputOp(operand(1)); + return ReturnOp(xla::Min(xla_input, xla_other), loctx); +} + torch_xla::XlaOpVector ClampMaxTensor::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); xla::XlaOp xla_other = loctx->GetOutputOp(operand(1)); return ReturnOp(xla::Min(xla_input, xla_other), loctx); } +torch_xla::XlaOpVector ClampMin::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + xla::XlaOp xla_other = loctx->GetOutputOp(operand(1)); + return ReturnOp(xla::Max(xla_input, xla_other), loctx); +} + torch_xla::XlaOpVector ClampMinTensor::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); xla::XlaOp xla_other = loctx->GetOutputOp(operand(1)); diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index 1881f6631c9f..3aaecbf5a02b 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -385,6 +385,40 @@ xla::Shape CholeskyOutputShape(const torch::lazy::Value& input, return GetXlaShape(input); } +xla::Shape ClampOutputShape(const torch::lazy::Value& input, + const std::optional& min, + const std::optional& max) { + // This shape function works in a bit of an odd/hacky way. + // If operands.size() > 1, operands[1] can be either min or + // max since they are both optional values. But in this code, + // we are always assuming operands[1] to be min if + // operands.size() > 1. This code works because xla::Min and + // xla::Max produce the same output shapes. + auto lower_for_shape_fn = + [&](absl::Span operands) -> xla::XlaOp { + xla::XlaOp res = operands[0]; + if (operands.size() > 1) { + auto promoted = XlaHelpers::Promote(res, operands[1]); + res = xla::Max( + promoted.first, promoted.second, + XlaHelpers::getBroadcastDimensions(promoted.first, promoted.second)); + } + if (operands.size() > 2) { + auto promoted = XlaHelpers::Promote(res, operands[2]); + res = xla::Min( + promoted.first, promoted.second, + XlaHelpers::getBroadcastDimensions(promoted.first, promoted.second)); + } + return res; + }; + std::vector shapes; + for (auto& i : + GetValuesVectorWithOptional({input}, {&min, &max})) { + shapes.push_back(GetXlaShape(i)); + } + return InferOutputShape(shapes, lower_for_shape_fn); +} + xla::Shape ClampTensorOutputShape( const torch::lazy::Value& input, const std::optional& min, @@ -420,6 +454,16 @@ xla::Shape ClampTensorOutputShape( return InferOutputShape(shapes, lower_for_shape_fn); } +xla::Shape ClampMaxOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& other) { + auto lower_for_shape_fn = + [](absl::Span operands) -> xla::XlaOp { + return xla::Min(operands[0], operands[1]); + }; + return InferOutputShape({GetXlaShape(input), GetXlaShape(other)}, + lower_for_shape_fn); +} + xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input, const torch::lazy::Value& other) { auto lower_for_shape_fn = @@ -430,6 +474,16 @@ xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input, lower_for_shape_fn); } +xla::Shape ClampMinOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& other) { + auto lower_for_shape_fn = + [](absl::Span operands) -> xla::XlaOp { + return xla::Max(operands[0], operands[1]); + }; + return InferOutputShape({GetXlaShape(input), GetXlaShape(other)}, + lower_for_shape_fn); +} + xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input, const torch::lazy::Value& other) { auto lower_for_shape_fn = diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index c280c79f4567..aa436c966c5b 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -104,13 +104,23 @@ xla::Shape CeilOutputShape(const torch::lazy::Value& input); xla::Shape CholeskyOutputShape(const torch::lazy::Value& input, const bool upper); +xla::Shape ClampOutputShape(const torch::lazy::Value& input, + const std::optional& min, + const std::optional& max); + xla::Shape ClampTensorOutputShape(const torch::lazy::Value& input, const std::optional& min, const std::optional& max); +xla::Shape ClampMaxOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& target); + xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input, const torch::lazy::Value& target); +xla::Shape ClampMinOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& target); + xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input, const torch::lazy::Value& target); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 80d799076048..bbae55b74c27 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1226,14 +1226,6 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) { input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha)); } -XLATensorPtr clamp(const XLATensorPtr& input, - const std::optional& min, - const std::optional& max) { - MinMaxValues min_max = GetMinMaxValues(input, min, max); - return input->CreateFrom( - Clamp(input->GetIrValue(), min_max.min, min_max.max)); -} - XLATensorPtr clone(const XLATensorPtr& input) { XLATensorPtr cloned = input->CreateFrom(input->GetIrValue()); if (input->sharding_spec() != nullptr) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 79f6acd8049d..003fd90934bd 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -320,13 +320,6 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor); XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha); void celu_(XLATensorPtr& input, const at::Scalar& alpha); -XLATensorPtr clamp(const XLATensorPtr& input, - const std::optional& min, - const std::optional& max); -XLATensorPtr clamp(const XLATensorPtr& input, - const std::optional& min, - const std::optional& max); - XLATensorPtr clone(const XLATensorPtr& input); XLATensorPtr conj(const XLATensorPtr& input); From 709891f3c7984d9accb2d7a99b78cfd75079d956 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 29 May 2025 01:25:24 +0000 Subject: [PATCH 2/3] Add clamp impl back because XLANativeFunctions::hardtanh uses it --- torch_xla/csrc/ops/ops.cpp | 17 +++++++++++++++++ torch_xla/csrc/ops/ops.h | 4 ++++ torch_xla/csrc/tensor_methods.cpp | 8 ++++++++ torch_xla/csrc/tensor_methods.h | 7 +++++++ 4 files changed, 36 insertions(+) diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 5fe7da8c0fec..527b07af31d0 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -195,6 +195,23 @@ torch::lazy::NodePtr SoftmaxBackwardOp(const torch::lazy::Value& grad_output, dim, GetXlaShape(grad_output).dimensions_size())); } +torch::lazy::NodePtr Clamp(const torch::lazy::Value& input, + const torch::lazy::Value& min, + const torch::lazy::Value& max) { + auto lower_fn = [](const XlaNode& node, + LoweringContext* loctx) -> XlaOpVector { + xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); + xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1)); + xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2)); + xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); + xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type); + xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type); + return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx); + }; + return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max}, + GetXlaShape(input), std::move(lower_fn)); +} + torch::lazy::NodePtr Celu(const torch::lazy::Value& input, const at::Scalar& alpha) { auto lower_fn = [=](const XlaNode& node, diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index ddb11288ec39..020bafdbf560 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -121,6 +121,10 @@ torch::lazy::NodePtr SoftmaxBackwardOp(const torch::lazy::Value& grad_output, const torch::lazy::Value& output, int64_t dim); +torch::lazy::NodePtr Clamp(const torch::lazy::Value& input, + const torch::lazy::Value& min, + const torch::lazy::Value& max); + torch::lazy::NodePtr Celu(const torch::lazy::Value& input, const at::Scalar& alpha); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index bbae55b74c27..80d799076048 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1226,6 +1226,14 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) { input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha)); } +XLATensorPtr clamp(const XLATensorPtr& input, + const std::optional& min, + const std::optional& max) { + MinMaxValues min_max = GetMinMaxValues(input, min, max); + return input->CreateFrom( + Clamp(input->GetIrValue(), min_max.min, min_max.max)); +} + XLATensorPtr clone(const XLATensorPtr& input) { XLATensorPtr cloned = input->CreateFrom(input->GetIrValue()); if (input->sharding_spec() != nullptr) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 003fd90934bd..79f6acd8049d 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -320,6 +320,13 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor); XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha); void celu_(XLATensorPtr& input, const at::Scalar& alpha); +XLATensorPtr clamp(const XLATensorPtr& input, + const std::optional& min, + const std::optional& max); +XLATensorPtr clamp(const XLATensorPtr& input, + const std::optional& min, + const std::optional& max); + XLATensorPtr clone(const XLATensorPtr& input); XLATensorPtr conj(const XLATensorPtr& input); From 478432d2a3f1c5c69f3f29ae9b60cefa91dbc26f Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 2 Jun 2025 19:18:19 +0000 Subject: [PATCH 3/3] WIP --- test/cpp/test_aten_xla_tensor_6.cpp | 1367 +-------------------------- torch_xla/csrc/aten_xla_type.cpp | 8 + 2 files changed, 17 insertions(+), 1358 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp index df3e1280b6e2..87ef21697036 100644 --- a/test/cpp/test_aten_xla_tensor_6.cpp +++ b/test/cpp/test_aten_xla_tensor_6.cpp @@ -23,1367 +23,18 @@ class AtenXlaTensorTest : public AtenXlaTensorTestBase {}; } // namespace -TEST_F(AtenXlaTensorTest, TestTransposedConv3DBackward) { - int in_channels = 4; - int out_channels = 8; - int kernel_size = 5; - for (int stride = 1; stride <= 2; ++stride) { - for (int padding = 0; padding <= 1; ++padding) { - for (int dilation = 1; dilation <= 2; ++dilation) { - for (int output_padding = 0; - output_padding < std::max(stride, dilation); ++output_padding) { - for (bool with_bias : {true, false}) { - for (int groups : - {1, 2, 4}) { // covers normal, grouped, depthwise conv. - auto testfn = [&](const std::vector& inputs) - -> torch::Tensor { - return torch::conv_transpose3d( - inputs[0], inputs[1], inputs[2], - /*stride=*/{stride, stride + 1, stride}, - /*padding=*/{padding, padding + 1, stride}, - /*output_padding=*/output_padding, - /*groups=*/groups, - /*dilation=*/{dilation, dilation + 1, dilation}); - }; - ForEachDevice([&](const torch::Device& device) { - torch::Tensor input = torch::rand( - {4, out_channels, 14, 14, 14}, - torch::TensorOptions(torch::kDouble).requires_grad(true)); - torch::Tensor weight = torch::rand( - {out_channels, in_channels / groups, kernel_size, - kernel_size, kernel_size}, - torch::TensorOptions(torch::kDouble).requires_grad(true)); - torch::Tensor bias = - with_bias ? torch::rand({in_channels}, - torch::TensorOptions(torch::kDouble) - .requires_grad(true)) - : torch::Tensor(); - TestBackward({input, weight, bias}, device, testfn); - }); - } - }; - } - } - } - } -} - -TEST_F(AtenXlaTensorTest, TestMaxPool2DBackward) { - int kernel_size = 3; - for (int stride = 1; stride <= 2; ++stride) { - for (int padding = 0; padding <= 1; ++padding) { - // Test ceil_mode=true through the CPU interop. - for (bool ceil_mode : {false, true}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::max_pool2d( - inputs[0], /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{1, 1}, - /*ceil_mode=*/ceil_mode); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand( - {1, 64, 112, 112}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::max_pool2d", cpp_test::GetIgnoredCounters()); - } - } - } -} - -TEST_F(AtenXlaTensorTest, TestMaxPool3DBackward) { - int kernel_size = 3; - for (int stride = 1; stride <= 2; ++stride) { - for (int padding = 0; padding <= 1; ++padding) { - // Test ceil_mode=true through the CPU interop. - for (bool ceil_mode : {false, true}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::max_pool3d( - inputs[0], - /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, - /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, /*dilation=*/{1, 1, 1}, - /*ceil_mode=*/ceil_mode); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand( - {1, 64, 16, 16, 16}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::max_pool3d", cpp_test::GetIgnoredCounters()); - } - } - } -} - -TEST_F(AtenXlaTensorTest, TestMaxPool2DNoBatchBackward) { - int kernel_size = 3; - for (int stride = 1; stride <= 2; ++stride) { - for (int padding = 0; padding <= 1; ++padding) { - // Test ceil_mode=true through the CPU interop. - for (bool ceil_mode : {false, true}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::max_pool2d( - inputs[0], /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{1, 1}, - /*ceil_mode=*/ceil_mode); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand( - {64, 112, 112}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); - } - } - } -} - -TEST_F(AtenXlaTensorTest, TestMaxPool3DNoBatchBackward) { - int kernel_size = 3; - for (int stride = 1; stride <= 2; ++stride) { - for (int padding = 0; padding <= 1; ++padding) { - // Test ceil_mode=true through the CPU interop. - for (bool ceil_mode : {false, true}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::max_pool3d( - inputs[0], - /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, - /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, /*dilation=*/{1, 1, 1}, - /*ceil_mode=*/ceil_mode); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand( - {64, 16, 16, 16}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::max_pool3d", cpp_test::GetIgnoredCounters()); - } - } - } -} - -TEST_F(AtenXlaTensorTest, TestMaxUnpool2DBackward) { - int kernel_size = 2; - torch::Tensor input = - torch::rand({2, 2, 8, 8}, torch::TensorOptions(torch::kFloat)); - for (int stride = 1; stride <= 2; ++stride) { - for (int padding = 0; padding <= 1; ++padding) { - // Test ceil_mode=true through the CPU interop. - for (bool ceil_mode : {false, true}) { - for (int dilation = 1; dilation <= 2; ++dilation) { - torch::Tensor output; - torch::Tensor indices; - std::tie(output, indices) = torch::max_pool2d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation}, - /*ceil_mode=*/ceil_mode); - - std::vector output_size({input.size(2), input.size(3)}); - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::max_unpool2d(inputs[0], inputs[1], output_size); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward({output.requires_grad_(true), indices}, device, - testfn); - }); - } - } - } - } - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestMaxUnpool3DBackward) { - int kernel_size = 2; - torch::Tensor input = - torch::rand({2, 2, 8, 8, 8}, torch::TensorOptions(torch::kFloat)); - for (int stride = 1; stride <= 2; ++stride) { - for (int padding = 0; padding <= 1; ++padding) { - // Test ceil_mode=true through the CPU interop. - for (bool ceil_mode : {false, true}) { - for (int dilation = 1; dilation <= 2; ++dilation) { - torch::Tensor output; - torch::Tensor indices; - std::tie(output, indices) = torch::max_pool3d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, - /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, - /*dilation=*/{dilation, dilation, dilation}, - /*ceil_mode=*/ceil_mode); - - std::vector output_size( - {input.size(2), input.size(3), input.size(4)}); - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::max_unpool3d(inputs[0], inputs[1], output_size, - /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward({output.requires_grad_(true), indices}, device, - testfn); - }); - } - } - } - } - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestTanhBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::tanh(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-5); - }); -} - -TEST_F(AtenXlaTensorTest, TestSigmoidBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::sigmoid(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestLogSigmoidBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::log_sigmoid(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-5); - }); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::log_sigmoid_forward", - cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLogSoftmaxBackward) { - for (int dim = -4; dim < 4; ++dim) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::log_softmax(inputs[0], dim); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand( - {5, 3, 4, 2}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); - }); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_log_softmax", cpp_test::GetIgnoredCounters()); - } -} - -TEST_F(AtenXlaTensorTest, TestSoftmaxBackward) { - for (int dim = -4; dim < 4; ++dim) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::softmax(inputs[0], dim); - }; - - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand( - {5, 3, 4, 2}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); - }); - } -} - -TEST_F(AtenXlaTensorTest, TestSoftplusBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::softplus(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn, /*rtol=*/1e-4); - }); -} - -TEST_F(AtenXlaTensorTest, TestReluBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::relu(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestRreluBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::rrelu(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestHardshrinkBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::hardshrink(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::randn({100}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestHardshrinkBackwardWithMixedDataType) { - if (UsingTpu()) { - GTEST_SKIP(); - } - torch::Tensor lambdaTensor = - torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32)); - torch::Scalar lambda = lambdaTensor.item(); - - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::hardshrink(inputs[0], lambda); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::randn( - {100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::softshrink(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::randn({100}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestSoftshrinkBackwardWithMixedDataType) { - if (UsingTpu()) { - GTEST_SKIP(); - } - torch::Tensor lambdaTensor = - torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32)); - torch::Scalar lambda = lambdaTensor.item(); - - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::softshrink(inputs[0], lambda); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::randn( - {100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestHardtanhBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::hardtanh(inputs[0]); - }; +TEST_F(AtenXlaTensorTest, TestClamp) { + torch::Tensor operand = + torch::rand({2, 2}, torch::TensorOptions(torch::kFloat)); + torch::Tensor min = torch::zeros_like(operand); + torch::Tensor max = torch::ones_like(operand); + torch::Tensor out = torch::clamp(operand, min, max); ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::randn({100}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); + torch::Tensor xla_operand = CopyToDevice(operand, device); + torch::Tensor xla_out = torch::clamp(xla_operand, min, max); + AllClose(out, xla_out, /*rtol=*/1e-3, /*atol=*/1e-5); }); } -TEST_F(AtenXlaTensorTest, TestEluBackward) { - torch::Scalar alpha = 0.5; - torch::Scalar scale = 2.5; - torch::Scalar input_scale = 1.5; - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::elu(inputs[0], alpha, scale, input_scale); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestGeluBackward) { - for (const auto& approximate : {"none", "tanh"}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::gelu(inputs[0], approximate); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand( - {2, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); - ExpectCounterChanged("xla::gelu_backward", cpp_test::GetIgnoredCounters()); - } -} - -TEST_F(AtenXlaTensorTest, TestLeakyReluBackward) { - double negative_slope = 0.01; - auto testfn = [=](const std::vector& inputs) -> torch::Tensor { - return torch::leaky_relu(inputs[0], negative_slope); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestTransposeBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::t(inputs[0]); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({2, 3}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestAddMatMulBackward) { - int in_channels = 32; - int out_channels = 320; - int labels = 50; - // Test beta != 1. through the CPU interop. - for (double beta : {1., 2.}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::addmm(inputs[0], inputs[1], inputs[2], /*beta=*/beta); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({labels}, - torch::TensorOptions(torch::kFloat).requires_grad(true)), - torch::rand({in_channels, out_channels}, - torch::TensorOptions(torch::kFloat).requires_grad(true)), - torch::rand( - {out_channels, labels}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); - } -} - -TEST_F(AtenXlaTensorTest, TestBinaryCrossEntropyBackward) { - if (UsingTpu()) { - GTEST_SKIP(); - } - int batch = 6; - int classes = 2; - for (auto dtype : {torch::kFloat, torch::kDouble}) { - for (bool def_weight : {false, true}) { - torch::Tensor input = torch::rand( - {batch, classes}, torch::TensorOptions(dtype).requires_grad(true)); - torch::Tensor target = - torch::rand({batch, classes}, torch::TensorOptions(dtype)); - torch::Tensor weight; - if (def_weight) { - weight = torch::rand({batch, classes}, torch::TensorOptions(dtype)); - } - for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, - torch::Reduction::None}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::binary_cross_entropy( - /*self=*/inputs[0], /*target=*/inputs[1], - /*weight=*/inputs[2], - /*reduction=*/reduction); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight}, device, testfn, /*rtol=*/1e-4, - /*atol=*/1e-7); - }); - } - } - } - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::binary_cross_entropy", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::binary_cross_entropy_backward", - cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestNllLossBackward) { - int batch = 6; - int classes = 2; - for (auto dtype : {torch::kFloat, torch::kDouble}) { - for (int ignore_index : {-1, 0, 1, 5}) { - for (bool def_weight : {false, true}) { - torch::Tensor input = torch::rand( - {batch, classes}, torch::TensorOptions(dtype).requires_grad(true)); - torch::Tensor target = - torch::randint(std::min(ignore_index, 0), classes, {batch}, - torch::TensorOptions(torch::kLong)); - torch::Tensor weight; - if (def_weight) { - weight = torch::rand({classes}, torch::TensorOptions(dtype)); - } - for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, - torch::Reduction::None}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::nll_loss( - /*self=*/inputs[0], /*target=*/inputs[1], - /*weight=*/inputs[2], - /*reduction=*/reduction, /*ignore_index=*/ignore_index); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); - }); - } - } - } - } - - ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::nll_loss_forward", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::nll_loss_backward", - cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestNllLoss2dBackward) { - int batch = 6; - int classes = 2; - int height = 3; - int width = 3; - for (auto dtype : {torch::kFloat, torch::kDouble}) { - for (int ignore_index : {-1, 0, 1, 5}) { - for (bool def_weight : {false, true}) { - torch::Tensor input = - torch::rand({batch, classes, height, width}, - torch::TensorOptions(dtype).requires_grad(true)); - torch::Tensor target = torch::randint( - std::min(ignore_index, 0), classes, {batch, height, width}, - torch::TensorOptions(torch::kLong)); - torch::Tensor weight; - if (def_weight) { - weight = torch::rand({classes}, torch::TensorOptions(dtype)); - } - for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, - torch::Reduction::None}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::nll_loss2d( - /*self=*/inputs[0], /*target=*/inputs[1], - /*weight=*/inputs[2], - /*reduction=*/reduction, /*ignore_index=*/ignore_index); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); - }); - } - } - } - } - - ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::nll_loss2d_forward", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::nll_loss2d_backward", - cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestSmoothL1LossBackward) { - torch::Tensor input = torch::randn( - {2, 4}, torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor target = - torch::randn({2, 4}, torch::TensorOptions(torch::kFloat)); - for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, - torch::Reduction::Sum}) { - for (double beta : {0.25, 1.}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::smooth_l1_loss(/*input=*/inputs[0], /*target=*/inputs[1], - /*reduction=*/reduction, /*beta=*/beta); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); - }); - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::smooth_l1_loss_backward", - cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestViewBackward) { - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return inputs[0].view({-1, 320}); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward( - {torch::rand({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).requires_grad(true))}, - device, testfn); - }); -} - -TEST_F(AtenXlaTensorTest, TestBatchNorm2DBackward) { - double momentum = 0.1; - double eps = 0.5; - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::batch_norm( - /*input=*/inputs[0], /*weight=*/inputs[1], /*bias=*/inputs[2], - /*running_mean=*/inputs[3], /*running_var=*/inputs[4], - /*training=*/true, /*momentum=*/momentum, /*eps=*/eps, - /*cudnn_enabled=*/false); - }; - int num_features = 3; - torch::Tensor undef; - for (bool undef_weight_bias : {false, true}) { - ForEachDevice([&](const torch::Device& device) { - torch::Tensor input = - torch::rand({2, num_features, 4, 4}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor weight = - undef_weight_bias - ? undef - : torch::rand( - {num_features}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor bias = - undef_weight_bias - ? undef - : torch::rand( - {num_features}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor running_mean = - torch::zeros({num_features}, torch::TensorOptions(torch::kFloat)); - torch::Tensor running_var = - torch::ones({num_features}, torch::TensorOptions(torch::kFloat)); - TestBackward({input, weight, bias, running_mean, running_var}, device, - testfn, - /*rtol=*/1e-3, /*atol=*/1e-4); - }); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::native_batch_norm", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::native_batch_norm_backward", - cpp_test::GetIgnoredCounters()); - } -} - -TEST_F(AtenXlaTensorTest, TestBatchNorm3DBackward) { - double momentum = 0.1; - double eps = 0.5; - auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::batch_norm( - /*input=*/inputs[0], /*weight=*/inputs[1], /*bias=*/inputs[2], - /*running_mean=*/inputs[3], /*running_var=*/inputs[4], - /*training=*/true, /*momentum=*/momentum, /*eps=*/eps, - /*cudnn_enabled=*/false); - }; - int num_features = 3; - torch::Tensor undef; - for (bool undef_weight_bias : {false, true}) { - ForEachDevice([&](const torch::Device& device) { - torch::Tensor input = - torch::rand({2, num_features, 4, 4, 2}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor weight = - undef_weight_bias - ? undef - : torch::rand( - {num_features}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor bias = - undef_weight_bias - ? undef - : torch::rand( - {num_features}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor running_mean = - torch::zeros({num_features}, torch::TensorOptions(torch::kFloat)); - torch::Tensor running_var = - torch::ones({num_features}, torch::TensorOptions(torch::kFloat)); - TestBackward({input, weight, bias, running_mean, running_var}, device, - testfn, - /*rtol=*/1e-3, /*atol=*/1e-3); - }); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::native_batch_norm", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::native_batch_norm_backward", - cpp_test::GetIgnoredCounters()); - } -} - -TEST_F(AtenXlaTensorTest, TestBCEWithLogitsBackward) { - int batch = 10; - int classes = 5; - torch::Tensor undef; - for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, - torch::Reduction::Sum}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::binary_cross_entropy_with_logits( - /*input=*/inputs[0], /*target=*/inputs[1], /*weight=*/inputs[2], - /*pos_weight=*/inputs[3], - /*reduction=*/reduction); - }; - for (bool undef_weight : {false, true}) { - for (bool undef_pos_weight : {false, true}) { - torch::Tensor input = torch::rand( - {batch, classes}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor target = torch::rand( - {batch, classes}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor weight = - undef_weight - ? undef - : torch::rand({classes}, torch::TensorOptions(torch::kFloat)); - torch::Tensor pos_weight = - undef_pos_weight - ? undef - : torch::rand({classes}, torch::TensorOptions(torch::kFloat)); - ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight, pos_weight}, device, testfn, - /*rtol=*/1e-3, /*atol=*/1e-5); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - // binary_cross_entropy_with_logits_backward is composed of - // sub/mul_/add_/exp_/add_/log_/... ops in upstream pytorch. - ExpectCounterChanged("xla::add", cpp_test::GetIgnoredCounters()); - } - } - } -} - -TEST_F(AtenXlaTensorTest, TestKlDivBackward) { - torch::Tensor input = torch::rand( - {4, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor target = torch::rand( - {4, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); - for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::kl_div(/*self=*/inputs[0], /*target=*/inputs[1], reduction); - }; - ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target}, device, testfn, /*rtol=*/1e-4, - /*atol=*/1e-5); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) { - int num_weights = 32; - for (int padding_idx = -1; padding_idx < num_weights; ++padding_idx) { - for (bool scale_grad_by_freq : {false, true}) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::embedding(inputs[0], inputs[1], - /*padding_idx=*/padding_idx, - /*scale_grad_by_freq=*/scale_grad_by_freq, - /*sparse=*/false); - }; - ForEachDevice([&](const torch::Device& device) { - torch::Tensor weight = torch::rand( - {num_weights, 7}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor indices = torch::randint( - num_weights, {3, 9, 4}, torch::TensorOptions(torch::kLong)); - TestBackward({weight, indices}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); - }); - } - } -} - -TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) { - XlaDeviceType hw_type = - static_cast(bridge::GetDefaultDevice()->type()); - if (hw_type != XlaDeviceType::CUDA && hw_type != XlaDeviceType::CPU) { - return; - } - torch::Tensor growth_tracker = - torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32)); - torch::Tensor current_scale = - torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); - torch::Tensor found_inf = - torch::scalar_tensor(1, torch::TensorOptions(torch::kFloat)); - torch::Tensor not_found_inf = - torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat)); - float scale_growth_factor = 2.0; - float scale_backoff_factor = 0.5; - int growth_interval = 3; - - torch::Tensor growth_tracker_result0 = - torch::scalar_tensor(1, torch::TensorOptions(torch::kInt32)); - torch::Tensor current_scale_result0 = - torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); - torch::Tensor growth_tracker_result1 = - torch::scalar_tensor(2, torch::TensorOptions(torch::kInt32)); - torch::Tensor current_scale_result1 = - torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); - torch::Tensor growth_tracker_result2 = - torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32)); - torch::Tensor current_scale_result2 = - torch::scalar_tensor(8, torch::TensorOptions(torch::kFloat)); - torch::Tensor growth_tracker_result3 = - torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32)); - torch::Tensor current_scale_result3 = - torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat)); - - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_growth_tracker = CopyToDevice(growth_tracker, device); - torch::Tensor xla_current_scale = CopyToDevice(current_scale, device); - torch::Tensor xla_found_inf = CopyToDevice(found_inf, device); - torch::Tensor xla_not_found_inf = CopyToDevice(not_found_inf, device); - - torch::_amp_update_scale_(xla_current_scale, xla_growth_tracker, - xla_not_found_inf, scale_growth_factor, - scale_backoff_factor, growth_interval); - AllClose(current_scale_result0, xla_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); - AllEqual(growth_tracker_result0, xla_growth_tracker); - - torch::_amp_update_scale_(xla_current_scale, xla_growth_tracker, - xla_not_found_inf, scale_growth_factor, - scale_backoff_factor, growth_interval); - AllClose(current_scale_result1, xla_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); - AllEqual(growth_tracker_result1, xla_growth_tracker); - - // torch::_amp_update_scale_ returns the reference of current_scale - xla_current_scale = torch::_amp_update_scale_( - xla_current_scale, xla_growth_tracker, xla_not_found_inf, - scale_growth_factor, scale_backoff_factor, growth_interval); - AllClose(current_scale_result2, xla_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); - AllEqual(growth_tracker_result2, xla_growth_tracker); - - xla_current_scale = torch::_amp_update_scale_( - xla_current_scale, xla_growth_tracker, xla_found_inf, - scale_growth_factor, scale_backoff_factor, growth_interval); - AllClose(current_scale_result3, xla_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); - AllEqual(growth_tracker_result3, xla_growth_tracker); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_amp_update_scale_", - cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestEarlySyncLiveTensors) { - torch::Tensor scalar_tensor = - torch::scalar_tensor(1., torch::TensorOptions(torch::kFloat)); - torch::Scalar scalar1 = scalar_tensor.item(); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_scalar_tensor = CopyToDevice(scalar_tensor, device); - torch::Scalar scalar2 = xla_scalar_tensor.item(); - ASSERT_EQ(scalar1.to(), scalar2.to()); - }); - if (DebugUtil::ExperimentEnabled("early_sync")) { - ExpectCounterChanged("EarlySyncLiveTensorsCount", - cpp_test::GetIgnoredCounters()); - } else { - ExpectCounterNotChanged("EarlySyncLiveTensorsCount", - cpp_test::GetIgnoredCounters()); - } - ExpectCounterChanged("aten::_local_scalar_dense", - cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLerp) { - torch::Tensor start = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor weight = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor res = torch::lerp(start, end, weight); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_start = CopyToDevice(start, device); - torch::Tensor xla_end = CopyToDevice(end, device); - torch::Tensor xla_weight = CopyToDevice(weight, device); - torch::Tensor xla_res = torch::lerp(xla_start, xla_end, xla_weight); - AllClose(res, xla_res); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLerpScalar) { - torch::Tensor start = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Scalar weight = torch::Scalar(3.0); - torch::Tensor res = torch::lerp(start, end, weight); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_start = CopyToDevice(start, device); - torch::Tensor xla_end = CopyToDevice(end, device); - torch::Tensor xla_res = torch::lerp(xla_start, xla_end, weight); - AllClose(res, xla_res); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLerpInplace) { - torch::Tensor input = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor weight = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor input_copy = input.clone(); - input.lerp_(end, weight); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input_copy, device); - torch::Tensor xla_end = CopyToDevice(end, device); - torch::Tensor xla_weight = CopyToDevice(weight, device); - xla_input.lerp_(xla_end, xla_weight); - AllClose(xla_input, input); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLerpScalarInplace) { - torch::Tensor input = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Scalar weight = torch::Scalar(3.0); - torch::Tensor input_copy = input.clone(); - input.lerp_(end, weight); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input_copy, device); - torch::Tensor xla_end = CopyToDevice(end, device); - xla_input.lerp_(xla_end, weight); - AllClose(xla_input, input); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLerpOut) { - torch::Tensor start = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor weight = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat)); - ; - torch::lerp_out(res, start, end, weight); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_start = CopyToDevice(start, device); - torch::Tensor xla_end = CopyToDevice(end, device); - torch::Tensor xla_weight = CopyToDevice(weight, device); - torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options()); - torch::lerp_out(xla_res, xla_start, xla_end, xla_weight); - AllClose(res, xla_res); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLerpScalarOut) { - torch::Tensor start = - torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Scalar weight = torch::Scalar(3.0); - torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat)); - torch::lerp_out(res, start, end, weight); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_start = CopyToDevice(start, device); - torch::Tensor xla_end = CopyToDevice(end, device); - torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options()); - torch::lerp_out(xla_res, xla_start, xla_end, weight); - AllClose(res, xla_res); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLinspaceStartEndMatch) { - torch::Scalar start = 0; - torch::Scalar end = 10; - int64_t steps = 100; - torch::Tensor res = torch::linspace(start, end, steps); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_res = torch::linspace( - start, end, steps, torch::TensorOptions().device(device)); - AllClose(res, xla_res); - AllEqual(torch::scalar_tensor(start), xla_res[0]); - AllEqual(torch::scalar_tensor(end), xla_res[steps - 1]); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::linspace", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestLinspaceDtypes) { - torch::Scalar start = 1; - torch::Scalar end = 100; - int64_t steps = 5; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kDouble, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor res = torch::linspace( - start, end, steps, torch::TensorOptions().dtype(scalar_type)); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_res = torch::linspace( - start, end, steps, - torch::TensorOptions().dtype(scalar_type).device(device)); - AllClose(res, xla_res); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::linspace", cpp_test::GetIgnoredCounters()); - }; -} - -TEST_F(AtenXlaTensorTest, TestLinspaceSmallSteps) { - torch::Scalar start = 0; - torch::Scalar end = 10; - for (int64_t steps : {0, 1, 2}) { - torch::Tensor res = torch::linspace(start, end, steps); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_res = torch::linspace( - start, end, steps, torch::TensorOptions().device(device)); - AllClose(res, xla_res); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::linspace", cpp_test::GetIgnoredCounters()); - } -} - -TEST_F(AtenXlaTensorTest, TestLinspaceReverse) { - torch::Scalar start = 0; - torch::Scalar end = -10; - int64_t steps = 100; - torch::Tensor res = torch::linspace(start, end, steps); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_res = torch::linspace( - start, end, steps, torch::TensorOptions().device(device)); - AllClose(res, xla_res); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::linspace", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestNanToNum) { - for (torch::ScalarType scalar_type : - {torch::kHalf, torch::kFloat, torch::kDouble, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor input = - isFloatingType(scalar_type) - ? torch::tensor( - {1.0, std::nan("1"), std::numeric_limits::infinity(), - -std::numeric_limits::infinity()}, - torch::TensorOptions(scalar_type)) - : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type)); - torch::Tensor output = torch::nan_to_num(input); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_output = torch::nan_to_num(xla_input); - if (static_cast( - bridge::AtenDeviceToXlaDevice(device).type()) == - XlaDeviceType::TPU && - scalar_type == torch::kDouble) { - // Since TPU converts double to float (unlike CPU), the Inf entries are - // expected to be different. Skipping checks for Inf entries. - AllEqual(output[0], xla_output[0]); - AllEqual(output[1], xla_output[1]); - } else { - AllClose(output, xla_output); - } - }); - output = - torch::nan_to_num(input, /*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_output = torch::nan_to_num( - xla_input, /*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0); - AllClose(output, xla_output); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::nan_to_num", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestNanToNumOut) { - for (torch::ScalarType scalar_type : - {torch::kHalf, torch::kFloat, torch::kDouble, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor input = - isFloatingType(scalar_type) - ? torch::tensor( - {1.0, std::nan("1"), std::numeric_limits::infinity(), - -std::numeric_limits::infinity()}, - torch::TensorOptions(scalar_type)) - : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type)); - torch::Tensor output = torch::zeros_like(input); - torch::nan_to_num_out(output, input); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_output = torch::zeros_like(input); - torch::nan_to_num_out(xla_output, xla_input); - if (static_cast( - bridge::AtenDeviceToXlaDevice(device).type()) == - XlaDeviceType::TPU && - scalar_type == torch::kDouble) { - // Since TPU converts double to float (unlike CPU), the Inf entries are - // expected to be different. Skipping checks for Inf entries. - AllEqual(output[0], xla_output[0]); - AllEqual(output[1], xla_output[1]); - } else { - AllClose(output, xla_output); - } - }); - torch::nan_to_num_out(output, input, /*nan=*/1.0, /*posinf=*/2.0, - /*neginf=*/3.0); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_output = torch::zeros_like(input); - torch::nan_to_num_out(xla_output, xla_input, /*nan=*/1.0, /*posinf=*/2.0, - /*neginf=*/3.0); - AllClose(output, xla_output); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::nan_to_num", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestRoll) { - std::vector input_shape = {2, 3, 4}; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor input = - isFloatingType(scalar_type) - ? torch::rand(input_shape, torch::TensorOptions(scalar_type)) - : torch::randint(0, 100, input_shape, - torch::TensorOptions(scalar_type)); - std::vector> dim_powerset = { - {}, {0}, {1}, {2}, {0, 1}, {1, 2}, {2, 0}, {0, 1, 2}}; - std::vector>> shift_set = { - {{0}, {1}, {1}, {-24}, {24}, {-27}, {27}}, - {{0}, {-1}, {1}, {-5}, {5}}, - {{0}, {-1}, {1}, {-5}, {5}}, - {{0}, {-1}, {1}, {-5}, {5}}, - {{0, 0}, {-1, 4}}, - {{1, 2}, {0, -1}}, - {{0, 2}, {-1, 0}}, - {{4, 3, 2}, {-4, 3, 2}}, - }; - for (size_t i = 0; i < dim_powerset.size(); ++i) { - std::vector roll_dims = dim_powerset[i]; - for (bool negative_dims : {false, true}) { - if (negative_dims) { - std::for_each(roll_dims.begin(), roll_dims.end(), - [](int64_t& dim) { dim -= 3; }); - } - for (std::vector roll_shifts : shift_set[i]) { - torch::Tensor output = torch::roll(input, roll_shifts, roll_dims); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_output = - torch::roll(xla_input, roll_shifts, roll_dims); - AllClose(output, xla_output); - }); - } - } - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::roll", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestViewIsAliasOf) { - torch::Tensor a = torch::empty(4, torch::TensorOptions(torch::kFloat)); - torch::Tensor b = torch::empty(4, torch::TensorOptions(torch::kFloat)); - - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_b = CopyToDevice(b, device); - EXPECT_EQ(!a.is_alias_of(b), !xla_a.is_alias_of(xla_b)); - - torch::Tensor c = a.view({2, 2}); - torch::Tensor xla_c = xla_a.view({2, 2}); - EXPECT_EQ(a.is_alias_of(c), xla_a.is_alias_of(xla_c)); - - torch::Tensor d = c.view({1, 4}); - torch::Tensor lazy_d = xla_c.view({1, 4}); - EXPECT_EQ(d.is_alias_of(c), lazy_d.is_alias_of(xla_c)); - EXPECT_EQ(d.is_alias_of(a), lazy_d.is_alias_of(xla_a)); - }); -} - -TEST_F(AtenXlaTensorTest, TestExpandIsAliasOf) { - torch::Tensor a = torch::empty(4, torch::TensorOptions(torch::kFloat)); - torch::Tensor b = a.expand(4, 3); - EXPECT_TRUE(a.is_alias_of(b)); - - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_b = xla_a.expand(4, 3); - EXPECT_EQ(a.is_alias_of(b), xla_a.is_alias_of(xla_b)); - }); -} - -TEST_F(AtenXlaTensorTest, TestCdistForward) { - torch::Tensor a = - torch::rand({2, 20, 5}, torch::TensorOptions(torch::kFloat)); - torch::Tensor b = - torch::rand({2, 10, 5}, torch::TensorOptions(torch::kFloat)); - std::vector p_list = {0.0, 1.0, 2.0, 5.0, - std::numeric_limits::infinity()}; - for (const auto& p : p_list) { - torch::Tensor c = torch::cdist(a, b, p); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_b = CopyToDevice(b, device); - torch::Tensor xla_c = torch::cdist(xla_a, xla_b, p); - AllClose(c, xla_c); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_cdist_forward", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestPdistForward) { - torch::Tensor a = torch::rand({10, 11}, torch::TensorOptions(torch::kFloat)); - std::vector p_list = {1.0, 2.0, 5.0}; - for (const auto& p : p_list) { - torch::Tensor c = torch::pdist(a, p); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_c = torch::pdist(xla_a, p); - AllClose(c, xla_c); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_pdist_forward", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestPdistForwardZeroSize) { - torch::Tensor a = torch::rand({0, 2}, torch::TensorOptions(torch::kFloat)); - std::vector p_list = {1.0, 2.0, 5.0}; - for (const auto& p : p_list) { - torch::Tensor c = torch::pdist(a, p); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_c = torch::pdist(xla_a, p); - AllClose(c, xla_c); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_pdist_forward", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestPdistForwardSingleRow) { - torch::Tensor a = torch::rand({1, 2}, torch::TensorOptions(torch::kFloat)); - std::vector p_list = {1.0, 2.0, 5.0}; - for (const auto& p : p_list) { - torch::Tensor c = torch::pdist(a, p); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_c = torch::pdist(xla_a, p); - AllClose(c, xla_c); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_pdist_forward", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestGlu) { - std::vector> sizes{ - {3, 8}, {3, 5, 6}, {3, 8, 5}, {3, 8, 8, 16}}; - std::vector dims{-1, -1, 1, 3}; - - auto size_it = sizes.begin(); - auto dim_it = dims.begin(); - for (; size_it != sizes.end() && dim_it != dims.end(); ++size_it, ++dim_it) { - torch::Tensor input = - torch::rand(*size_it, torch::TensorOptions(torch::kFloat)); - torch::Tensor output = torch::glu(input, *dim_it); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_output = torch::glu(xla_input, *dim_it); - AllClose(output, xla_output); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::glu", cpp_test::GetIgnoredCounters()); -} - } // namespace cpp_test } // namespace torch_xla diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index b04169cc29ae..250e97422945 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1312,6 +1312,14 @@ at::Tensor& XLANativeFunctions::celu_(at::Tensor& self, return self; } +at::Tensor XLANativeFunctions::clamp(const at::Tensor& self, + const at::Scalar& min_val, + const at::Scalar& max_val) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + return bridge::AtenFromXlaTensor( + XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val)); +} + at::Tensor XLANativeFunctions::clone( const at::Tensor& self, std::optional /* memory_format */) {