From cd5f5e051f1f10acbf1c179ffc59af67f493d88c Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Tue, 17 Sep 2024 12:07:55 -0700 Subject: [PATCH] Integrate StableHLO at openxla/stablehlo@78c753ad PiperOrigin-RevId: 675658944 --- third_party/stablehlo/temporary.patch | 384 -------------------------- third_party/stablehlo/workspace.bzl | 4 +- 2 files changed, 2 insertions(+), 386 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index f803c97b5f40b..8b137891791fe 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,385 +1 @@ -diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel ---- stablehlo/BUILD.bazel -+++ stablehlo/BUILD.bazel -@@ -1112,11 +1112,13 @@ - "stablehlo/transforms/StablehloRefineShapes.cpp", - "stablehlo/transforms/VhloLegalizeToStablehlo.cpp", - "stablehlo/transforms/VhloToVersion.cpp", -+ "stablehlo/transforms/passes_utils.cpp", - ], - hdrs = [ - "stablehlo/transforms/MapStablehloToVhlo.h", - "stablehlo/transforms/Passes.h", - "stablehlo/transforms/StablehloRefineShapes.h", -+ "stablehlo/transforms/passes_utils.h", - ], - strip_include_prefix = ".", - deps = [ -diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir ---- stablehlo/stablehlo/conversions/tosa/tests/unary.mlir -+++ stablehlo/stablehlo/conversions/tosa/tests/unary.mlir -@@ -121,8 +121,8 @@ - - // CHECK-LABEL: @transpose - func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { -- // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64> -- // CHECK-DAG: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]] -+ // CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -+ // CHECK: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]] - %0 = "stablehlo.transpose"(%arg0) {permutation = array} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> - return %0 : tensor<3x2x1xf32> - } -diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp ---- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp -+++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp -@@ -451,9 +451,10 @@ - - auto perms = op.getPermutation(); - auto type = RankedTensorType::get({static_cast(perms.size())}, -- rewriter.getI64Type()); -+ rewriter.getI32Type()); -+ std::vector perms_int32(perms.begin(), perms.end()); - auto constOp = rewriter.create( -- op->getLoc(), type, DenseIntElementsAttr::get(type, perms)); -+ op->getLoc(), type, DenseIntElementsAttr::get(type, perms_int32)); - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getOperand(), constOp); - return success(); -diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp ---- stablehlo/stablehlo/dialect/Version.cpp -+++ stablehlo/stablehlo/dialect/Version.cpp -@@ -80,9 +80,9 @@ - case CompatibilityRequirement::NONE: - return Version::getCurrentVersion(); - case CompatibilityRequirement::WEEK_4: -- return Version(1, 3, 0); // v1.3.0 - Jul 15, 2024 -+ return Version(1, 5, 0); // v1.3.0 - Aug 1, 2024 - case CompatibilityRequirement::WEEK_12: -- return Version(1, 0, 0); // v1.0.0 - May 14, 2024 -+ return Version(1, 1, 0); // v1.1.0 - May 30, 2024 - case CompatibilityRequirement::MAX: - return Version::getMinimumVersion(); - } -diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ---- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir -+++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir -@@ -3029,28 +3029,11 @@ - - // ----- - --// CHECK-LABEL: @tan_f16 --// CHECK-SAME: (%[[ARG:.*]]: tensor) --func.func @tan_f16(%arg : tensor) -> tensor { -- // %[[TMP_0:.*]] = stablehlo.convert [[ARG]] : (tensor) -> tensor -- // %[[TMP_1:.*]] = stablehlo.sine %[[TMP_0]] -- // %[[TMP_2:.*]] = stablehlo.cosine %[[TMP_0]] -- // %[[TMP_3:.*]] = stablehlo.divide %[[TMP_1]], %[[TMP_2]] -- // %[[TMP_4:.*]] = stablehlo.convert %[[TMP_3]] : (tensor) -> tensor -- // return %[[TMP_4]] : tensor -- %1 = chlo.tan %arg : tensor -> tensor -- func.return %1 : tensor --} -- --// ----- -- - // CHECK-LABEL: @tan_f32 - // CHECK-SAME: (%[[ARG:.*]]: tensor) - func.func @tan_f32(%arg : tensor) -> tensor { -- // %[[TMP_0:.*]] = stablehlo.sine %[[ARG]] -- // %[[TMP_1:.*]] = stablehlo.cosine %[[ARG]] -- // %[[TMP_2:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_1]] -- // return %[[TMP_2]] : tensor -+ // CHECK: %[[TMP_0:.*]] = stablehlo.tan %[[ARG]] : tensor -+ // CHECK: return %[[TMP_0]] : tensor - %1 = chlo.tan %arg : tensor -> tensor - func.return %1 : tensor - } -@@ -3060,22 +3043,11 @@ - // CHECK-LABEL: @tan_complexf32 - // CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32> - func.func @tan_complexf32(%arg0 : tensor<1xf32>, %arg1 : tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { -- // CHECK: %[[COMPLEX:.+]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex> -- // CHECK: %[[REAL:.+]] = stablehlo.real %[[COMPLEX]] : (tensor<1xcomplex>) -> tensor<1xf32> -- // CHECK: %[[SINE:.+]] = stablehlo.sine %[[REAL]] -- // CHECK: %[[COS:.+]] = stablehlo.cosine %[[REAL]] -- // CHECK: %[[TAN:.+]] = stablehlo.divide %[[SINE]], %[[COS]] -- // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[COMPLEX]] : (tensor<1xcomplex>) -> tensor<1xf32> -- // CHECK: %[[TANH:.+]] = stablehlo.tanh %[[IMAG]] -- // CHECK: %[[NUM:.+]] = stablehlo.complex %[[TAN]], %[[TANH]] -- // CHECK: %[[ONE:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<1xf32> -- // CHECK: %[[MUL:.+]] = stablehlo.multiply %[[TAN]], %[[TANH]] -- // CHECK: %[[NEG:.+]] = stablehlo.negate %[[MUL]] -- // CHECK: %[[DEN:.+]] = stablehlo.complex %[[ONE]], %[[NEG]] -- // CHECK: %[[RES:.+]] = stablehlo.divide %[[NUM]], %[[DEN]] -- // CHECK: %[[REAL:.+]] = stablehlo.real %[[RES]] -- // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[RES]] -- // CHECK: return %[[REAL]], %[[IMAG]] -+ // CHECK: %[[TMP_0:.*]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex> -+ // CHECK: %[[TMP_1:.*]] = stablehlo.tan %[[TMP_0]] : tensor<1xcomplex> -+ // CHECK: %[[TMP_2:.*]] = stablehlo.real %[[TMP_1]] : (tensor<1xcomplex>) -> tensor<1xf32> -+ // CHECK: %[[TMP_3:.*]] = stablehlo.imag %[[TMP_1]] : (tensor<1xcomplex>) -> tensor<1xf32> -+ // CHECK: return %[[TMP_2]], %[[TMP_3]] : tensor<1xf32>, tensor<1xf32> - %0 = stablehlo.complex %arg0, %arg1 : tensor<1xcomplex> - %1 = chlo.tan %0 : tensor<1xcomplex> -> tensor<1xcomplex> - %2 = stablehlo.real %1 : (tensor<1xcomplex>) -> tensor<1xf32> -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -@@ -1,5 +1,5 @@ --// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK --// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE -+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.0.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK -+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE - - // ----- - -diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt ---- stablehlo/stablehlo/transforms/CMakeLists.txt -+++ stablehlo/stablehlo/transforms/CMakeLists.txt -@@ -53,6 +53,7 @@ - StablehloRefineShapes.cpp - VhloLegalizeToStablehlo.cpp - VhloToVersion.cpp -+ passes_utils.cpp - - DEPENDS - ChloDecompositionPatternsIncGen -diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td ---- stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td -+++ stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td -@@ -109,26 +109,9 @@ - (STABLEHLO_DEFAULT_COMPARISON_TYPE) - )>; - --// Express `tan` as --// sine(x) / cosine(x) --def : Pat<(CHLO_TanOp NonComplexElementType:$input), -- (StableHLO_DivOp -- (StableHLO_SineOp $input), -- (StableHLO_CosineOp $input) -- )>; - -- --// Express `tan(a + bi)` as --// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b)) --def : Pat<(CHLO_TanOp ComplexElementType:$input), -- (StableHLO_DivOp -- (StableHLO_ComplexOp -- (CHLO_TanOp:$tan (StableHLO_RealOp $input)), -- (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))), -- (StableHLO_ComplexOp -- (StableHLO_ConstantLike<"1.0"> $tan), -- (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh))) -- )>; -+def : Pat<(CHLO_TanOp $input), -+ (StableHLO_TanOp $input)>; - - def : Pat<(CHLO_ConstantOp $v), - (StableHLO_ConstantOp $v)>; -diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp ---- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp -+++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp -@@ -49,6 +49,7 @@ - #include "stablehlo/dialect/ChloOps.h" - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/transforms/Passes.h" -+#include "stablehlo/transforms/passes_utils.h" - - namespace mlir { - namespace stablehlo { -@@ -169,29 +170,6 @@ - patterns->add>( - context, args...); --} -- --template --static Value getConstantLike(OpBuilder &b, Location loc, T constant, -- Value val) { -- Type ty = getElementTypeOrSelf(val.getType()); -- auto getAttr = [&]() -> Attribute { -- if (isa(ty)) return b.getIntegerAttr(ty, constant); -- if (isa(ty)) return b.getFloatAttr(ty, constant); -- if (auto complexTy = dyn_cast(ty)) { -- return complex::NumberAttr::get(complexTy, constant, 0); -- } -- llvm_unreachable("unhandled element type"); -- }; -- return b.create(loc, cast(getAttr()), -- val); --} -- --static Value getConstantLike(OpBuilder &b, Location loc, -- const APFloat &constant, Value val) { -- Type ty = getElementTypeOrSelf(val.getType()); -- return b.create(loc, b.getFloatAttr(ty, constant), -- val); - } - - static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, -diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td ---- stablehlo/stablehlo/transforms/Passes.td -+++ stablehlo/stablehlo/transforms/Passes.td -@@ -338,5 +338,6 @@ - ]; - let dependentDialects = [ - "mlir::stablehlo::StablehloDialect", -- ]; --} -+ "mlir::chlo::ChloDialect", -+ ]; -+} -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -@@ -26,6 +26,7 @@ - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/dialect/Version.h" - #include "stablehlo/transforms/Passes.h" -+#include "stablehlo/transforms/passes_utils.h" - - namespace mlir { - namespace stablehlo { -@@ -37,21 +38,6 @@ - //===----------------------------------------------------------------------===// - // Helpers. - //===----------------------------------------------------------------------===// -- --// Creates a constant with all ones. --static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) { -- if (!isa(getElementTypeOrSelf(val))) -- llvm_unreachable("Unsupported element type, expecting float"); -- -- auto shapedTy = dyn_cast(val.getType()); -- if (!shapedTy) llvm_unreachable("Unsupported shaped type."); -- -- mlir::DenseElementsAttr elementsAttr = -- mlir::DenseElementsAttr::get(shapedTy, 1.0); -- -- return b.create(loc, val.getType(), -- elementsAttr); --} - - // Check user-specified target version. - vhlo::Version validateTargetVersion(llvm::StringRef versionRef) { -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td -@@ -24,7 +24,8 @@ - CPred<"!isa(cast($_self).getElementType())">, - "Non-complex element type">; - --def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">; -+class StableHLO_ConstantLike : NativeCodeCall< -+ "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; - - // Express `tan` as - // sine(x) / cosine(x) -@@ -42,6 +43,6 @@ - (StableHLO_TanOp:$tan (StableHLO_RealOp $input)), - (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))), - (StableHLO_ComplexOp -- (createConstantWithAllOnes $tan), -+ (StableHLO_ConstantLike<"1.0"> $tan), - (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh))) - )>; -diff --ruN a/stablehlo/stablehlo/transforms/passes_utils.cpp b/stablehlo/stablehlo/transforms/passes_utils.cpp ---- stablehlo/stablehlo/transforms/passes_utils.cpp -+++ stablehlo/stablehlo/transforms/passes_utils.cpp -@@ -0,0 +1,34 @@ -+/* Copyright 2024 The StableHLO Authors. All Rights Reserved. -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ http://www.apache.org/licenses/LICENSE-2.0 -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#include "stablehlo/transforms/passes_utils.h" -+ -+#include "mlir/IR/Builders.h" -+#include "mlir/IR/Location.h" -+#include "mlir/IR/TypeUtilities.h" -+#include "mlir/IR/Types.h" -+#include "mlir/IR/Value.h" -+#include "mlir/Support/LLVM.h" -+#include "stablehlo/dialect/ChloOps.h" -+ -+namespace mlir { -+namespace stablehlo { -+ -+Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, -+ Value val) { -+ Type ty = getElementTypeOrSelf(val.getType()); -+ return b.create(loc, b.getFloatAttr(ty, constant), -+ val); -+} -+ -+} // namespace stablehlo -+} // namespace mlir -diff --ruN a/stablehlo/stablehlo/transforms/passes_utils.h b/stablehlo/stablehlo/transforms/passes_utils.h ---- stablehlo/stablehlo/transforms/passes_utils.h -+++ stablehlo/stablehlo/transforms/passes_utils.h -@@ -0,0 +1,57 @@ -+/* Copyright 2024 The StableHLO Authors. All Rights Reserved. -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ http://www.apache.org/licenses/LICENSE-2.0 -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#ifndef THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_ -+#define THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_ -+ -+#include "llvm/Support/ErrorHandling.h" -+#include "mlir/Dialect/Complex/IR/Complex.h" -+#include "mlir/IR/Builders.h" -+#include "mlir/IR/BuiltinAttributeInterfaces.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Location.h" -+#include "mlir/IR/TypeUtilities.h" -+#include "mlir/IR/Types.h" -+#include "mlir/IR/Value.h" -+#include "mlir/Support/LLVM.h" -+#include "stablehlo/dialect/ChloOps.h" -+ -+namespace mlir { -+namespace stablehlo { -+// Add utility functions common across passes. -+ -+// Creates a chlo::ConstantLikeOp using a splat `constant` of the same shape -+// as `val`. -+template -+Value getConstantLike(OpBuilder &b, Location loc, T constant, Value val) { -+ Type ty = getElementTypeOrSelf(val.getType()); -+ auto getAttr = [&]() -> Attribute { -+ if (isa(ty)) return b.getIntegerAttr(ty, constant); -+ if (isa(ty)) return b.getFloatAttr(ty, constant); -+ if (auto complexTy = dyn_cast(ty)) { -+ return complex::NumberAttr::get(complexTy, constant, 0); -+ } -+ llvm_unreachable("unhandled element type"); -+ }; -+ return b.create(loc, cast(getAttr()), -+ val); -+} -+ -+// Creates a chlo::ConstantLikeOp using a APFloat splat `constant` of the -+// same shape as `val`. -+Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, -+ Value val); -+ -+} // namespace stablehlo -+} // namespace mlir -+ -+#endif // THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_ diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 0bdf74f3882cc..97fd0b990fc1c 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "e51fd95e5b2c28861f22dc9d609fb2a7f002124e" - STABLEHLO_SHA256 = "72007352d60cc42784908263ccf171f8721a40adf92a17e13b1b6893e986b8b6" + STABLEHLO_COMMIT = "78c753ad13ad8205cacc5fcc12418c1ac97276c7" + STABLEHLO_SHA256 = "b7fef892020eb465a6d1ed921160f5229398ba10acff36b6345171b9867ccc7c" # LINT.ThenChange(Google-internal path) tf_http_archive(