diff --git a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden index b160c6f4a57acc..73848e39a3bfb9 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden +++ b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden @@ -45,7 +45,7 @@ Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, AbstractTensorHa // Summary: // // Description: -Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a, bool transpose_b, const char* name, const char* raw_device_name) { +Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a, bool transpose_b, bool grad_a, bool grad_b, const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("MatMul", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -53,6 +53,8 @@ Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTenso TF_RETURN_IF_ERROR(op_ptr->AddInput(b)); TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("transpose_a", transpose_a)); TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("transpose_b", transpose_b)); + TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("grad_a", grad_a)); + TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("grad_b", grad_b)); int num_retvals = 1; return op_ptr->Execute(absl::MakeSpan(product, 1), &num_retvals); } diff --git a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden index 99b797a7c112ea..06c5856c7cc451 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden +++ b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden @@ -28,7 +28,7 @@ namespace ops { Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, AbstractTensorHandle** y, const char* name = nullptr, const char* raw_device_name = nullptr); // -Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a = false, bool transpose_b = false, const char* name = nullptr, const char* raw_device_name = nullptr); +Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a = false, bool transpose_b = false, bool grad_a = false, bool grad_b = false, const char* name = nullptr, const char* raw_device_name = nullptr); // Status IdentityN(AbstractContext* ctx, absl::Span input, absl::Span output, const char* name = nullptr, const char* raw_device_name = nullptr); diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir index f6e3ca10f5a279..56620e66870520 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir @@ -51,7 +51,7 @@ func.func @batchmatmulv2(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> // SUPPORTED_FALLBACK_DEVICE: mhlo.dot_general // SUPPORTED_FALLBACK_DEVICE: mhlo.transpose - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, grad_x = false, grad_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> func.return %0 : tensor<3x4x4xf32> } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir index b952fe58156eb7..9860cb1e521610 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir @@ -542,7 +542,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: mhlo.reduce // CHECK: mhlo.dot_general // CHECK: mhlo.transpose - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, grad_x = false, grad_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> func.return %0 : tensor<3x4x4xf32> } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 76a91179da62be..c13ee6062325d0 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -32,6 +32,8 @@ class BatchMatMulOp : public XlaOpKernel { explicit BatchMatMulOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_x", &adj_x_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_y", &adj_y_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_x", &grad_x_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_y", &grad_y_)); if (ctx->HasAttr("Tout")) { DataType output_type; @@ -50,13 +52,15 @@ class BatchMatMulOp : public XlaOpKernel { : xla::PrecisionConfig::HIGHEST; auto result = xla::BatchDot(MaybeConjugate(ctx->Input(0), adj_x_), adj_x_, MaybeConjugate(ctx->Input(1), adj_y_), adj_y_, - precision, preferred_element_type_); + precision, grad_x_, grad_y_, preferred_element_type_); ctx->SetOutput(0, result); } private: bool adj_x_; bool adj_y_; + bool grad_x_; + bool grad_y_; std::optional preferred_element_type_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 8eaa39a1fcde12..959829bc99d841 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/tsl/platform/tensor_float_32_utils.h" +#include "tensorflow/compiler/jit/flags.h" namespace tensorflow { namespace { @@ -246,6 +247,8 @@ StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth] TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + xla::XlaOp conv_forward; + // For 2D convolution, there should be 4 dimensions. int num_dims = attrs.num_spatial_dims + 2; if (input_shape.dimensions_size() != num_dims) { @@ -329,13 +332,32 @@ StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, xla::PrecisionConfig precision_config = GetPrecisionConfig(); if (padding_type != xla::PaddingType::PADDING_INVALID) { - return xla::DynamicConvForward( + conv_forward = xla::DynamicConvForward( conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, dims, /*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count, /*batch_group_count=*/1, &precision_config, padding_type); + } else { + conv_forward = xla::ConvGeneralDilated( + conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, + dims, + /*feature_group_count=*/attrs.depthwise ? in_depth + : feature_group_count, + /*batch_group_count=*/1, &precision_config); + } + + // TODO : CHECK_EQ(HloOpcode::kConvolution, conv_forward->opcode()); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + // Set call_context attribute, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED + auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + state = tensorflow::GetMlirBridgeRolloutState(std::nullopt); + if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(conv_forward, "call_context", + "kForward")); } +#endif return xla::ConvGeneralDilated( conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, @@ -361,6 +383,8 @@ StatusOr MakeXlaBackpropInputConvOp(StringPiece type_string, TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, builder->GetShape(out_backprop)); + xla::XlaOp input_backprop; + int64_t in_depth = input_shape.dimensions(feature_dim), filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims), feature_group_count = @@ -429,7 +453,7 @@ StatusOr MakeXlaBackpropInputConvOp(StringPiece type_string, filter = xla::Rev(filter, kernel_spatial_dims); if (padding_type != xla::PaddingType::PADDING_INVALID) { TF_RET_CHECK(input_sizes != nullptr); - return xla::DynamicConvInputGrad( + input_backprop = xla::DynamicConvInputGrad( *input_sizes, out_backprop, filter, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums, /*feature_group_count=*/ @@ -443,6 +467,18 @@ StatusOr MakeXlaBackpropInputConvOp(StringPiece type_string, /*feature_group_count=*/ feature_group_count, /*batch_group_count=*/1, &precision_config); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + // Set call_context attribute, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED + auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + state = tensorflow::GetMlirBridgeRolloutState(std::nullopt); + if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(input_backprop, "call_context", + "kBackpropData")); + } +#endif + + return input_backprop; } StatusOr MakeXlaBackpropFilterConvOp(StringPiece type_string, @@ -598,6 +634,18 @@ StatusOr MakeXlaBackpropFilterConvOp(StringPiece type_string, /*batch_group_count=*/batch_group_count, &precision_config); } + // TODO : CHECK_EQ(HloOpcode::kConvolution, filter_backprop->opcode()); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + // Set call_context attribute, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED + auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + state = tensorflow::GetMlirBridgeRolloutState(std::nullopt); + if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(filter_backprop, "call_context", + "kBackpropFilter")); + } +#endif + if (attrs.depthwise) { filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions()); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 86c0d97e97f936..cdffcbeab35285 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -39,6 +39,8 @@ class MatMulOp : public XlaOpKernel { : XlaOpKernel(ctx), is_sparse_(is_sparse) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_a", &grad_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_b", &grad_b_)); if (is_sparse) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Ta", &a_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tb", &b_type_)); @@ -96,13 +98,15 @@ class MatMulOp : public XlaOpKernel { ? xla::PrecisionConfig::DEFAULT : xla::PrecisionConfig::HIGHEST; ctx->SetOutput(0, - xla::BatchDot(a, transpose_a_, b, transpose_b_, precision)); + xla::BatchDot(a, transpose_a_, b, transpose_b_, precision, grad_a_, grad_b_)); } private: bool is_sparse_; bool transpose_a_; bool transpose_b_; + bool grad_a_; + bool grad_b_; DataType a_type_; DataType b_type_; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d5c708e4b764ad..c325c0866fb8c6 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -234,6 +234,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/jit:flags", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index eb4a8a5e0b5b31..55e73b71253f77 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/jit/flags.h" #include #include @@ -45,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/jit/flags.h" namespace xla { @@ -385,25 +387,27 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, absl::Span y_config, absl::Span output_config, xla::PrecisionConfig::Precision precision, - std::optional preferred_element_type) { + std::optional preferred_element_type, + bool grad_x, bool grad_y) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { auto x_diagonal_labels = EinsumDiagonalLabels(x_config); if (x_diagonal_labels) { return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y, - y_config, output_config, precision, preferred_element_type); + y_config, output_config, precision, preferred_element_type, + grad_x, grad_y); } auto y_diagonal_labels = EinsumDiagonalLabels(y_config); if (y_diagonal_labels) { return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels->at(0), output_config, precision, - preferred_element_type); + preferred_element_type, grad_x, grad_y); } auto output_diagonal_labels = EinsumDiagonalLabels(output_config); if (output_diagonal_labels) { return EinsumInverseDiagonal( Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0), - precision, preferred_element_type), + precision, preferred_element_type, grad_x, grad_y), output_config); } @@ -547,6 +551,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, precision_proto.add_operand_precision(precision); auto dot = DotGeneral(x, y, dnums, &precision_proto, preferred_element_type); + + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + // Set grad_x, grad_y attributes, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED + auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + state = tensorflow::GetMlirBridgeRolloutState(std::nullopt); + if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(dot, "grad_x", + (grad_x ? "true" : "false"))); + TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(dot, "grad_y", + (grad_y ? "true" : "false"))); + } +#endif + dot = Transpose(dot, transpose_dims); if (transpose_rank == output_rank) { return dot; @@ -573,11 +591,12 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision, std::optional preferred_element_type) { - return BatchDot(x, false, y, false, precision, preferred_element_type); + return BatchDot(x, false, y, false, precision, false, false, preferred_element_type); } XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y, PrecisionConfig::Precision precision, + bool grad_x, bool grad_y, std::optional preferred_element_type) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { @@ -588,7 +607,8 @@ XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y, if (transpose_y) { std::swap(string[6 + 3], string[6 + 4]); } - return Einsum(x, y, string, precision, preferred_element_type); + return Einsum(x, y, string, precision, preferred_element_type, grad_x, + grad_y); }); } @@ -709,12 +729,14 @@ std::string NormalizeEinsumString(absl::string_view einsum_config) { XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, PrecisionConfig::Precision precision, - std::optional preferred_element_type) { + std::optional preferred_element_type, + bool grad_x, bool grad_y) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { auto new_config = NormalizeEinsumString(einsum_config); if (!new_config.empty()) { - return Einsum(x, y, new_config, precision, preferred_element_type); + return Einsum(x, y, new_config, precision, preferred_element_type, grad_x, + grad_y); } TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); @@ -722,7 +744,8 @@ XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, auto einsum_config_numeric, ParseEinsumString(einsum_config, x_shape.rank(), y_shape.rank())); return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1], - einsum_config_numeric[2], precision, preferred_element_type); + einsum_config_numeric[2], precision, preferred_element_type, + grad_x, grad_y); }); } diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index b24feca3ea8e6d..49a816fdc38dc6 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -97,6 +97,7 @@ xla::XlaOp BatchDot( xla::XlaOp BatchDot( xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + bool grad_x = false, bool grad_y = false, std::optional preferred_element_type = std::nullopt); // Parse an einsum string into dimension numbers: @@ -128,7 +129,8 @@ std::string NormalizeEinsumString(absl::string_view einsum_config); xla::XlaOp Einsum( xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt); + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); xla::XlaOp Einsum( xla::XlaOp x, absl::string_view einsum_config, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); @@ -143,7 +145,8 @@ xla::XlaOp Einsum( xla::XlaOp x, absl::Span x_config, xla::XlaOp y, absl::Span y_config, absl::Span output_config, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt); + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/xla/examples/axpy/BUILD b/tensorflow/compiler/xla/examples/axpy/BUILD index 8c1442922c3432..e6512ddf60de50 100644 --- a/tensorflow/compiler/xla/examples/axpy/BUILD +++ b/tensorflow/compiler/xla/examples/axpy/BUILD @@ -1,8 +1,9 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) -xla_cc_test( +tf_cc_test( name = "stablehlo_compile_test", srcs = ["stablehlo_compile_test.cc"], data = ["stablehlo_axpy.mlir"], diff --git a/tensorflow/compiler/xla/hlo/evaluator/BUILD b/tensorflow/compiler/xla/hlo/evaluator/BUILD index a4253c21179701..8d837cd06effb4 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/BUILD +++ b/tensorflow/compiler/xla/hlo/evaluator/BUILD @@ -3,6 +3,7 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -81,7 +82,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_evaluator_test", srcs = ["hlo_evaluator_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD index 2ccbf23c02f9d7..70e713c5a65495 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary", "xla_cc_test") load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -188,7 +189,7 @@ build_test( ], ) -xla_cc_test( +tf_cc_test( name = "auto_sharding_test", srcs = ["auto_sharding_test.cc"], deps = [ @@ -203,7 +204,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "auto_sharding_solver_test", srcs = ["auto_sharding_solver_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/hlo/transforms/BUILD b/tensorflow/compiler/xla/hlo/transforms/BUILD index 1f399dbb837739..64878df1655432 100644 --- a/tensorflow/compiler/xla/hlo/transforms/BUILD +++ b/tensorflow/compiler/xla/hlo/transforms/BUILD @@ -3,6 +3,7 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -24,7 +25,7 @@ cc_library( deps = ["//tensorflow/compiler/xla/service:hlo_pass"], ) -xla_cc_test( +tf_cc_test( name = "hlo_constant_splitter_test", srcs = ["hlo_constant_splitter_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/hlo/utils/BUILD b/tensorflow/compiler/xla/hlo/utils/BUILD index b478af17925411..73b1e8d96fea34 100644 --- a/tensorflow/compiler/xla/hlo/utils/BUILD +++ b/tensorflow/compiler/xla/hlo/utils/BUILD @@ -1,6 +1,7 @@ # Description: # Implementation of XLA’s HLO utilities used for higher-level transformations. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load( "//tensorflow/compiler/xla:xla.bzl", @@ -45,7 +46,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_live_range_test", srcs = ["hlo_live_range_test.cc"], deps = [ @@ -75,7 +76,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_matchers_test", srcs = ["hlo_matchers_test.cc"], deps = [ @@ -113,7 +114,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_sharding_util_test", srcs = [ "hlo_sharding_util_test.cc", diff --git a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index b76587b9648d96..54bc3489c41591 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -168,7 +168,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { F64Attr:$alpha_real, F64Attr:$alpha_imag, F64Attr:$beta, - OptionalAttr:$algorithm); + OptionalAttr:$algorithm, + OptionalAttr:$grad_x, + OptionalAttr:$grad_y + ); } def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandSegments]> { @@ -185,7 +188,9 @@ def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandS F64Attr:$alpha_imag, F64Attr:$beta, CublasLtMatmulEpilogueAttr:$epilogue, - I64Attr:$algorithm); + I64Attr:$algorithm, + OptionalAttr:$grad_x, + OptionalAttr:$grad_y); } def LHLOGPU_CublasLtMatmulF8Op : LHLOGPU_Op<"cublas.lt.matmul.f8", [AttrSizedOperandSegments]> { @@ -206,7 +211,9 @@ def LHLOGPU_CublasLtMatmulF8Op : LHLOGPU_Op<"cublas.lt.matmul.f8", [AttrSizedOpe F64Attr:$alpha_imag, F64Attr:$beta, CublasLtMatmulEpilogueAttr:$epilogue, - I64Attr:$algorithm); + I64Attr:$algorithm, + OptionalAttr:$grad_x, + OptionalAttr:$grad_y); } def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index e879fe61ed455a..35bf9f83c93ede 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load( "//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library", @@ -58,7 +59,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "semaphore_test", srcs = ["semaphore_test.cc"], deps = [ @@ -95,7 +96,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "tracked_device_buffer_test", srcs = ["tracked_device_buffer_test.cc"], deps = [ @@ -145,7 +146,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_api_test", srcs = ["pjrt_api_test.cc"], deps = [ @@ -205,7 +206,7 @@ cc_library( alwayslink = 1, ) -xla_cc_test( +tf_cc_test( name = "pjrt_client_test_cpu", srcs = ["pjrt_client_test_cpu.cc"], deps = [ @@ -236,7 +237,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_executable_test", srcs = ["pjrt_executable_test.cc"], deps = [ @@ -277,7 +278,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_compiler_test", srcs = ["pjrt_compiler_test.cc"], deps = [ @@ -385,7 +386,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_stream_executor_client_test", srcs = ["pjrt_stream_executor_client_test.cc"], deps = [ @@ -504,7 +505,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "tracked_tfrt_cpu_device_buffer_test", srcs = ["tracked_tfrt_cpu_device_buffer_test.cc"], deps = [ @@ -618,7 +619,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "tfrt_cpu_pjrt_client_test", srcs = ["tfrt_cpu_pjrt_client_test.cc"], deps = [ @@ -652,7 +653,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "lru_cache_test", srcs = ["lru_cache_test.cc"], deps = [ @@ -689,7 +690,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "transpose_test", srcs = ["transpose_test.cc"], deps = [ @@ -754,7 +755,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "tf_pjrt_client_test", srcs = ["tf_pjrt_client_test.cc"], deps = [ @@ -779,7 +780,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "host_callback_test", srcs = ["host_callback_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/pjrt/c/BUILD b/tensorflow/compiler/xla/pjrt/c/BUILD index 8b86f30ba2380a..a898392d215558 100644 --- a/tensorflow/compiler/xla/pjrt/c/BUILD +++ b/tensorflow/compiler/xla/pjrt/c/BUILD @@ -4,6 +4,7 @@ load( "tf_cuda_tests_tags", ) load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -122,7 +123,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_c_api_gpu_test", srcs = ["pjrt_c_api_gpu_test.cc"], tags = tf_cuda_tests_tags(), @@ -140,7 +141,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_c_api_helpers_test", srcs = ["pjrt_c_api_helpers_test.cc"], deps = [ @@ -155,7 +156,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_c_api_cpu_test", srcs = ["pjrt_c_api_cpu_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/pjrt/distributed/BUILD b/tensorflow/compiler/xla/pjrt/distributed/BUILD index 9109ee6518bcf0..555a3532f3084c 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/BUILD +++ b/tensorflow/compiler/xla/pjrt/distributed/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") licenses(["notice"]) diff --git a/tensorflow/compiler/xla/pjrt/gpu/BUILD b/tensorflow/compiler/xla/pjrt/gpu/BUILD index 3b749c980316fd..e843b3a7cba080 100644 --- a/tensorflow/compiler/xla/pjrt/gpu/BUILD +++ b/tensorflow/compiler/xla/pjrt/gpu/BUILD @@ -3,6 +3,7 @@ load("//tensorflow/tsl:tsl.bzl", "if_nccl") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") load("//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library") @@ -86,7 +87,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "se_gpu_pjrt_client_test", srcs = if_gpu_is_configured(["se_gpu_pjrt_client_test.cc"]), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), @@ -159,7 +160,7 @@ cc_library( ] + if_nccl(["@local_config_nccl//:nccl"]), ) -xla_cc_test( +tf_cc_test( name = "pjrt_client_test_se_gpu", srcs = ["pjrt_client_test_se_gpu.cc"], tags = [ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ed0b717abff8f7..19f21def8a3723 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -5,6 +5,7 @@ load( "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load( "//tensorflow/compiler/xla:xla.bzl", "xla_cc_test", @@ -616,7 +617,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "outfeed_receiver_test_cpu", size = "small", srcs = ["outfeed_receiver_test.cc"], diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index ae36a1ebe8026e..dca44056390610 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") package_group( name = "friends", @@ -90,7 +91,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "array_test", size = "small", srcs = ["array_test.cc"], @@ -102,7 +103,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "future_test", size = "small", srcs = ["future_test.cc"], @@ -115,7 +116,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "index_domain_test", size = "small", srcs = ["index_domain_test.cc"], @@ -125,7 +126,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "index_test", size = "small", srcs = ["index_test.cc"], @@ -135,7 +136,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "memory_test", size = "small", srcs = ["memory_test.cc"], @@ -145,7 +146,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "shape_test", size = "small", srcs = ["shape_test.cc"], @@ -155,7 +156,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "sharding_test", size = "small", srcs = ["sharding_test.cc"], @@ -226,7 +227,7 @@ cc_library( alwayslink = True, ) -xla_cc_test( +tf_cc_test( name = "array_test_no_impl", srcs = [], deps = [ @@ -247,7 +248,7 @@ cc_library( alwayslink = True, ) -xla_cc_test( +tf_cc_test( name = "client_test_no_impl", srcs = [], deps = [ @@ -273,7 +274,7 @@ cc_library( alwayslink = True, ) -xla_cc_test( +tf_cc_test( name = "tuple_test_no_impl", srcs = [], deps = [ @@ -318,7 +319,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "serdes_test", srcs = ["serdes_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index bed257ecb9ee98..b3774455109d86 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") load("//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") package_group( name = "friends", @@ -93,7 +94,7 @@ cc_library( alwayslink = True, ) -xla_cc_test( +tf_cc_test( name = "xla_program_serdes_test", srcs = ["xla_program_serdes_test.cc"], deps = [ @@ -133,7 +134,7 @@ cc_library( alwayslink = True, ) -xla_cc_test( +tf_cc_test( name = "xla_sharding_serdes_test", srcs = ["xla_sharding_serdes_test.cc"], deps = [ @@ -165,7 +166,7 @@ cc_library( ) # TODO(hyeontaek): Move this target out of pjrt_ifrt. -xla_cc_test( +tf_cc_test( name = "xla_executable_test_no_impl", srcs = [], deps = [ @@ -176,7 +177,7 @@ xla_cc_test( ) # TODO(hyeontaek): Move this target out of pjrt_ifrt. -xla_cc_test( +tf_cc_test( name = "xla_sharding_test", size = "small", srcs = ["xla_sharding_test.cc"], @@ -250,7 +251,7 @@ cc_library( alwayslink = True, ) -xla_cc_test( +tf_cc_test( name = "pjrt_array_impl_test_tfrt_cpu", size = "small", srcs = ["pjrt_array_impl_test_tfrt_cpu.cc"], @@ -263,7 +264,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_client_impl_test_tfrt_cpu", size = "small", srcs = [], @@ -274,7 +275,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_executable_impl_test_tfrt_cpu", size = "small", srcs = ["pjrt_executable_impl_test_tfrt_cpu.cc"], @@ -287,7 +288,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "pjrt_tuple_impl_test_tfrt_cpu", size = "small", srcs = [], diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 855e9d4a46ba9f..804287666e4a4d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -4,6 +4,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load( "//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary", @@ -119,7 +120,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "async_collective_creator_test", srcs = ["async_collective_creator_test.cc"], deps = [ @@ -144,7 +145,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "async_op_canonicalizer_test", srcs = ["async_op_canonicalizer_test.cc"], deps = [ @@ -172,7 +173,7 @@ cc_library( deps = [":change_op_data_type"], ) -xla_cc_test( +tf_cc_test( name = "all_reduce_promotion_test", srcs = ["all_reduce_promotion_test.cc"], deps = [ @@ -203,7 +204,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_reduce_reassociate_test", srcs = ["all_reduce_reassociate_test.cc"], deps = [ @@ -235,7 +236,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_reduce_folder_test", srcs = ["all_reduce_folder_test.cc"], deps = [ @@ -268,7 +269,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "broadcast_canonicalizer_test", srcs = ["broadcast_canonicalizer_test.cc"], deps = [ @@ -297,7 +298,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "bfloat16_conversion_folding_test", srcs = ["bfloat16_conversion_folding_test.cc"], deps = [ @@ -331,7 +332,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "float_normalization_test", srcs = ["float_normalization_test.cc"], deps = [ @@ -372,7 +373,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "bfloat16_propagation_test", srcs = ["bfloat16_propagation_test.cc"], deps = [ @@ -404,7 +405,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "collective_permute_decomposer_test", srcs = ["collective_permute_decomposer_test.cc"], deps = [ @@ -432,7 +433,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "constant_value_test", srcs = ["constant_value_test.cc"], deps = [ @@ -469,7 +470,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "value_range_test", srcs = ["value_range_test.cc"], deps = [ @@ -482,7 +483,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "convert_async_collectives_to_sync_test", srcs = ["convert_async_collectives_to_sync_test.cc"], deps = [ @@ -525,7 +526,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "collective_pipeliner_test", srcs = ["collective_pipeliner_test.cc"], deps = [ @@ -591,7 +592,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "shape_inference_test", srcs = ["shape_inference_test.cc"], deps = [ @@ -610,7 +611,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_opcode_test", srcs = ["hlo_opcode_test.cc"], deps = [ @@ -656,7 +657,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "sharding_propagation_test", srcs = [ "sharding_propagation_test.cc", @@ -691,7 +692,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "sharding_remover_test", size = "small", srcs = [ @@ -722,7 +723,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "dynamic_parameter_binding_test", srcs = ["dynamic_parameter_binding_test.cc"], deps = [ @@ -758,7 +759,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "dfs_hlo_visitor_with_default_test", srcs = ["dfs_hlo_visitor_with_default_test.cc"], deps = [ @@ -786,7 +787,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "pattern_matcher_test", srcs = ["pattern_matcher_test.cc"], deps = [ @@ -812,7 +813,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "pattern_matcher_gmock_test", srcs = ["pattern_matcher_gmock_test.cc"], deps = [ @@ -825,7 +826,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_reachability_test", srcs = ["hlo_reachability_test.cc"], deps = [ @@ -839,7 +840,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_instruction_test", srcs = ["hlo_instruction_test.cc"], deps = [ @@ -859,7 +860,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_sharding_test", srcs = ["hlo_sharding_test.cc"], deps = [ @@ -893,7 +894,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "call_graph_test", srcs = ["call_graph_test.cc"], deps = [ @@ -942,7 +943,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "call_inliner_test", size = "small", srcs = ["call_inliner_test.cc"], @@ -972,7 +973,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_computation_deduplicator_test", size = "small", srcs = ["hlo_computation_deduplicator_test.cc"], @@ -993,7 +994,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "flatten_call_graph_test", srcs = ["flatten_call_graph_test.cc"], deps = [ @@ -1173,7 +1174,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "latency_hiding_scheduler_test", srcs = ["latency_hiding_scheduler_test.cc"], deps = [ @@ -1202,7 +1203,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "latency_hiding_scheduler_preparation_test", srcs = ["latency_hiding_scheduler_preparation_test.cc"], deps = [ @@ -1232,7 +1233,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "profile_guided_latency_estimator_test", srcs = ["profile_guided_latency_estimator_test.cc"], deps = [ @@ -1348,7 +1349,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "shaped_buffer_test", srcs = ["shaped_buffer_test.cc"], deps = [ @@ -1531,7 +1532,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "name_uniquer_test", srcs = ["name_uniquer_test.cc"], deps = [ @@ -1581,7 +1582,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "buffer_assignment_test", srcs = ["buffer_assignment_test.cc"], deps = [ @@ -1638,7 +1639,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_ordering_test", size = "small", srcs = ["hlo_ordering_test.cc"], @@ -1686,7 +1687,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ @@ -1707,7 +1708,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_module_group_test", srcs = ["hlo_module_group_test.cc"], # TODO(b/148211710) Test fails in OSS. @@ -1783,7 +1784,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], deps = [ @@ -1801,7 +1802,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_input_output_alias_config_test", srcs = ["hlo_input_output_alias_config_test.cc"], deps = [ @@ -1845,7 +1846,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_memory_scheduler_test", srcs = ["hlo_memory_scheduler_test.cc"], deps = [ @@ -1896,7 +1897,7 @@ cc_library( ] + if_google(["@com_google_absl//absl/types:source_location"]), ) -xla_cc_test( +tf_cc_test( name = "instruction_fusion_test", srcs = ["instruction_fusion_test.cc"], deps = [ @@ -1965,7 +1966,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "fusion_node_indexing_evaluation_test", srcs = ["fusion_node_indexing_evaluation_test.cc"], deps = [ @@ -1980,7 +1981,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_creation_utils_test", srcs = ["hlo_creation_utils_test.cc"], deps = [ @@ -2088,7 +2089,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "scatter_expander_test", srcs = ["scatter_expander_test.cc"], deps = [ @@ -2127,7 +2128,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "triangular_solve_expander_test", size = "medium", srcs = ["triangular_solve_expander_test.cc"], @@ -2201,7 +2202,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "real_imag_expander_test", size = "small", srcs = ["real_imag_expander_test.cc"], @@ -2262,7 +2263,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "convolution_4d_expander_test", srcs = ["convolution_4d_expander_test.cc"], deps = [ @@ -2292,7 +2293,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "convolution_pred_expander_test", srcs = ["convolution_pred_expander_test.cc"], deps = [ @@ -2384,7 +2385,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "algebraic_simplifier_test", srcs = ["algebraic_simplifier_test.cc"], deps = [ @@ -2434,7 +2435,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "simplify_fp_conversions_test", srcs = ["simplify_fp_conversions_test.cc"], deps = [ @@ -2471,7 +2472,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "logistic_expander_test", srcs = ["logistic_expander_test.cc"], deps = [ @@ -2514,7 +2515,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "collectives_schedule_linearizer_test", srcs = ["collectives_schedule_linearizer_test.cc"], deps = [ @@ -2599,7 +2600,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "bitcast_dtypes_expander_test", srcs = ["bitcast_dtypes_expander_test.cc"], deps = [ @@ -2613,7 +2614,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "all_gather_broadcast_reorder_test", srcs = ["all_gather_broadcast_reorder_test.cc"], deps = [ @@ -2648,7 +2649,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_gather_combiner_test", srcs = ["all_gather_combiner_test.cc"], deps = [ @@ -2685,7 +2686,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_reduce_combiner_test", srcs = ["all_reduce_combiner_test.cc"], deps = [ @@ -2715,7 +2716,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_reduce_contiguous_test", srcs = ["all_reduce_contiguous_test.cc"], deps = [ @@ -2752,7 +2753,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "reduce_scatter_combiner_test", srcs = ["reduce_scatter_combiner_test.cc"], deps = [ @@ -2778,7 +2779,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_reduce_simplifier_test", srcs = ["all_reduce_simplifier_test.cc"], deps = [ @@ -2814,7 +2815,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "reduce_scatter_decomposer_test", srcs = ["reduce_scatter_decomposer_test.cc"], deps = [ @@ -2845,7 +2846,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "reduce_scatter_reassociate_test", srcs = ["reduce_scatter_reassociate_test.cc"], deps = [ @@ -2869,7 +2870,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "batch_dot_simplification_test", srcs = ["batch_dot_simplification_test.cc"], deps = [ @@ -2881,7 +2882,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gather_expander_test", srcs = ["gather_expander_test.cc"], deps = [ @@ -2917,7 +2918,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "conditional_simplifier_test", srcs = ["conditional_simplifier_test.cc"], deps = [ @@ -2961,7 +2962,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "conditional_code_motion_test", srcs = ["conditional_code_motion_test.cc"], deps = [ @@ -3003,7 +3004,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "convolution_group_converter_test", size = "small", srcs = ["convolution_group_converter_test.cc"], @@ -3060,7 +3061,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "space_to_batch_converter_test", size = "small", srcs = ["space_to_batch_converter_test.cc"], @@ -3088,7 +3089,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_analysis_test", srcs = ["while_loop_analysis_test.cc"], deps = [ @@ -3123,7 +3124,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ @@ -3156,7 +3157,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_trip_count_annotator_test", srcs = ["while_loop_trip_count_annotator_test.cc"], deps = [ @@ -3189,7 +3190,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "defuser_test", srcs = ["defuser_test.cc"], deps = [ @@ -3202,7 +3203,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "despecializer_test", srcs = ["despecializer_test.cc"], deps = [ @@ -3235,7 +3236,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "dot_decomposer_test", srcs = ["dot_decomposer_test.cc"], deps = [ @@ -3263,7 +3264,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "dot_dimension_merger_test", srcs = ["dot_dimension_merger_test.cc"], deps = [ @@ -3286,7 +3287,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "dot_merger_test", srcs = ["dot_merger_test.cc"], deps = [ @@ -3313,7 +3314,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "convert_mover_test", srcs = ["convert_mover_test.cc"], deps = [ @@ -3356,7 +3357,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_gather_decomposer_test", srcs = ["all_gather_decomposer_test.cc"], deps = [ @@ -3381,7 +3382,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "tuple_simplifier_test", srcs = ["tuple_simplifier_test.cc"], deps = [ @@ -3438,7 +3439,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "reduce_decomposer_test", srcs = ["reduce_decomposer_test.cc"], deps = [ @@ -3452,7 +3453,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "reshape_decomposer_test", srcs = ["reshape_decomposer_test.cc"], deps = [ @@ -3518,7 +3519,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "dynamic_dimension_simplifier_test", srcs = ["dynamic_dimension_simplifier_test.cc"], deps = [ @@ -3611,7 +3612,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "dynamic_dimension_inference_test", srcs = ["dynamic_dimension_inference_test.cc"], deps = [ @@ -3633,7 +3634,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "reshape_mover_test", srcs = ["reshape_mover_test.cc"], deps = [ @@ -3766,7 +3767,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_cost_analysis_test", srcs = ["hlo_cost_analysis_test.cc"], deps = [ @@ -3809,7 +3810,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_execution_profile_test", srcs = ["hlo_execution_profile_test.cc"], tags = ["no_mac_arm64"], @@ -3823,7 +3824,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_computation_test", srcs = ["hlo_computation_test.cc"], deps = [ @@ -3844,7 +3845,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_module_test", srcs = ["hlo_module_test.cc"], deps = [ @@ -3870,7 +3871,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_module_metadata_test", srcs = ["hlo_module_metadata_test.cc"], deps = [ @@ -3975,7 +3976,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_dataflow_analysis_test", srcs = ["hlo_dataflow_analysis_test.cc"], deps = [ @@ -4015,7 +4016,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_phi_graph_test", srcs = ["hlo_phi_graph_test.cc"], deps = [ @@ -4049,7 +4050,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_value_semantics_analysis_test", srcs = ["hlo_value_semantics_analysis_test.cc"], deps = [ @@ -4077,7 +4078,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_replication_analysis_test", srcs = ["hlo_replication_analysis_test.cc"], deps = [ @@ -4114,7 +4115,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_liveness_analysis_test", srcs = ["hlo_liveness_analysis_test.cc"], deps = [ @@ -4178,7 +4179,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_alias_analysis_test", srcs = ["hlo_alias_analysis_test.cc"], deps = [ @@ -4246,7 +4247,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "tuple_points_to_analysis_test", srcs = ["tuple_points_to_analysis_test.cc"], deps = [ @@ -4368,7 +4369,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "copy_insertion_test", srcs = ["copy_insertion_test.cc"], deps = [ @@ -4389,7 +4390,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "loop_schedule_linearizer_test", srcs = ["loop_schedule_linearizer_test.cc"], deps = [ @@ -4443,7 +4444,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "memory_space_assignment_best_fit_repacker_test", srcs = ["memory_space_assignment_best_fit_repacker_test.cc"], deps = [ @@ -4486,7 +4487,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "memory_space_assignment_test", srcs = ["memory_space_assignment_test.cc"], deps = [ @@ -4531,7 +4532,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "memory_space_propagation_test", srcs = ["memory_space_propagation_test.cc"], deps = [ @@ -4606,7 +4607,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_verifier_test", srcs = ["hlo_verifier_test.cc"], deps = [ @@ -4673,7 +4674,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_rematerialization_test_utils_test", srcs = ["hlo_rematerialization_test_utils_test.cc"], deps = [ @@ -4684,7 +4685,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_rematerialization_test", srcs = ["hlo_rematerialization_test.cc"], deps = [ @@ -4703,7 +4704,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_dce_test", srcs = ["hlo_dce_test.cc"], deps = [ @@ -4721,7 +4722,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_module_dce_test", srcs = ["hlo_module_dce_test.cc"], deps = [ @@ -4737,7 +4738,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], deps = [ @@ -4810,7 +4811,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_pass_pipeline_test", srcs = ["hlo_pass_pipeline_test.cc"], deps = [ @@ -4841,7 +4842,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_cse_test", srcs = ["hlo_cse_test.cc"], deps = [ @@ -4882,7 +4883,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_constant_folding_test", srcs = ["hlo_constant_folding_test.cc"], deps = [ @@ -4959,7 +4960,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_domain_test", srcs = ["hlo_domain_test.cc"], deps = [ @@ -4994,7 +4995,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_element_type_converter_test", srcs = ["hlo_element_type_converter_test.cc"], deps = [ @@ -5016,7 +5017,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "conditional_canonicalizer_test", srcs = ["conditional_canonicalizer_test.cc"], deps = [ @@ -5202,7 +5203,7 @@ cc_library( alwayslink = 1, ) -xla_cc_test( +tf_cc_test( name = "hlo_graph_dumper_test", srcs = ["hlo_graph_dumper_test.cc"], deps = [ @@ -5237,7 +5238,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "transpose_folding_test", srcs = ["transpose_folding_test.cc"], deps = [ @@ -5278,7 +5279,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "zero_sized_hlo_elimination_test", srcs = ["zero_sized_hlo_elimination_test.cc"], deps = [ @@ -5304,7 +5305,7 @@ cc_library( deps = ["//tensorflow/compiler/xla/stream_executor"], ) -xla_cc_test( +tf_cc_test( name = "stream_pool_test", srcs = ["stream_pool_test.cc"], deps = [ @@ -5330,7 +5331,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_proto_util_test", srcs = ["hlo_proto_util_test.cc"], deps = [ @@ -5438,7 +5439,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "sort_simplifier_test", srcs = ["sort_simplifier_test.cc"], deps = [ @@ -5468,7 +5469,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "stable_sort_expander_test", srcs = ["stable_sort_expander_test.cc"], deps = [ @@ -5496,7 +5497,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "tuple_util_test", srcs = ["tuple_util_test.cc"], deps = [ @@ -5523,7 +5524,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "root_instruction_sinker_test", srcs = ["root_instruction_sinker_test.cc"], deps = [ @@ -5552,7 +5553,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_util_test", srcs = ["while_util_test.cc"], deps = [ @@ -5590,7 +5591,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_all_reduce_code_motion_test", srcs = ["while_loop_all_reduce_code_motion_test.cc"], deps = [ @@ -5632,7 +5633,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_concat_code_motion_test", srcs = ["while_loop_concat_code_motion_test.cc"], deps = [ @@ -5674,7 +5675,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_invariant_code_motion_test", srcs = ["while_loop_invariant_code_motion_test.cc"], deps = [ @@ -5707,7 +5708,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_expensive_invariant_code_motion_test", srcs = ["while_loop_expensive_invariant_code_motion_test.cc"], deps = [ @@ -5734,7 +5735,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "while_loop_constant_sinking_test", srcs = ["while_loop_constant_sinking_test.cc"], deps = [ @@ -5792,7 +5793,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "indexed_array_analysis_test", srcs = ["indexed_array_analysis_test.cc"], deps = [ @@ -5833,7 +5834,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_parser_test", size = "small", srcs = ["hlo_parser_test.cc"], @@ -5909,7 +5910,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "optimize_input_output_buffer_alias_test", srcs = ["optimize_input_output_buffer_alias_test.cc"], deps = [ @@ -5976,7 +5977,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "dynamic_index_splitter_test", srcs = ["dynamic_index_splitter_test.cc"], deps = [ @@ -5990,7 +5991,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "ar_crs_combiner_test", srcs = ["ar_crs_combiner_test.cc"], deps = [ @@ -6003,7 +6004,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "map_inliner_test", srcs = ["map_inliner_test.cc"], deps = [ @@ -6020,7 +6021,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_casting_utils_test", srcs = ["hlo_casting_utils_test.cc"], deps = [ @@ -6048,7 +6049,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "conditional_to_select_test", srcs = ["conditional_to_select_test.cc"], deps = [ @@ -6138,7 +6139,7 @@ cc_library( visibility = ["//visibility:public"], ) -xla_cc_test( +tf_cc_test( name = "custom_call_status_test", srcs = ["custom_call_status_test.cc"], deps = [ @@ -6157,7 +6158,7 @@ cc_library( deps = [":custom_call_status"], ) -xla_cc_test( +tf_cc_test( name = "slice_sinker_test", srcs = ["slice_sinker_test.cc"], deps = [ @@ -6255,7 +6256,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "collective_transformation_reorderer_test", srcs = ["collective_transformation_reorderer_test.cc"], deps = [ @@ -6267,7 +6268,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "collective_ops_utils_test", srcs = ["collective_ops_utils_test.cc"], deps = [ @@ -6299,7 +6300,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "topk_rewriter_test", srcs = ["topk_rewriter_test.cc"], deps = [ @@ -6333,7 +6334,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "operand_upcaster_test", srcs = ["operand_upcaster_test.cc"], deps = [ @@ -6357,7 +6358,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "result_caster_test", srcs = ["result_caster_test.cc"], deps = [ @@ -6395,7 +6396,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "convert_operand_folding_test", srcs = ["convert_operand_folding_test.cc"], deps = [ @@ -6425,7 +6426,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "xla_debug_info_manager_test", srcs = ["xla_debug_info_manager_test.cc"], deps = [ @@ -6481,7 +6482,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "mapped_ptr_container_sorter_test", srcs = ["mapped_ptr_container_sorter_test.cc"], deps = [ @@ -6541,7 +6542,7 @@ tf_proto_library( cc_api_version = 2, ) -xla_cc_test( +tf_cc_test( name = "compilation_environments_test", srcs = ["compilation_environments_test.cc"], deps = [ @@ -6606,7 +6607,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "scatter_simplifier_test", srcs = ["scatter_simplifier_test.cc"], deps = [ @@ -6630,7 +6631,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "select_and_scatter_expander_test", srcs = ["select_and_scatter_expander_test.cc"], deps = [ @@ -6641,7 +6642,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "layout_normalization_test", srcs = [ "layout_normalization_test.cc", @@ -6712,7 +6713,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "stochastic_convert_decomposer_test", srcs = ["stochastic_convert_decomposer_test.cc"], deps = [ @@ -6749,7 +6750,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gather_simplifier_test", srcs = ["gather_simplifier_test.cc"], deps = [ @@ -6759,7 +6760,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "change_op_data_type_test", srcs = ["change_op_data_type_test.cc"], deps = [ @@ -6869,7 +6870,7 @@ xla_aot_compile_gpu_runtime_autotuning( module = "xla_aot_compile_test_convolution.mlir", ) -xla_cc_test( +tf_cc_test( name = "xla_aot_compile_cpu_test", srcs = ["xla_aot_compile_cpu_test.cc"], data = [":xla_aot_compile_test_cpu_executable"], @@ -6891,7 +6892,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "xla_aot_compile_stablehlo_cpu_test", srcs = ["xla_aot_compile_stablehlo_cpu_test.cc"], data = [":xla_aot_compile_stablehlo_test_cpu_executable"], @@ -6913,7 +6914,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "xla_aot_compile_gpu_test", srcs = if_cuda_is_configured(["xla_aot_compile_gpu_test.cc"]), data = if_cuda_is_configured([ diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 55e4e99a492507..374d6d0986aa15 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -2,6 +2,7 @@ # LLVM-based CPU backend for XLA. load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load( "//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS", @@ -1202,7 +1203,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], shard_count = 10, @@ -1229,7 +1230,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "runtime_fft_test", srcs = [ "runtime_fft_impl.h", @@ -1246,7 +1247,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_instruction_fusion_test", srcs = ["cpu_instruction_fusion_test.cc"], deps = [ @@ -1263,7 +1264,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "xfeed_manager_test", size = "small", srcs = ["xfeed_manager_test.cc"], @@ -1306,7 +1307,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "ir_emission_utils_test", srcs = ["ir_emission_utils_test.cc"], deps = [ @@ -1348,7 +1349,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_layout_assignment_test", size = "small", srcs = ["cpu_layout_assignment_test.cc"], @@ -1393,7 +1394,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "conv_canonicalization_test", srcs = ["conv_canonicalization_test.cc"], deps = [ @@ -1417,7 +1418,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "shape_partition_test", srcs = ["shape_partition_test.cc"], deps = [ @@ -1447,7 +1448,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "parallel_task_assignment_test", srcs = ["parallel_task_assignment_test.cc"], deps = [ @@ -1515,7 +1516,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_eigen_tensor_alignment_test", size = "small", srcs = ["cpu_eigen_tensor_alignment_test.cc"], @@ -1528,7 +1529,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "vectorized_reduce_with_no_vector_registers_test", size = "small", srcs = ["vectorized_reduce_with_no_vector_registers_test.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index bb39d3300af6de..f816103ce331c5 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -3,6 +3,7 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -37,7 +38,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_dyn_shape_test", srcs = ["cpu_dyn_shape_test.cc"], deps = [ @@ -51,7 +52,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_fusion_test", srcs = ["cpu_fusion_test.cc"], deps = [ @@ -70,7 +71,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_bytesizeof_test", srcs = ["cpu_bytesizeof_test.cc"], deps = [ @@ -81,7 +82,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_external_constants_test", srcs = ["cpu_external_constants_test.cc"], deps = [ @@ -94,7 +95,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_noalias_test", srcs = ["cpu_noalias_test.cc"], deps = [ @@ -115,7 +116,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_intrinsic_test", srcs = ["cpu_intrinsic_test.cc"], deps = [ @@ -132,7 +133,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_eigen_dot_operation_test", srcs = ["cpu_eigen_dot_operation_test.cc"], tags = ["no_mac_arm64"], @@ -149,7 +150,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_profiling_test", srcs = ["cpu_profiling_test.cc"], deps = [ @@ -166,7 +167,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "tree_reduction_rewriter_test", srcs = ["tree_reduction_rewriter_test.cc"], deps = [ @@ -192,7 +193,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_infeed_test", srcs = ["cpu_infeed_test.cc"], deps = [ @@ -216,7 +217,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_literal_caching_test", srcs = ["cpu_literal_caching_test.cc"], deps = [ @@ -231,7 +232,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_outfeed_test", srcs = ["cpu_outfeed_test.cc"], deps = [ @@ -245,7 +246,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_key_value_sort_test", srcs = ["cpu_key_value_sort_test.cc"], deps = [ @@ -259,7 +260,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_spmd_compile_test", srcs = ["cpu_spmd_compile_test.cc"], deps = [ @@ -276,7 +277,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_topk_test", srcs = ["cpu_topk_test.cc"], deps = [ @@ -292,7 +293,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_vectorization_test", srcs = ["cpu_vectorization_test.cc"], deps = [ @@ -309,7 +310,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "cpu_while_test", srcs = ["cpu_while_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 07c15218dd7790..f34297927e5b6e 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -6,6 +6,7 @@ load( "//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library", ) +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load( "//tensorflow/tsl/platform:build_config_root.bzl", "if_static", @@ -63,7 +64,7 @@ tf_proto_library( ], ) -xla_cc_test( +tf_cc_test( name = "backend_configs_test", srcs = ["backend_configs_test.cc"], deps = [ @@ -140,7 +141,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "custom_call_test", srcs = if_gpu_is_configured(["custom_call_test.cc"]), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), @@ -169,7 +170,7 @@ xla_cc_test( ]), ) -xla_cc_test( +tf_cc_test( name = "gpu_copy_insertion_test", srcs = if_gpu_is_configured(["gpu_copy_insertion_test.cc"]), tags = tf_cuda_tests_tags(), @@ -225,7 +226,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "target_util_test", srcs = ["target_util_test.cc"], deps = [ @@ -258,7 +259,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_device_info_test", srcs = ["gpu_device_info_test.cc"], tags = tf_cuda_tests_tags(), @@ -864,7 +865,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "non_atomically_upgradeable_rw_lock_test", srcs = ["non_atomically_upgradeable_rw_lock_test.cc"], deps = [ @@ -1041,7 +1042,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "ir_emission_utils_test", srcs = ["ir_emission_utils_test.cc"], deps = [ @@ -1253,7 +1254,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gemm_rewriter_triton_test", srcs = ["gemm_rewriter_triton_test.cc"], deps = [ @@ -1407,7 +1408,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "gemm_algorithm_picker_test", srcs = ["gemm_algorithm_picker_test.cc"], tags = [ @@ -1475,7 +1476,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "matmul_utils_test", srcs = ["matmul_utils_test.cc"], deps = [ @@ -1506,7 +1507,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "dot_dimension_sorter_test", srcs = ["dot_dimension_sorter_test.cc"], tags = tf_cuda_tests_tags() + ["no_rocm"], @@ -1529,7 +1530,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_async_collective_annotator_test", srcs = ["gpu_async_collective_annotator_test.cc"], deps = [ @@ -1554,7 +1555,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_convert_async_collectives_to_sync_test", srcs = ["gpu_convert_async_collectives_to_sync_test.cc"], deps = [ @@ -1607,7 +1608,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "conv_algorithm_picker_test", srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]), tags = [ @@ -1714,7 +1715,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "move_copy_to_users_test", srcs = ["move_copy_to_users_test.cc"], deps = [ @@ -1725,7 +1726,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_conv_rewriter_test", srcs = ["gpu_conv_rewriter_test.cc"], deps = [ @@ -1812,7 +1813,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "instruction_fusion_test", srcs = ["instruction_fusion_test.cc"], tags = [ @@ -1860,7 +1861,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "priority_fusion_test", srcs = ["priority_fusion_test.cc"], tags = ["no_pip"], @@ -1902,7 +1903,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "multi_output_fusion_test", srcs = ["multi_output_fusion_test.cc"], tags = [ @@ -1920,7 +1921,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "softmax_rewriter_triton_test", srcs = ["softmax_rewriter_triton_test.cc"], deps = [ @@ -1945,7 +1946,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_sanitize_constant_names_test", srcs = ["gpu_sanitize_constant_names_test.cc"], deps = [ @@ -1978,7 +1979,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "fusion_merger_test", srcs = ["fusion_merger_test.cc"], tags = [ @@ -2014,7 +2015,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_conv_padding_legalization_test", srcs = ["gpu_conv_padding_legalization_test.cc"], deps = [ @@ -2045,7 +2046,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cudnn_support_utils_test", srcs = ["cudnn_support_utils_test.cc"], deps = [ @@ -2089,7 +2090,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cudnn_pad_for_convolutions_test", srcs = ["cudnn_pad_for_convolutions_test.cc"], deps = [ @@ -2123,7 +2124,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cudnn_vectorize_convolutions_test", srcs = ["cudnn_vectorize_convolutions_test.cc"], deps = [ @@ -2156,7 +2157,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cudnn_simplify_padding_test", srcs = ["cudnn_simplify_padding_test.cc"], deps = [ @@ -2207,7 +2208,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cublas_pad_for_gemms_test", srcs = ["cublas_pad_for_gemms_test.cc"], tags = [ @@ -2534,7 +2535,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "gpu_compiler_test", srcs = ["gpu_compiler_test.cc"], tags = tf_cuda_tests_tags(), @@ -2554,7 +2555,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "auto_sharding_gpu_compiler_test", srcs = ["auto_sharding_gpu_compiler_test.cc"], tags = tf_cuda_tests_tags() + ["no_oss"], # TODO(b/277355322): Make autosharding work in OSS @@ -2661,7 +2662,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "nvptx_compiler_test", srcs = if_gpu_is_configured([ "nvptx_compiler_test.cc", @@ -2685,7 +2686,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_aot_compilation_test", srcs = if_cuda_is_configured([ "gpu_aot_compilation_test.cc", @@ -2778,7 +2779,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "all_reduce_blueconnect_test", srcs = ["all_reduce_blueconnect_test.cc"], deps = [ @@ -2861,7 +2862,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_layout_assignment_test", srcs = ["gpu_layout_assignment_test.cc"], tags = tf_cuda_tests_tags(), @@ -2905,7 +2906,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_hlo_schedule_test", srcs = [ "gpu_hlo_schedule_test.cc", @@ -2924,7 +2925,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "while_transformer_test", srcs = ["while_transformer_test.cc"], tags = [ @@ -3007,7 +3008,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_hlo_cost_analysis_test", srcs = ["gpu_hlo_cost_analysis_test.cc"], deps = [ @@ -3062,7 +3063,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_performance_model_test", srcs = ["gpu_performance_model_test.cc"], deps = [ @@ -3115,7 +3116,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_op_profiler_run", timeout = "eternal", srcs = ["hlo_op_profiler_run.cc"], @@ -3145,7 +3146,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_op_profiler_test", srcs = if_cuda_is_configured(["hlo_op_profiler_test.cc"]), tags = tf_cuda_tests_tags() + ["no_rocm",], @@ -3177,7 +3178,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "buffer_comparator_test", srcs = if_cuda_is_configured(["buffer_comparator_test.cc"]), tags = tf_cuda_tests_tags(), @@ -3207,7 +3208,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_fusible_test", srcs = ["gpu_fusible_test.cc"], tags = [ @@ -3246,7 +3247,7 @@ cc_library( ]), ) -xla_cc_test( +tf_cc_test( name = "cudnn_fused_conv_rewriter_test", srcs = ["cudnn_fused_conv_rewriter_test.cc"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), @@ -3346,7 +3347,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "cudnn_fused_mha_rewriter_test", srcs = ["cudnn_fused_mha_rewriter_test.cc"], tags = tf_cuda_tests_tags(), @@ -3394,7 +3395,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "conv_layout_normalization_test", srcs = ["conv_layout_normalization_test.cc"], tags = tf_cuda_tests_tags() + ["no_rocm"], @@ -3433,7 +3434,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "variadic_op_splitter_test", srcs = ["variadic_op_splitter_test.cc"], tags = [ @@ -3478,7 +3479,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_algorithm_denylist_test", srcs = ["hlo_algorithm_denylist_test.cc"], data = ["data/hlo_algorithm_denylist.pbtxt"], @@ -3507,7 +3508,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "alias_passthrough_params_test", srcs = ["alias_passthrough_params_test.cc"], tags = [ @@ -3540,7 +3541,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "horizontal_loop_fusion_test", srcs = ["horizontal_loop_fusion_test.cc"], tags = tf_cuda_tests_tags(), @@ -3580,7 +3581,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "horizontal_input_fusion_test", srcs = ["horizontal_input_fusion_test.cc"], tags = tf_cuda_tests_tags(), @@ -3638,7 +3639,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "reduction_splitter_test", srcs = ["reduction_splitter_test.cc"], deps = [ @@ -3953,7 +3954,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_fusion_stats_test", srcs = ["hlo_fusion_stats_test.cc"], tags = [ @@ -3981,7 +3982,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "scatter_slice_simplifier_test", srcs = ["scatter_slice_simplifier_test.cc"], deps = [ @@ -4055,7 +4056,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "topk_splitter_test", srcs = ["topk_splitter_test.cc"], deps = [ @@ -4122,7 +4123,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "copy_fusion_test", srcs = ["copy_fusion_test.cc"], deps = [ @@ -4134,7 +4135,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "autotuner_util_test", srcs = if_cuda_is_configured(["autotuner_util_test.cc"]), deps = if_cuda_is_configured([ diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index 3663e51639e680..30c5d73d8862e5 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -61,6 +61,7 @@ message CudnnConvBackendConfig { // Serialization of the graph described by the convolution and adjacent // pointwise ops. optional string serialized_graph = 9; + string call_context = 10; } // Backend config for the GEMM operation running through cuBLAS. @@ -94,6 +95,12 @@ message GemmBackendConfig { } Epilogue epilogue = 13; + + int64 lhs_stride = 14; + int64 rhs_stride = 15; + + bool grad_x = 16; + bool grad_y = 17; } // Backend config for bitcast operation generated from MLIR MHLO dialect. diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc index 44040912ad89a7..6e04ca5a348194 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc @@ -188,7 +188,7 @@ StatusOr> GetAlgorithms( /* filter_data = */ DeviceMemoryBase(nullptr), config.output_descriptor, /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc, - use_fallback, nullptr, numeric_options, &runners)); + se::dnn::CallContext::kNone, use_fallback, nullptr, numeric_options, &runners)); for (auto& runner : runners) { TF_ASSIGN_OR_RETURN( auto runner_cache, @@ -208,7 +208,7 @@ GetMIOpenAlgorithms(const HloCustomCallInstruction* instr, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec, - ScratchAllocator* scratch_allocator, se::Stream* stream, + ScratchAllocator* scratch_allocator, se::dnn::CallContext call_context, se::Stream* stream, const se::NumericOptions& numeric_options) { TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); @@ -227,7 +227,7 @@ GetMIOpenAlgorithms(const HloCustomCallInstruction* instr, params.config->input_descriptor, params.input_buf, params.config->filter_descriptor, params.filter_buf, params.config->output_descriptor, params.output_buf, - params.config->conv_desc, /* use_fallback = */ false, scratch_allocator, + params.config->conv_desc, call_context, /* use_fallback = */ false, scratch_allocator, numeric_options, &runners)); return runners; @@ -921,10 +921,15 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( ScratchAllocator scratch_allocator(device_ordinal, allocator); + TF_ASSIGN_OR_RETURN(auto backend_config, + instr->backend_config()); + se::dnn::CallContext call_context = + GetCallContext(backend_config.call_context()); + TF_ASSIGN_OR_RETURN( std::vector> runners, GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers), result_buffer, - stream_exec, &scratch_allocator, stream, + stream_exec, &scratch_allocator, call_context, stream, numeric_options)); std::vector profile_results; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index a571906a3ea54c..46bdb43893ed4a 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -457,6 +457,22 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->dot_dimension_numbers(); *gemm_backend_config.mutable_precision_config() = instr->precision_config(); + HloInstruction *lhs = instr->mutable_operand(0); + HloInstruction *rhs = instr->mutable_operand(1); + auto attributes = instr->frontend_attributes().map(); + gemm_backend_config.set_grad_x(attributes["grad_x"] == "true"); + gemm_backend_config.set_grad_y(attributes["grad_y"] == "true"); + + int64_t lhs_batch_dims_size = + instr->dot_dimension_numbers().lhs_batch_dimensions_size(); + int64_t lhs_stride = lhs->shape().dimensions(lhs_batch_dims_size) * + lhs->shape().dimensions(lhs_batch_dims_size + 1); + int64_t rhs_stride = rhs->shape().dimensions(lhs_batch_dims_size) * + rhs->shape().dimensions(lhs_batch_dims_size + 1); + + gemm_backend_config.set_lhs_stride(lhs_stride); + gemm_backend_config.set_rhs_stride(rhs_stride); + // First try to match the fp8 gemm pattern. TF_ASSIGN_OR_RETURN(bool supported_by_cublaslt, GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); @@ -1743,7 +1759,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { dot_dims.rhs_contracting_dimensions(), /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(), gemm_backend_config.alpha_imag(), gemm_backend_config.beta(), - /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision)); + /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision, + gemm_backend_config.grad_x(), gemm_backend_config.grad_y())); if (matrix_name == "lhs" || matrix_name == "a") { return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc index bee7667cdeb538..9900ef6eab3a1c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc @@ -705,8 +705,11 @@ StatusOr RunOnInstruction(HloInstruction* conv) { return false; } - TF_RETURN_IF_ERROR( - custom_call->set_backend_config(GetDefaultBackendConfig())); + auto backend_config = GetDefaultBackendConfig(); + auto attributes = conv->frontend_attributes().map(); + backend_config.set_call_context(attributes["call_context"]); + + TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); VLOG(1) << "Replacing convolution " << conv->ToString() << " with " << custom_call->ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 235769658a152c..55fddbb45b32d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -81,7 +81,8 @@ Status RunGpuConvUnfused(const GpuConvParams& params, se::Stream* stream, params.config->input_descriptor, params.config->filter_descriptor, params.config->output_descriptor, - params.config->conv_desc}; + params.config->conv_desc, + params.config->call_context}; TF_ASSIGN_OR_RETURN(auto* runner, lazy_runner->GetOrCreateRunner(config, stream)); @@ -314,6 +315,16 @@ int64_t GetVectCSize(FilterLayout layout) { } // anonymous namespace +se::dnn::CallContext GetCallContext(const absl::string_view call_context) { + if (call_context == "kForward") + return se::dnn::CallContext::kForward; + else if (call_context == "kBackpropData") + return se::dnn::CallContext::kBackpropData; + else if (call_context == "kBackpropFilter") + return se::dnn::CallContext::kBackpropFilter; + return se::dnn::CallContext::kNone; +} + StatusOr GetGpuConvConfig( const GpuConvDescriptor& desc, const absl::string_view inst_as_string) { GpuConvConfig config; @@ -327,6 +338,7 @@ StatusOr GetGpuConvConfig( config.output_type = result_shape.element_type(); config.kind = desc.kind; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); + config.call_context = GetCallContext(backend_config.call_context()); config.conv_result_scale = backend_config.conv_result_scale(); config.serialized_graph = backend_config.serialized_graph(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index c746584b53a31b..f821a1ee454e5f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -66,6 +66,7 @@ struct GpuConvConfig { se::dnn::BatchDescriptor output_descriptor; se::dnn::ConvolutionDescriptor conv_desc; se::dnn::BatchDescriptor bias_descriptor; + se::dnn::CallContext call_context; Shape input_shape; Shape filter_shape; @@ -198,6 +199,8 @@ struct RunConvOptions { // This file contains low-level routines for running cudnn convolutions. +se::dnn::CallContext GetCallContext(const absl::string_view call_context); + // Calls into cudnn to run the specified convolution. // // We provide one overload which takes a scratch buffer, and another which takes diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 3a9061ca35e69b..e17aaf4b35b44c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -791,6 +791,28 @@ Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { op.getResultScale().convertToDouble()); descriptor.backend_config.set_reordered_int8_nchw_vect( op.getBackendConfig().getIsCudnnReorderedInt8()); + auto attr_call_context = + op->template getAttrOfType("call_context"); + if (attr_call_context) { + descriptor.backend_config.set_call_context( + attr_call_context.getValue().str()); + } else { + std::string call_context = "kNone"; + switch (descriptor.kind) { + case CudnnConvKind::kForward: + call_context = "kForward"; + break; + case CudnnConvKind::kBackwardInput: + call_context = "kBackpropData"; + break; + case CudnnConvKind::kBackwardFilter: + call_context = "kBackpropFilter"; + break; + default: + break; + } + descriptor.backend_config.set_call_context(call_context); + } }; auto set_activation_mode = [&](auto op) -> Status { diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc index 4bf23c883495d3..807a051b7c327e 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc @@ -45,7 +45,9 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" #include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h" #include "tensorflow/tsl/platform/tensor_float_32_utils.h" -#endif // GOOGLE_CUDA +#elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace xla { namespace gpu { @@ -263,12 +265,13 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, - std::optional algorithm, int64_t compute_precision) { + std::optional algorithm, int64_t compute_precision, + bool gx, bool gy) { return GemmConfig::For(lhs_shape, lhs_batch_dims, lhs_contracting_dims, rhs_shape, rhs_batch_dims, rhs_contracting_dims, /*c_shape=*/output_shape, /*bias_shape_ptr=*/nullptr, output_shape, alpha_real, alpha_imag, beta, algorithm, - compute_precision); + compute_precision, gx, gy); } /*static*/ StatusOr GemmConfig::For( @@ -278,7 +281,7 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, absl::Span rhs_contracting_dims, const Shape& c_shape, const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, std::optional algorithm, - int64_t compute_precision) { + int64_t compute_precision, bool gx, bool gy) { absl::Span lhs_col_dims = lhs_contracting_dims; TF_ASSIGN_OR_RETURN( std::vector lhs_row_dims, @@ -388,6 +391,8 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, beta, algorithm, compute_precision, + gx, + gy }; } @@ -405,13 +410,17 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers(); const Shape& output_shape = gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) : gemm->shape(); + auto attributes = gemm->frontend_attributes().map(); + bool gx = (attributes["grad_x"] == "true"); + bool gy = (attributes["grad_y"] == "true"); return GemmConfig::For( lhs_shape, dot_dims.lhs_batch_dimensions(), dot_dims.lhs_contracting_dimensions(), rhs_shape, dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(), - output_shape, config.alpha_real(), config.alpha_imag(), config.beta(), - algorithm, se::blas::kDefaultComputePrecision); + /*output_shape=*/gemm->shape(), config.alpha_real(), config.alpha_imag(), + config.beta(), algorithm, se::blas::kDefaultComputePrecision, + gx, gy); } /*static*/ StatusOr GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op) { @@ -419,6 +428,13 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, std::optional algorithm; if (op.getAlgorithm()) algorithm = *op.getAlgorithm(); + bool gx=false, gy=false; + auto attr_grad_x = op.getGradX(); + if (attr_grad_x) + gx=attr_grad_x.value(); + auto attr_grad_y = op.getGradY(); + if (attr_grad_y) + gx=attr_grad_y.value(); int64_t compute_precision = 0; // Default if (op.getPrecisionConfig().has_value()) { @@ -438,7 +454,8 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, dot_dims.getRhsBatchingDimensions(), dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), - op.getBeta().convertToDouble(), algorithm, compute_precision); + op.getBeta().convertToDouble(), algorithm, compute_precision, + gx, gy); } StatusOr GetBlasComputationType( @@ -553,7 +570,8 @@ Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, se::blas::AlgorithmType algorithm, se::blas::ComputePrecision compute_precision, const se::NumericOptions& numeric_options, - se::blas::ProfileResult* profile_result) { + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); PrimitiveType lhs_type = primitive_util::NativeToPrimitiveType(); PrimitiveType output_type = primitive_util::NativeToPrimitiveType(); @@ -568,13 +586,14 @@ Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, output.leading_dim_stride, output.batch_stride, batch_size, - computation_type, algorithm, numeric_options, profile_result); + computation_type, algorithm, numeric_options, profile_result, + context); } else { return stream->ThenBlasGemmWithAlgorithm( lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, &output_data, output.leading_dim_stride, computation_type, algorithm, - numeric_options, profile_result); + numeric_options, profile_result, context); } } @@ -586,7 +605,8 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, std::optional algorithm, se::blas::ComputePrecision compute_precision, const se::NumericOptions& numeric_options, - se::blas::ProfileResult* profile_result) { + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); se::DeviceMemory output_data(output.data); @@ -594,7 +614,7 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, if (algorithm) { return DoGemmWithAlgorithm( batch_size, m, n, k, lhs, rhs, output, alpha, beta, stream, *algorithm, - compute_precision, numeric_options, profile_result); + compute_precision, numeric_options, profile_result, context); } #endif @@ -604,13 +624,13 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, output.leading_dim_stride, output.batch_stride, batch_size, - numeric_options); + numeric_options, context); } return stream->ThenBlasGemm( lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, numeric_options); + &output_data, output.leading_dim_stride, numeric_options, context); } } // namespace @@ -645,6 +665,17 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, if (!algorithm) algorithm = config.algorithm; + se::blas::CallContext context = se::blas::CallContext::kNone; + if (config.grad_x) { + context = must_swap_operands + ? se::blas::CallContext::kBackpropInput2 + : se::blas::CallContext::kBackpropInput1; + } + if (config.grad_y) { + context = must_swap_operands + ? se::blas::CallContext::kBackpropInput1 + : se::blas::CallContext::kBackpropInput2; + } std::tuple operand_types{ lhs_layout.dtype, rhs_layout.dtype, output_layout.dtype}; @@ -658,7 +689,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, batch_size, m, n, k, lhs, rhs, output, \ static_cast(config.alpha.real()), \ static_cast(config.beta), stream, algorithm, \ - config.compute_precision, numeric_options, profile_result); \ + config.compute_precision, numeric_options, profile_result, context); \ } #define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ @@ -671,7 +702,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, batch_size, m, n, k, lhs, rhs, output, \ static_cast(config.alpha), \ static_cast(config.beta), stream, algorithm, \ - config.compute_precision, numeric_options, profile_result); \ + config.compute_precision, numeric_options, profile_result, context); \ } if (output_layout.dtype == S32) { @@ -680,7 +711,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, batch_size, m, n, k, lhs, rhs, output, static_cast(config.alpha.real()), static_cast(config.beta), stream, *algorithm, - se::blas::kDefaultComputePrecision, numeric_options, profile_result); + se::blas::kDefaultComputePrecision, numeric_options, profile_result, context); } TYPED_GEMM(F32, BF16, BF16, BF16) @@ -786,6 +817,7 @@ StatusOr AsBlasLtMatrixLayout( } #if TF_HIPBLASLT +#if TF_ROCM_VERSION < 50700 using cudaDataType_t = hipblasDatatype_t; #define CUDA_R_16BF HIPBLAS_R_16B #define CUDA_R_16F HIPBLAS_R_16F @@ -793,6 +825,15 @@ using cudaDataType_t = hipblasDatatype_t; #define CUDA_R_64F HIPBLAS_R_64F #define CUDA_C_32F HIPBLAS_C_32F #define CUDA_C_64F HIPBLAS_C_64F +#else +using cudaDataType_t = hipblasltDatatype_t; +#define CUDA_R_16BF HIPBLASLT_R_16B +#define CUDA_R_16F HIPBLASLT_R_16F +#define CUDA_R_32F HIPBLASLT_R_32F +#define CUDA_R_64F HIPBLASLT_R_64F +#define CUDA_C_32F HIPBLASLT_C_32F +#define CUDA_C_64F HIPBLASLT_C_64F +#endif #endif template diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h index d51613c3821f12..e6299af05bde7c 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h @@ -114,7 +114,8 @@ struct GemmConfig { absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, - std::optional algorithm, int64_t compute_precision); + std::optional algorithm, int64_t compute_precision, + bool grad_x, bool grad_y); // As above with additional `c_shape` and `bias_shape_ptr` parameter, both // which are only necessarily for F8 gemms. @@ -125,7 +126,7 @@ struct GemmConfig { absl::Span rhs_contracting_dims, const Shape& c_shape, const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, std::optional algorithm, - int64_t compute_precision); + int64_t compute_precision, bool grad_x, bool grad_y); template algorithm; int64_t compute_precision; + bool grad_x, grad_y; }; StatusOr GetBlasComputationType( diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index bf557858872610..05d4e35d1597c7 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load( @@ -233,7 +234,7 @@ cuda_library( ], ) -xla_cc_test( +tf_cc_test( name = "topk_kernel_test", srcs = if_cuda_is_configured(["topk_kernel_test.cc"]), linkstatic = 1, @@ -253,7 +254,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "topk_test", srcs = if_cuda_is_configured(["topk_test.cc"]), tags = tf_cuda_tests_tags(), diff --git a/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc b/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc index 97b94614bdf873..20ac1f35cb12e6 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/cublas_lt_matmul.cc @@ -103,7 +103,7 @@ static absl::Status CublasLtMatmulImpl( a, b, c, algorithm, alpha_real, alpha_imag, beta, dot_dims.lhs_batch, dot_dims.lhs_contract, dot_dims.rhs_batch, dot_dims.rhs_contract, precision.empty() ? se::blas::kDefaultComputePrecision - : *absl::c_max_element(precision))); + : *absl::c_max_element(precision), false, false)); })); // Get the matmul plan for this instance of matmul. diff --git a/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc b/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc index 7a26ed0d40ced4..84f459c9e6ff35 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc @@ -130,7 +130,8 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, dot_dims.lhs_batch, dot_dims.lhs_contract, dot_dims.rhs_batch, dot_dims.rhs_contract, precision.empty() ? se::blas::kDefaultComputePrecision - : *absl::c_max_element(precision)); + : *absl::c_max_element(precision), + false, false); return ToAbsl(gemm_config); })); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/support.h b/tensorflow/compiler/xla/service/gpu/runtime/support.h index 671b299108088b..cc6f0a6a9846bb 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/support.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/support.h @@ -99,10 +99,11 @@ inline StatusOr GetGemmConfig( const runtime::StridedMemrefView& out, int64_t algorithm, double alpha_real, double alpha_imag, double beta, absl::Span lhs_batch, absl::Span lhs_contract, absl::Span rhs_batch, - absl::Span rhs_contract, int64_t compute_precision) { + absl::Span rhs_contract, int64_t compute_precision, bool grad_x, bool grad_y) { return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs), rhs_batch, rhs_contract, ToShape(out), alpha_real, - alpha_imag, beta, algorithm, compute_precision); + alpha_imag, beta, algorithm, + se::blas::kDefaultComputePrecision, grad_x, grad_y); } // adds Dot Dimension Attribute encodings for calls to Gemm and cuBLASLt diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 9b6a0ed54cc724..d6eb3bd7f81c5a 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -1,6 +1,7 @@ # Description: GPU-specific XLA tests. For example, codegen tests that # verify the IR emitted. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load( @@ -68,7 +69,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "element_wise_row_vectorization_test", srcs = ["element_wise_row_vectorization_test.cc"], tags = tf_cuda_tests_tags(), @@ -80,7 +81,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "pred_arithmetic_test", srcs = ["pred_arithmetic_test.cc"], tags = tf_cuda_tests_tags(), @@ -91,7 +92,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_reduce_scatter_creator_test", srcs = ["gpu_reduce_scatter_creator_test.cc"], deps = [ @@ -106,7 +107,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_spmd_e2e_compile_test", size = "small", srcs = ["gpu_spmd_e2e_compile_test.cc"], @@ -120,7 +121,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gemm_rewrite_test", srcs = if_cuda_is_configured(["gemm_rewrite_test.cc"]), tags = tf_cuda_tests_tags(), @@ -142,7 +143,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gemm_broadcast_folding_rewrite_test", srcs = [ "gemm_broadcast_folding_rewrite_test.cc", @@ -161,7 +162,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_too_many_blocks_test", srcs = [ "gpu_too_many_blocks_test.cc", @@ -177,7 +178,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "reduction_degenerate_dim_remover_test", srcs = [ "reduction_degenerate_dim_remover_test.cc", @@ -200,7 +201,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "reduction_layout_normalizer_test", srcs = [ "reduction_layout_normalizer_test.cc", @@ -225,7 +226,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "tree_reduction_rewriter_test", srcs = [ "tree_reduction_rewriter_test.cc", @@ -249,7 +250,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "swap_conv_operands_test", srcs = [ "swap_conv_operands_test.cc", @@ -275,7 +276,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "reduction_vectorization_test", srcs = [ "reduction_vectorization_test.cc", @@ -303,7 +304,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "reduction_dimension_grouper_test", srcs = [ "reduction_dimension_grouper_test.cc", @@ -326,7 +327,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "parallel_reduction_test", srcs = [ "parallel_reduction_test.cc", @@ -348,7 +349,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_compilation_parallelism_test", srcs = [ "gpu_compilation_parallelism_test.cc", @@ -362,7 +363,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], tags = tf_cuda_tests_tags(), @@ -379,7 +380,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_copy_alone_test", srcs = [ "gpu_copy_alone_test.cc", @@ -393,7 +394,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_dyn_shape_test", srcs = ["gpu_dyn_shape_test.cc"], tags = tf_cuda_tests_tags(), @@ -408,7 +409,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], tags = tf_cuda_tests_tags(), @@ -419,7 +420,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_index_test", srcs = ["gpu_index_test.cc"], tags = tf_cuda_tests_tags(), @@ -440,7 +441,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_infeed_test", srcs = ["infeed_test.cc"], tags = tf_cuda_tests_tags(), @@ -485,7 +486,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "transpose_emitter_test", srcs = ["transpose_emitter_test.cc"], tags = tf_cuda_tests_tags(), @@ -501,7 +502,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "reduction_emitter_test", srcs = ["reduction_emitter_test.cc"], tags = tf_cuda_tests_tags(), @@ -517,7 +518,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_ldg_test", srcs = ["gpu_ldg_test.cc"], tags = tf_cuda_tests_tags(), @@ -535,7 +536,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_noalias_test", srcs = ["gpu_noalias_test.cc"], tags = tf_cuda_tests_tags(), @@ -552,7 +553,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_fusion_test", srcs = ["gpu_fusion_test.cc"], tags = tf_cuda_tests_tags(), @@ -565,7 +566,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_fusion_pipeline_test", srcs = ["gpu_fusion_pipeline_test.cc"], tags = tf_cuda_tests_tags(), @@ -583,7 +584,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_unrolling_test", srcs = ["gpu_unrolling_test.cc"], tags = tf_cuda_tests_tags(), @@ -597,7 +598,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_alignment_test", testonly = True, srcs = ["gpu_alignment_test.cc"], @@ -614,7 +615,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_atomic_test", srcs = ["gpu_atomic_test.cc"], tags = tf_cuda_tests_tags(), @@ -626,7 +627,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_input_fusible_slice_test", srcs = ["gpu_input_fusible_slice_test.cc"], tags = tf_cuda_tests_tags(), @@ -660,7 +661,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "select_and_scatter_test", srcs = ["select_and_scatter_test.cc"], tags = tf_cuda_tests_tags(), @@ -672,7 +673,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "sorting_test", srcs = ["sorting_test.cc"], tags = tf_cuda_tests_tags(), @@ -804,7 +805,7 @@ filegroup( ], ) -xla_cc_test( +tf_cc_test( name = "kernel_launch_test", srcs = ["kernel_launch_test.cc"], tags = tf_cuda_tests_tags(), @@ -817,7 +818,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "in_place_op_test", srcs = ["in_place_op_test.cc"], tags = tf_cuda_tests_tags(), @@ -828,7 +829,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "dynamic_shared_memory_test", srcs = if_cuda_is_configured(["dynamic_shared_memory_test.cc"]), tags = tf_cuda_tests_tags(), @@ -870,7 +871,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "gpu_fused_mha_test", srcs = ["gpu_fused_mha_test.cc"], tags = tf_cuda_tests_tags(), diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index e82ec13e7af33d..e6c087bf7491fe 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -46,7 +47,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "alias_analysis_test", srcs = ["alias_analysis_test.cc"], deps = [ @@ -294,7 +295,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "ir_array_test", srcs = ["ir_array_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index 467871ca225f91..072f6cc3d11195 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -76,7 +77,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "spmd_partitioner_test", srcs = ["spmd_partitioner_test.cc"], deps = [ @@ -95,7 +96,7 @@ xla_cc_test( ], ) -xla_cc_test( +tf_cc_test( name = "canonicalize_all_gather_for_cse_test", srcs = ["canonicalize_all_gather_for_cse_test.cc"], deps = [ @@ -125,7 +126,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "schedule_aware_collective_ops_cse_test", srcs = ["schedule_aware_collective_ops_cse_test.cc"], deps = [ @@ -197,7 +198,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "stateful_rng_spmd_partitioner_test", srcs = ["stateful_rng_spmd_partitioner_test.cc"], deps = [ @@ -234,7 +235,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "collective_permute_motion_test", srcs = ["collective_permute_motion_test.cc"], deps = [ @@ -260,7 +261,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "partition_assignment_test", srcs = ["partition_assignment_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/stream_executor/blas.h b/tensorflow/compiler/xla/stream_executor/blas.h index 7fd5e05edaafc8..18364d3211fbfa 100644 --- a/tensorflow/compiler/xla/stream_executor/blas.h +++ b/tensorflow/compiler/xla/stream_executor/blas.h @@ -118,6 +118,16 @@ enum class ComputationType { kTF32AsF32, // Allow downcast to TF32 precision. }; +// Call context information for GEMM API calls +// This is extra information that can optionally be passed down to the blas +// library, so that it can pick the efficient imlpementation based on context +enum class CallContext { + kNone = 0, // No information + kForward = 1, // call happens in "forward" pass + kBackpropInput1 = 2, // call happens in "backprop" pass for the first input + kBackpropInput2 = 4, // call happens in "backprop" pass for the second input +}; + // Converts a ComputationType to a string. std::string ComputationTypeString(ComputationType ty); @@ -318,7 +328,8 @@ class BlasSupport { const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, - const NumericOptions &numeric_options) = 0; + const NumericOptions &numeric_options, + blas::CallContext context) = 0; // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. virtual bool GetBlasGemmAlgorithms( @@ -343,7 +354,8 @@ class BlasSupport { DeviceMemoryBase *c, DataType type_c, int ldc, ComputationType computation_type, AlgorithmType algorithm, const NumericOptions &numeric_options, - ProfileResult *output_profile_result) = 0; + ProfileResult *output_profile_result, + blas::CallContext context) = 0; virtual tsl::Status DoBlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, @@ -353,7 +365,8 @@ class BlasSupport { const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc, int64_t stride_c, int batch_count, ComputationType computation_type, AlgorithmType algorithm, const NumericOptions &numeric_options, - ProfileResult *output_profile_result) = 0; + ProfileResult *output_profile_result, + blas::CallContext context) = 0; // Computes a batch of matrix-matrix product with general matrices. // This is a batched version of DoBlasGemm. @@ -367,7 +380,8 @@ class BlasSupport { float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) = 0; + ScratchAllocator *scratch_allocator, + blas::CallContext context) = 0; virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, float alpha, @@ -377,7 +391,8 @@ class BlasSupport { DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) = 0; + ScratchAllocator *scratch_allocator, + blas::CallContext context) = 0; virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, float alpha, @@ -386,7 +401,8 @@ class BlasSupport { float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) = 0; + ScratchAllocator *scratch_allocator, + blas::CallContext context) = 0; virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, double alpha, @@ -395,7 +411,8 @@ class BlasSupport { double beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) = 0; + ScratchAllocator *scratch_allocator, + blas::CallContext context) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, std::complex alpha, @@ -403,7 +420,8 @@ class BlasSupport { DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) = 0; + ScratchAllocator *scratch_allocator, + blas::CallContext context) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, std::complex alpha, @@ -411,7 +429,7 @@ class BlasSupport { DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) = 0; + ScratchAllocator *scratch_allocator, blas::CallContext context) = 0; // Batched gemm with strides instead of pointer arrays. virtual tsl::Status DoBlasGemmStridedBatched( @@ -420,7 +438,7 @@ class BlasSupport { const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - const NumericOptions &numeric_options) = 0; + const NumericOptions &numeric_options, blas::CallContext context) = 0; // Solves a triangular matrix equation. // @@ -570,7 +588,7 @@ class BlasSupport { uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, \ const void *beta, DeviceMemoryBase *c, int ldc, \ - const NumericOptions &numeric_options) override; \ + const NumericOptions &numeric_options, blas::CallContext context) override; \ bool GetBlasGemmAlgorithms(Stream *stream, \ std::vector *out_algorithms) \ override; \ @@ -582,7 +600,7 @@ class BlasSupport { const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ const NumericOptions &numeric_options, \ - blas::ProfileResult *output_profile_result) override; \ + blas::ProfileResult *output_profile_result, blas::CallContext context) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, float alpha, \ @@ -590,7 +608,7 @@ class BlasSupport { DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ const NumericOptions &numeric_options, \ - ScratchAllocator *scratch_allocator) override; \ + ScratchAllocator *scratch_allocator, blas::CallContext context) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, float alpha, \ @@ -598,21 +616,21 @@ class BlasSupport { DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ const NumericOptions &numeric_options, \ - ScratchAllocator *scratch_allocator) override; \ + ScratchAllocator *scratch_allocator, blas::CallContext context) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, float alpha, DeviceMemorySlice a, \ int lda, DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ const NumericOptions &numeric_options, \ - ScratchAllocator *scratch_allocator) override; \ + ScratchAllocator *scratch_allocator, blas::CallContext context) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, double alpha, \ DeviceMemorySlice a, int lda, DeviceMemorySlice b, \ int ldb, double beta, DeviceMemorySlice c, int ldc, \ int batch_count, const NumericOptions &numeric_options, \ - ScratchAllocator *scratch_allocator) override; \ + ScratchAllocator *scratch_allocator, blas::CallContext context) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, std::complex alpha, \ @@ -620,7 +638,7 @@ class BlasSupport { DeviceMemorySlice> b, int ldb, \ std::complex beta, DeviceMemorySlice> c, \ int ldc, int batch_count, const NumericOptions &numeric_options, \ - ScratchAllocator *scratch_allocator) override; \ + ScratchAllocator *scratch_allocator, blas::CallContext context) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, std::complex alpha, \ @@ -628,14 +646,14 @@ class BlasSupport { DeviceMemorySlice> b, int ldb, \ std::complex beta, DeviceMemorySlice> c, \ int ldc, int batch_count, const NumericOptions &numeric_options, \ - ScratchAllocator *scratch_allocator) override; \ + ScratchAllocator *scratch_allocator, blas::CallContext context) override; \ tsl::Status DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ const DeviceMemoryBase &a, int lda, int64_t stride_a, \ const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, \ DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, \ - const NumericOptions &numeric_options) override; \ + const NumericOptions &numeric_options, blas::CallContext context) override; \ tsl::Status DoBlasGemmStridedBatchedWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, const void *alpha, \ @@ -645,7 +663,7 @@ class BlasSupport { blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ const NumericOptions &numeric_options, \ - blas::ProfileResult *output_profile_result) override; \ + blas::ProfileResult *output_profile_result, blas::CallContext context) override; \ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64_t m, \ uint64_t n, float alpha, const DeviceMemory &a, \ diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc index 76baf619e5d120..fe493dd5657327 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc @@ -586,7 +586,8 @@ tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, const void *alpha, const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, + blas::CallContext context) { cublasMath_t math_type = CUBLAS_DEFAULT_MATH; #if CUDA_VERSION < 11000 @@ -781,7 +782,8 @@ tsl::Status CUDABlas::DoBlasGemmWithAlgorithm( blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result) { + blas::ProfileResult *output_profile_result, + blas::CallContext context) { TF_ASSIGN_OR_RETURN( cublasMath_t math_type, GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options)); @@ -815,7 +817,8 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result) { + blas::ProfileResult *output_profile_result, + blas::CallContext context) { TF_ASSIGN_OR_RETURN( cublasMath_t math_type, GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options)); @@ -1121,7 +1124,8 @@ bool CUDABlas::DoBlasGemmBatched( int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { // Note: The func passed here (cublasSgemmBatched) is not actually called, // due to special handling of fp16 inside DoBlasGemmBatchedInternal. tsl::Status status = DoBlasGemmBatchedInternal( @@ -1141,7 +1145,8 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { // Note: The func passed here (cublasSgemmBatched) is not actually called, // due to special handling of bf16 inside DoBlasGemmBatchedInternal. tsl::Status status = DoBlasGemmBatchedInternal( @@ -1162,7 +1167,8 @@ bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { tsl::Status status = DoBlasGemmBatchedInternal( cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, @@ -1181,7 +1187,8 @@ bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, double beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { tsl::Status status = DoBlasGemmBatchedInternal( cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, @@ -1199,7 +1206,7 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator,blas::CallContext context) { tsl::Status status = DoBlasGemmBatchedInternal( cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, @@ -1217,7 +1224,7 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, blas::CallContext context) { tsl::Status status = DoBlasGemmBatchedInternal( cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, @@ -1234,7 +1241,8 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatched( const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, + blas::CallContext context) { cublasMath_t math_type = CUBLAS_DEFAULT_MATH; #if CUDA_VERSION < 11000 if (dtype == dnn::kHalf) { diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index b789ed146afba9..8c380bf6b27627 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -6300,7 +6300,7 @@ tsl::Status CudnnSupport::DoConvolve( DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, - dnn::ProfileResult* profile_result) { + dnn::CallContext call_context, dnn::ProfileResult* profile_result) { cudnnDataType_t cudnn_type = ToCudnnDataType(element_type, input_descriptor.layout()); CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type); @@ -6817,7 +6817,8 @@ tsl::Status CudnnSupport::GetConvolveRunners( DeviceMemoryBase /*filter_data*/, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase /*output_data*/, - const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context, bool use_fallback, ScratchAllocator* /*scratch_allocator*/, const NumericOptions& numeric_options, std::vector>* out_exec_plans) { @@ -6881,7 +6882,8 @@ tsl::Status CudnnSupport::GetConvolveRunners( for (const auto& algo : algorithms) { auto runner_or = ConvolveRunnerFromDesc( stream, algo, kind, input_type, output_type, input_descriptor, - filter_descriptor, output_descriptor, convolution_descriptor); + filter_descriptor, output_descriptor, convolution_descriptor, + call_context); if (!runner_or.ok()) { // Failures here can result from trying to query the workspace size // for algorithms that aren't supported for the present configuration. @@ -6947,7 +6949,8 @@ CudnnSupport::ConvolveRunnerFromDesc( dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor) { + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context) { if (!algorithm_desc.is_cudnn_frontend()) { CudnnConvolutionDescriptor conv( convolution_descriptor, @@ -8733,7 +8736,8 @@ bool CudnnSupport::DoMatMul(Stream* stream, if (!stream ->ThenBlasGemm(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, weights, m, - input_data, k, output_data, m, NumericOptions{}) + input_data, k, output_data, m, NumericOptions{}, + blas::CallContext::kNone) .ok()) { return false; } @@ -8816,7 +8820,8 @@ bool CudnnSupport::DoMatMul(Stream* stream, stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, alpha, toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c), - ldc, batch_count, NumericOptions{}); + ldc, batch_count, NumericOptions{}, + blas::CallContext::kNone); } return stream->ok(); diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h index 1949f26265d0b1..ee38c7ad45d883 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h @@ -226,7 +226,7 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, + dnn::CallContext call_context, bool use_fallback, ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, std::vector>* out_exec_plans) override; @@ -237,7 +237,8 @@ class CudnnSupport : public dnn::DnnSupport { dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor) override; + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context) override; tsl::Status GetGraphConvolveRunners( dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -481,6 +482,7 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, + dnn::CallContext call_context, dnn::ProfileResult* output_profile_result) override; tsl::Status DoFusedConvolve( diff --git a/tensorflow/compiler/xla/stream_executor/dnn.cc b/tensorflow/compiler/xla/stream_executor/dnn.cc index 9c47caeb811888..d3167f2da64cd2 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/dnn.cc @@ -132,7 +132,7 @@ tsl::Status DnnSupport::GetConvolveRunners( const dnn::BatchDescriptor& /*output_descriptor*/, DeviceMemoryBase /*output_data*/, const dnn::ConvolutionDescriptor& /*convolution_descriptor*/, - bool /*use_fallback*/, ScratchAllocator* /*scratch_allocator*/, + dnn::CallContext call_context, bool /*use_fallback*/, ScratchAllocator* /*scratch_allocator*/, const NumericOptions& /*numeric_options*/, std::vector>* /*exec_plans*/) { return tsl::errors::Unimplemented("GetConvolveRunners not implemented."); @@ -145,7 +145,8 @@ DnnSupport::ConvolveRunnerFromDesc( dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor) { + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context) { return tsl::errors::Unimplemented("ConvolveRunnerFromDesc not implemented."); } @@ -329,7 +330,7 @@ bool DnnSupport::GetMIOpenConvolveAlgorithms( const dnn::BatchDescriptor& /*output_descriptor*/, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& /*convolution_descriptor*/, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* /*out_algorithms*/) { return false; } diff --git a/tensorflow/compiler/xla/stream_executor/dnn.h b/tensorflow/compiler/xla/stream_executor/dnn.h index e5cb61779e4d1b..177e0bf61038ce 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.h +++ b/tensorflow/compiler/xla/stream_executor/dnn.h @@ -67,6 +67,16 @@ enum class DimIndex : int { Z = 2, }; +// Call context information for GEMM API calls +// This is extra information that can optionally be passed down to the blas +// library, so that it can pick the efficient imlpementation based on context +enum class CallContext { + kNone = 0, // No information + kForward = 1, // call happens in "forward" pass + kBackpropData = 2, // call happens in "backprop" pass for data + kBackpropFilter = 4, // call happens in "backprop" pass for filter +}; + // Return a reordered dims. std::vector ReorderDims(const std::vector& input, const DataLayout& from, const DataLayout& to); @@ -1605,7 +1615,7 @@ class DnnSupport { DeviceMemoryBase output_data, const ConvolutionDescriptor& convolution_descriptor, AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, - ProfileResult* output_profile_result) = 0; + dnn::CallContext call_context, ProfileResult* output_profile_result) = 0; virtual tsl::Status GetConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, @@ -1616,7 +1626,7 @@ class DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, + dnn::CallContext call_context, bool use_fallback, ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, std::vector>* out_exec_plans); @@ -1627,7 +1637,8 @@ class DnnSupport { dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor); + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context); virtual tsl::Status GetGraphConvolveRunners( dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -1781,7 +1792,7 @@ class DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms); // Returns a list of supported rnn algorithms. diff --git a/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h b/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h index 1ce60c3af57dc5..cd9a19f838c262 100644 --- a/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h +++ b/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h @@ -146,6 +146,7 @@ struct ConvOp { const FilterDescriptor& filter_descriptor; const BatchDescriptor& output_descriptor; const ConvolutionDescriptor& convolution_descriptor; + dnn::CallContext call_context; }; static tsl::StatusOr>> @@ -154,7 +155,8 @@ struct ConvOp { return stream->ConvolveRunnerFromDesc( desc, config.kind, config.input_type, config.output_type, config.input_descriptor, config.filter_descriptor, - config.output_descriptor, config.convolution_descriptor); + config.output_descriptor, config.convolution_descriptor, + config.call_context); } }; diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc index 34ace3fa2513be..021819f2029ce5 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.cc @@ -146,7 +146,7 @@ tsl::Status BlasLt::Init() { return std::move(layout); } -hipblasDatatype_t BlasLt::MatrixLayout::type() const { return HIPBLAS_R_32F; } +hipblasltDatatype_t BlasLt::MatrixLayout::type() const { return HIPBLASLT_R_32F; } /*static*/ tsl::StatusOr BlasLt::MatmulDesc::Create( blas::ComputationType compute_type, blas::DataType scale_type, @@ -176,8 +176,8 @@ hipblasLtComputeType_t BlasLt::MatmulDesc::compute_type() const { return HIPBLASLT_COMPUTE_F32; } -hipblasDatatype_t BlasLt::MatmulDesc::scale_type() const { - return HIPBLAS_R_32F; +hipblasltDatatype_t BlasLt::MatmulDesc::scale_type() const { + return HIPBLASLT_R_32F; } hipblasPointerMode_t BlasLt::MatmulDesc::pointer_mode() const { diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h index 071437ddfdd1de..5f8258a4ad6306 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h @@ -18,9 +18,21 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h" #include "tensorflow/tsl/platform/status.h" +#include "rocm/rocm_config.h" #if TF_HIPBLASLT -#include "rocm/rocm_config.h" +#if TF_ROCM_VERSION < 50700 +#define hipblasltDatatype_t hipblasDatatype_t +#define HIPBLASLT_R_16F HIPBLAS_R_16F +#define HIPBLASLT_R_16B HIPBLAS_R_16B +#define HIPBLASLT_R_32F HIPBLAS_R_32F +#define HIPBLASLT_R_64F HIPBLAS_R_64F +#define HIPBLASLT_R_8I HIPBLAS_R_8I +#define HIPBLASLT_R_32I HIPBLAS_R_32I +#define HIPBLASLT_C_32F HIPBLAS_R_32F +#define HIPBLASLT_C_64F HIPBLAS_R_64F +#endif + #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h" #include "tensorflow/compiler/xla/stream_executor/rocm/hipblaslt_wrapper.h" @@ -54,7 +66,7 @@ class BlasLt { std::optional leading_dim_stride = std::nullopt, std::optional batch_stride = std::nullopt); - hipblasDatatype_t type() const; + hipblasltDatatype_t type() const; hipblasLtMatrixLayout_t get() const { return handle_.get(); } @@ -92,7 +104,7 @@ class BlasLt { PointerMode pointer_mode = PointerMode::kHost); hipblasLtComputeType_t compute_type() const; - hipblasDatatype_t scale_type() const; + hipblasltDatatype_t scale_type() const; hipblasPointerMode_t pointer_mode() const; hipblasLtMatmulDesc_t get() const { return handle_.get(); } diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc index d7dfdf95472d3c..bc965818c559b2 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.cc @@ -17,6 +17,18 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "rocm/rocm_config.h" +#if TF_ROCM_VERSION < 50700 +#define hipblasltDatatype_t hipblasDatatype_t +#define HIPBLASLT_R_16F HIPBLAS_R_16F +#define HIPBLASLT_R_16B HIPBLAS_R_16B +#define HIPBLASLT_R_32F HIPBLAS_R_32F +#define HIPBLASLT_R_64F HIPBLAS_R_64F +#define HIPBLASLT_R_8I HIPBLAS_R_8I +#define HIPBLASLT_R_32I HIPBLAS_R_32I +#define HIPBLASLT_C_32F HIPBLAS_R_32F +#define HIPBLASLT_C_64F HIPBLAS_R_64F +#endif namespace stream_executor { namespace rocm { @@ -30,27 +42,27 @@ tsl::Status ToStatus(hipblasStatus_t status, const char* prefix) { return tsl::OkStatus(); } -hipblasDatatype_t AsHipblasDataType(blas::DataType type) { +hipblasltDatatype_t AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: case blas::DataType::kF8E4M3FN: LOG(FATAL) << "hipblaslt does not support F8 yet"; case blas::DataType::kHalf: - return HIPBLAS_R_16F; + return HIPBLASLT_R_16F; case blas::DataType::kBF16: - return HIPBLAS_R_16B; + return HIPBLASLT_R_16B; case blas::DataType::kFloat: - return HIPBLAS_R_32F; + return HIPBLASLT_R_32F; case blas::DataType::kDouble: - return HIPBLAS_R_64F; + return HIPBLASLT_R_64F; case blas::DataType::kInt8: - return HIPBLAS_R_8I; + return HIPBLASLT_R_8I; case blas::DataType::kInt32: - return HIPBLAS_R_32I; + return HIPBLASLT_R_32I; case blas::DataType::kComplexFloat: - return HIPBLAS_C_32F; + return HIPBLASLT_C_32F; case blas::DataType::kComplexDouble: - return HIPBLAS_C_64F; + return HIPBLASLT_C_64F; default: LOG(FATAL) << "unknown data type"; } diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h index a05b332492061c..eb99a0501d51ac 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/hip_blas_utils.h @@ -23,6 +23,12 @@ limitations under the License. #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status.h" +#include "rocm/rocm_config.h" +#if TF_ROCM_VERSION < 50700 +#define hipblasltDatatype_t hipblasDatatype_t +#endif + + namespace stream_executor { namespace rocm { @@ -30,7 +36,7 @@ namespace rocm { TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr)) tsl::Status ToStatus(hipblasStatus_t status, const char* prefix); -hipblasDatatype_t AsHipblasDataType(blas::DataType type); +hipblasltDatatype_t AsHipblasDataType(blas::DataType type); hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type); hipblasOperation_t AsHipblasOperation(blas::Transpose trans); diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc index 2dc79958fb70a4..f2b474cf0931d6 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "rocm/rocm_config.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_activation.h" @@ -420,7 +421,8 @@ tsl::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, const void *alpha, const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, + blas::CallContext call_context) { blas_log("DoBlasGemm"); VLOG(1) << absl::StreamFormat( "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u " @@ -458,6 +460,17 @@ tsl::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, tsl::StatusOr maybe_hasXDLOPS = GpuDriver::GetMFMASupport(); if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.value()) { VLOG(1) << "Using rocblas_gemm_ex"; + + bool is_backprop = + (call_context == blas::CallContext::kBackpropInput1) || + (call_context == blas::CallContext::kBackpropInput2); + + uint32_t flags = rocblas_gemm_flags_none; +#if TF_ROCM_VERSION >= 50000 + if (is_backprop) { + flags = rocblas_gemm_flags_fp16_alt_impl; + } +#endif return DoBlasInternalStatus( wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), @@ -465,7 +478,7 @@ tsl::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, rocblas_datatype_f16_r, lda, b.opaque(), rocblas_datatype_f16_r, ldb, beta, c->opaque(), rocblas_datatype_f16_r, ldc, c->opaque(), rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + rocblas_gemm_algo_standard, 0, flags); } else { VLOG(1) << "Using rocblas_hgemm"; const Eigen::half alpha_half(*static_cast(alpha)); @@ -544,7 +557,8 @@ tsl::Status ROCMBlas::DoBlasGemmWithAlgorithm( blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result) { + blas::ProfileResult *output_profile_result, + blas::CallContext context) { // ROCM TODO: properly implement the interface return tsl::errors::Internal("DoBlasGemmWithAlgorithm ", "is not implemented on ROCm yet"); @@ -558,7 +572,8 @@ tsl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result) { + blas::ProfileResult *output_profile_result, + blas::CallContext context) { // ROCM TODO: properly implement the interface return tsl::errors::Internal("DoBlasGemmStridedBatchedWithAlgorithm ", "is not implemented on ROCm yet"); @@ -863,7 +878,8 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { blas_log("DoBlasGemmBatched"); const Eigen::half alpha_half(alpha); const Eigen::half beta_half(beta); @@ -886,7 +902,8 @@ bool ROCMBlas::DoBlasGemmBatched( DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { blas_log("DoBlasGemmBatched"); const Eigen::bfloat16 alpha_bf16(alpha); const Eigen::bfloat16 beta_bf16(beta); @@ -909,7 +926,8 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, @@ -929,7 +947,8 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, double beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, @@ -948,7 +967,8 @@ bool ROCMBlas::DoBlasGemmBatched( DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k, @@ -968,7 +988,7 @@ bool ROCMBlas::DoBlasGemmBatched( DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, blas::CallContext context) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k, @@ -1098,7 +1118,8 @@ tsl::Status ROCMBlas::DoBlasGemmStridedBatched( const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, + blas::CallContext context) { VLOG(1) << absl::StreamFormat( "doing rocBLAS SGEMM Strided Batched: at=%d bt=%d m=%u n=%u " "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc index 8c36dcce5dadd0..0e278573d63ff9 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc @@ -3138,7 +3138,8 @@ class RocmConvRunner : public dnn::ConvRunner { BatchDescriptor input_descriptor, BatchDescriptor output_descriptor, FilterDescriptor filter_descriptor, - ConvolutionDescriptor conv_descriptor) + ConvolutionDescriptor conv_descriptor, + dnn::CallContext call_context) : parent_(parent), miopen_(miopen), algo_id_(algo_id), @@ -3148,7 +3149,18 @@ class RocmConvRunner : public dnn::ConvRunner { input_desc_{input_descriptor, ToMIOpenDataType(input_type)}, output_desc_{output_descriptor, ToMIOpenDataType(input_type)}, filter_desc_{filter_descriptor, ToMIOpenDataType(input_type)}, - conv_desc_{conv_descriptor, ToMIOpenDataType(input_type)} {} + conv_desc_{conv_descriptor, ToMIOpenDataType(input_type)}, + call_context_(call_context) { + bool is_backprop = (call_context == dnn::CallContext::kBackpropData) || + (call_context == dnn::CallContext::kBackpropFilter); + +#if TF_ROCM_VERSION >= 50000 + if (is_backprop && (ToMIOpenDataType(input_type) == miopenHalf)) { + wrap::miopenSetConvolutionAttribute( + conv_desc_.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); + } +#endif + } std::string ToString() const override { return dnn::AlgorithmDesc{algo_id_, false, workspace_size_}.ToString(); @@ -3273,6 +3285,7 @@ class RocmConvRunner : public dnn::ConvRunner { ScopedTensorDescriptor output_desc_; ScopedFilterDescriptor filter_desc_; ScopedConvolutionDescriptor conv_desc_; + dnn::CallContext call_context_; }; tsl::Status MIOpenSupport::DoConvolve( @@ -3284,12 +3297,13 @@ tsl::Status MIOpenSupport::DoConvolve( DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, - dnn::ProfileResult* output_profile_result) { + dnn::CallContext call_context, dnn::ProfileResult* output_profile_result) { TF_ASSIGN_OR_RETURN( auto runner, ConvolveRunnerFromDesc(stream, algorithm_desc, kind, element_type, output_type, input_descriptor, filter_descriptor, - output_descriptor, convolution_descriptor)); + output_descriptor, convolution_descriptor, + call_context)); return (*runner)(stream, output_profile_result, scratch_memory, input_data, filter_data, output_data); @@ -3302,7 +3316,7 @@ tsl::Status MIOpenSupport::GetConvolveRunners( const dnn::FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::CallContext call_context, bool use_fallback, ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, std::vector>* out_runners) { if (input_type != output_type) { @@ -3316,7 +3330,7 @@ tsl::Status MIOpenSupport::GetConvolveRunners( if (!GetMIOpenConvolveAlgorithms( kind, input_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, scratch_allocator, &profile_results)) { + convolution_descriptor, scratch_allocator, call_context, &profile_results)) { return tsl::Status( absl::StatusCode::kUnknown, "GetConvolveRunners: GetMIOpenConvolveAlgorithms failed"); @@ -3327,7 +3341,8 @@ tsl::Status MIOpenSupport::GetConvolveRunners( auto runner, ConvolveRunnerFromDesc( stream, profile_result.algorithm(), kind, input_type, output_type, input_descriptor, filter_descriptor, - output_descriptor, convolution_descriptor)); + output_descriptor, convolution_descriptor, + call_context)); out_runners->push_back(std::move(runner)); } @@ -3341,7 +3356,8 @@ MIOpenSupport::ConvolveRunnerFromDesc( dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor) { + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context) { if (input_type != output_type) { return tsl::errors::Unimplemented( absl::StrFormat("MIOpen backend does not support different input and " @@ -3358,7 +3374,7 @@ MIOpenSupport::ConvolveRunnerFromDesc( return {std::make_unique( parent_, miopen_.get(), algorithm_desc.algo_id(), *workspace_size, kind, input_type, use_immediate_mode_, input_descriptor, output_descriptor, - filter_descriptor, convolution_descriptor)}; + filter_descriptor, convolution_descriptor, call_context)}; } bool MIOpenSupport::GetMIOpenConvolveAlgorithms( @@ -3368,19 +3384,19 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithms( DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms) { return use_immediate_mode_ ? GetMIOpenConvolveAlgorithmsImmediateMode( kind, element_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, scratch_allocator, - out_algorithms) + call_context, out_algorithms) : GetMIOpenConvolveAlgorithmsFindMode( kind, element_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, scratch_allocator, - out_algorithms); + call_context, out_algorithms); } bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( @@ -3390,7 +3406,7 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms) { auto miopen = miopen_->GetHandle(parent_, stream); @@ -3403,6 +3419,15 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( ScopedConvolutionDescriptor conv{convolution_descriptor, ToMIOpenDataType(element_type)}; + bool is_backprop = (call_context == dnn::CallContext::kBackpropData) || + (call_context == dnn::CallContext::kBackpropFilter); + +#if TF_ROCM_VERSION >= 50000 + if (is_backprop && (ToMIOpenDataType(element_type) == miopenHalf)) { + wrap::miopenSetConvolutionAttribute( + conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); + } +#endif // First determine the number of algorityhms available size_t maxSolutionCount = 0; @@ -3599,7 +3624,7 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms) { auto miopen = miopen_->GetHandle(parent_, stream); @@ -3612,6 +3637,15 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( ScopedConvolutionDescriptor conv{convolution_descriptor, ToMIOpenDataType(element_type)}; + bool is_backprop = (call_context == dnn::CallContext::kBackpropData) || + (call_context == dnn::CallContext::kBackpropFilter); + +#if TF_ROCM_VERSION >= 50000 + if (is_backprop && (ToMIOpenDataType(element_type) == miopenHalf)) { + wrap::miopenSetConvolutionAttribute( + conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); + } +#endif // Determine the workspace memory size that will need by the call to Find size_t scratch_memory_size = 0; switch (kind) { @@ -4015,7 +4049,8 @@ tsl::Status ROCmFusedMatmulRunner::gemm(Stream* stream, DeviceMemoryBase a_data, return stream->ThenBlasGemm( tb, ta, _n, _m, _k, static_cast>(b_data), _ldb, static_cast>(a_data), _lda, - static_cast*>(&c_data), _ldc, NumericOptions{}); + static_cast*>(&c_data), _ldc, NumericOptions{}, + blas::CallContext::kNone); } template @@ -4182,7 +4217,8 @@ bool MIOpenSupport::DoMatMul(Stream* stream, if (!stream ->ThenBlasGemm(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, weights, m, - input_data, k, output_data, m, NumericOptions{}) + input_data, k, output_data, m, NumericOptions{}, + blas::CallContext::kNone) .ok()) { return false; } @@ -4265,7 +4301,7 @@ bool MIOpenSupport::DoMatMul(Stream* stream, stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, alpha, toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c), - ldc, batch_count, NumericOptions{}); + ldc, batch_count, NumericOptions{}, blas::CallContext::kNone); } return stream->ok(); diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h index 1649f2987877dc..8e40674dd2cf88 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h @@ -241,7 +241,7 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, + dnn::CallContext call_context, bool use_fallback, ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, std::vector>* out_runners) override; @@ -252,7 +252,8 @@ class MIOpenSupport : public dnn::DnnSupport { dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor) override; + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context) override; bool GetMIOpenConvolveAlgorithms( dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, @@ -262,7 +263,7 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms) override; bool GetRnnAlgorithms( @@ -334,6 +335,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, + dnn::CallContext call_context, dnn::ProfileResult* output_profile_result) override; tsl::Status DoFusedConvolve( @@ -798,7 +800,7 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms); bool GetMIOpenConvolveAlgorithmsFindMode( @@ -809,7 +811,7 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms); SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport); diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc index f3c8dd104e5867..3f389c53670559 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.cc +++ b/tensorflow/compiler/xla/stream_executor/stream.cc @@ -508,6 +508,7 @@ Stream &Stream::ThenConvolve( filter_descriptor, filter_data, output_descriptor, *output, convolution_descriptor, /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(), + dnn::CallContext::kForward, /*output_profile_result=*/nullptr) .ok()); } @@ -1562,11 +1563,12 @@ Stream &Stream::ThenBlasGemmBatched( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, numeric_options, - /*scratch_allocator=*/nullptr); + /*scratch_allocator=*/nullptr, + context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1575,7 +1577,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1584,11 +1587,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( float, DeviceMemorySlice, int, DeviceMemorySlice, int, float, DeviceMemorySlice, int, int, const NumericOptions &, - ScratchAllocator *> + ScratchAllocator *, blas::CallContext> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator); + numeric_options, scratch_allocator, context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1597,7 +1600,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1606,22 +1610,24 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( float, DeviceMemorySlice, int, DeviceMemorySlice, int, float, DeviceMemorySlice, int, int, - const NumericOptions &, ScratchAllocator *> + const NumericOptions &, ScratchAllocator *, + blas::CallContext> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator); + numeric_options, scratch_allocator, context); } Stream &Stream::ThenBlasGemmBatched( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count, const NumericOptions &numeric_options) { + int ldc, int batch_count, const NumericOptions &numeric_options, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, numeric_options, - /*scratch_allocator=*/nullptr); + /*scratch_allocator=*/nullptr, + context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1629,7 +1635,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1637,11 +1643,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( ThenBlasImpl, int, DeviceMemorySlice, int, float, DeviceMemorySlice, int, int, - const NumericOptions &, ScratchAllocator *> + const NumericOptions &, ScratchAllocator *, blas::CallContext> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator); + numeric_options, scratch_allocator, context); } Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, @@ -1651,11 +1657,13 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, + blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, numeric_options, - /*scratch_allocator=*/nullptr); + /*scratch_allocator=*/nullptr, + context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1664,7 +1672,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1673,11 +1682,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( double, DeviceMemorySlice, int, DeviceMemorySlice, int, double, DeviceMemorySlice, int, int, const NumericOptions &, - ScratchAllocator *> + ScratchAllocator *, blas::CallContext> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator); + numeric_options, scratch_allocator, context); } Stream &Stream::ThenBlasGemmBatched( @@ -1686,11 +1695,12 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, numeric_options, - /*scratch_allocator=*/nullptr); + /*scratch_allocator=*/nullptr, + context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1700,7 +1710,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, + blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1709,11 +1720,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( std::complex, DeviceMemorySlice>, int, DeviceMemorySlice>, int, std::complex, DeviceMemorySlice>, int, int, - const NumericOptions &, ScratchAllocator *> + const NumericOptions &, ScratchAllocator *, blas::CallContext> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator); + numeric_options, scratch_allocator, context); } Stream &Stream::ThenBlasGemmBatched( @@ -1722,11 +1733,13 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options) { + int ldc, int batch_count, const NumericOptions &numeric_options, + blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, numeric_options, - /*scratch_allocator=*/nullptr); + /*scratch_allocator=*/nullptr, + context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1736,7 +1749,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator) { + ScratchAllocator *scratch_allocator, blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1745,11 +1758,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( std::complex, DeviceMemorySlice>, int, DeviceMemorySlice>, int, std::complex, DeviceMemorySlice>, - int, int, const NumericOptions &, ScratchAllocator *> + int, int, const NumericOptions &, ScratchAllocator *, blas::CallContext> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator); + numeric_options, scratch_allocator, context); } Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, diff --git a/tensorflow/compiler/xla/stream_executor/stream.h b/tensorflow/compiler/xla/stream_executor/stream.h index 2445d1b6ee4e1d..5c457e62704a9e 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.h +++ b/tensorflow/compiler/xla/stream_executor/stream.h @@ -359,6 +359,7 @@ class Stream { const dnn::ConvolutionDescriptor &convolution_descriptor, ScratchAllocator *scratch_allocator, const dnn::AlgorithmConfig &algorithm_config, + dnn::CallContext call_context, dnn::ProfileResult *output_profile_result) { DeviceMemory scratch_memory; dnn::AlgorithmDesc algorithm_desc; @@ -373,7 +374,8 @@ class Stream { input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, algorithm_desc, - scratch_memory, output_profile_result); + scratch_memory, call_context, + output_profile_result); } return tsl::errors::Unimplemented("DNN library is not found."); } @@ -426,14 +428,16 @@ class Stream { const dnn::BatchDescriptor &input_descriptor, const dnn::FilterDescriptor &filter_descriptor, const dnn::BatchDescriptor &output_descriptor, - const dnn::ConvolutionDescriptor &convolution_descriptor) { + const dnn::ConvolutionDescriptor &convolution_descriptor, + dnn::CallContext call_context) { dnn::DnnSupport *dnn_support = parent_->AsDnn(); if (!dnn_support) { return tsl::errors::Unimplemented("DNN library is not found."); } return dnn_support->ConvolveRunnerFromDesc( this, algorithm_desc, kind, element_type, output_type, input_descriptor, - filter_descriptor, output_descriptor, convolution_descriptor); + filter_descriptor, output_descriptor, convolution_descriptor, + call_context); } tsl::StatusOr> @@ -954,11 +958,12 @@ class Stream { const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, + blas::CallContext context) { InputType alpha{1.0}; InputType beta{0.0}; return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, numeric_options); + ldc, numeric_options, context); } // TODO(reedwm): Update all callers (if there are any) to pass correct @@ -968,9 +973,10 @@ class Stream { uint64_t m, uint64 n, uint64 k, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, - DeviceMemory *c, int ldc) { + DeviceMemory *c, int ldc, + blas::CallContext context) { return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc, - NumericOptions{}); + NumericOptions{}, context); } template @@ -979,7 +985,8 @@ class Stream { const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, - int ldc, const NumericOptions &numeric_options) { + int ldc, const NumericOptions &numeric_options, + blas::CallContext context) { static_assert( detail::is_any_of, @@ -1010,7 +1017,8 @@ class Stream { return blas->DoBlasGemm(this, transa, transb, m, n, k, blas::ToDataType::value, alpha_ptr, a, - lda, b, ldb, beta_ptr, c, ldc, numeric_options); + lda, b, ldb, beta_ptr, c, ldc, numeric_options, + context); } // TODO(reedwm): Update all callers to pass correct NumericOptions. @@ -1020,9 +1028,9 @@ class Stream { const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, - int ldc) { + int ldc, blas::CallContext context) { return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, NumericOptions{}); + ldc, NumericOptions{}, context); } template @@ -1032,12 +1040,13 @@ class Stream { const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, - blas::ProfileResult *output_profile_result) { + blas::ProfileResult *output_profile_result, + blas::CallContext context) { OutputType alpha{1}; OutputType beta{0}; return ThenBlasGemmWithAlgorithm( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - computation_type, algorithm, NumericOptions{}, output_profile_result); + computation_type, algorithm, NumericOptions{}, output_profile_result, context); } template @@ -1048,7 +1057,8 @@ class Stream { DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result) { + blas::ProfileResult *output_profile_result, + blas::CallContext context) { TF_RETURN_IF_ERROR( CheckTypesForExtendedBlas( computation_type)); @@ -1071,7 +1081,7 @@ class Stream { blas::ToDataType::value, lda, b, blas::ToDataType::value, ldb, beta_ptr, c, blas::ToDataType::value, ldc, computation_type, algorithm, - numeric_options, output_profile_result); + numeric_options, output_profile_result, context); if (output_profile_result) { // The error is recorded in the profile. return ::tsl::OkStatus(); @@ -1087,7 +1097,8 @@ class Stream { int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result) { + blas::ProfileResult *output_profile_result, + blas::CallContext context) { TF_RETURN_IF_ERROR( CheckTypesForExtendedBlas( computation_type)); @@ -1108,7 +1119,8 @@ class Stream { blas::ToDataType::value, lda, stride_a, b, blas::ToDataType::value, ldb, stride_b, beta_ptr, c, blas::ToDataType::value, ldc, stride_c, batch_count, - computation_type, algorithm, numeric_options, output_profile_result); + computation_type, algorithm, numeric_options, output_profile_result, + context); if (output_profile_result) { // The error is recorded in the profile. return ::tsl::OkStatus(); @@ -1126,21 +1138,24 @@ class Stream { DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options); + const NumericOptions &numeric_options, + blas::CallContext context); Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options); + const NumericOptions &numeric_options, + blas::CallContext context); Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, double alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options); + const NumericOptions &numeric_options, + blas::CallContext context); Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, std::complex alpha, @@ -1149,28 +1164,30 @@ class Stream { std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, - const NumericOptions &numeric_options); + const NumericOptions &numeric_options, + blas::CallContext context); Stream &ThenBlasGemmBatched( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, std::complex alpha, DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options); + int ldc, int batch_count, const NumericOptions &numeric_options, + blas::CallContext context); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator); + ScratchAllocator *scratch_allocator,blas::CallContext context); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator); + ScratchAllocator *scratch_allocator,blas::CallContext context); Stream &ThenBlasGemmBatchedWithScratch(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, @@ -1179,14 +1196,15 @@ class Stream { float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator); + ScratchAllocator *scratch_allocator, + blas::CallContext context); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, double alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator); + ScratchAllocator *scratch_allocator,blas::CallContext context); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, std::complex alpha, @@ -1194,7 +1212,7 @@ class Stream { DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator); + ScratchAllocator *scratch_allocator,blas::CallContext context); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, std::complex alpha, @@ -1202,7 +1220,7 @@ class Stream { DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator); + ScratchAllocator *scratch_allocator,blas::CallContext context); template tsl::Status ThenBlasGemmStridedBatched( @@ -1211,7 +1229,8 @@ class Stream { int64_t stride_a, const DeviceMemory &b, int ldb, int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, int batch_count, - const NumericOptions &numeric_options) { + const NumericOptions &numeric_options, + blas::CallContext context) { static_assert( detail::is_any_of, @@ -1238,7 +1257,7 @@ class Stream { return blas->DoBlasGemmStridedBatched( this, transa, transb, m, n, k, blas::ToDataType::value, alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, - stride_c, batch_count, numeric_options); + stride_c, batch_count, numeric_options, context); } // See BlasSupport::DoBlasTrsm. diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc index 43b0e5f3c1af5f..9845de719bd49b 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc @@ -272,7 +272,7 @@ tsl::Status StreamExecutor::GetConvolveRunners( const dnn::FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::CallContext call_context, bool use_fallback, ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, std::vector>* out_exec_plans) { dnn::DnnSupport* dnn_support = AsDnn(); @@ -282,7 +282,7 @@ tsl::Status StreamExecutor::GetConvolveRunners( return dnn_support->GetConvolveRunners( use_cudnn_frontend, kind, input_type, output_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, - output_descriptor, output_data, convolution_descriptor, use_fallback, + output_descriptor, output_data, convolution_descriptor, call_context, use_fallback, scratch_allocator, numeric_options, out_exec_plans); } @@ -356,7 +356,7 @@ bool StreamExecutor::GetMIOpenConvolveAlgorithms( DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, + ScratchAllocator* scratch_allocator, dnn::CallContext call_context, std::vector* out_algorithms) { dnn::DnnSupport* dnn_support = AsDnn(); if (!dnn_support) { @@ -365,7 +365,7 @@ bool StreamExecutor::GetMIOpenConvolveAlgorithms( return dnn_support->GetMIOpenConvolveAlgorithms( kind, element_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, scratch_allocator, out_algorithms); + convolution_descriptor, scratch_allocator, call_context, out_algorithms); } bool StreamExecutor::GetRnnAlgorithms( diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h index 74eb78cc10268d..f77dc18b451ae9 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h @@ -373,6 +373,7 @@ class StreamExecutor { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::CallContext call_context, bool use_fallback, ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, std::vector>* out_exec_plans); @@ -423,6 +424,7 @@ class StreamExecutor { DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, ScratchAllocator* scratch_allocator, + dnn::CallContext call_context, std::vector* out_algorithms); // Returns the list of supported algorithms for rnn operation. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 4567a4acbb4db5..dddf0738292e3f 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -8,6 +8,7 @@ load( ) load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load( "//tensorflow/compiler/xla/stream_executor:build_defs.bzl", @@ -2280,7 +2281,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "llvm_compiler_test", srcs = if_gpu_is_configured(["llvm_compiler_test.cc"]), tags = tf_cuda_tests_tags(), @@ -2375,7 +2376,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "local_client_aot_test", srcs = [ "local_client_aot_test.cc", @@ -2458,7 +2459,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_metadata_test", srcs = [ "hlo_metadata_test.cc", @@ -2525,7 +2526,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "literal_test_util_test", srcs = ["literal_test_util_test.cc"], deps = [ @@ -2585,7 +2586,7 @@ xla_test( ) # A demo of test that loads an hlo module from a file and compares results on gpu and cpu. -xla_cc_test( +tf_cc_test( name = "sample_file_test", srcs = ["sample_file_test.cc"], data = ["isolated_convolution.hlo"], @@ -2639,7 +2640,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "multiple_devices_on_host_test", srcs = ["multiple_devices_on_host_test.cc"], args = ["--xla_force_host_platform_device_count=4"], @@ -2770,7 +2771,7 @@ xla_test( ], ) -xla_cc_test( +tf_cc_test( name = "tile_assignment_test", srcs = ["tile_assignment_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 425c144d1a9033..81b2ff0a6b7ab8 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -1,6 +1,6 @@ """Build rules for XLA testing.""" -load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load( "//tensorflow/compiler/xla/stream_executor:build_defs.bzl", @@ -22,8 +22,8 @@ def xla_test( disabled_backends = [], real_hardware_only = False, args = [], - tags = [], copts = [], + tags = [], data = [], backend_tags = {}, backend_args = {}, @@ -132,11 +132,11 @@ def xla_test( for lib_dep in xla_test_library_deps: backend_deps += ["%s_%s" % (lib_dep, backend)] - xla_cc_test( + tf_cc_test( name = test_name, srcs = srcs, tags = tags + backend_tags.get(backend, []) + this_backend_tags, - copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + this_backend_copts, args = args + this_backend_args, deps = deps + backend_deps, diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 21867d3f9f6def..db86bd4fc27409 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -21,6 +21,7 @@ load( "xla_cc_test", "xla_py_proto_library", ) +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load( "//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library", @@ -238,7 +239,7 @@ xla_cc_binary( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_extractor_test", srcs = ["hlo_extractor_test.cc"], deps = [ @@ -327,7 +328,7 @@ xla_cc_binary( ]), ) -xla_cc_test( +tf_cc_test( name = "interactive_graphviz_bin_test", srcs = ["interactive_graphviz_bin_test.cc"], data = [ @@ -363,7 +364,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_module_loader_test", srcs = ["hlo_module_loader_test.cc"], deps = [ @@ -463,7 +464,7 @@ xla_cc_binary( ]), ) -xla_cc_test( +tf_cc_test( name = "run_hlo_module_bin_test", srcs = ["run_hlo_module_bin_test.cc"], data = [ @@ -497,7 +498,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_control_flow_flattening_test", srcs = ["hlo_control_flow_flattening_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/tools/hlo_bisect/BUILD b/tensorflow/compiler/xla/tools/hlo_bisect/BUILD index 6ac516f14c1b5d..2aced2b5ce7601 100644 --- a/tensorflow/compiler/xla/tools/hlo_bisect/BUILD +++ b/tensorflow/compiler/xla/tools/hlo_bisect/BUILD @@ -1,5 +1,6 @@ # Description: # A tool for reducing a HLO module that produces incorrect results. +load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary") load( "//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary", @@ -55,7 +56,7 @@ cc_library( ], ) -xla_cc_test( +tf_cc_test( name = "hlo_bisect_state_test", srcs = ["hlo_bisect_state_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index e09c4d9a2d50b6..71f5afd9ef945d 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -880,6 +880,10 @@ void SetMatmulAttributes(OpT op, const xla::gpu::GemmBackendConfig& config, } op.setPrecisionConfigAttr( xla::ConvertPrecisionConfig(&config.precision_config(), &builder)); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + op.setGradXAttr(builder.getBoolAttr(config.grad_x())); + op.setGradYAttr(builder.getBoolAttr(config.grad_y())); +#endif } tsl::StatusOr AsLhloEpilogue( @@ -1221,6 +1225,8 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnConvolution( attrs.set(op.getBackendConfigAttrName(), config); op->setAttrs(attrs.getDictionary(op->getContext())); + op->setAttr("call_context", + builder_.getStringAttr(backend_config.call_context())); return op.getOperation(); }; diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc index 69fc08bb3d364f..90fa1b7bfdf45a 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc @@ -740,7 +740,8 @@ void LaunchConvBackpropFilterOpImpl( OP_REQUIRES_OK(context, stream->ThenBlasGemm( se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, n, m, k, a_ptr, - n, b_ptr, m, &c_ptr, n, GetNumericOptions())); + n, b_ptr, m, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kBackpropInput2)); return; } else if (!is_grouped_convolution && dims.filter_size(0) == dims.input_size(0) && @@ -762,7 +763,8 @@ void LaunchConvBackpropFilterOpImpl( OP_REQUIRES_OK(context, stream->ThenBlasGemm( se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, n, m, k, b_ptr, - n, a_ptr, m, &c_ptr, n, GetNumericOptions())); + n, a_ptr, m, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kBackpropInput2)); return; } diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc index 403a6122d7f273..b13dec83cc8c51 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc @@ -264,7 +264,8 @@ void LaunchConv2DBackpropFilterOpImpl( OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, n, m, k, a_ptr, n, b_ptr, m, &c_ptr, - n, GetNumericOptions())); + n, GetNumericOptions(), + se::blas::CallContext::kBackpropInput2)); return; } else if (dims.spatial_dims[0].filter_size == dims.spatial_dims[0].input_size && @@ -289,7 +290,8 @@ void LaunchConv2DBackpropFilterOpImpl( OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, n, m, k, b_ptr, n, a_ptr, m, &c_ptr, - n, GetNumericOptions())); + n, GetNumericOptions(), + se::blas::CallContext::kBackpropInput2)); return; } diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index bf6b7dc986e3fb..fbe730c337b3b7 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -162,7 +162,8 @@ void LaunchConv2DBackpropInputOpGpuImpl( OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + stream_executor::blas::CallContext::kBackpropInput1)); return; } else if (dims.spatial_dims[0].filter_size == dims.spatial_dims[0].input_size && @@ -189,7 +190,8 @@ void LaunchConv2DBackpropInputOpGpuImpl( OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + stream_executor::blas::CallContext::kBackpropInput1)); return; } diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc index 5e15e72a66eaa0..8ab7b1d523cd1e 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc @@ -741,7 +741,8 @@ void LaunchConvBackpropInputOpImpl( OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, a_ptr, k, &c_ptr, - n, GetNumericOptions())); + n, GetNumericOptions(), + stream_executor::blas::CallContext::kBackpropInput1)); return; } else if (!is_grouped_convolution && dims.filter_size(0) == dims.input_size(0) && @@ -765,7 +766,8 @@ void LaunchConvBackpropInputOpImpl( OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, a_ptr, k, &c_ptr, - n, GetNumericOptions())); + n, GetNumericOptions(), + stream_executor::blas::CallContext::kBackpropInput1)); return; } @@ -1156,7 +1158,6 @@ class Conv3DBackpropInputOp : public OpKernel { bool cudnn_use_autotune_; }; - #define REGISTER_GPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint("T"), \ diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index b380f87d588565..7e17c6932cfa23 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -259,7 +259,8 @@ void LaunchConv3DOpImpl(OpKernelContext* ctx, bool cudnn_use_autotune, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kForward)); return; } else if (!is_grouped_convolution && filter_planes == in_planes && filter_rows == in_rows && filter_cols == in_cols && @@ -280,7 +281,8 @@ void LaunchConv3DOpImpl(OpKernelContext* ctx, bool cudnn_use_autotune, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kForward)); return; } diff --git a/tensorflow/core/kernels/conv_ops_gpu.cc b/tensorflow/core/kernels/conv_ops_gpu.cc index 085c1d4c1765ef..476e804e30e8c6 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu.cc @@ -247,6 +247,22 @@ StatusOr> AutotuneUnfusedConv( auto* stream = ctx->op_device_context()->stream(); + se::dnn::CallContext call_context = se::dnn::CallContext::kNone; + switch (kind) { + case se::dnn::ConvolutionKind::FORWARD: + call_context = se::dnn::CallContext::kForward; + break; + case se::dnn::ConvolutionKind::BACKWARD_DATA: + call_context = se::dnn::CallContext::kBackpropData; + break; + case se::dnn::ConvolutionKind::BACKWARD_FILTER: + call_context = se::dnn::CallContext::kBackpropFilter; + break; + default: + return errors::InvalidArgument( + absl::StrFormat("Unknown ConvolutionKind %d", kind)); + } + if (!autotune_map->Find(conv_parameters, &autotune_entry)) { profiler::ScopedAnnotation annotation("cudnn_autotuning"); @@ -283,7 +299,7 @@ StatusOr> AutotuneUnfusedConv( TF_RETURN_IF_ERROR(stream->parent()->GetConvolveRunners( CudnnUseFrontend(), kind, element_type, element_type, stream, input_desc, input_ptr, filter_desc, filter_ptr, output_desc, output_ptr, - conv_desc, /*use_fallback=*/false, &rz_allocator, GetNumericOptions(), + conv_desc, call_context, /*use_fallback=*/false, &rz_allocator, GetNumericOptions(), &runners)); auto launch_func = [&](se::ScratchAllocator* allocator_used, @@ -328,7 +344,7 @@ StatusOr> AutotuneUnfusedConv( TF_RETURN_IF_ERROR(stream->parent()->GetConvolveRunners( CudnnUseFrontend(), kind, element_type, element_type, stream, input_desc, input_ptr, filter_desc, filter_ptr, output_desc, - output_ptr, conv_desc, /*use_fallback=*/true, &rz_allocator, + output_ptr, conv_desc, call_context, /*use_fallback=*/true, &rz_allocator, GetNumericOptions(), &fallback_runners)); TF_ASSIGN_OR_RETURN(auto fallback_results, @@ -353,7 +369,7 @@ StatusOr> AutotuneUnfusedConv( if (!stream->parent()->GetMIOpenConvolveAlgorithms( kind, se::dnn::ToDataType::value, stream, input_desc, input_ptr, filter_desc, filter_ptr, output_desc, output_ptr, conv_desc, - &scratch_allocator, &algorithms)) { + &scratch_allocator, call_context, &algorithms)) { return errors::Unknown( "Failed to get convolution algorithm. This is probably " "because MIOpen failed to initialize, so try looking to " @@ -379,7 +395,7 @@ StatusOr> AutotuneUnfusedConv( output_ptr, conv_desc, &scratch_allocator, se::dnn::AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()), - &profile_result); + call_context, &profile_result); if (miopen_launch_status.ok() && profile_result.is_valid()) { results.emplace_back(); auto& result = results.back(); diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index c3ba9dd4cd4397..c9a456d2c31d13 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -174,6 +174,22 @@ Status LaunchAutotunedConv(const AutotuneEntry& autotune_entry, const se::dnn::ConvolutionDescriptor& conv_desc, const se::dnn::BatchDescriptor& output_desc, se::DeviceMemory out_ptr) { + se::dnn::CallContext call_context = se::dnn::CallContext::kNone; + switch (kind) { + case se::dnn::ConvolutionKind::FORWARD: + call_context = se::dnn::CallContext::kForward; + break; + case se::dnn::ConvolutionKind::BACKWARD_DATA: + call_context = se::dnn::CallContext::kBackpropData; + break; + case se::dnn::ConvolutionKind::BACKWARD_FILTER: + call_context = se::dnn::CallContext::kBackpropFilter; + break; + default: + return errors::InvalidArgument( + absl::StrFormat("Unknown ConvolutionKind %d", kind)); + } + if (!autotune_entry.is_algorithm_config()) { const auto& runners = autotune_entry.GetOpRunners(); se::dnn::DataType element_type = se::dnn::ToDataType::value; @@ -201,7 +217,7 @@ Status LaunchAutotunedConv(const AutotuneEntry& autotune_entry, return stream->ConvolveWithAlgorithm( kind, input_desc, in_ptr, filter_desc, filter_ptr, output_desc, out_ptr, conv_desc, scratch_allocator, autotune_entry.GetAlgorithmConfig(), - nullptr); + call_context, nullptr); } } diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h index 73dbb7a292677f..96b4fd5942b75c 100644 --- a/tensorflow/core/kernels/conv_ops_impl.h +++ b/tensorflow/core/kernels/conv_ops_impl.h @@ -804,7 +804,8 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK(context, stream->ThenBlasGemm( no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kForward)); return; } else if (!is_grouped_convolution && filter_same_dims && padding == VALID && data_format == FORMAT_NHWC) { @@ -825,7 +826,8 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK(context, stream->ThenBlasGemm( no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kForward)); return; } @@ -1311,7 +1313,8 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kForward)); return; } else if (patch_rows == in_rows && patch_cols == in_cols && !is_grouped_convolution && row_dilation == 1 && @@ -1333,7 +1336,8 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, GetNumericOptions())); + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kForward)); return; } diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index 003a49d5b1c659..6bedafd99f2bb1 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -471,6 +471,7 @@ struct EinsumHelper { ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, /*adj_y=*/false, trans_x, trans_y, + /*grad_x=*/false, /*gradj_x=*/false, bcast, &output_reshaped); return OkStatus(); } diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 427f790559eb4e..de054548c60710 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -410,7 +410,8 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, bool grad_x, bool grad_y, + const MatMulBCast& bcast, Tensor* out) { typedef ParallelMatMulKernel::IsComplex> ParallelMatMulKernel; bool conjugate_result = false; @@ -539,7 +540,8 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, bool grad_x, bool grad_y, + const MatMulBCast& bcast, Tensor* out) { se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, se::blas::Transpose::kConjugateTranspose}; @@ -582,6 +584,14 @@ struct LaunchBatchMatMul { std::is_same_v; using Coefficient = std::conditional_t; + se::blas::CallContext call_context = se::blas::CallContext::kNone; + if (grad_x) { + call_context = se::blas::CallContext::kBackpropInput1; + } + if (grad_y) { + call_context = se::blas::CallContext::kBackpropInput2; + } + #if GOOGLE_CUDA || TF_HIPBLASLT static const bool use_autotune = MatmulAutotuneEnable(); bool bCublasLtSupport = true; @@ -718,7 +728,7 @@ struct LaunchBatchMatMul { static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, static_cast(0.0), c_ptrs, n, batch_size, - GetNumericOptions(), &scratch_allocator) + GetNumericOptions(), &scratch_allocator, call_context) .ok(); if (!blas_launch_status) { context->SetStatus(errors::Internal( @@ -816,7 +826,8 @@ struct LaunchBatchMatMul { blas_transpose_b, blas_transpose_a, n, m, k, *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]), adj_x || trans_x ? m : k, - c_ptrs[0], n, GetNumericOptions())); + c_ptrs[0], n, GetNumericOptions(), + call_context)); } else if (use_strided_batched) { OP_REQUIRES_OK( context, stream->ThenBlasGemmStridedBatched( @@ -825,7 +836,8 @@ struct LaunchBatchMatMul { adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], adj_x || trans_x ? m : k, a_stride, static_cast(0.0), c_ptrs[0], n, c_stride, - batch_size, GetNumericOptions())); + batch_size, GetNumericOptions(), + call_context)); } else { BlasScratchAllocator scratch_allocator(context); bool blas_launch_status = @@ -835,7 +847,8 @@ struct LaunchBatchMatMul { static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, static_cast(0.0), c_ptrs, n, batch_size, - GetNumericOptions(), &scratch_allocator) + GetNumericOptions(), &scratch_allocator, + call_context) .ok(); if (!blas_launch_status) { context->SetStatus(errors::Internal( @@ -865,11 +878,15 @@ class BaseBatchMatMulOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_)); adj_x_ = false; adj_y_ = false; + OP_REQUIRES_OK(context, context->GetAttr("grad_a", &grad_input_1_)); + OP_REQUIRES_OK(context, context->GetAttr("grad_b", &grad_input_2_)); } else { OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); trans_x_ = false; trans_y_ = false; + OP_REQUIRES_OK(context, context->GetAttr("grad_x", &grad_input_1_)); + OP_REQUIRES_OK(context, context->GetAttr("grad_y", &grad_input_2_)); } } @@ -954,7 +971,7 @@ class BaseBatchMatMulOp : public OpKernel { LaunchBatchMatMul::Launch( ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_, - trans_y_, bcast, &out_reshaped_float); + trans_y_, grad_input_1_, grad_input_2_, bcast, &out_reshaped_float); FloatToBFloat16(out_reshaped_float.flat().data(), out_reshaped.flat().data(), out->NumElements()); } else { @@ -966,9 +983,9 @@ class BaseBatchMatMulOp : public OpKernel { if (!std::is_same::value) { in1_reshaped = CastTensor(in1_reshaped); } - LaunchBatchMatMul::Launch(ctx, in0_reshaped, in1_reshaped, - adj_x_, adj_y_, trans_x_, - trans_y_, bcast, &out_reshaped); + LaunchBatchMatMul::Launch( + ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, trans_x_, trans_y_, + grad_input_1_, grad_input_2_, bcast, &out_reshaped); } } @@ -982,6 +999,8 @@ class BaseBatchMatMulOp : public OpKernel { bool adj_y_ = false; bool trans_x_ = false; bool trans_y_ = false; + bool grad_input_1_ = false; + bool grad_input_2_ = false; // Cast `t` from `SrcT` to `DstT`. template diff --git a/tensorflow/core/kernels/rnn/blas_gemm.cc b/tensorflow/core/kernels/rnn/blas_gemm.cc index b83de9f75201ae..e8cb77be22e721 100644 --- a/tensorflow/core/kernels/rnn/blas_gemm.cc +++ b/tensorflow/core/kernels/rnn/blas_gemm.cc @@ -54,7 +54,8 @@ void TensorCuBlasGemm::operator()(OpKernelContext* ctx, bool transa, ctx, ctx->op_device_context()->stream()->ThenBlasGemm( trans[transa], trans[transb], m, n, k, static_cast(alpha), a_ptr, lda, b_ptr, ldb, static_cast(beta), &c_ptr, ldc, - GetNumericOptions())); + GetNumericOptions(), + se::blas::CallContext::kNone)); #else ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA.")); #endif diff --git a/tensorflow/core/kernels/rnn/blas_gemm.h b/tensorflow/core/kernels/rnn/blas_gemm.h index 3fe247337cf57b..c93ba21af49f84 100644 --- a/tensorflow/core/kernels/rnn/blas_gemm.h +++ b/tensorflow/core/kernels/rnn/blas_gemm.h @@ -25,6 +25,10 @@ limitations under the License. #include "tensorflow/tsl/framework/contraction/eigen_contraction_kernel.h" #endif +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + namespace tensorflow { class OpKernelContext; namespace functor { diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 2e612135233f7b..9d251de0bb9f97 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -126,6 +126,8 @@ REGISTER_OP("BatchMatMul") "complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulShape); REGISTER_OP("BatchMatMulV2") @@ -137,6 +139,8 @@ REGISTER_OP("BatchMatMulV2") "uint16, uint32, uint64, complex64, complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulV2Shape); REGISTER_OP("BatchMatMulV3") @@ -154,6 +158,8 @@ REGISTER_OP("BatchMatMulV3") "complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulV2Shape); #ifdef INTEL_MKL @@ -173,6 +179,8 @@ REGISTER_OP("_MklBatchMatMulV2") .Attr("T: {bfloat16, float}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulV2Shape); #endif // INTEL_MKL @@ -1080,6 +1088,8 @@ REGISTER_OP("MatMul") .Attr( "T: {bfloat16, half, float, double, int32, int64, uint8, " "uint16, uint32, uint64, complex64, complex128}") + .Attr("grad_a: bool = false") + .Attr("grad_b: bool = false") .SetShapeFn(shape_inference::MatMulShape); #ifdef INTEL_MKL @@ -1090,6 +1100,8 @@ REGISTER_OP("_MklMatMul") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") .Attr("T: {bfloat16, float}") + .Attr("grad_a: bool = false") + .Attr("grad_b: bool = false") .SetShapeFn(shape_inference::MatMulShape); #endif // INTEL_MKL diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 73d793d535eaa2..8df32bde385ed9 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4204,6 +4204,20 @@ op { b: false } } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } } op { name: "BatchMatMulV2" @@ -4254,6 +4268,20 @@ op { b: false } } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } } op { name: "BatchMatMulV3" @@ -4338,6 +4366,20 @@ op { b: false } } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } } op { name: "BatchMatrixBandPart" @@ -26613,6 +26655,20 @@ op { } } } + attr { + name: "grad_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_b" + type: "bool" + default_value { + b: false + } + } } op { name: "MatchingFiles" diff --git a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py index 03b3d18f5b096f..4894142176fd96 100644 --- a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py @@ -2708,7 +2708,8 @@ def testInputGradientKernelSizeMatchesInputSize(self): padding="VALID", test_input=True, data_format=data_format, - use_gpu=use_gpu) + use_gpu=use_gpu, + max_err=0.005 if test.is_built_with_rocm() else 0.003) @test_util.deprecated_graph_mode_only def testFilterGradientKernelSizeMatchesInputSize(self): diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 06699244a414e5..06fe870bb664bf 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -1527,6 +1527,8 @@ def _matmul( # pylint:disable=missing-docstring a_is_sparse=False, b_is_sparse=False, output_type=None, # pylint: disable=unused-argument + grad_a=False, + grad_b=False, name=None): if transpose_a or transpose_b: raise ValueError("Transposing not supported at this time.") diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 71d03a2704eea7..d29a6b6cb7872f 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -1701,13 +1701,13 @@ def _MatMulGradAgainstFirstOnly(op, grad): t_b = op.get_attr("transpose_b") b = math_ops.conj(op.inputs[1]) if not t_a and not t_b: - grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) + grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True, grad_a=True) elif not t_a and t_b: - grad_a = gen_math_ops.mat_mul(grad, b) + grad_a = gen_math_ops.mat_mul(grad, b, grad_a=True) elif t_a and not t_b: - grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) + grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True, grad_a=True) elif t_a and t_b: - grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) + grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True, grad_a=True) return grad_a, None @@ -1717,13 +1717,13 @@ def _MatMulGradAgainstSecondOnly(op, grad): t_b = op.get_attr("transpose_b") a = math_ops.conj(op.inputs[0]) if not t_a and not t_b: - grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) + grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True, grad_b=True) elif not t_a and t_b: - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, grad_b=True) elif t_a and not t_b: - grad_b = gen_math_ops.mat_mul(a, grad) + grad_b = gen_math_ops.mat_mul(a, grad, grad_b=True) elif t_a and t_b: - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True, grad_b=True) return None, grad_b @@ -1746,17 +1746,17 @@ def _MatMulGrad(op, grad): a = math_ops.conj(op.inputs[0]) b = math_ops.conj(op.inputs[1]) if not t_a and not t_b: - grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) - grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) + grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True, grad_a=True) + grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True, grad_b=True) elif not t_a and t_b: - grad_a = gen_math_ops.mat_mul(grad, b) - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) + grad_a = gen_math_ops.mat_mul(grad, b, grad_a=True) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, grad_b=True) elif t_a and not t_b: - grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) - grad_b = gen_math_ops.mat_mul(a, grad) + grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True, grad_a=True) + grad_b = gen_math_ops.mat_mul(a, grad, grad_b=True) elif t_a and t_b: - grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) + grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True, grad_a=True) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True, grad_b=True) return grad_a, grad_b @@ -1843,18 +1843,18 @@ def _BatchMatMul(op, grad): if not adj_x: if not adj_y: - grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) - grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) + grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True, grad_a=True) + grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False, grad_b=True) else: - grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) - grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) + grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False, grad_a=True) + grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False, grad_b=True) else: if not adj_y: - grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) - grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) + grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True, grad_a=True) + grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False, grad_b=True) else: - grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) - grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) + grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True, grad_a=True) + grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True, grad_b=True) return grad_x, grad_y @@ -1870,18 +1870,18 @@ def _BatchMatMulV2(op, grad): if not adj_x: if not adj_y: - grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) - grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) + grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True, grad_a=True) + grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False, grad_b=True) else: - grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) - grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) + grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False, grad_a=True) + grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False, grad_b=True) else: if not adj_y: - grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) - grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) + grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True, grad_a=True) + grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False, grad_b=True) else: - grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) - grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) + grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True, grad_a=True) + grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True, grad_b=True) # Possibly reduce along the broadcasted batch dimensions, if broadcasting # is required. diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index fcda7d59e4c5e7..a5230f5e299e35 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -3624,6 +3624,8 @@ def matmul(a, a_is_sparse=False, b_is_sparse=False, output_type=None, + grad_a=False, + grad_b=False, name=None): """Multiplies matrix `a` by matrix `b`, producing `a` * `b`. @@ -3789,10 +3791,12 @@ def matmul(a, adjoint_b = True if use_batch_matmul_v3: return gen_math_ops.batch_mat_mul_v3( - a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name) + a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, + grad_x=grad_a, grad_y=grad_b, name=name) else: return gen_math_ops.batch_mat_mul_v2( - a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name) + a, b, adj_x=adjoint_a, adj_y=adjoint_b, + grad_x=grad_a, grad_y=grad_b, name=name) # Neither matmul nor sparse_matmul support adjoint, so we conjugate # the matrix and use transpose instead. Conj() is a noop for real @@ -3836,10 +3840,12 @@ def matmul(a, adjoint_a = adjoint_a or transpose_a adjoint_b = adjoint_b or transpose_b return gen_math_ops.batch_mat_mul_v3( - a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name) + a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, + grad_x=grad_a, grad_y=grad_b, name=name) else: return gen_math_ops.mat_mul( - a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name) + a, b, transpose_a=transpose_a, transpose_b=transpose_b, + grad_a=grad_a, grad_b=grad_b, name=name) @tf_export("linalg.matvec") diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py index ef98fb344aed2d..382a2e86c6ba9e 100644 --- a/tensorflow/python/ops/ragged/ragged_math_ops.py +++ b/tensorflow/python/ops/ragged/ragged_math_ops.py @@ -801,6 +801,8 @@ def matmul(a: ragged_tensor.RaggedOrDense, a_is_sparse=False, b_is_sparse=False, output_type=None, + grad_a=False, + grad_b=False, name=None): """Multiplies matrix `a` by matrix `b`. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index d517b4a6219751..e8b27d1124aff1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -186,7 +186,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_rank" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 591f10f6a2129b..35161ca326ae08 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1674,7 +1674,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_band_part" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index b0f235c59d46f8..78044e942e552f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -398,15 +398,15 @@ tf_module { } member_method { name: "BatchMatMul" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV2" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV3" - argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatrixBandPart" @@ -2466,7 +2466,7 @@ tf_module { } member_method { name: "MatMul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "MatchingFiles" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt index 2319f6abb046b6..b1861f63d55b8d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -198,7 +198,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_rank" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 7bb360dfe4bc37..06bc5da880ce34 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -834,7 +834,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_square_root" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index b0f235c59d46f8..78044e942e552f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -398,15 +398,15 @@ tf_module { } member_method { name: "BatchMatMul" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV2" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV3" - argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatrixBandPart" @@ -2466,7 +2466,7 @@ tf_module { } member_method { name: "MatMul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "MatchingFiles" diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 905a984732bbb6..f4fdfb50c1f4fe 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -199,6 +199,7 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")