From 7c61f6d3b0f8c8120c58e188ccfde9ad9a76a64f Mon Sep 17 00:00:00 2001
From: xla authors
Date: Wed, 25 Sep 2024 18:39:54 -0700
Subject: [PATCH] Remove AutoShardingSolverResult in favor of StatusOr, since
 AutoShardingSolverResult::skip_auto_sharding is now dead code after recent
 changes.

PiperOrigin-RevId: 678928364
---
 third_party/stablehlo/temporary.patch         | 901 ------------------
 third_party/stablehlo/workspace.bzl           |   4 +-
 xla/hlo/experimental/auto_sharding/BUILD      |   3 +-
 .../auto_sharding/auto_sharding.cc            |  87 +-
 .../auto_sharding/auto_sharding.h             |  21 +-
 .../auto_sharding/auto_sharding_impl.cc       |   3 +-
 .../auto_sharding/auto_sharding_solver.cc     |  44 +-
 .../auto_sharding/auto_sharding_solver.h      |  16 +-
 .../auto_sharding_solver_impl.cc              |   2 +-
 .../auto_sharding_solver_test.cc              | 147 ++-
 .../auto_sharding/auto_sharding_test.cc       |   3 +-
 .../auto_sharding/auto_sharding_wrapper.h     |  17 +-
 xla/mlir_hlo/BUILD                            |   2 +-
 .../stablehlo_ext/transforms/passes.h         |   2 +-
 ...p => stablehlo_compatibility_expander.cpp} |   4 +-
 xla/pjrt/mlir_to_hlo.cc                       |   2 +-
 xla/tsl/BUILD                                 |  15 +-
 17 files changed, 150 insertions(+), 1123 deletions(-)
 rename xla/mlir_hlo/stablehlo_ext/transforms/{stablehlo_create_compatibility_expander.cpp => stablehlo_compatibility_expander.cpp} (87%)

diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index 7102b01238d61..8b137891791fe 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,902 +1 @@
-diff --ruN a/stablehlo/examples/c++/ExampleAdd.cpp b/stablehlo/examples/c++/ExampleAdd.cpp
---- stablehlo/examples/c++/ExampleAdd.cpp
-+++ stablehlo/examples/c++/ExampleAdd.cpp
-@@ -18,7 +18,7 @@
- #include "llvm/ADT/SmallVector.h"
- #include "llvm/Support/LogicalResult.h"
- #include "mlir/Dialect/Func/IR/FuncOps.h"
--#include "mlir/Dialect/Quant/QuantOps.h"
-+#include "mlir/Dialect/Quant/IR/Quant.h"
- #include "mlir/IR/Attributes.h"
- #include "mlir/IR/Block.h"
- #include "mlir/IR/Builders.h"
-@@ -43,7 +43,7 @@
- mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
- module->getContext()->loadDialect();
- module->getContext()->loadDialect();
-- module->getContext()->loadDialect();
-+ module->getContext()->loadDialect();
- module->setName("test_module");
-
- /** create function **/
-diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp
---- stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp
-+++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp
-@@ -17,7 +17,7 @@
- #include
-
- #include "mlir/Dialect/Func/IR/FuncOps.h"
--#include "mlir/Dialect/Quant/QuantOps.h"
-+#include "mlir/Dialect/Quant/IR/Quant.h"
- #include "mlir/Dialect/Tosa/IR/TosaOps.h"
- #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
- #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
-diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp
---- stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp
-+++ stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp
-@@ -18,7 +18,7 @@
- #include
-
- #include "mlir/Dialect/Func/IR/FuncOps.h"
--#include "mlir/Dialect/Quant/QuantOps.h"
-+#include "mlir/Dialect/Quant/IR/Quant.h"
- #include "mlir/Dialect/Tosa/IR/TosaOps.h"
- #include
"mlir/Dialect/Tosa/Utils/ConversionUtils.h" - #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp ---- stablehlo/stablehlo/dialect/Base.cpp -+++ stablehlo/stablehlo/dialect/Base.cpp -@@ -31,7 +31,7 @@ - #include "llvm/ADT/SmallVector.h" - #include "llvm/Support/Debug.h" - #include "llvm/Support/ErrorHandling.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -diff --ruN a/stablehlo/stablehlo/dialect/ChloOps.h b/stablehlo/stablehlo/dialect/ChloOps.h ---- stablehlo/stablehlo/dialect/ChloOps.h -+++ stablehlo/stablehlo/dialect/ChloOps.h -@@ -20,7 +20,7 @@ - #include "llvm/ADT/APFloat.h" - #include "llvm/ADT/StringRef.h" - #include "mlir/Bytecode/BytecodeOpInterface.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinTypes.h" -diff --ruN a/stablehlo/stablehlo/dialect/Register.cpp b/stablehlo/stablehlo/dialect/Register.cpp ---- stablehlo/stablehlo/dialect/Register.cpp -+++ stablehlo/stablehlo/dialect/Register.cpp -@@ -17,7 +17,7 @@ - #include "stablehlo/dialect/Register.h" - - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" - #include "mlir/IR/DialectRegistry.h" - #include "stablehlo/dialect/ChloOps.h" -@@ -30,7 +30,7 @@ - void registerAllDialects(mlir::DialectRegistry ®istry) { - // clang-format off - registry.insert(); - registry.insert - - #include "llvm/ADT/StringRef.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" -diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp ---- stablehlo/stablehlo/dialect/TypeInference.cpp -+++ stablehlo/stablehlo/dialect/TypeInference.cpp -@@ -52,7 +52,7 @@ - #include "llvm/Support/Regex.h" - #include "llvm/Support/raw_ostream.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.cpp b/stablehlo/stablehlo/dialect/VhloTypes.cpp ---- stablehlo/stablehlo/dialect/VhloTypes.cpp -+++ stablehlo/stablehlo/dialect/VhloTypes.cpp -@@ -20,7 +20,7 @@ - #include "llvm/ADT/SmallVectorExtras.h" - #include "llvm/ADT/StringRef.h" - #include "llvm/ADT/TypeSwitch.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinTypes.h" -diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/reference/Api.cpp ---- stablehlo/stablehlo/reference/Api.cpp -+++ stablehlo/stablehlo/reference/Api.cpp -@@ -31,7 +31,7 @@ - #include "llvm/Support/Path.h" - #include "llvm/Support/SourceMgr.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/BuiltinAttributes.h" - #include 
"mlir/IR/BuiltinOps.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" -diff --ruN a/stablehlo/stablehlo/tests/CheckOps.h b/stablehlo/stablehlo/tests/CheckOps.h ---- stablehlo/stablehlo/tests/CheckOps.h -+++ stablehlo/stablehlo/tests/CheckOps.h -@@ -17,7 +17,7 @@ - #define STABLEHLO_DIALECT_CHECKOPS_H_ - - #include "mlir/Bytecode/BytecodeOpInterface.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/Dialect.h" -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -@@ -1338,24 +1338,24 @@ - - // ----- - -+// expected-error@+1 {{scale out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { -- // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return - } - - // ----- - -+// expected-error@+1 {{scale out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { -- // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return - } - - // ----- - -+// expected-error@+1 {{illegal quantized dimension: -1}} - func.func @quantized_element_type_c11(%arg0: tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>>) { -- // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x5x2x!quant.uniform>'}} - %0 = stablehlo.add %arg0, %arg0 : tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>> - func.return - } -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -@@ -69,7 +69,7 @@ - index_vector_dim = 3 - >, - slice_sizes = array, -- indices_are_sorted = true -+ indices_are_sorted = false - } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> - func.return %0 : tensor<4x3x5x8xi32> - } -@@ -77,9 +77,9 @@ - // ----- - - // CHECK-LABEL: @gather_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota 
dim = 0 : tensor<4x3x5x1xi32> --// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> - // CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ - // CHECK-SAME: dimension_numbers = #stablehlo.gather< -@@ -102,7 +102,7 @@ - index_vector_dim = 3 - >, - slice_sizes = array, -- indices_are_sorted = true -+ indices_are_sorted = false - }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> - func.return %0 : tensor<4x3x5x8xi32> - } -@@ -133,9 +133,305 @@ - index_vector_dim = 3 - >, - slice_sizes = array, -- indices_are_sorted = true -+ indices_are_sorted = false - }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> - func.return %0 : tensor<0x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<3x4x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 1 : tensor<3x4x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<3x4x5x1xi32>, tensor<3x4x5x1xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x4xi32>) -> tensor<3x4x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<3x4x5x8xi32> -+func.func @gather_batching_dims_indices_become_unsorted(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [0, 1], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> -+ func.return %0 : tensor<3x4x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted_2 -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> -+func.func @gather_batching_dims_indices_become_unsorted_2(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) 
-> tensor<2x3x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [2, 3], -+ operand_batching_dims = [0, 1], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [2, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> -+ func.return %0 : tensor<2x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_indices_remain_sorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = true, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> -+func.func @gather_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [2, 3], -+ operand_batching_dims = [0, 1], -+ start_indices_batching_dims = [0, 2], -+ start_index_map = [2, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> -+ func.return %0 : tensor<2x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_indices_remain_unsorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> -+func.func @gather_batching_dims_indices_remain_unsorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [2, 3], -+ operand_batching_dims = [0, 1], -+ start_indices_batching_dims = [0, 2], -+ start_index_map = [2, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> -+ func.return %0 : tensor<2x3x5x8xi32> -+} -+ -+// ----- -+ 
-+// CHECK-LABEL: @gather_batching_dims_does_not_overflow_indices_type -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x127x5x1xi8> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x127x5x1xi8> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x127x5x1xi8>, tensor<4x127x5x1xi8>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x4xi8> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x4xi8>) -> tensor<4x127x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x127x5x8xi32> -+func.func @gather_batching_dims_does_not_overflow_indices_type(%arg0: tensor<127x2x4x7x9xi32>, %arg1: tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> -+ func.return %0 : tensor<4x127x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_signless_indices_type -+// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5x2xi8>) -> tensor<4x128x5x2xi32> -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x2xi32>) -> tensor<4x128x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x4xi32>) -> tensor<4x128x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> -+func.func @gather_batching_dim_overflows_signless_indices_type(%arg0: tensor<128x2x4x7x9xi32>, %arg1: tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> -+ func.return %0 : tensor<4x128x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_unsigned_indices_type -+// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<256x4x5x2xui8>) -> tensor<256x4x5x2xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<256x4x5x1xi32> -+// CHECK-NEXT: 
%[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<256x4x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim0]], %[[iota_dim1]], %[[convert]], dim = 3 : (tensor<256x4x5x1xi32>, tensor<256x4x5x1xi32>, tensor<256x4x5x2xi32>) -> tensor<256x4x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x4xi32>) -> tensor<256x4x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<256x4x5x8xi32> -+func.func @gather_batching_dim_overflows_unsigned_indices_type(%arg0: tensor<256x2x4x7x9xi32>, %arg1: tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [0, 1], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> -+ func.return %0 : tensor<256x4x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_indices_type_and_i32 -+// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x2xi64> -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x2147483648x5x1xi64> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x2147483648x5x1xi64> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x2xi64>) -> tensor<4x2147483648x5x4xi64> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x4xi64>) -> tensor<4x2147483648x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x2147483648x5x8xi32> -+func.func @gather_batching_dim_overflows_indices_type_and_i32(%arg0: tensor<2147483648x2x4x7x9xi32>, %arg1: tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> -+ func.return %0 : tensor<4x2147483648x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_dynamic_size -+// CHECK: operand_batching_dims = [0, 2] -+// CHECK: start_indices_batching_dims = [1, 0] -+func.func @gather_batching_dim_dynamic_size(%arg0: tensor, %arg1: tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ 
dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor, tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> -+ func.return %0 : tensor<4x?x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_and_no_index_vector_dim -+// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5xi8>) -> tensor<4x128x5xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %[[convert]] : (tensor<4x128x5xi32>) -> tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>) -> tensor<4x128x5x3xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], -+// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<128x2x4x9xi32>, tensor<4x128x5x3xi32>) -> tensor<4x128x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> -+func.func @gather_batching_dim_overflows_and_no_index_vector_dim(%arg0: tensor<128x2x4x9xi32>, %arg1: tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<128x2x4x9xi32>, tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> -+ func.return %0 : tensor<4x128x5x8xi32> - } - - // ----- -@@ -156,7 +452,7 @@ - // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] - // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -- indices_are_sorted = true, -+ indices_are_sorted = false, - scatter_dimension_numbers = #stablehlo.scatter< - update_window_dims = [3], - inserted_window_dims = [1, 3], -@@ -176,9 +472,9 @@ - // ----- - - // CHECK-LABEL: @scatter_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> --// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> - // CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ - // CHECK-SAME: indices_are_sorted = false, -@@ -192,7 +488,7 @@ - // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] - // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -- 
indices_are_sorted = true, -+ indices_are_sorted = false, - scatter_dimension_numbers = #stablehlo.scatter< - update_window_dims = [3], - inserted_window_dims = [1], -@@ -208,3 +504,60 @@ - }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> - func.return %0 : tensor<3x2x4x9xi32> - } -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_batching_dims_indices_remain_sorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = true, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = false}> -+// CHECK: (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<2x5x4x7x9xi32> -+func.func @scatter_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>, %arg2: tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> { -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [2, 3], -+ input_batching_dims = [0, 1], -+ scatter_indices_batching_dims = [0, 2], -+ scatter_dims_to_operand_dims = [2, 3], -+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> -+ func.return %0 : tensor<2x5x4x7x9xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_batching_dim_dynamic_scatter_indices -+// CHECK: input_batching_dims = [0, 2] -+// CHECK: scatter_indices_batching_dims = [1, 0] -+func.func @scatter_batching_dim_dynamic_scatter_indices(%arg0: tensor, %arg1: tensor<4x?x5x2xi32>, %arg2: tensor<4x?x5x8xi32>) -> tensor { -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = false, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1, 3], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1, 3], -+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor, tensor<4x?x5x2xi32>, tensor<4x?x5x8xi32>) -> tensor -+ func.return %0 : tensor -+} -diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp ---- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp -+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp -@@ -24,7 +24,7 @@ - #include "llvm/Support/ErrorHandling.h" - #include "llvm/Support/LogicalResult.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinOps.h" - #include "mlir/IR/DialectRegistry.h" -@@ -237,7 +237,7 
@@ - }, - [](DialectRegistry ®istry) { - registry.insert(); -- registry.insert(); -+ registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); -diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h ---- stablehlo/stablehlo/transforms/Passes.h -+++ stablehlo/stablehlo/transforms/Passes.h -@@ -19,7 +19,7 @@ - #include - - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/BuiltinOps.h" - #include "mlir/Pass/Pass.h" -diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td ---- stablehlo/stablehlo/transforms/Passes.td -+++ stablehlo/stablehlo/transforms/Passes.td -@@ -68,7 +68,7 @@ - let summary = "Legalize VHLO to StableHLO."; - let dependentDialects = [ - "mlir::func::FuncDialect", -- "mlir::quant::QuantizationDialect", -+ "mlir::quant::QuantDialect", - "mlir::shape::ShapeDialect", - "mlir::stablehlo::StablehloDialect", - ]; -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -@@ -22,8 +22,11 @@ - #include "llvm/ADT/STLExtras.h" - #include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" -+#include "llvm/Support/MathExtras.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/PatternMatch.h" -@@ -75,6 +78,42 @@ - return result; - } - -+bool fitsInIntegralType(int64_t size, IntegerType type) { -+ if (type.isUnsigned()) { -+ return llvm::isUIntN(type.getWidth(), size); -+ } else { -+ return llvm::isIntN(type.getWidth(), size); -+ } -+} -+ -+// If `type` is an integer type in which `size` doesn't fit, promote it to i32 -+// or i64 (depending on `size`). -+Type promoteTypeForSize(Type type, int64_t size, OpBuilder &builder) { -+ // Gather/Scatter should have an integer type, but we check just in case. -+ auto intType = dyn_cast(type); -+ if (!intType || fitsInIntegralType(size, intType)) { -+ return type; -+ } -+ if (fitsInIntegralType(size, builder.getI32Type())) { -+ return builder.getI32Type(); -+ } -+ return builder.getI64Type(); -+} -+ -+// If `indices_batching_dims` and `updated_index_map` are both sorted, then the -+// `indices_are_sorted` property is preserved. -+// -+// This is because each concatenated iota is monotonically increasing, sorted -+// indices batching dims mean their order corresponds to the order of batching -+// dims in the operand, and a sorted updated start index map means the order of -+// the index vector dim corresponds to the order of operand dims. -+bool getUpdatedIndicesAreSorted(bool indices_are_sorted, -+ ArrayRef indices_batching_dims, -+ ArrayRef updated_index_map) { -+ return indices_are_sorted && llvm::is_sorted(indices_batching_dims) && -+ llvm::is_sorted(updated_index_map); -+} -+ - // Returns an updated indices tensor such that an `IotaOp` is prepended for each - // dim in `indicesBatchingDims` with a `ConcatenateOp`. 
- // -@@ -85,16 +124,31 @@ - PatternRewriter &rewriter) { - Location loc = indices.getLoc(); - auto indicesType = cast(indices.getType()); -+ Type elementType = indicesType.getElementType(); -+ -+ // The batching dim sizes might not fit in the existing element type, -+ // in which case we need to promote it. -+ for (int64_t batchingDim : indicesBatchingDims) { -+ elementType = promoteTypeForSize( -+ elementType, indicesType.getDimSize(batchingDim), rewriter); -+ } -+ if (elementType != indicesType.getElementType()) { -+ indicesType = RankedTensorType::get(indicesType.getShape(), elementType); -+ indices = rewriter.create(loc, indicesType, indices); -+ } -+ - bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); -- - SmallVector iotaShape(indicesType.getShape()); - if (indexVectorDimOnLastDim) { - iotaShape.push_back(1); - } else { - iotaShape[indexVectorDim] = 1; - } -- auto iotaType = -- RankedTensorType::get(iotaShape, indicesType.getElementType()); -+ auto iotaType = RankedTensorType::get(iotaShape, elementType); -+ -+ if (indexVectorDimOnLastDim) { -+ indices = rewriter.create(loc, iotaType, indices); -+ } - - SmallVector indicesToConcat; - indicesToConcat.reserve(indicesBatchingDims.size() + 1); -@@ -102,12 +156,7 @@ - indicesToConcat.push_back( - rewriter.create(loc, iotaType, batchingDim)); - } -- if (indexVectorDimOnLastDim) { -- indicesToConcat.push_back( -- rewriter.create(loc, iotaType, indices)); -- } else { -- indicesToConcat.push_back(indices); -- } -+ indicesToConcat.push_back(indices); - return rewriter.create(loc, indicesToConcat, indexVectorDim); - } - -@@ -125,9 +174,17 @@ - PatternRewriter &rewriter) const override { - GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); - ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); -+ ArrayRef startIndicesBatchingDims = -+ dimNumbers.getStartIndicesBatchingDims(); - if (operandBatchingDims.empty()) { - return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - diag << "gather op has no batching dims"; -+ }); -+ } -+ -+ if (!op.getStartIndices().getType().hasStaticShape()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has start indices with dynamic shape, can't expand"; - }); - } - -@@ -136,16 +193,18 @@ - SmallVector newStartIndexMap = - llvm::to_vector(llvm::concat( - operandBatchingDims, dimNumbers.getStartIndexMap())); -- Value newIndices = createConcatIndices( -- op.getStartIndices(), dimNumbers.getIndexVectorDim(), -- dimNumbers.getStartIndicesBatchingDims(), rewriter); -+ Value newIndices = createConcatIndices(op.getStartIndices(), -+ dimNumbers.getIndexVectorDim(), -+ startIndicesBatchingDims, rewriter); - rewriter.replaceOpWithNewOp( - op, op.getOperand(), newIndices, - GatherDimensionNumbersAttr::get( - op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, - /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, - newStartIndexMap, dimNumbers.getIndexVectorDim()), -- op.getSliceSizes(), /*indicesAreSorted=*/false); -+ op.getSliceSizes(), -+ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), -+ startIndicesBatchingDims, newStartIndexMap)); - - return success(); - } -@@ -161,9 +220,17 @@ - PatternRewriter &rewriter) const override { - ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); - ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); -+ ArrayRef scatterIndicesBatchingDims = -+ dimNumbers.getScatterIndicesBatchingDims(); - if (inputBatchingDims.empty()) { - return 
rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - diag << "scatter op has no batching dims"; -+ }); -+ } -+ -+ if (!op.getScatterIndices().getType().hasStaticShape()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has start indices with dynamic shape, can't expand"; - }); - } - -@@ -174,7 +241,7 @@ - inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); - Value newIndices = createConcatIndices( - op.getScatterIndices(), dimNumbers.getIndexVectorDim(), -- dimNumbers.getScatterIndicesBatchingDims(), rewriter); -+ scatterIndicesBatchingDims, rewriter); - auto newScatterOp = rewriter.create( - op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, - op.getUpdates(), -@@ -183,7 +250,10 @@ - newInsertedWindowDims, - /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, - newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), -- /*indicesAreSorted=*/false, op.getUniqueIndices()); -+ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), -+ scatterIndicesBatchingDims, -+ newScatterDimsToOperandDims), -+ op.getUniqueIndices()); - - newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); - rewriter.replaceOp(op, newScatterOp.getResults()); -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp -@@ -15,7 +15,7 @@ - - #include "llvm/ADT/SmallVector.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Operation.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/Transforms/DialectConversion.h" // Include for TypeConverter -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp -@@ -24,8 +24,8 @@ - #include "llvm/ADT/SmallVector.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/Dialect/Func/Transforms/FuncConversions.h" --#include "mlir/Dialect/Quant/QuantOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" -@@ -1331,7 +1331,7 @@ - populateReturnOpTypeConversionPattern(patterns, converter); - - ConversionTarget target(*op->getContext()); -- target.addIllegalDialect(); -+ target.addIllegalDialect(); - auto isLegal = [&converter](Operation *op) { - return converter.isLegal(op); - }; -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp -@@ -17,7 +17,7 @@ - - #include "llvm/ADT/STLExtras.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/IR/TypeRange.h" diff --git 
a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 0bd9fb077ccc9..5ad132c608ca7 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee" - STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815" + STABLEHLO_COMMIT = "f7f8e4e35296deeff2e12e39421ac8d9599ba340" + STABLEHLO_SHA256 = "c92b55d5512e58d6fefba62c58e60d7762adb184dc3ad489521de562f6ca7aeb" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index ee63f4e96cc4f..554ecde61e91d 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -217,6 +217,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) @@ -227,7 +228,6 @@ cc_library( compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", - ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_wrapper", @@ -236,6 +236,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 29ad335d888e2..21a1fcd62bf0e 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1747,7 +1747,8 @@ std::unique_ptr CreateReshapeStrategies( return strategy_group; } -AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver( +absl::StatusOr +CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, @@ -3504,14 +3505,14 @@ std::pair ReduceMemoryTerms( return num_terms; } -absl::StatusOr AutoShardingImplementation::RunAutoSharding( +absl::StatusOr AutoShardingImplementation::RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, const absl::flat_hash_map& sharding_propagation_solution) { if (!option_.enable) { - return AutoShardingResult::kModuleUnchanged; + return false; } bool module_is_changed = false; @@ -3790,16 +3791,11 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // ----- Call the ILP Solver ----- std::string request_name = absl::StrCat("mesh_idx_", mesh_idx); - auto solver_result = + spmd::AutoShardingSolverResult solver_result = Solve(*module, *hlo_live_range, strategy_map, strategy_groups, cost_graph, alias_set, reduced_node_intervals, reduced_edge_intervals, reduced_node_groups, reduced_edge_groups, option_, request_name, sharding_propagation_solution); - if (solver_result.skip_auto_sharding) { - return AutoShardingResult::kModuleUnchangedNoShardingPerformed; - } else if (!solver_result.status.ok()) { - return AutoShardingResult::kModuleUnchanged; - } 
TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output, solver_result.status); if (mesh_idx == partial_mesh_shapes.size() - 1) { @@ -3823,21 +3819,14 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1)); if (mesh_idx == partial_mesh_shapes.size() - 1) { - if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard, - preserve_shardings) - .ok()) { - return AutoShardingResult::kModuleUnchanged; - } - - if (!InsertReshardReshapes( - sequence, instructions_to_shard, strategy_map, cost_graph, - output.s_val, cluster_env, - /* crash_at_error */ !option_.try_multiple_mesh_shapes, - option_.insert_resharding_reshapes_for_non_dot_ops, - preserve_shardings) - .ok()) { - return AutoShardingResult::kModuleUnchanged; - } + TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing( + sequence, instructions_to_shard, preserve_shardings)); + TF_RETURN_IF_ERROR(InsertReshardReshapes( + sequence, instructions_to_shard, strategy_map, cost_graph, + output.s_val, cluster_env, + /* crash_at_error */ !option_.try_multiple_mesh_shapes, + option_.insert_resharding_reshapes_for_non_dot_ops, + preserve_shardings)); } else { spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings); } @@ -3878,8 +3867,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( } } - return module_is_changed ? AutoShardingResult::kModuleChangedShardingPerformed - : AutoShardingResult::kModuleUnchanged; + return module_is_changed; } bool ModuleIsManuallyPartitioned(const HloModule* module) { @@ -4109,15 +4097,12 @@ absl::StatusOr AutoSharding::Run( } } - absl::StatusOr min_mesh_pass_result = - AutoShardingResult::kModuleUnchanged; - + bool module_is_changed = false; VLOG(1) << "Original mesh shape " << spmd::ToString(option_.device_mesh_shape); double min_objective_value = std::numeric_limits::max(); int min_mesh_shape_index = -1; std::unique_ptr min_mesh_shape_module; - bool skip_auto_sharding = true; for (size_t i = 0; i < mesh_shapes.size(); ++i) { VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]); AutoShardingOption this_option = option_; @@ -4130,7 +4115,7 @@ absl::StatusOr AutoSharding::Run( } auto pass = std::make_unique(this_option); std::unique_ptr module_clone = CloneModule(module); - absl::StatusOr pass_result = + absl::StatusOr pass_result = pass->RunAutoSharding(module_clone.get(), replicated_small_tensors, execution_threads, sharding_propagation_solution); if (!pass_result.ok()) { @@ -4148,19 +4133,11 @@ absl::StatusOr AutoSharding::Run( min_mesh_shape_index = i; min_mesh_shape_module = std::move(module_clone); min_objective_value = this_mesh_objective_value; - min_mesh_pass_result = pass_result; - } - if (*pass_result != - AutoShardingResult::kModuleUnchangedNoShardingPerformed) { - skip_auto_sharding = false; + CHECK_OK(pass_result); + module_is_changed = *pass_result; } } - if (skip_auto_sharding) { - RecordPassEndAndDumpModule(start_time, module); - LOG(FATAL) << "The auto-sharding solver has timed out without a solution."; - } - std::string trying_to_find = option_.try_multiple_mesh_shapes ? "a device mesh (and the corresponding shardings)" @@ -4173,28 +4150,18 @@ absl::StatusOr AutoSharding::Run( "higher budget). 
If you think you have set a reasonably large memory " "budget, please report this as a bug."; - if (!min_mesh_pass_result.ok()) { - RecordPassEndAndDumpModule(start_time, module); - return min_mesh_pass_result.status(); - } - - absl::StatusOr module_is_changed; solver_optimal_objective_value_ = min_objective_value; - if (*min_mesh_pass_result != - AutoShardingResult::kModuleChangedShardingPerformed) { - RecordPassEndAndDumpModule(start_time, module); - return false; + if (module_is_changed) { + VLOG(1) << "Choosing mesh shape " + << spmd::ToString(mesh_shapes[min_mesh_shape_index]) + << " which had the minimal solver objective value of " + << min_objective_value; + chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index]; + TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule( + min_mesh_shape_module.get(), module)); } - - VLOG(1) << "Choosing mesh shape " - << spmd::ToString(mesh_shapes[min_mesh_shape_index]) - << " which had the minimal solver objective value of " - << min_objective_value; - chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index]; - TF_RETURN_IF_ERROR( - MoveComputationsFromModuleToModule(min_mesh_shape_module.get(), module)); RecordPassEndAndDumpModule(start_time, module); - return true; + return module_is_changed; } } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h index e749fb8682532..7153bf860515c 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -50,18 +50,12 @@ limitations under the License. namespace xla { -enum class AutoShardingResult { - kModuleUnchanged, - kModuleChangedShardingPerformed, - kModuleUnchangedNoShardingPerformed -}; - class AutoShardingImplementation { public: explicit AutoShardingImplementation(const AutoShardingOption& option); ~AutoShardingImplementation() = default; - absl::StatusOr RunAutoSharding( + absl::StatusOr RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, @@ -216,19 +210,6 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ShardingStrategy& strategy, const ClusterEnvironment& cluster_env); -// The high-level "recipe" for solving an Auto Sharding problem. -AutoShardingSolverResult Solve( - const HloModule& hlo_module, const HloLiveRange& hlo_live_range, - const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, - const CostGraph& cost_graph, const AliasSet& alias_set, - const std::vector>& node_intervals, - const std::vector>& edge_intervals, - const std::vector>& node_groups, - const std::vector>& edge_groups, - const AutoShardingOption& option, absl::string_view request_prefix, - const absl::flat_hash_map& - sharding_propagation_solution = {}); - // Populates temporal distance values. void PopulateTemporalValues(const CostGraph& cost_graph, AutoShardingSolverRequest& request); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index 7a92ac5715039..b9226f561244e 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -37,7 +38,7 @@ limitations under the License. namespace xla { namespace spmd { -AutoShardingSolverResult Solve( +absl::StatusOr Solve( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 114cca321a050..dec88705f4086 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -81,12 +81,6 @@ bool AutoShardingSolverOutput::operator==( peak_times == other.peak_times; } -bool AutoShardingSolverResult::operator==( - const AutoShardingSolverResult& other) const { - return status == other.status && - skip_auto_sharding == other.skip_auto_sharding; -} - void PrintLargestInstructions( const std::vector& chosen_strategy, const AutoShardingSolverRequest& request) { @@ -143,7 +137,7 @@ void PrintLargestInstructions( } } -AutoShardingSolverResult SolveAndExtractSolution( +absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, const std::vector>& e, @@ -399,7 +393,7 @@ void AddMemoryTerms( // can be a few (usually < 10) edges in the problem with negative costs. This // is guaranteed to never produce a negative overall cost for the graph, // however. -AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( +absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& unscaled_request) { const absl::Time start_time = absl::Now(); const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request); @@ -568,8 +562,7 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( LOG(FATAL) << err_msg; } else { LOG(WARNING) << err_msg; - return AutoShardingSolverResult(absl::InternalError(err_msg), - /*skip_auto_sharding=*/false); + return absl::InternalError(err_msg); } } } @@ -783,9 +776,9 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( } auto result = SolveAndExtractSolution(request, s, e, overbudget_var, makespan_var, *solver); - if (result.status.ok()) { + if (result.ok()) { const AutoShardingEvaluation evaluation = - Evaluate(unscaled_request, result); + Evaluate(unscaled_request, *result); LOG(INFO) << "*** Total costs for the (unscaled) solver request ***"; LOG(INFO) << "Total Communication Cost: " << evaluation.total.communication_cost @@ -831,7 +824,7 @@ std::vector GetChosenNodeStrategy( return chosen_node_strategy; } -AutoShardingSolverResult SolveAndExtractSolution( +absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, const std::vector>& e, @@ -869,22 +862,18 @@ AutoShardingSolverResult SolveAndExtractSolution( } } #endif - return AutoShardingSolverResult( - absl::InternalError("MPSolver could not find any feasible solution."), - /*skip_auto_sharding=*/false); + return absl::InternalError( + "MPSolver could not find any feasible solution."); } else if (status == operations_research::MPSolver::MODEL_INVALID) { - LOG(FATAL) << "Solver says that the input MIP is invalid. 
This is most " - "likely a bug and should be reported."; - return AutoShardingSolverResult(absl::InternalError("Invalid MIP."), - /*skip_auto_sharding=*/false); + LOG(FATAL) << "The MIP fed to the solver is invalid. This is most likely a " + "bug and should be reported."; + return absl::InternalError("Invalid MIP."); } else if (status == operations_research::MPSolver::NOT_SOLVED) { LOG(WARNING) << "Solver timeout; no solution was produced"; - return AutoShardingSolverResult(absl::InternalError("Solver timed out."), - /*skip_auto_sharding=*/true); + return absl::InternalError("Solver timed out."); } else if (status != operations_research::MPSolver::OPTIMAL) { LOG(WARNING) << "Solver timeout; moving forward with a suboptimal solution"; } - // Fingerprint the model & solution (useful when checking for determinism). // We use TensorFlow's fingerprint library here, which differs from CP-SAT's. operations_research::MPModelProto model_proto; @@ -951,9 +940,8 @@ AutoShardingSolverResult SolveAndExtractSolution( << request.memory_budget() / (1024 * 1024 * 1024) << " GB"; } PrintLargestInstructions(chosen_node_strategy, request); - const AutoShardingSolverOutput output = {std::move(chosen_node_strategy), - solver.Objective().Value()}; - return AutoShardingSolverResult(output, /*skip_auto_sharding=*/false); + return AutoShardingSolverOutput{.s_val = std::move(chosen_node_strategy), + .cost = solver.Objective().Value()}; } bool CostComponents::operator==(const CostComponents& other) const { @@ -977,13 +965,13 @@ bool AutoShardingEvaluation::operator==( } AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result) { + const AutoShardingSolverOutput& result) { const auto& c = request.computation_costs(); const auto& d = request.communication_costs(); const auto& r = request.resharding_costs(); const auto& v = request.value_costs(); const auto& p = request.departure_costs(); - const std::vector& s_val = result.status->s_val; + const std::vector& s_val = result.s_val; const auto e_val = [&](EdgeIdx edge_idx) { const auto& edge = request.edges(edge_idx); return s_val[edge.first()] * request.s_len(edge.second()) + diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 88884f7286d0b..e6dd82717b6e8 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -37,17 +37,7 @@ struct AutoShardingSolverOutput { bool operator==(const AutoShardingSolverOutput& other) const; }; -struct AutoShardingSolverResult { - public: - AutoShardingSolverResult(absl::StatusOr status, - bool skip_auto_sharding) - : status(status), skip_auto_sharding(skip_auto_sharding) {} - bool operator==(const AutoShardingSolverResult& other) const; - absl::StatusOr status; - bool skip_auto_sharding; -}; - -AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( +absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& request); enum AutoShardingViolationCode { @@ -92,7 +82,7 @@ struct AutoShardingEvaluation { // Evaluates the given solver result w.r.t. the input request, computing various // solution quality metrics and validating the consistency of hard constraints. AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result); + const AutoShardingSolverOutput& result); // Creates and returns a variable for makespan. 
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc
index 4be54f98a0a49..176f1426a9866 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc
@@ -33,7 +33,7 @@ MPVariable* CreateMakespanVar(const AutoShardingSolverRequest& request,
 }
 
 double EvaluateMakespan(const AutoShardingSolverRequest& request,
-                        const AutoShardingSolverResult& result,
+                        const AutoShardingSolverOutput& result,
                         AutoShardingEvaluation& evaluation) {
   return 0.0;  // TODO(moffitt): Implement this.
 }
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
index 3e0c82d3b7551..4ddafbee670ca 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
 #include "tsl/platform/platform.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace spmd {
@@ -253,14 +254,13 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() {
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) {
   const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
   const double objective_value = 7650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) {
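TF_ASSERT_OK_AND_ASSIGN (from tsl/platform/statusor.h, newly included above) asserts that a StatusOr is OK and moves its value out in one step. A hand-written equivalent as a sketch, reusing the DefaultAutoShardingSolverRequest() helper and the expected values from SolvesOptimally; the test name is hypothetical:

    #include <utility>
    #include <vector>

    #include "absl/status/statusor.h"

    // Hypothetical test: manual equivalent of
    //   TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
    //                           FormulateAndSolveMIPFromSolverRequest(request));
    TEST(FormulateAndSolveMIPFromSolverRequestTest, ManualStatusOrUnwrap) {
      const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();

      absl::StatusOr<AutoShardingSolverOutput> statusor =
          FormulateAndSolveMIPFromSolverRequest(request);
      ASSERT_TRUE(statusor.ok()) << statusor.status();  // abort test on failure
      const AutoShardingSolverOutput result = *std::move(statusor);

      // Same expectations as SolvesOptimally above.
      const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
      EXPECT_EQ(result, (AutoShardingSolverOutput{s_val, 7650.0}));
    }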
@@ -268,42 +268,39 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) {
   request.set_memory_budget(100000);
   request.mutable_overbudget_coeff()->set_coeff(10.0);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
   const double objective_value = 9007650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) {
   AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
   request.mutable_max_departures()->set_coeff(3.0);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
   const double objective_value = 7872.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) {
   AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
   request.set_minimize_departures(true);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 1, 0, 0, 1};
   const double objective_value = 3.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) {
@@ -312,28 +309,26 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) {
   request.mutable_computation_costs(0)->set_costs(1, kInfinityCost);
   request.mutable_computation_costs(0)->set_costs(2, kInfinityCost);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {3, 0, 0, 0, 0};
   const double objective_value = 10683.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) {
   AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
   request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
   const double objective_value = 7872.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) {
@@ -352,14 +347,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) {
                            70000, 71000, 72000, 73000}};
   AddCosts(request.mutable_duration_costs(), t);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
   const double objective_value = 12650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) {
@@ -380,14 +374,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) {
                            80000, 81000, 82000, 83000}};
   AddCosts(request.mutable_duration_costs(), t);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
   const double objective_value = 13972.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) {
@@ -395,38 +388,36 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) {
   const auto s_hint = {1, 0, 0, 0, 0};  // Not optimal, but close.
   request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end());
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
   const double objective_value = 7650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) {
   AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
   request.mutable_max_cost()->set_coeff(7600.0);  // Best possible is 7650.0
 
-  const AutoShardingSolverResult result =
+  const absl::StatusOr<AutoShardingSolverOutput> result =
       FormulateAndSolveMIPFromSolverRequest(request);
 
-  EXPECT_TRUE(absl::IsInternal(result.status.status()));
+  EXPECT_TRUE(absl::IsInternal(result.status()));
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) {
   AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
   request.mutable_max_cost()->set_coeff(1e19);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
   const double objective_value = 7650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) {
@@ -443,14 +434,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) {
   AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
   request.set_enable_memory_edge_costs(true);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
   const double objective_value = 7872.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) {
@@ -472,14 +462,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) {
   AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
   request.set_enable_memory_edge_costs(true);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
   const double objective_value = 7872.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest,
@@ -506,14 +495,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest,
   AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
   request.set_enable_memory_edge_costs(true);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
   const double objective_value = 7872.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest,
@@ -527,14 +515,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest,
   AddGroups(request.mutable_node_groups(), node_groups);
   request.set_enable_memory_edge_costs(false);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
   const double objective_value = 7650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest,
@@ -569,28 +556,26 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest,
   request.set_enable_memory_edge_costs(true);
   request.set_memory_budget(4321);
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
   const double objective_value = 7650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) {
   const AutoShardingSolverRequest request =
       AutoShardingSolverRequestWithEquivalences();
 
-  const AutoShardingSolverResult result =
-      FormulateAndSolveMIPFromSolverRequest(request);
+  TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result,
+                          FormulateAndSolveMIPFromSolverRequest(request));
 
   const std::vector<NodeStrategyIdx> s_val = {0, 0, 5, 5, 1};
   const double objective_value = 7650.0;
   const AutoShardingSolverOutput expected_output = {s_val, objective_value};
-  const AutoShardingSolverResult expected_result = {expected_output, false};
-  EXPECT_EQ(result, expected_result);
+  EXPECT_EQ(result, expected_output);
 }
 
 TEST(AutoShardingEvaluatorTest, NoViolations) {
@@ -598,9 +583,8 @@ TEST(AutoShardingEvaluatorTest, NoViolations) {
   const std::vector<NodeStrategyIdx> s_val = {3, 1, 2, 2, 1};
   const double objective_value = 12149.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.total.computation_cost = 159.0;  // 13+21+32+42+51
@@ -620,9 +604,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) {
   const std::vector<NodeStrategyIdx> s_val = {2 /* violates */, 1, 2, 2, 1};
   const double objective_value = 11138.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.total.computation_cost = 158.0;  // 12+21+32+42+51
@@ -648,9 +631,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) {
   const std::vector<NodeStrategyIdx> s_val = {2 /* violates */, 1, 2, 2, 1};
   const double objective_value = 11138.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.total.computation_cost = 158.0;  // 12+21+32+42+51
@@ -679,9 +661,8 @@ TEST(AutoShardingEvaluatorTest,
   const std::vector<NodeStrategyIdx> s_val = {2 /* violates */, 1, 2, 2, 1};
   const double objective_value = 11138.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.total.computation_cost = 158.0;  // 12+21+32+42+51
@@ -701,9 +682,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) {
   const std::vector<NodeStrategyIdx> s_val = {3, 1, 2, 1 /* violates */, 1};
   const double objective_value = 12138.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.violation_codes = {kFollowerViolationCode};
@@ -722,9 +702,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) {
   const std::vector<NodeStrategyIdx> s_val = {3, 1, 2, 2, 0 /* violates */};
   const double objective_value = 12138.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.violation_codes = {kAliasViolationCode};
@@ -743,9 +722,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) {
   const std::vector<NodeStrategyIdx> s_val = {2 /* violates */, 1, 2, 2, 1};
   const double objective_value = 11138.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.violation_codes = {kMemoryViolationCode};
@@ -767,9 +745,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) {
   const std::vector<NodeStrategyIdx> s_val = {0 /* violates */, 1, 2, 2, 1};
   const double objective_value = 1e+20;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.violation_codes = {kInfiniteCostViolationCode};
@@ -789,9 +766,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) {
   const std::vector<NodeStrategyIdx> s_val = {0, 1, 2, 2, 1};
   const double objective_value = 1e+20;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.violation_codes = {kInfiniteCostViolationCode};
@@ -811,9 +787,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) {
   const std::vector<NodeStrategyIdx> s_val = {3, 1, 2, 2, 1};
   const double objective_value = 12149.0;
   const AutoShardingSolverOutput output = {s_val, objective_value};
-  const AutoShardingSolverResult result = {output, false};
 
-  const AutoShardingEvaluation evaluation = Evaluate(request, result);
+  const AutoShardingEvaluation evaluation = Evaluate(request, output);
 
   AutoShardingEvaluation expected_evaluation;
   expected_evaluation.violation_codes = {kMaxDeparturesViolationCode};
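Since Evaluate() now consumes an AutoShardingSolverOutput directly, a candidate solution can be checked without constructing any wrapper. A sketch; the request, output, and evaluation types come from this patch, while the helper function is hypothetical and assumes violation_codes is a standard container with empty():

    #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h"

    namespace xla {
    namespace spmd {

    // Hypothetical helper: true iff the chosen strategies break no hard
    // constraint (follower, alias, memory, infinite-cost, max-departures).
    bool SolutionRespectsHardConstraints(
        const AutoShardingSolverRequest& request,
        const AutoShardingSolverOutput& output) {
      const AutoShardingEvaluation evaluation = Evaluate(request, output);
      return evaluation.violation_codes.empty();
    }

    }  // namespace spmd
    }  // namespace xla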
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
index cd12aeaf3d846..aa3e45ea54b0b 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -2839,7 +2839,8 @@ ENTRY matmul {
   // TODO(b/369616683) Fix the error message output in this case.
   EXPECT_DEATH(
       absl::StatusOr<bool> status = AutoSharding(option).Run(module.get()),
-      "The auto-sharding solver has timed out without a solution.");
+      "The auto-sharding pass could not find shardings that works for this "
+      "input.");
 }
 
 TEST_F(AutoShardingTest, IgnoreShardAsShardLike) {
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
index f9058802eea52..333df715447f0 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
@@ -25,6 +25,7 @@ limitations under the License.
 
 #include "absl/container/btree_set.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
@@ -39,9 +40,23 @@ limitations under the License.
 namespace xla {
 namespace spmd {
 
+// The high-level "recipe" for solving an Auto Sharding problem.
+absl::StatusOr<AutoShardingSolverOutput> Solve(
+    const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
+    const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
+    const CostGraph& cost_graph, const AliasSet& alias_set,
+    const std::vector<std::pair<LivenessIdx, LivenessIdx>>& node_intervals,
+    const std::vector<std::pair<LivenessIdx, LivenessIdx>>& edge_intervals,
+    const std::vector<absl::btree_set<int64_t>>& node_groups,
+    const std::vector<absl::btree_set<int64_t>>& edge_groups,
+    const AutoShardingOption& option, absl::string_view request_prefix,
+    const absl::flat_hash_map<std::string, HloSharding>&
+        sharding_propagation_solution = {});
+
 // A wrapper around the solver that converts the given objects into a
 // combinatorial optimization problem & solves it.
-AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver(
+absl::StatusOr<AutoShardingSolverOutput>
+CreateAutoShardingSolverRequestAndCallSolver(
     const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
     const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
     const CostGraph& cost_graph, const AliasSet& alias_set,
diff --git a/xla/mlir_hlo/BUILD b/xla/mlir_hlo/BUILD
index 30eb4a8fccf14..e8147577018ee 100644
--- a/xla/mlir_hlo/BUILD
+++ b/xla/mlir_hlo/BUILD
@@ -1147,7 +1147,7 @@ cc_library(
     srcs = [
         "stablehlo_ext/transforms/chlo_recompose_ops.cpp",
         "stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
-        "stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp",
+        "stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp",
        "stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
     ],
     hdrs = [
diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/passes.h b/xla/mlir_hlo/stablehlo_ext/transforms/passes.h
index c72a92f112b23..22c6d637c1409 100644
--- a/xla/mlir_hlo/stablehlo_ext/transforms/passes.h
+++ b/xla/mlir_hlo/stablehlo_ext/transforms/passes.h
@@ -35,7 +35,7 @@ void createChloLegalizeToStablehloPipeline(OpPassManager &pm);
 // Expand backward compatibility with the given StableHLO version by decomposing
 // newer StableHLO operations into equivalent operations supported by that older
 // version.
-std::unique_ptr<::mlir::Pass> createStablehloCreateCompatibilityExpanderPass(
+std::unique_ptr<::mlir::Pass> createStablehloCompatibilityExpanderPass(
     std::string targetVersionOption);
 
 }  // namespace stablehlo_ext
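The renamed factory slots into a pass pipeline exactly as before, as the mlir_to_hlo.cc hunk further below shows. A standalone sketch; the include path and the brace-initialized options struct follow the conventions visible in this patch, and the helper function plus its target-version argument are hypothetical:

    #include <string>

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "stablehlo/transforms/Passes.h"

    // Hypothetical helper mirroring the mlir_to_hlo.cc usage below.
    void AddCompatibilityExpansion(mlir::PassManager& pm,
                                   const std::string& target_version) {
      // Decomposes ops newer than `target_version` into equivalent older ops,
      // so the serialized module stays loadable by older consumers.
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::stablehlo::createStablehloCompatibilityExpanderPass(
              {target_version}));
    }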
diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp
similarity index 87%
rename from xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp
rename to xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp
index 0db5fd4780b67..96dc5b8c645b9 100644
--- a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp
+++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp
@@ -25,9 +25,9 @@ namespace stablehlo_ext {
 
 // TODO(b/369406385): remove this method (and file) once issue is resolved.
-std::unique_ptr<::mlir::Pass> createStablehloCreateCompatibilityExpanderPass(
+std::unique_ptr<::mlir::Pass> createStablehloCompatibilityExpanderPass(
     std::string targetVersionOption) {
-  return mlir::stablehlo::createStablehloCreateCompatibilityExpanderPass(
+  return mlir::stablehlo::createStablehloCompatibilityExpanderPass(
       {std::move(targetVersionOption)});
 }
diff --git a/xla/pjrt/mlir_to_hlo.cc b/xla/pjrt/mlir_to_hlo.cc
index 393377a73f157..a7adeedd9203d 100644
--- a/xla/pjrt/mlir_to_hlo.cc
+++ b/xla/pjrt/mlir_to_hlo.cc
@@ -199,7 +199,7 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
   pm.addNestedPass<mlir::func::FuncOp>(
       mlir::stablehlo::createChloLegalizeToStablehloPass());
   pm.addNestedPass<mlir::func::FuncOp>(
-      mlir::stablehlo::createStablehloCreateCompatibilityExpanderPass(
+      mlir::stablehlo::createStablehloCompatibilityExpanderPass(
          {std::string(target)}));
   pm.addNestedPass<mlir::func::FuncOp>(
       mlir::stablehlo::createChloLegalizeToStablehloPass());
diff --git a/xla/tsl/BUILD b/xla/tsl/BUILD
index 18f54ec37eb9f..401cf088936d3 100644
--- a/xla/tsl/BUILD
+++ b/xla/tsl/BUILD
@@ -3,7 +3,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
 load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
 load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
 load("//xla/tsl:package_groups.bzl", "tsl_package_groups")
-load("//xla/tsl:tsl.bzl", "if_google", "if_oss")
+load("//xla/tsl:tsl.bzl", "if_google", "if_oss", "internal_visibility")
 load(
     "//xla/tsl:tsl.default.bzl",
     "tsl_extra_config_settings",
@@ -500,11 +500,20 @@ config_setting(
 )
 
 config_setting(
-    name = "no_nccl_support",
+    name = "using_no_nccl_support_define",
     define_values = dict(
-        if_google({"GOOGLE_CUDA_COMPILER": "clang"}),
        no_nccl_support = "true",
     ),
+    visibility = internal_visibility(["//visibility:private"]),
+)
+
+selects.config_setting_group(
+    name = "no_nccl_support",
+    match_all = [
+        ":using_no_nccl_support_define",
+    ] + if_google([
+        "@local_config_cuda//cuda:using_config_cuda",
+    ]),
     visibility = ["//visibility:public"],
 )