Skip to content

Commit

Permalink
Integrate StableHLO at openxla/stablehlo@78c753ad
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675658944
  • Loading branch information
sdasgup3 authored and Google-ML-Automation committed Sep 17, 2024
1 parent 161e7db commit cd5f5e0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 386 deletions.
384 changes: 0 additions & 384 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1,385 +1 @@
diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
--- stablehlo/BUILD.bazel
+++ stablehlo/BUILD.bazel
@@ -1112,11 +1112,13 @@
"stablehlo/transforms/StablehloRefineShapes.cpp",
"stablehlo/transforms/VhloLegalizeToStablehlo.cpp",
"stablehlo/transforms/VhloToVersion.cpp",
+ "stablehlo/transforms/passes_utils.cpp",
],
hdrs = [
"stablehlo/transforms/MapStablehloToVhlo.h",
"stablehlo/transforms/Passes.h",
"stablehlo/transforms/StablehloRefineShapes.h",
+ "stablehlo/transforms/passes_utils.h",
],
strip_include_prefix = ".",
deps = [
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/unary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/unary.mlir
@@ -121,8 +121,8 @@

// CHECK-LABEL: @transpose
func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
- // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
- // CHECK-DAG: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]]
+ // CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ // CHECK: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]]
%0 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 2, 1, 0>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}
diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
--- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
+++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
@@ -451,9 +451,10 @@

auto perms = op.getPermutation();
auto type = RankedTensorType::get({static_cast<int64_t>(perms.size())},
- rewriter.getI64Type());
+ rewriter.getI32Type());
+ std::vector<int32_t> perms_int32(perms.begin(), perms.end());
auto constOp = rewriter.create<tosa::ConstOp>(
- op->getLoc(), type, DenseIntElementsAttr::get(type, perms));
+ op->getLoc(), type, DenseIntElementsAttr::get(type, perms_int32));
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(op, op.getType(),
op.getOperand(), constOp);
return success();
diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp
--- stablehlo/stablehlo/dialect/Version.cpp
+++ stablehlo/stablehlo/dialect/Version.cpp
@@ -80,9 +80,9 @@
case CompatibilityRequirement::NONE:
return Version::getCurrentVersion();
case CompatibilityRequirement::WEEK_4:
- return Version(1, 3, 0); // v1.3.0 - Jul 15, 2024
+ return Version(1, 5, 0); // v1.3.0 - Aug 1, 2024
case CompatibilityRequirement::WEEK_12:
- return Version(1, 0, 0); // v1.0.0 - May 14, 2024
+ return Version(1, 1, 0); // v1.1.0 - May 30, 2024
case CompatibilityRequirement::MAX:
return Version::getMinimumVersion();
}
diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
--- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
+++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
@@ -3029,28 +3029,11 @@

// -----

-// CHECK-LABEL: @tan_f16
-// CHECK-SAME: (%[[ARG:.*]]: tensor<f16>)
-func.func @tan_f16(%arg : tensor<f16>) -> tensor<f16> {
- // %[[TMP_0:.*]] = stablehlo.convert [[ARG]] : (tensor<f16>) -> tensor<f32>
- // %[[TMP_1:.*]] = stablehlo.sine %[[TMP_0]]
- // %[[TMP_2:.*]] = stablehlo.cosine %[[TMP_0]]
- // %[[TMP_3:.*]] = stablehlo.divide %[[TMP_1]], %[[TMP_2]]
- // %[[TMP_4:.*]] = stablehlo.convert %[[TMP_3]] : (tensor<f32>) -> tensor<f16>
- // return %[[TMP_4]] : tensor<f16>
- %1 = chlo.tan %arg : tensor<f16> -> tensor<f16>
- func.return %1 : tensor<f16>
-}
-
-// -----
-
// CHECK-LABEL: @tan_f32
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func.func @tan_f32(%arg : tensor<f32>) -> tensor<f32> {
- // %[[TMP_0:.*]] = stablehlo.sine %[[ARG]]
- // %[[TMP_1:.*]] = stablehlo.cosine %[[ARG]]
- // %[[TMP_2:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_1]]
- // return %[[TMP_2]] : tensor<f32>
+ // CHECK: %[[TMP_0:.*]] = stablehlo.tan %[[ARG]] : tensor<f32>
+ // CHECK: return %[[TMP_0]] : tensor<f32>
%1 = chlo.tan %arg : tensor<f32> -> tensor<f32>
func.return %1 : tensor<f32>
}
@@ -3060,22 +3043,11 @@
// CHECK-LABEL: @tan_complexf32
// CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32>
func.func @tan_complexf32(%arg0 : tensor<1xf32>, %arg1 : tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) {
- // CHECK: %[[COMPLEX:.+]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex<f32>>
- // CHECK: %[[REAL:.+]] = stablehlo.real %[[COMPLEX]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
- // CHECK: %[[SINE:.+]] = stablehlo.sine %[[REAL]]
- // CHECK: %[[COS:.+]] = stablehlo.cosine %[[REAL]]
- // CHECK: %[[TAN:.+]] = stablehlo.divide %[[SINE]], %[[COS]]
- // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[COMPLEX]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
- // CHECK: %[[TANH:.+]] = stablehlo.tanh %[[IMAG]]
- // CHECK: %[[NUM:.+]] = stablehlo.complex %[[TAN]], %[[TANH]]
- // CHECK: %[[ONE:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
- // CHECK: %[[MUL:.+]] = stablehlo.multiply %[[TAN]], %[[TANH]]
- // CHECK: %[[NEG:.+]] = stablehlo.negate %[[MUL]]
- // CHECK: %[[DEN:.+]] = stablehlo.complex %[[ONE]], %[[NEG]]
- // CHECK: %[[RES:.+]] = stablehlo.divide %[[NUM]], %[[DEN]]
- // CHECK: %[[REAL:.+]] = stablehlo.real %[[RES]]
- // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[RES]]
- // CHECK: return %[[REAL]], %[[IMAG]]
+ // CHECK: %[[TMP_0:.*]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex<f32>>
+ // CHECK: %[[TMP_1:.*]] = stablehlo.tan %[[TMP_0]] : tensor<1xcomplex<f32>>
+ // CHECK: %[[TMP_2:.*]] = stablehlo.real %[[TMP_1]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
+ // CHECK: %[[TMP_3:.*]] = stablehlo.imag %[[TMP_1]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
+ // CHECK: return %[[TMP_2]], %[[TMP_3]] : tensor<1xf32>, tensor<1xf32>
%0 = stablehlo.complex %arg0, %arg1 : tensor<1xcomplex<f32>>
%1 = chlo.tan %0 : tensor<1xcomplex<f32>> -> tensor<1xcomplex<f32>>
%2 = stablehlo.real %1 : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
@@ -1,5 +1,5 @@
-// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
-// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.0.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE

// -----

diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
--- stablehlo/stablehlo/transforms/CMakeLists.txt
+++ stablehlo/stablehlo/transforms/CMakeLists.txt
@@ -53,6 +53,7 @@
StablehloRefineShapes.cpp
VhloLegalizeToStablehlo.cpp
VhloToVersion.cpp
+ passes_utils.cpp

DEPENDS
ChloDecompositionPatternsIncGen
diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
--- stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
+++ stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
@@ -109,26 +109,9 @@
(STABLEHLO_DEFAULT_COMPARISON_TYPE)
)>;

-// Express `tan` as
-// sine(x) / cosine(x)
-def : Pat<(CHLO_TanOp NonComplexElementType:$input),
- (StableHLO_DivOp
- (StableHLO_SineOp $input),
- (StableHLO_CosineOp $input)
- )>;

-
-// Express `tan(a + bi)` as
-// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
-def : Pat<(CHLO_TanOp ComplexElementType:$input),
- (StableHLO_DivOp
- (StableHLO_ComplexOp
- (CHLO_TanOp:$tan (StableHLO_RealOp $input)),
- (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
- (StableHLO_ComplexOp
- (StableHLO_ConstantLike<"1.0"> $tan),
- (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
- )>;
+def : Pat<(CHLO_TanOp $input),
+ (StableHLO_TanOp $input)>;

def : Pat<(CHLO_ConstantOp $v),
(StableHLO_ConstantOp $v)>;
diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
+++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
@@ -49,6 +49,7 @@
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/Passes.h"
+#include "stablehlo/transforms/passes_utils.h"

namespace mlir {
namespace stablehlo {
@@ -169,29 +170,6 @@
patterns->add<Pattern<mlir::chlo::BroadcastCompareOp,
mlir::stablehlo::CompareOp, HloCompareAdaptor>>(
context, args...);
-}
-
-template <typename T>
-static Value getConstantLike(OpBuilder &b, Location loc, T constant,
- Value val) {
- Type ty = getElementTypeOrSelf(val.getType());
- auto getAttr = [&]() -> Attribute {
- if (isa<IntegerType>(ty)) return b.getIntegerAttr(ty, constant);
- if (isa<FloatType>(ty)) return b.getFloatAttr(ty, constant);
- if (auto complexTy = dyn_cast<ComplexType>(ty)) {
- return complex::NumberAttr::get(complexTy, constant, 0);
- }
- llvm_unreachable("unhandled element type");
- };
- return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
- val);
-}
-
-static Value getConstantLike(OpBuilder &b, Location loc,
- const APFloat &constant, Value val) {
- Type ty = getElementTypeOrSelf(val.getType());
- return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
- val);
}

static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc,
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
--- stablehlo/stablehlo/transforms/Passes.td
+++ stablehlo/stablehlo/transforms/Passes.td
@@ -338,5 +338,6 @@
];
let dependentDialects = [
"mlir::stablehlo::StablehloDialect",
- ];
-}
+ "mlir::chlo::ChloDialect",
+ ];
+}
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
@@ -26,6 +26,7 @@
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"
+#include "stablehlo/transforms/passes_utils.h"

namespace mlir {
namespace stablehlo {
@@ -37,21 +38,6 @@
//===----------------------------------------------------------------------===//
// Helpers.
//===----------------------------------------------------------------------===//
-
-// Creates a constant with all ones.
-static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
- if (!isa<FloatType>(getElementTypeOrSelf(val)))
- llvm_unreachable("Unsupported element type, expecting float");
-
- auto shapedTy = dyn_cast<mlir::ShapedType>(val.getType());
- if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
-
- mlir::DenseElementsAttr elementsAttr =
- mlir::DenseElementsAttr::get(shapedTy, 1.0);
-
- return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(),
- elementsAttr);
-}

// Check user-specified target version.
vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
@@ -24,7 +24,8 @@
CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
"Non-complex element type">;

-def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
+class StableHLO_ConstantLike<string value> : NativeCodeCall<
+ "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;

// Express `tan` as
// sine(x) / cosine(x)
@@ -42,6 +43,6 @@
(StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
(StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
(StableHLO_ComplexOp
- (createConstantWithAllOnes $tan),
+ (StableHLO_ConstantLike<"1.0"> $tan),
(StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
)>;
diff --ruN a/stablehlo/stablehlo/transforms/passes_utils.cpp b/stablehlo/stablehlo/transforms/passes_utils.cpp
--- stablehlo/stablehlo/transforms/passes_utils.cpp
+++ stablehlo/stablehlo/transforms/passes_utils.cpp
@@ -0,0 +1,34 @@
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "stablehlo/transforms/passes_utils.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "stablehlo/dialect/ChloOps.h"
+
+namespace mlir {
+namespace stablehlo {
+
+Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
+ Value val) {
+ Type ty = getElementTypeOrSelf(val.getType());
+ return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
+ val);
+}
+
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/transforms/passes_utils.h b/stablehlo/stablehlo/transforms/passes_utils.h
--- stablehlo/stablehlo/transforms/passes_utils.h
+++ stablehlo/stablehlo/transforms/passes_utils.h
@@ -0,0 +1,57 @@
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_
+#define THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_
+
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "stablehlo/dialect/ChloOps.h"
+
+namespace mlir {
+namespace stablehlo {
+// Add utility functions common across passes.
+
+// Creates a chlo::ConstantLikeOp using a splat `constant` of the same shape
+// as `val`.
+template <typename T>
+Value getConstantLike(OpBuilder &b, Location loc, T constant, Value val) {
+ Type ty = getElementTypeOrSelf(val.getType());
+ auto getAttr = [&]() -> Attribute {
+ if (isa<IntegerType>(ty)) return b.getIntegerAttr(ty, constant);
+ if (isa<FloatType>(ty)) return b.getFloatAttr(ty, constant);
+ if (auto complexTy = dyn_cast<ComplexType>(ty)) {
+ return complex::NumberAttr::get(complexTy, constant, 0);
+ }
+ llvm_unreachable("unhandled element type");
+ };
+ return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
+ val);
+}
+
+// Creates a chlo::ConstantLikeOp using a APFloat splat `constant` of the
+// same shape as `val`.
+Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
+ Value val);
+
+} // namespace stablehlo
+} // namespace mlir
+
+#endif // THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "e51fd95e5b2c28861f22dc9d609fb2a7f002124e"
STABLEHLO_SHA256 = "72007352d60cc42784908263ccf171f8721a40adf92a17e13b1b6893e986b8b6"
STABLEHLO_COMMIT = "78c753ad13ad8205cacc5fcc12418c1ac97276c7"
STABLEHLO_SHA256 = "b7fef892020eb465a6d1ed921160f5229398ba10acff36b6345171b9867ccc7c"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
Expand Down

0 comments on commit cd5f5e0

Please sign in to comment.