[JAX] Remove uses of dialect="mhlo" from the JAX compiler_ir() function (#28)

JAX no longer uses MHLO itself and will drop support for it shortly. JAX now emits StableHLO; either consume that StableHLO directly or convert it to MHLO.
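For illustration only (not part of this commit), a minimal sketch of the new usage; the function `f` and its input shape are invented for the example:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * 2.0

    # lower() stages the jitted function for the given example inputs;
    # compiler_ir(dialect="stablehlo") returns the StableHLO MLIR module
    # that JAX now emits (the "mhlo" dialect option is going away).
    lowered = jax.jit(f).lower(jnp.ones((2, 3), jnp.float32))
    module = lowered.compiler_ir(dialect="stablehlo")

    # Consumers that still need MHLO (like compile_with_xla.cc below) can
    # parse this text and run the stablehlo-legalize-to-hlo pass on it.
    source = str(module)
    print(source)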

Co-authored-by: Peter Hawkins <[email protected]>
wsmoses and hawkinsp authored Jan 16, 2024
1 parent f55975c commit 13f0c63
Showing 3 changed files with 30 additions and 20 deletions.
10 changes: 10 additions & 0 deletions src/enzyme_ad/jax/compile_with_xla.cc
@@ -11,12 +11,15 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/client/client_library.h"
 #include "xla/client/executable_build_options.h"
 #include "xla/client/xla_computation.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
 #include "xla/printer.h"
 #include "xla/service/cpu/backend_config.pb.h"
 #include "xla/service/cpu/cpu_executable.h"
@@ -36,10 +39,17 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
   context.loadDialect<mlir::arith::ArithDialect>();
   context.loadDialect<mlir::func::FuncDialect>();
   context.loadDialect<mlir::mhlo::MhloDialect>();
+  context.loadDialect<mlir::stablehlo::StablehloDialect>();
   mlir::ParserConfig parser_config(&context);
   mlir::OwningOpRef<mlir::ModuleOp> parsed_module =
       mlir::parseSourceString<mlir::ModuleOp>(mhlo_text, parser_config);
 
+  mlir::PassManager pm(&context);
+  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
+  if (!mlir::succeeded(pm.run(*parsed_module))) {
+    throw pybind11::value_error("StableHLO => MHLO failed");
+  }
+
   // Convert to XLA Computation.
   xla::HloProto hlo_proto;
   mlir::ConvertMlirHloToHlo(*parsed_module, &hlo_proto,
10 changes: 5 additions & 5 deletions src/enzyme_ad/jax/primitives.py
@@ -173,7 +173,7 @@ def _enzyme_aug_abstract_eval(
     (in_tree, func) = source
     avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat)
     lowered_func = jax.jit(func).lower(*avals_in)
-    mhlo = lowered_func.compiler_ir(dialect="mhlo")
+    mhlo = lowered_func.compiler_ir(dialect="stablehlo")
     source = str(mhlo)
     kept = lowered_func.compile()._executable._kept_var_idx
     in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
@@ -247,7 +247,7 @@ def _enzyme_primal_lowering(
     (in_tree, func) = source
     avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in)
     lowered_func = jax.jit(func).lower(*avals_in)
-    mhlo = lowered_func.compiler_ir(dialect="mhlo")
+    mhlo = lowered_func.compiler_ir(dialect="stablehlo")
     source = str(mhlo)
     kept = lowered_func.compile()._executable._kept_var_idx
     in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
@@ -308,7 +308,7 @@ def _enzyme_fwd_lowering(
     (in_tree, func) = source
     avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2])
     lowered_func = jax.jit(func).lower(*avals_in)
-    mhlo = lowered_func.compiler_ir(dialect="mhlo")
+    mhlo = lowered_func.compiler_ir(dialect="stablehlo")
    source = str(mhlo)
     kept = lowered_func.compile()._executable._kept_var_idx
     in_args = tuple(arg for (i, arg) in enumerate(in_args) if i // 2 in kept)
@@ -368,7 +368,7 @@ def _enzyme_aug_lowering(
     (in_tree, func) = source
     avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in)
     lowered_func = jax.jit(func).lower(*avals_in)
-    mhlo = lowered_func.compiler_ir(dialect="mhlo")
+    mhlo = lowered_func.compiler_ir(dialect="stablehlo")
     source = str(mhlo)
     kept = lowered_func.compile()._executable._kept_var_idx
     in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
@@ -432,7 +432,7 @@ def _enzyme_rev_lowering(
     (in_tree, func) = source
     avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out)
     lowered_func = jax.jit(func).lower(*avals_in)
-    mhlo = lowered_func.compiler_ir(dialect="mhlo")
+    mhlo = lowered_func.compiler_ir(dialect="stablehlo")
     source = str(mhlo)
     kept = lowered_func.compile()._executable._kept_var_idx
     # in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
30 changes: 15 additions & 15 deletions test/lit_tests/ir.pyt
@@ -41,12 +41,12 @@ def fwdmode(a, b, c, d):
     return jax.jvp(do_something, (a, b), (c, d))
 
 
-print(fwdmode.lower(ones, twos, ones, twos).compiler_ir(dialect="mhlo"))
+print(fwdmode.lower(ones, twos, ones, twos).compiler_ir(dialect="stablehlo"))
 
 # CHECK: module @jit_fwdmode attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 # CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg2: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg3: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<6x9xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}) {
-# CHECK-NEXT: %0 = mhlo.constant dense<1> : tensor<1xi64>
-# CHECK-NEXT: %1:4 = mhlo.custom_call @jaxzyme.fwd(%0, %arg0, %arg2, %arg1, %arg3) {backend_config = ""} : (tensor<1xi64>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<5x7xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<6x9xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
+# CHECK-NEXT: %0 = stablehlo.constant dense<1> : tensor<1xi64>
+# CHECK-NEXT: %1:4 = stablehlo.custom_call @jaxzyme.fwd(%0, %arg0, %arg2, %arg1, %arg3) : (tensor<1xi64>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<5x7xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<6x9xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
 # CHECK-NEXT: return %1#0, %1#2, %1#1, %1#3 : tensor<6x9xf32>, tensor<4x6xf32>, tensor<6x9xf32>, tensor<4x6xf32>
 # CHECK-NEXT: }
 # CHECK-NEXT: }
@@ -57,12 +57,12 @@ def f(a, b):
     return jax.vjp(do_something, a, b)
 
 
-print(f.lower(ones, twos).compiler_ir(dialect="mhlo"))
+print(f.lower(ones, twos).compiler_ir(dialect="stablehlo"))
 
 # CHECK: module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 # CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<16xi8> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][0]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][1]", mhlo.layout_mode = "default"}, tensor<5x7xf32> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][2]", mhlo.layout_mode = "default"}) {
-# CHECK-NEXT: %0 = mhlo.constant dense<2> : tensor<1xi64>
-# CHECK-NEXT: %1:3 = mhlo.custom_call @jaxzyme.aug(%0, %arg0, %arg1) {backend_config = ""} : (tensor<1xi64>, tensor<2x3xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>)
+# CHECK-NEXT: %0 = stablehlo.constant dense<2> : tensor<1xi64>
+# CHECK-NEXT: %1:3 = stablehlo.custom_call @jaxzyme.aug(%0, %arg0, %arg1) : (tensor<1xi64>, tensor<2x3xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>)
 # CHECK-NEXT: return %1#0, %1#1, %1#2, %arg0, %arg1 : tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>, tensor<2x3xf32>, tensor<5x7xf32>
 # CHECK-NEXT: }
 # CHECK-NEXT: }
@@ -79,28 +79,28 @@ def g(a, b, x, y):
 
 # CHECK: module @jit_g attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 # CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg2: tensor<6x9xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg3: tensor<4x6xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<5x7xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}) {
-# CHECK-NEXT: %0 = mhlo.constant dense<3> : tensor<1xi64>
-# CHECK-NEXT: %1:3 = mhlo.custom_call @jaxzyme.aug(%0, %arg0, %arg1) {backend_config = ""} : (tensor<1xi64>, tensor<2x3xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>)
-# CHECK-NEXT: %2 = mhlo.constant dense<4> : tensor<1xi64>
-# CHECK-NEXT: %3:2 = mhlo.custom_call @jaxzyme.rev(%2, %1#2, %arg2, %arg3) {backend_config = ""} : (tensor<1xi64>, tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>)
+# CHECK-NEXT: %0 = stablehlo.constant dense<3> : tensor<1xi64>
+# CHECK-NEXT: %1:3 = stablehlo.custom_call @jaxzyme.aug(%0, %arg0, %arg1) : (tensor<1xi64>, tensor<2x3xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>)
+# CHECK-NEXT: %2 = stablehlo.constant dense<4> : tensor<1xi64>
+# CHECK-NEXT: %3:2 = stablehlo.custom_call @jaxzyme.rev(%2, %1#2, %arg2, %arg3) : (tensor<1xi64>, tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>)
 # CHECK-NEXT: return %1#0, %1#1, %3#0, %3#1 : tensor<6x9xf32>, tensor<4x6xf32>, tensor<2x3xf32>, tensor<5x7xf32>
 # CHECK-NEXT: }
 # CHECK-NEXT: }
 
-print(g.lower(ones, twos, x, y).compiler_ir(dialect="mhlo"))
+print(g.lower(ones, twos, x, y).compiler_ir(dialect="stablehlo"))
 
 primals, f_vjp = jax.vjp(jax.jit(do_something), ones, twos)
 
-print(jax.jit(f_vjp).lower((x, y)).compiler_ir(dialect="mhlo"))
+print(jax.jit(f_vjp).lower((x, y)).compiler_ir(dialect="stablehlo"))
 # CHECK: module @jit__unnamed_wrapped_function_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 # CHECK-NEXT: func.func public @main(%arg0: tensor<6x9xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<4x6xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}, tensor<5x7xf32> {mhlo.layout_mode = "default"}) {
-# CHECK-NEXT: %0 = mhlo.constant dense<[0, 0, -128, 63, 0, 0, -128, 63, 0, 0, -128, 63, 0, 0, -128, 63]> : tensor<16xi8>
+# CHECK-NEXT: %0 = stablehlo.constant dense<[0, 0, -128, 63, 0, 0, -128, 63, 0, 0, -128, 63, 0, 0, -128, 63]> : tensor<16xi8>
 # CHECK-NEXT: %1:2 = call @do_something(%0, %arg0, %arg1) : (tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>)
 # CHECK-NEXT: return %1#0, %1#1 : tensor<2x3xf32>, tensor<5x7xf32>
 # CHECK-NEXT: }
 # CHECK-NEXT: func.func private @do_something(%arg0: tensor<16xi8>, %arg1: tensor<6x9xf32>, %arg2: tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>) {
-# CHECK-NEXT: %0 = mhlo.constant dense<6> : tensor<1xi64>
-# CHECK-NEXT: %1:2 = mhlo.custom_call @jaxzyme.rev(%0, %arg0, %arg1, %arg2) {backend_config = ""} : (tensor<1xi64>, tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>)
+# CHECK-NEXT: %0 = stablehlo.constant dense<6> : tensor<1xi64>
+# CHECK-NEXT: %1:2 = stablehlo.custom_call @jaxzyme.rev(%0, %arg0, %arg1, %arg2) : (tensor<1xi64>, tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>)
 # CHECK-NEXT: return %1#0, %1#1 : tensor<2x3xf32>, tensor<5x7xf32>
 # CHECK-NEXT: }
 # CHECK-NEXT: }
