Merge pull request tensorflow#2219 from ROCmSoftwarePlatform/develop-upstream-QA-rocm60-denorm

[develop-upstream-QA-rocm60] Changes to track call-context information (forw…
jayfurmanek authored Sep 14, 2023
2 parents 0c64573 + 988f560 commit 5b0eb43
Showing 92 changed files with 1,196 additions and 641 deletions.
@@ -45,14 +45,16 @@ Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, AbstractTensorHa
// Summary:
//
// Description:
Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a, bool transpose_b, const char* name, const char* raw_device_name) {
Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a, bool transpose_b, bool grad_a, bool grad_b, const char* name, const char* raw_device_name) {
AbstractOperationPtr op_ptr(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op_ptr->Reset("MatMul", raw_device_name));
TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name));
TF_RETURN_IF_ERROR(op_ptr->AddInput(a));
TF_RETURN_IF_ERROR(op_ptr->AddInput(b));
TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("transpose_a", transpose_a));
TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("transpose_b", transpose_b));
TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("grad_a", grad_a));
TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("grad_b", grad_b));
int num_retvals = 1;
return op_ptr->Execute(absl::MakeSpan(product, 1), &num_retvals);
}
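Illustrative only, not part of this commit: with the regenerated wrapper above, gradient code could thread the new flags through explicitly. The helper name, the choice of operands, and the reading of grad_a as "this matmul is the gradient-of-a computation" are assumptions for the sketch, not taken from the diff.

// Hypothetical caller: dA = dZ * B^T for Z = A * B, tagged via grad_a.
Status MatMulGradA(AbstractContext* ctx, AbstractTensorHandle* const dz,
                   AbstractTensorHandle* const b, AbstractTensorHandle** da) {
  return MatMul(ctx, dz, b, da,
                /*transpose_a=*/false, /*transpose_b=*/true,
                /*grad_a=*/true, /*grad_b=*/false,
                /*name=*/"MatMulGradA", /*raw_device_name=*/nullptr);
}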
@@ -28,7 +28,7 @@ namespace ops {
Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, AbstractTensorHandle** y, const char* name = nullptr, const char* raw_device_name = nullptr);

//
Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a = false, bool transpose_b = false, const char* name = nullptr, const char* raw_device_name = nullptr);
Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a = false, bool transpose_b = false, bool grad_a = false, bool grad_b = false, const char* name = nullptr, const char* raw_device_name = nullptr);

//
Status IdentityN(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> input, absl::Span<AbstractTensorHandle*> output, const char* name = nullptr, const char* raw_device_name = nullptr);
@@ -51,7 +51,7 @@ func.func @batchmatmulv2(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
// SUPPORTED_FALLBACK_DEVICE: mhlo.dot_general
// SUPPORTED_FALLBACK_DEVICE: mhlo.transpose

%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, grad_x = false, grad_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
func.return %0 : tensor<3x4x4xf32>
}

@@ -542,7 +542,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK: mhlo.reduce
// CHECK: mhlo.dot_general
// CHECK: mhlo.transpose
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, grad_x = false, grad_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
func.return %0 : tensor<3x4x4xf32>
}

6 changes: 5 additions & 1 deletion tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -32,6 +32,8 @@ class BatchMatMulOp : public XlaOpKernel {
explicit BatchMatMulOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_y", &adj_y_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_x", &grad_x_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_y", &grad_y_));

if (ctx->HasAttr("Tout")) {
DataType output_type;
@@ -50,13 +52,15 @@ class BatchMatMulOp : public XlaOpKernel {
: xla::PrecisionConfig::HIGHEST;
auto result = xla::BatchDot(MaybeConjugate(ctx->Input(0), adj_x_), adj_x_,
MaybeConjugate(ctx->Input(1), adj_y_), adj_y_,
precision, preferred_element_type_);
precision, grad_x_, grad_y_, preferred_element_type_);
ctx->SetOutput(0, result);
}

private:
bool adj_x_;
bool adj_y_;
bool grad_x_;
bool grad_y_;
std::optional<xla::PrimitiveType> preferred_element_type_;
};

52 changes: 50 additions & 2 deletions tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
#include "tensorflow/compiler/jit/flags.h"

namespace tensorflow {
namespace {
@@ -246,6 +247,8 @@ StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
// Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));

xla::XlaOp conv_forward;

// For 2D convolution, there should be 4 dimensions.
int num_dims = attrs.num_spatial_dims + 2;
if (input_shape.dimensions_size() != num_dims) {
@@ -329,13 +332,32 @@ StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
xla::PrecisionConfig precision_config = GetPrecisionConfig();

if (padding_type != xla::PaddingType::PADDING_INVALID) {
return xla::DynamicConvForward(
conv_forward = xla::DynamicConvForward(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
dims,
/*feature_group_count=*/attrs.depthwise ? in_depth
: feature_group_count,
/*batch_group_count=*/1, &precision_config, padding_type);
} else {
conv_forward = xla::ConvGeneralDilated(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
dims,
/*feature_group_count=*/attrs.depthwise ? in_depth
: feature_group_count,
/*batch_group_count=*/1, &precision_config);
}

// TODO : CHECK_EQ(HloOpcode::kConvolution, conv_forward->opcode());

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Set call_context attribute, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED
auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
state = tensorflow::GetMlirBridgeRolloutState(std::nullopt);
if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(conv_forward, "call_context",
"kForward"));
}
#endif

return xla::ConvGeneralDilated(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
@@ -361,6 +383,8 @@ StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(StringPiece type_string,
TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
builder->GetShape(out_backprop));

xla::XlaOp input_backprop;

int64_t in_depth = input_shape.dimensions(feature_dim),
filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
feature_group_count =
@@ -429,7 +453,7 @@ StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(StringPiece type_string,
filter = xla::Rev(filter, kernel_spatial_dims);
if (padding_type != xla::PaddingType::PADDING_INVALID) {
TF_RET_CHECK(input_sizes != nullptr);
return xla::DynamicConvInputGrad(
input_backprop = xla::DynamicConvInputGrad(
*input_sizes, out_backprop, filter, /*window_strides=*/ones, padding,
lhs_dilation, rhs_dilation, dnums,
/*feature_group_count=*/
@@ -443,6 +467,18 @@ StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(StringPiece type_string,
/*feature_group_count=*/
feature_group_count,
/*batch_group_count=*/1, &precision_config);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Set call_context attribute, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED
auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
state = tensorflow::GetMlirBridgeRolloutState(std::nullopt);
if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(input_backprop, "call_context",
"kBackpropData"));
}
#endif

return input_backprop;
}
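The rollout-state guard above is repeated for each convolution path in this file and again in matrix.cc further down. Purely as a sketch, it could be factored into a single helper; the function name is hypothetical, and the calls are the same ones already used in this diff.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Hypothetical helper: annotate only when the MLIR bridge is not enabled.
static bool ShouldAnnotateCallContext() {
  return tensorflow::GetMlirBridgeRolloutState(std::nullopt) !=
         tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
}
#endif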

StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(StringPiece type_string,
@@ -598,6 +634,18 @@ StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(StringPiece type_string,
/*batch_group_count=*/batch_group_count, &precision_config);
}

// TODO : CHECK_EQ(HloOpcode::kConvolution, filter_backprop->opcode());

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Set call_context attribute, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED
auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
state = tensorflow::GetMlirBridgeRolloutState(std::nullopt);
if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(filter_backprop, "call_context",
"kBackpropFilter"));
}
#endif

if (attrs.depthwise) {
filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions());
}
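Illustrative only, not part of this commit: a later pass or backend could read the annotation back off the lowered HLO. The helper below is a hypothetical sketch assuming the usual xla::HloInstruction API; the attribute name and values ("kForward", "kBackpropData", "kBackpropFilter") are the ones set above.

// Hypothetical consumer of the call_context annotation stamped above.
std::optional<std::string> GetCallContext(const xla::HloInstruction* conv) {
  const auto& attrs = conv->frontend_attributes().map();
  auto it = attrs.find("call_context");
  if (it == attrs.end()) return std::nullopt;
  return it->second;
}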
6 changes: 5 additions & 1 deletion tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -39,6 +39,8 @@ class MatMulOp : public XlaOpKernel {
: XlaOpKernel(ctx), is_sparse_(is_sparse) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_a", &grad_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_b", &grad_b_));
if (is_sparse) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Ta", &a_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tb", &b_type_));
@@ -96,13 +98,15 @@ class MatMulOp : public XlaOpKernel {
? xla::PrecisionConfig::DEFAULT
: xla::PrecisionConfig::HIGHEST;
ctx->SetOutput(0,
xla::BatchDot(a, transpose_a_, b, transpose_b_, precision));
xla::BatchDot(a, transpose_a_, b, transpose_b_, precision, grad_a_, grad_b_));
}

private:
bool is_sparse_;
bool transpose_a_;
bool transpose_b_;
bool grad_a_;
bool grad_b_;
DataType a_type_;
DataType b_type_;
};
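The constructor change above makes grad_a and grad_b hard requirements of the kernel; graphs serialized before these attributes existed may only load if default attributes are filled in at import time. Purely as a sketch of an alternative, and not what this commit does, the constructor could fall back to false using the same HasAttr pattern this diff's batch_matmul_op.cc already uses for "Tout":

bool grad_a = false;  // hypothetical lenient read; defaults when absent
if (ctx->HasAttr("grad_a")) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_a", &grad_a));
}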
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/client/lib/BUILD
@@ -234,6 +234,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/jit:flags",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
41 changes: 32 additions & 9 deletions tensorflow/compiler/xla/client/lib/matrix.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/jit/flags.h"

#include <algorithm>
#include <array>
@@ -45,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/jit/flags.h"

namespace xla {

@@ -385,25 +387,27 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64_t> x_config,
xla::XlaOp y, absl::Span<const int64_t> y_config,
absl::Span<const int64_t> output_config,
xla::PrecisionConfig::Precision precision,
std::optional<PrimitiveType> preferred_element_type) {
std::optional<PrimitiveType> preferred_element_type,
bool grad_x, bool grad_y) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
if (x_diagonal_labels) {
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
y_config, output_config, precision, preferred_element_type);
y_config, output_config, precision, preferred_element_type,
grad_x, grad_y);
}
auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
if (y_diagonal_labels) {
return Einsum(x, x_config, EinsumDiagonal(y, y_config),
y_diagonal_labels->at(0), output_config, precision,
preferred_element_type);
preferred_element_type, grad_x, grad_y);
}
auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
if (output_diagonal_labels) {
return EinsumInverseDiagonal(
Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
precision, preferred_element_type),
precision, preferred_element_type, grad_x, grad_y),
output_config);
}

@@ -547,6 +551,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64_t> x_config,
precision_proto.add_operand_precision(precision);
auto dot =
DotGeneral(x, y, dnums, &precision_proto, preferred_element_type);


#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Set grad_x, grad_y attributes, but only if !MLIR_BRIDGE_ROLLOUT_ENABLED
auto state = tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
state = tensorflow::GetMlirBridgeRolloutState(std::nullopt);
if (state != tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(dot, "grad_x",
(grad_x ? "true" : "false")));
TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(dot, "grad_y",
(grad_y ? "true" : "false")));
}
#endif

dot = Transpose(dot, transpose_dims);
if (transpose_rank == output_rank) {
return dot;
@@ -573,11 +591,12 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64_t> x_config,

XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision,
std::optional<PrimitiveType> preferred_element_type) {
return BatchDot(x, false, y, false, precision, preferred_element_type);
return BatchDot(x, false, y, false, precision, false, false, preferred_element_type);
}

XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y,
PrecisionConfig::Precision precision,
bool grad_x, bool grad_y,
std::optional<PrimitiveType> preferred_element_type) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
Expand All @@ -588,7 +607,8 @@ XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y,
if (transpose_y) {
std::swap(string[6 + 3], string[6 + 4]);
}
return Einsum(x, y, string, precision, preferred_element_type);
return Einsum(x, y, string, precision, preferred_element_type, grad_x,
grad_y);
});
}
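One consequence of inserting grad_x and grad_y ahead of preferred_element_type, shown here for illustration only (x, y, and precision are placeholders, not from the diff): positional callers now pass the element type last.

// Before: BatchDot(x, true, y, false, precision, xla::F32);
XlaOp dot = BatchDot(x, /*transpose_x=*/true, y, /*transpose_y=*/false,
                     precision, /*grad_x=*/false, /*grad_y=*/false,
                     /*preferred_element_type=*/xla::F32);

Because a bare PrimitiveType converts implicitly to bool, an old positional call like the commented one could still compile against the new signature while binding xla::F32 to grad_x, so existing call sites are worth auditing.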

@@ -709,20 +729,23 @@ std::string NormalizeEinsumString(absl::string_view einsum_config)

XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
PrecisionConfig::Precision precision,
std::optional<PrimitiveType> preferred_element_type) {
std::optional<PrimitiveType> preferred_element_type,
bool grad_x, bool grad_y) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto new_config = NormalizeEinsumString(einsum_config);
if (!new_config.empty()) {
return Einsum(x, y, new_config, precision, preferred_element_type);
return Einsum(x, y, new_config, precision, preferred_element_type, grad_x,
grad_y);
}
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
TF_ASSIGN_OR_RETURN(
auto einsum_config_numeric,
ParseEinsumString(einsum_config, x_shape.rank(), y_shape.rank()));
return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
einsum_config_numeric[2], precision, preferred_element_type);
einsum_config_numeric[2], precision, preferred_element_type,
grad_x, grad_y);
});
}

7 changes: 5 additions & 2 deletions tensorflow/compiler/xla/client/lib/matrix.h
@@ -97,6 +97,7 @@ xla::XlaOp BatchDot(
xla::XlaOp BatchDot(
xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
bool grad_x = false, bool grad_y = false,
std::optional<PrimitiveType> preferred_element_type = std::nullopt);

// Parse an einsum string into dimension numbers:
Expand Down Expand Up @@ -128,7 +129,8 @@ std::string NormalizeEinsumString(absl::string_view einsum_config);
xla::XlaOp Einsum(
xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
std::optional<PrimitiveType> preferred_element_type = std::nullopt);
std::optional<PrimitiveType> preferred_element_type = std::nullopt,
bool grad_x = false, bool grad_y = false);
xla::XlaOp Einsum(
xla::XlaOp x, absl::string_view einsum_config,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
@@ -143,7 +145,8 @@ xla::XlaOp Einsum(
xla::XlaOp x, absl::Span<const int64_t> x_config, xla::XlaOp y,
absl::Span<const int64_t> y_config, absl::Span<const int64_t> output_config,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
std::optional<PrimitiveType> preferred_element_type = std::nullopt);
std::optional<PrimitiveType> preferred_element_type = std::nullopt,
bool grad_x = false, bool grad_y = false);
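Illustrative only, not part of this commit: with the string-form overload declared above, gradient code could tag the einsum that produces dY for Z = X·Y. The variable names are placeholders, and reading grad_y as "the y-side gradient computation" is an assumption about the flag's meaning.

// dY[b,c] = sum_a X[a,b] * dZ[a,c]
xla::XlaOp dy = xla::Einsum(x, dz, "ab,ac->bc",
                            xla::PrecisionConfig::DEFAULT,
                            /*preferred_element_type=*/std::nullopt,
                            /*grad_x=*/false, /*grad_y=*/true);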

// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/examples/axpy/BUILD
@@ -1,8 +1,9 @@
load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary")

# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])

xla_cc_test(
tf_cc_test(
name = "stablehlo_compile_test",
srcs = ["stablehlo_compile_test.cc"],
data = ["stablehlo_axpy.mlir"],
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/hlo/evaluator/BUILD
@@ -3,6 +3,7 @@

load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library")
load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_test","tf_cc_binary")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -81,7 +82,7 @@ cc_library(
],
)

xla_cc_test(
tf_cc_test(
name = "hlo_evaluator_test",
srcs = ["hlo_evaluator_test.cc"],
deps = [