Skip to content

Commit

Permalink
Merge branch 'main' into f8E4M3_f8E3M4
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored Aug 31, 2024
2 parents 356dc4b + 21dcdd2 commit 7f4c57d
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 34 deletions.
17 changes: 17 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,21 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "stablehlo_create_compatibility_expander_inc_gen",
tbl_outs = [
(
["--gen-rewriters"],
"stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
deps = [
":stablehlo_ops_td_files",
],
)

cc_library(
name = "interpreter_ops",
srcs = [
Expand Down Expand Up @@ -1086,6 +1101,7 @@ cc_library(
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloConvertToSignless.cpp",
"stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
"stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
Expand All @@ -1109,6 +1125,7 @@ cc_library(
":chlo_ops",
":chlo_rewriters_inc_gen",
":linalg_passes",
":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
Expand Down
43 changes: 43 additions & 0 deletions docs/generated/stablehlo_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,49 @@ value, then the operation can be transformed to ReshapeOp.

_Pass to transform the IR to be on signless integers._

### `-stablehlo-create-compatibility-expander`

_Create compatibility expander for StableHLO operations._

StableHLO ops gets updates or new op is introduced in the latest versions.
This opt-in pass expands backward compatibility with older StableHLO
versions by decomposing newer StableHLO operations into equivalent
operations supported by those older versions.

Why is this an opt-in pass?

Occasionally, StableHLO op enhancements are used to greatly simplify the
handling of certain common patterns in the OpenXLA ecosystem. This
includes things like TanOp, which has high framework and compiler support,
as well as gather/scatter batching dimensions, which can be represented
using slices, but makes sharding much more difficult. For this category of
new features, we do not offer automatic downgrade, since it may throw away
important information used in subsequent optimizations. This pass can be
used to expand these ops based on a target version to maximize compatibility
at the expense of potentially less optimal compilation.

```mlir
func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
%1 = stablehlo.tan %arg0 : tensor<4xf64>
func.return %1 : tensor<4xf64>
}
```

will become:

```mlir
func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
%0 = stablehlo.sine %arg0 : tensor<4xf64>
%1 = stablehlo.cosine %arg0 : tensor<4xf64>
%2 = stablehlo.divide %0, %1 : tensor<4xf64>
return %2 : tensor<4xf64>
}
```

#### Options
```
-target : The target version. Must be a version of the form #.#.#.
```
### `-stablehlo-legalize-composite-to-call`

_Replaces composite ops with a call to their decomposition_
Expand Down
34 changes: 0 additions & 34 deletions stablehlo/integrations/python/StablehloApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,40 +189,6 @@ void AddPortableApi(py::module &m) {
return accumulator.toString();
});

//
// Serialization APIs (deprecated, use _str methods).
//
m.def(
"serialize_portable_artifact",
[](std::string_view moduleStrOrBytecode,
std::string_view targetVersion) -> py::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(stablehloSerializePortableArtifact(
toMlirStringRef(moduleStrOrBytecode),
toMlirStringRef(targetVersion),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to serialize module");
return "";
}
return py::bytes(accumulator.toString());
},
py::arg("module_str"), py::arg("target_version"));

m.def(
"deserialize_portable_artifact",
[](std::string_view artifact) -> py::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact(
toMlirStringRef(artifact), accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
return "";
}
return py::bytes(accumulator.toString());
},
py::arg("artifact_str"));

//
// Serialization APIs.
//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE

// -----

// CHECK-LABEL @tan_op_non_complex
// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64>
// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64>
// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64>
// CHECK-NEXT: return %[[div2]] : tensor<4xf64>
func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
// CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64>
%1 = stablehlo.tan %arg0 : tensor<4xf64>
func.return %1 : tensor<4xf64>
}

// -----

// CHECK-LABEL: @tan_op_complex
// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64>
// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64>
// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64>
// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64>
// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64>
// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex<f64>>
// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64>
// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64>
// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex<f64>>
// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex<f64>>
// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64>
func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) {
%0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
// CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>>
%1 = stablehlo.tan %0 : tensor<4xcomplex<f64>>
%2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
%3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
}
7 changes: 7 additions & 0 deletions stablehlo/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ set(LLVM_TARGET_DEFINITIONS ChloDecompositionPatterns.td)
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(ChloDecompositionPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td)
mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td)
mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen)
Expand All @@ -28,6 +32,7 @@ set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td)
mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(VhloToVersionPatterns)


add_mlir_dialect_library(StablehloPasses
PARTIAL_SOURCES_INTENDED
ChloLegalizeToStablehlo.cpp
Expand All @@ -37,6 +42,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloAggressiveSimplification.cpp
StablehloCanonicalizeDynamism.cpp
StablehloConvertToSignless.cpp
StablehloCreateCompatibilityExpander.cpp
StablehloLegalizeCompositeToCall.cpp
StablehloLegalizeDeprecatedOps.cpp
StablehloLegalizeQuantToMath.cpp
Expand All @@ -53,6 +59,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloLegalizeDeprecatedOpsPatternsIncGen
PassesIncGen
VhloToVersionPatterns
StablehloCreateCompatibilityExpanderPatternsIncGen

LINK_LIBS PUBLIC
ChloOps
Expand Down
7 changes: 7 additions & 0 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/Version.h"

namespace mlir {
namespace stablehlo {
Expand Down Expand Up @@ -96,6 +97,12 @@ void populateStablehloLegalizeDeprecatedOpsPatterns(
void populateShapeToStablehloPatterns(MLIRContext *context,
RewritePatternSet *patterns);

/// Collection of patterns to create compatibility expander for StableHLO
/// operations.
void populateStablehloCreateCompatibilityExpanderPatterns(
RewritePatternSet *patterns, MLIRContext *context,
vhlo::Version targetVersion);

//// Additional pass constructors ////

std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
Expand Down
48 changes: 48 additions & 0 deletions stablehlo/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,51 @@ def StablehloLegalizeQDQToQuantizedOpPass : Pass<"stablehlo-legalize-qdq-to-quan
"mlir::stablehlo::StablehloDialect",
];
}

def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> {
let summary = "Create compatibility expander for StableHLO operations.";

let description = [{
StableHLO ops gets updates or new op is introduced in the latest versions.
This opt-in pass expands backward compatibility with older StableHLO
versions by decomposing newer StableHLO operations into equivalent
operations supported by those older versions.

Why is this an opt-in pass?

Occasionally, StableHLO op enhancements are used to greatly simplify the
handling of certain common patterns in the OpenXLA ecosystem. This
includes things like TanOp, which has high framework and compiler support,
as well as gather/scatter batching dimensions, which can be represented
using slices, but makes sharding much more difficult. For this category of
new features, we do not offer automatic downgrade, since it may throw away
important information used in subsequent optimizations. This pass can be
used to expand these ops based on a target version to maximize compatibility
at the expense of potentially less optimal compilation.

```mlir
func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
%1 = stablehlo.tan %arg0 : tensor<4xf64>
func.return %1 : tensor<4xf64>
}
```

will become:

```mlir
func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
%0 = stablehlo.sine %arg0 : tensor<4xf64>
%1 = stablehlo.cosine %arg0 : tensor<4xf64>
%2 = stablehlo.divide %0, %1 : tensor<4xf64>
return %2 : tensor<4xf64>
}
```
}];
let options = [
Option<"targetVersionOption", "target", "std::string", "",
"The target version. Must be a version of the form #.#.#.">,
];
let dependentDialects = [
"mlir::stablehlo::StablehloDialect",
];
}
Loading

0 comments on commit 7f4c57d

Please sign in to comment.