Skip to content

Commit

Permalink
CHLO -> StableHLO : use TanOp and StablehloCreateCompatibilityExpanderPass
Browse files Browse the repository at this point in the history

This should unblock jax-ml/jax#23261

PiperOrigin-RevId: 674567980
  • Loading branch information
abhigunj authored and Google-ML-Automation committed Sep 14, 2024
1 parent 0a79c44 commit 5a07f58
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 49 deletions.
339 changes: 339 additions & 0 deletions third_party/stablehlo/temporary.patch

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct ChloLegalizeToHighLevelMhloPass
// Consider the mhlo dialect legal for tests. Also add helper dialects
// that are needed by the patterns.
conversionTarget.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect>();
conversionTarget.addIllegalOp<chlo::TopKOp, chlo::ErfOp, chlo::TanOp>();
conversionTarget.addIllegalOp<chlo::TopKOp, chlo::ErfOp>();

if (failed(applyPartialConversion(getOperation(), conversionTarget,
std::move(conversionPatterns)))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ include "stablehlo/dialect/ChloOps.td"
// Direct CHLO->MHLO conversions
//===----------------------------------------------------------------------===//

// Rewrite chlo.tan directly to mhlo.tan; addBenefit(10) makes this direct
// mapping win over any lower-priority decomposition patterns.
def : Pat<(CHLO_TanOp $v),
(MHLO_TanOp $v),
[], [], (addBenefit 10)>;

// Rewrite chlo.erf directly to mhlo.erf with the same elevated benefit.
def : Pat<(CHLO_ErfOp $v),
(MHLO_ErfOp $v),
[], [], (addBenefit 10)>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ bool hasExperimentalFeaturesNotInStablehlo(HloOpTy hloOp) {
// fit for StableHLO, and are usually accompanied by a StableHLO GitHub ticket.
template <typename HloOpTy>
std::optional<int64_t> getPublicFeaturesNotInStablehlo(HloOpTy hloOp) {
// StableHLO doesn't support TanOp yet.
// Proposal: https://github.com/openxla/stablehlo/issues/954
if constexpr (std::is_same<HloOpTy, mhlo::TanOp>::value) {
// Version 1: Initial version for TanOp.
return 1;
}
// StableHLO doesn't support TopK yet.
// Proposal: https://github.com/openxla/stablehlo/pull/1593
if constexpr (std::is_same<HloOpTy, mhlo::TopKOp>::value) {
Expand Down Expand Up @@ -460,8 +454,7 @@ LogicalResult convertAttributes(ConversionPatternRewriter& rewriter,
}

// Handle DenseElements --> DenseArray for certain StableHLO ops
if constexpr (!std::is_same<HloOpTy, mhlo::TanOp>::value &&
!std::is_same<HloOpTy, mhlo::ErfOp>::value &&
if constexpr (!std::is_same<HloOpTy, mhlo::ErfOp>::value &&
!std::is_same<HloOpTy, mhlo::TopKOp>::value) {
if (!stablehloAttr)
stablehloAttr = convertDenseArray<HloToStablehloOp<HloOpTy>>(
Expand Down Expand Up @@ -729,8 +722,7 @@ void populateHloToStablehloPatterns(RewritePatternSet* patterns,
#include "stablehlo/dialect/StablehloOps.cpp.inc"
>(patterns, converter, context, allowExperimentalFeatures);

populateHloToStablehloCustomCallPatterns<mhlo::TanOp, mhlo::TopKOp,
mhlo::ErfOp>(
populateHloToStablehloCustomCallPatterns<mhlo::TopKOp, mhlo::ErfOp>(
patterns, converter, context, allowExperimentalFeatures);
}

Expand Down
24 changes: 0 additions & 24 deletions xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2845,30 +2845,6 @@ func.func @next_after_f32(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32>

// -----

// Verifies that chlo.tan on an f16 scalar tensor legalizes to mhlo.tan
// (both in the high-level pass and the full lowering).
// CHECK-LABEL: @tan_f16
// CHECK-SAME: (%[[ARG:.*]]: tensor<f16>)
func.func @tan_f16(%arg : tensor<f16>) -> tensor<f16> {
// CHECK-HIGH-LEVEL: mhlo.tan
// CHECK: %[[RESULT:.*]] = mhlo.tan %[[ARG]] : tensor<f16>
// CHECK: return %[[RESULT]]
%1 = chlo.tan %arg : tensor<f16> -> tensor<f16>
func.return %1 : tensor<f16>
}

// -----

// Same as @tan_f16 but for f32: chlo.tan must lower to mhlo.tan.
// CHECK-LABEL: @tan_f32
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func.func @tan_f32(%arg : tensor<f32>) -> tensor<f32> {
// CHECK-HIGH-LEVEL: mhlo.tan
// CHECK: %[[RESULT:.*]] = mhlo.tan %[[ARG]] : tensor<f32>
// CHECK: return %[[RESULT]]
%1 = chlo.tan %arg : tensor<f32> -> tensor<f32>
func.return %1 : tensor<f32>
}

// -----

// CHECK-LABEL: @top_k
// CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>)
func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1620,15 +1620,6 @@ func.func @op_subtract(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
func.return %0 : tensor<f32>
}

// CHECK-LABEL: "op_tan"
func.func @op_tan(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: "stablehlo.custom_call"([[ARG0:%arg[0-9]+]]) <{
// CHECK-SAME: call_target_name = "mhlo.tan"}> {mhlo.attributes = {}, mhlo.version = 1 : i64}
// CHECK-SAME: (tensor<f32>) -> tensor<f32>
%0 = "mhlo.tan"(%arg0) : (tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
}

// CHECK-LABEL: "op_tanh"
func.func @op_tanh(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: "stablehlo.tanh"([[ARG0:%arg[0-9]+]]) : (tensor<f32>) -> tensor<f32>
Expand Down
6 changes: 5 additions & 1 deletion xla/pjrt/mlir_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
mlir::mhlo::createChloLegalizeToHighLevelMhloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createStablehloCreateCompatibilityExpanderPass(
{std::string(target)}));
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createShapeLegalizeToStablehloPass());
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
Expand Down Expand Up @@ -264,7 +269,6 @@ absl::StatusOr<std::string> Serialize(mlir::ModuleOp module,
if (!llvm::isa<mlir::ModuleOp>(op) &&
!llvm::isa<mlir::stablehlo::StablehloDialect, mlir::func::FuncDialect,
mlir::chlo::ChloDialect>(op->getDialect())) {
std::cout << op->getDialect()->getNamespace().str() << "\n";
all_stablehlo = false;
return mlir::WalkResult::interrupt();
}
Expand Down
17 changes: 17 additions & 0 deletions xla/pjrt/mlir_to_hlo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,23 @@ TEST(MlirToHloTest, ChloTest) {
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
}

TEST(MlirToHloTest, ChloTanOpTest) {
  // A module containing chlo.tan must serialize successfully: the op is
  // expanded to StableHLO and emitted as a VHLO artifact.
  constexpr char kTanModule[] =
      R"(
func.func @add(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = chlo.tan %arg0 : tensor<1x2xf32> -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
)";
  mlir::MLIRContext ctx;
  TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> parsed,
                          ParseMlirModuleString(kTanModule, ctx));
  TF_ASSERT_OK_AND_ASSIGN(std::string artifact,
                          Serialize(*parsed, 47, "1.0.0"));

  // CHLO decomposes to StableHLO, so uses VHLO serialization.
  EXPECT_THAT(artifact, IsVhloArtifact("1.0.0"));
}

TEST(MlirToHloTest, MhloTest) {
constexpr char kProgram[] =
R"(
Expand Down

0 comments on commit 5a07f58

Please sign in to comment.