Skip to content

Commit

Permalink
[IFRT] Add pass to legalize VIFRT into IFRT.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696167532
  • Loading branch information
ICGog authored and Google-ML-Automation committed Nov 20, 2024
1 parent c738435 commit 70e7b55
Show file tree
Hide file tree
Showing 14 changed files with 679 additions and 94 deletions.
1 change: 1 addition & 0 deletions xla/python/ifrt/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ cc_library(
"//xla/python/ifrt:serdes",
"//xla/python/ifrt/ir/transforms:passes",
"//xla/python/ifrt/support:module_parsing",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
8 changes: 6 additions & 2 deletions xla/python/ifrt/ir/ifrt_ir_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ struct IfrtIRProgram : llvm::RTTIExtends<IfrtIRProgram, Program> {
struct SerializeIfrtIRProgramOptions
: llvm::RTTIExtends<SerializeIfrtIRProgramOptions, SerializeOptions> {
explicit SerializeIfrtIRProgramOptions(std::string ifrt_version,
std::string atom_program_version)
std::string atom_program_version,
bool version_in_place = true)
: ifrt_version(std::move(ifrt_version)),
atom_program_version(std::move(atom_program_version)) {}
atom_program_version(std::move(atom_program_version)),
version_in_place(version_in_place) {}

static char ID; // NOLINT

Expand All @@ -71,6 +73,8 @@ struct SerializeIfrtIRProgramOptions
// String of the form "major.minor.patch", representing the atom program
// version (currently VHLO version).
std::string atom_program_version;
// Whether to version the IFRT IR ModuleOp in-place.
bool version_in_place;
};

// Options for deserializing IFRT IR programs.
Expand Down
29 changes: 20 additions & 9 deletions xla/python/ifrt/ir/ifrt_ir_program_serdes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include <string>
#include <utility>

#include "absl/cleanup/cleanup.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand All @@ -26,6 +27,7 @@ limitations under the License.
#include "llvm/Support/ExtensibleRTTI.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -62,7 +64,7 @@ class IfrtIRProgramSerDes
absl::StatusOr<std::string> Serialize(
Serializable& serializable,
std::unique_ptr<SerializeOptions> options) override {
const auto& program = llvm::cast<IfrtIRProgram>(serializable);
auto& program = llvm::cast<IfrtIRProgram>(serializable);
if (program.mlir_module == nullptr) {
return absl::InvalidArgumentError("Unable to serialize null MLIR module");
}
Expand All @@ -88,13 +90,20 @@ class IfrtIRProgramSerDes
}
} else {
program_proto.set_ifrt_version(serialize_options->ifrt_version);

mlir::OwningOpRef<mlir::ModuleOp> cloned;
mlir::ModuleOp mlir_module;
if (serialize_options->version_in_place) {
mlir_module = program.mlir_module;
} else {
cloned = program.mlir_module.clone();
mlir_module = *cloned;
}
// Run the pipeline to convert IFRT IR program to a versioned artifact.
mlir::PassManager pm(program.mlir_module->getContext());
mlir::PassManager pm(mlir_module->getContext());
CreateIfrtToVersionedPipeline(pm, serialize_options->ifrt_version,
serialize_options->atom_program_version,
program_proto);
if (mlir::failed(pm.run(program.mlir_module))) {
if (mlir::failed(pm.run(mlir_module))) {
return absl::InvalidArgumentError(
absl::StrFormat("Failed to version IFRT IR program: %s",
diagnostic_handler.ConsumeStatus().message()));
Expand All @@ -114,7 +123,7 @@ class IfrtIRProgramSerDes
mlir::BytecodeWriterConfig writer_config(bytecode_version_string);
writer_config.setDesiredBytecodeVersion(*fail_or_bytecode_version);
if (mlir::failed(mlir::writeBytecodeToFile(
program.mlir_module, ifrt_ir_program_stream, writer_config))) {
mlir_module, ifrt_ir_program_stream, writer_config))) {
return absl::InvalidArgumentError(absl::StrFormat(
"Failed to serialize versioned IFRT IR module string: %s",
diagnostic_handler.ConsumeStatus().message()));
Expand Down Expand Up @@ -143,6 +152,12 @@ class IfrtIRProgramSerDes
context =
std::unique_ptr<mlir::MLIRContext>(deserialize_options->context);
}
absl::Cleanup release_context_pointer = [&]() {
if (use_existing_context) {
// Release the pointer s.t. the existing context is not freed.
context.release();
}
};

IfrtIrProgramProto program_proto;
if (!program_proto.ParseFromString(serialized)) {
Expand All @@ -156,8 +171,6 @@ class IfrtIRProgramSerDes
// The program was not versioned on serialization. The whole IFRT IR
// program was serialized to bytecode.
if (use_existing_context) {
// Release the point s.t. the existing context is not freed.
context.release();
return std::make_unique<IfrtIRProgram>(module.release());
} else {
return std::make_unique<IfrtIRProgram>(std::move(context),
Expand All @@ -176,8 +189,6 @@ class IfrtIRProgramSerDes
}

if (use_existing_context) {
// Release the point s.t. the existing context is not freed.
context.release();
return std::make_unique<IfrtIRProgram>(module.release());
} else {
return std::make_unique<IfrtIRProgram>(std::move(context),
Expand Down
3 changes: 2 additions & 1 deletion xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ module @multiple_calls_of_same_module {
Serialize(*initial_program,
std::make_unique<SerializeIfrtIRProgramOptions>(
Version::getCurrentVersion().toString(),
::mlir::vhlo::Version::getCurrentVersion().toString())));
::mlir::vhlo::Version::getCurrentVersion().toString(),
/*version_in_place=*/false)));

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<IfrtIRProgram> deserialized_program,
Expand Down
1 change: 1 addition & 0 deletions xla/python/ifrt/ir/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ xla_cc_binary(
"//xla/python/ifrt/ir:ifrt_ir_program_serdes", # build_cleaner: keep
"//xla/python/ifrt/ir:version",
"//xla/python/ifrt/ir:vifrt",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt/ir/tests/ifrt-translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ limitations under the License.
#include <memory>
#include <string>

#include "absl/strings/str_cat.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -80,6 +82,8 @@ mlir::TranslateFromMLIRRegistration serializeRegistration(
os << serialized_or->SerializeAsString();
return mlir::success();
} else {
module.emitError(absl::StrCat("failed to serialize: ",
serialized_or.status().message()));
return mlir::failure();
}
},
Expand All @@ -105,6 +109,8 @@ mlir::TranslateToMLIRRegistration deserializeRegistration(
return mlir::OwningOpRef<mlir::ModuleOp>(
deserialized_program_or.value()->mlir_module);
} else {
llvm::dbgs() << "failed to deserialize: "
<< deserialized_program_or.status().message() << "\n";
return nullptr;
}
},
Expand Down
202 changes: 129 additions & 73 deletions xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s
// RUN: ifrt-translate --serialize --ifrt_version=current --atom_program_version=current %s | ifrt-translate --deserialize | ifrt-opt > %t.0
// RUN: ifrt-opt %s > %t.1
// RUN: diff %t.0 %t.1

// ============ Types and attributes ============

Expand Down Expand Up @@ -144,79 +147,6 @@ func.func @op_after(%arg0: !array_cp0, %arg1: !array_cp1)
return %1, %2: !array_cp0, !array_cp0
}

!array_op_call = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<2x1 to [0] on 2>, [0,1]>
// CHECK-LABEL: "op_call"
// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}):
func.func @op_call(
%arg0: !array_op_call {ifrt.donated}, %arg1: !array_op_call {ifrt.donated})
-> !array_op_call attributes {ifrt.function} {
// CHECK: %[[OUT0:.+]]:2 = "vifrt.CallV1"(%[[ARG0]])
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_one::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32>
// CHECK-DAG: io_aliases = []
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]
: (!array_op_call) -> !array_op_call

// Verifies that the control value is passed to the next call.

// CHECK: %[[OUT1:.+]]:2 = "vifrt.CallV1"(%[[OUT0]]#0, %[[OUT0]]#1)
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_one::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32>
// CHECK-DAG: io_aliases = []
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 1>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%1, %ctrl_1 = ifrt.Call @add_one::@main(%0) after %ctrl_0 on devices [0,1]
: (!array_op_call) -> !array_op_call

// Verifies that the donated input indices attribute is converted.

// CHECK: "vifrt.CallV1"(%[[ARG0]])
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_one::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32: 0>
// CHECK-DAG: io_aliases = []
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%2, %ctrl_2 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]
{donated_input_indices=array<i32: 0>} : (!array_op_call) -> !array_op_call

// Verifies that the io_aliases attribute is converted.

// CHECK: "vifrt.CallV1"(%[[ARG1]])
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_one::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32>,
// CHECK-DAG: io_aliases = [array<i32: 0, 0>]
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%3, %ctrl_3 = ifrt.Call @add_one::@main(%arg1) on devices [0,1]
{io_aliases=[array<i32: 0, 0>]} : (!array_op_call) -> !array_op_call

return %1 : !array_op_call
}

// CHECK-NOT @add_one
module @add_one attributes {sym_visibility = "private"} {
func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = stablehlo.constant dense<1> : tensor<2x2xi32>
%1 = stablehlo.add %arg0, %0 : tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
}

!array_le_in = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
!array_le_out = !ifrt.array<tensor<4x4xi32>,
Expand Down Expand Up @@ -347,3 +277,129 @@ func.func @donated_arguments(
// CHECK: "vifrt.ReturnV1"(%[[OUT]]#0, %[[OUT]]#1) : (!vifrt.array_v1<tensor<2xi32>, #vifrt.sharding_param_v1<1 to [0] on 1>, [2], memory_kind = "vifrt.default">, !vifrt.array_v1<tensor<2xi32>, #vifrt.sharding_param_v1<1 to [0] on 1>, [3], memory_kind = "vifrt.default">)
return %0, %1 : !array_r1, !array_r2
}

// CHECK-LABEL: "op_func_call"
// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}):
func.func @op_func_call(%arg0: !array_cp0) -> !array_cp1
attributes {ifrt.function} {
// CHECK: %[[OUT0:.+]]:2 = "vifrt.CopyArraysV1"(%[[ARG0]])
// CHECK-SAME: <{
// CHECK-DAG: donated = false
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default">, !vifrt.control_v1)
%0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1
// CHECK: %[[OUT1:.+]] = "vifrt.CallFuncV1"(%[[OUT0]]#0)
// CHECK-SAME: <{callee = @copy_back}>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default">) -> !vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">
%1 = func.call @copy_back(%0) : (!array_cp1) -> !array_cp0
return %0: !array_cp1
}

// CHECK: "vifrt.FuncV1"()
// CHECK-SAME: <{
// CHECK-DAG: arg_attrs = []
// CHECK-DAG: function_type = #vifrt.type_v1<!vifrt.func_v1<(!vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default">) -> !vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">>>
// CHECK-DAG: res_attrs = []
// CHECK-DAG: sym_name = "copy_back"
// CHECK-DAG: sym_visibility = "vifrt.default"
// CHECK-SAME: }>
// CHECK-NEXT: (%[[ARG1:.*]]: {{.*}}):
func.func @copy_back(%arg1: !array_cp1) -> !array_cp0
attributes {ifrt.function} {
// CHECK: "vifrt.CopyArraysV1"(%[[ARG1]])
// CHECK-SAME: <{
// CHECK-DAG: donated = false
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x4xi32>, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%0, %ctrl = ifrt.CopyArrays(%arg1) : (!array_cp1) -> !array_cp0
return %0: !array_cp0
}

// Important: The test verifying CallOps must be last. This is necessary because
// in order to test serialization rountrip the tests in this file are not split
// into per file tests. However, during deserialization we do not know where to
// re-introduce the atom program modules within the module, and thus we append
// them at the end.
!array_op_call = !ifrt.array<tensor<2x2xi32>,
#ifrt.sharding_param<2x1 to [0] on 2>, [0,1]>
// CHECK-LABEL: "op_call"
// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}):
func.func @op_call(
%arg0: !array_op_call {ifrt.donated}, %arg1: !array_op_call {ifrt.donated})
-> !array_op_call attributes {ifrt.function} {
// CHECK: %[[OUT0:.+]]:2 = "vifrt.CallV1"(%[[ARG0]])
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_one::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32>
// CHECK-DAG: io_aliases = []
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]
: (!array_op_call) -> !array_op_call

// Verifies that the control value is passed to the next call.

// CHECK: %[[OUT1:.+]]:2 = "vifrt.CallV1"(%[[OUT0]]#0, %[[OUT0]]#1)
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_one::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32>
// CHECK-DAG: io_aliases = []
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 1>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%1, %ctrl_1 = ifrt.Call @add_one::@main(%0) after %ctrl_0 on devices [0,1]
: (!array_op_call) -> !array_op_call

// Verifies that the donated input indices attribute is converted.

// CHECK: "vifrt.CallV1"(%[[ARG0]])
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_one::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32: 0>
// CHECK-DAG: io_aliases = []
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%2, %ctrl_2 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]
{donated_input_indices=array<i32: 0>} : (!array_op_call) -> !array_op_call

// Verifies that the io_aliases attribute is converted.

// CHECK: "vifrt.CallV1"(%[[ARG1]])
// CHECK-SAME: <{
// CHECK-DAG: callee = "@add_two::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32>,
// CHECK-DAG: io_aliases = [array<i32: 0, 0>]
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%3, %ctrl_3 = ifrt.Call @add_two::@main(%arg1) on devices [0,1]
{io_aliases=[array<i32: 0, 0>]} : (!array_op_call) -> !array_op_call

return %1 : !array_op_call
}

// CHECK-NOT @add_one
module @add_one attributes {sym_visibility = "private"} {
func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = stablehlo.constant dense<1> : tensor<2x2xi32>
%1 = stablehlo.add %arg0, %0 : tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
}

// CHECK-NOT @add_two
module @add_two attributes {sym_visibility = "private"} {
func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = stablehlo.constant dense<2> : tensor<2x2xi32>
%1 = stablehlo.add %arg0, %0 : tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
}
1 change: 1 addition & 0 deletions xla/python/ifrt/ir/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ cc_library(
"passes.cc",
"spmd_expandable_interface_verification_pass.cc",
"spmd_expansion_pass.cc",
"vifrt_legalize_to_ifrt_pass.cc",
],
hdrs = [
"map_ifrt_to_vifrt.h",
Expand Down
Loading

0 comments on commit 70e7b55

Please sign in to comment.