diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 7102b01238d619..8b137891791fe9 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,902 +1 @@ -diff --ruN a/stablehlo/examples/c++/ExampleAdd.cpp b/stablehlo/examples/c++/ExampleAdd.cpp ---- stablehlo/examples/c++/ExampleAdd.cpp -+++ stablehlo/examples/c++/ExampleAdd.cpp -@@ -18,7 +18,7 @@ - #include "llvm/ADT/SmallVector.h" - #include "llvm/Support/LogicalResult.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Block.h" - #include "mlir/IR/Builders.h" -@@ -43,7 +43,7 @@ - mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); - module->getContext()->loadDialect(); - module->getContext()->loadDialect(); -- module->getContext()->loadDialect(); -+ module->getContext()->loadDialect(); - module->setName("test_module"); - - /** create function **/ -diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp ---- stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp -+++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp -@@ -17,7 +17,7 @@ - #include - - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/Dialect/Tosa/IR/TosaOps.h" - #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" - #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp ---- stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp 
-+++ stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp -@@ -18,7 +18,7 @@ - #include - - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/Dialect/Tosa/IR/TosaOps.h" - #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" - #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp ---- stablehlo/stablehlo/dialect/Base.cpp -+++ stablehlo/stablehlo/dialect/Base.cpp -@@ -31,7 +31,7 @@ - #include "llvm/ADT/SmallVector.h" - #include "llvm/Support/Debug.h" - #include "llvm/Support/ErrorHandling.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -diff --ruN a/stablehlo/stablehlo/dialect/ChloOps.h b/stablehlo/stablehlo/dialect/ChloOps.h ---- stablehlo/stablehlo/dialect/ChloOps.h -+++ stablehlo/stablehlo/dialect/ChloOps.h -@@ -20,7 +20,7 @@ - #include "llvm/ADT/APFloat.h" - #include "llvm/ADT/StringRef.h" - #include "mlir/Bytecode/BytecodeOpInterface.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinTypes.h" -diff --ruN a/stablehlo/stablehlo/dialect/Register.cpp b/stablehlo/stablehlo/dialect/Register.cpp ---- stablehlo/stablehlo/dialect/Register.cpp -+++ stablehlo/stablehlo/dialect/Register.cpp -@@ -17,7 +17,7 @@ - #include "stablehlo/dialect/Register.h" - - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" - #include "mlir/IR/DialectRegistry.h" - #include "stablehlo/dialect/ChloOps.h" -@@ -30,7 +30,7 @@ - void 
registerAllDialects(mlir::DialectRegistry ®istry) { - // clang-format off - registry.insert(); - registry.insert - - #include "llvm/ADT/StringRef.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" -diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp ---- stablehlo/stablehlo/dialect/TypeInference.cpp -+++ stablehlo/stablehlo/dialect/TypeInference.cpp -@@ -52,7 +52,7 @@ - #include "llvm/Support/Regex.h" - #include "llvm/Support/raw_ostream.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.cpp b/stablehlo/stablehlo/dialect/VhloTypes.cpp ---- stablehlo/stablehlo/dialect/VhloTypes.cpp -+++ stablehlo/stablehlo/dialect/VhloTypes.cpp -@@ -20,7 +20,7 @@ - #include "llvm/ADT/SmallVectorExtras.h" - #include "llvm/ADT/StringRef.h" - #include "llvm/ADT/TypeSwitch.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinTypes.h" -diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/reference/Api.cpp ---- stablehlo/stablehlo/reference/Api.cpp -+++ stablehlo/stablehlo/reference/Api.cpp -@@ -31,7 +31,7 @@ - #include "llvm/Support/Path.h" - #include "llvm/Support/SourceMgr.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinOps.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" -diff --ruN 
a/stablehlo/stablehlo/tests/CheckOps.h b/stablehlo/stablehlo/tests/CheckOps.h ---- stablehlo/stablehlo/tests/CheckOps.h -+++ stablehlo/stablehlo/tests/CheckOps.h -@@ -17,7 +17,7 @@ - #define STABLEHLO_DIALECT_CHECKOPS_H_ - - #include "mlir/Bytecode/BytecodeOpInterface.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/Dialect.h" -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir -@@ -1338,24 +1338,24 @@ - - // ----- - -+// expected-error@+1 {{scale out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { -- // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return - } - - // ----- - -+// expected-error@+1 {{scale out of expressed type range}} - func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform>) { -- // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2x!quant.uniform>'}} - %0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform> - func.return - } - - // ----- - -+// expected-error@+1 {{illegal quantized dimension: -1}} - func.func 
@quantized_element_type_c11(%arg0: tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>>) { -- // expected-error-re@+1 {{operand #0 must be ranked tensor of {{.*}} 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x5x2x!quant.uniform>'}} - %0 = stablehlo.add %arg0, %arg0 : tensor<1x5x2x!quant.uniform:f32:-1, {0.1:-30, 0.1:-30}>> - func.return - } -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -@@ -69,7 +69,7 @@ - index_vector_dim = 3 - >, - slice_sizes = array, -- indices_are_sorted = true -+ indices_are_sorted = false - } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> - func.return %0 : tensor<4x3x5x8xi32> - } -@@ -77,9 +77,9 @@ - // ----- - - // CHECK-LABEL: @gather_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> --// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> - // CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ - // CHECK-SAME: dimension_numbers = #stablehlo.gather< -@@ -102,7 +102,7 @@ - index_vector_dim = 3 - >, - 
slice_sizes = array, -- indices_are_sorted = true -+ indices_are_sorted = false - }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> - func.return %0 : tensor<4x3x5x8xi32> - } -@@ -133,9 +133,305 @@ - index_vector_dim = 3 - >, - slice_sizes = array, -- indices_are_sorted = true -+ indices_are_sorted = false - }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> - func.return %0 : tensor<0x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<3x4x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 1 : tensor<3x4x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<3x4x5x1xi32>, tensor<3x4x5x1xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x4xi32>) -> tensor<3x4x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<3x4x5x8xi32> -+func.func @gather_batching_dims_indices_become_unsorted(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [0, 1], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> -+ func.return %0 : tensor<3x4x5x8xi32> -+} -+ -+// ----- -+ -+// 
CHECK-LABEL: @gather_batching_dims_indices_become_unsorted_2 -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> -+func.func @gather_batching_dims_indices_become_unsorted_2(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [2, 3], -+ operand_batching_dims = [0, 1], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [2, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> -+ func.return %0 : tensor<2x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_indices_remain_sorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = 
"stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = true, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> -+func.func @gather_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [2, 3], -+ operand_batching_dims = [0, 1], -+ start_indices_batching_dims = [0, 2], -+ start_index_map = [2, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> -+ func.return %0 : tensor<2x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_indices_remain_unsorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : 
tensor<2x3x5x8xi32> -+func.func @gather_batching_dims_indices_remain_unsorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [2, 3], -+ operand_batching_dims = [0, 1], -+ start_indices_batching_dims = [0, 2], -+ start_index_map = [2, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> -+ func.return %0 : tensor<2x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dims_does_not_overflow_indices_type -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x127x5x1xi8> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x127x5x1xi8> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x127x5x1xi8>, tensor<4x127x5x1xi8>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x4xi8> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x4xi8>) -> tensor<4x127x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x127x5x8xi32> -+func.func @gather_batching_dims_does_not_overflow_indices_type(%arg0: tensor<127x2x4x7x9xi32>, %arg1: tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ 
>, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> -+ func.return %0 : tensor<4x127x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_signless_indices_type -+// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5x2xi8>) -> tensor<4x128x5x2xi32> -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x2xi32>) -> tensor<4x128x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x4xi32>) -> tensor<4x128x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> -+func.func @gather_batching_dim_overflows_signless_indices_type(%arg0: tensor<128x2x4x7x9xi32>, %arg1: tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> -+ func.return %0 : tensor<4x128x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_unsigned_indices_type -+// CHECK-NEXT: %[[convert:.*]] 
= stablehlo.convert %arg1 : (tensor<256x4x5x2xui8>) -> tensor<256x4x5x2xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<256x4x5x1xi32> -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<256x4x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim0]], %[[iota_dim1]], %[[convert]], dim = 3 : (tensor<256x4x5x1xi32>, tensor<256x4x5x1xi32>, tensor<256x4x5x2xi32>) -> tensor<256x4x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x4xi32>) -> tensor<256x4x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<256x4x5x8xi32> -+func.func @gather_batching_dim_overflows_unsigned_indices_type(%arg0: tensor<256x2x4x7x9xi32>, %arg1: tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [0, 1], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> -+ func.return %0 : tensor<256x4x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_indices_type_and_i32 -+// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x2xi64> -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x2147483648x5x1xi64> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x2147483648x5x1xi64> -+// CHECK-NEXT: 
%[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x2xi64>) -> tensor<4x2147483648x5x4xi64> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x4xi64>) -> tensor<4x2147483648x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x2147483648x5x8xi32> -+func.func @gather_batching_dim_overflows_indices_type_and_i32(%arg0: tensor<2147483648x2x4x7x9xi32>, %arg1: tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> -+ func.return %0 : tensor<4x2147483648x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_dynamic_size -+// CHECK: operand_batching_dims = [0, 2] -+// CHECK: start_indices_batching_dims = [1, 0] -+func.func @gather_batching_dim_dynamic_size(%arg0: tensor, %arg1: tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ 
indices_are_sorted = false -+ } : (tensor, tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> -+ func.return %0 : tensor<4x?x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_batching_dim_overflows_and_no_index_vector_dim -+// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5xi8>) -> tensor<4x128x5xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %[[convert]] : (tensor<4x128x5xi32>) -> tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>) -> tensor<4x128x5x3xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], -+// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<128x2x4x9xi32>, tensor<4x128x5x3xi32>) -> tensor<4x128x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> -+func.func @gather_batching_dim_overflows_and_no_index_vector_dim(%arg0: tensor<128x2x4x9xi32>, %arg1: tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> { -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = false -+ } : (tensor<128x2x4x9xi32>, tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> -+ func.return %0 : tensor<4x128x5x8xi32> - } - - // ----- -@@ -156,7 +452,7 @@ - // CHECK-NO-DOWNGRADE: input_batching_dims = 
[0, 2] - // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -- indices_are_sorted = true, -+ indices_are_sorted = false, - scatter_dimension_numbers = #stablehlo.scatter< - update_window_dims = [3], - inserted_window_dims = [1, 3], -@@ -176,9 +472,9 @@ - // ----- - - // CHECK-LABEL: @scatter_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> --// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> - // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> - // CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ - // CHECK-SAME: indices_are_sorted = false, -@@ -192,7 +488,7 @@ - // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] - // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -- indices_are_sorted = true, -+ indices_are_sorted = false, - scatter_dimension_numbers = #stablehlo.scatter< - update_window_dims = [3], - inserted_window_dims = [1], -@@ -208,3 +504,60 @@ - }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> - func.return %0 : tensor<3x2x4x9xi32> - } -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_batching_dims_indices_remain_sorted -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, 
tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = true, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1, 2, 3], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = false}> -+// CHECK: (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<2x5x4x7x9xi32> -+func.func @scatter_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>, %arg2: tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> { -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [2, 3], -+ input_batching_dims = [0, 1], -+ scatter_indices_batching_dims = [0, 2], -+ scatter_dims_to_operand_dims = [2, 3], -+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> -+ func.return %0 : tensor<2x5x4x7x9xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_batching_dim_dynamic_scatter_indices -+// CHECK: input_batching_dims = [0, 2] -+// CHECK: scatter_indices_batching_dims = [1, 0] -+func.func @scatter_batching_dim_dynamic_scatter_indices(%arg0: tensor, %arg1: tensor<4x?x5x2xi32>, %arg2: tensor<4x?x5x8xi32>) -> tensor { -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = false, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1, 3], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1, 3], 
-+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor, tensor<4x?x5x2xi32>, tensor<4x?x5x8xi32>) -> tensor -+ func.return %0 : tensor -+} -diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp ---- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp -+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp -@@ -24,7 +24,7 @@ - #include "llvm/Support/ErrorHandling.h" - #include "llvm/Support/LogicalResult.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinOps.h" - #include "mlir/IR/DialectRegistry.h" -@@ -237,7 +237,7 @@ - }, - [](DialectRegistry ®istry) { - registry.insert(); -- registry.insert(); -+ registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); -diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h ---- stablehlo/stablehlo/transforms/Passes.h -+++ stablehlo/stablehlo/transforms/Passes.h -@@ -19,7 +19,7 @@ - #include - - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/BuiltinOps.h" - #include "mlir/Pass/Pass.h" -diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td ---- stablehlo/stablehlo/transforms/Passes.td -+++ stablehlo/stablehlo/transforms/Passes.td -@@ -68,7 +68,7 @@ - let summary = "Legalize VHLO to StableHLO."; - let dependentDialects = [ - "mlir::func::FuncDialect", -- "mlir::quant::QuantizationDialect", -+ "mlir::quant::QuantDialect", - "mlir::shape::ShapeDialect", - "mlir::stablehlo::StablehloDialect", - ]; -diff --ruN 
a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -@@ -22,8 +22,11 @@ - #include "llvm/ADT/STLExtras.h" - #include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" -+#include "llvm/Support/MathExtras.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/PatternMatch.h" -@@ -75,6 +78,42 @@ - return result; - } - -+bool fitsInIntegralType(int64_t size, IntegerType type) { -+ if (type.isUnsigned()) { -+ return llvm::isUIntN(type.getWidth(), size); -+ } else { -+ return llvm::isIntN(type.getWidth(), size); -+ } -+} -+ -+// If `type` is an integer type in which `size` doesn't fit, promote it to i32 -+// or i64 (depending on `size`). -+Type promoteTypeForSize(Type type, int64_t size, OpBuilder &builder) { -+ // Gather/Scatter should have an integer type, but we check just in case. -+ auto intType = dyn_cast(type); -+ if (!intType || fitsInIntegralType(size, intType)) { -+ return type; -+ } -+ if (fitsInIntegralType(size, builder.getI32Type())) { -+ return builder.getI32Type(); -+ } -+ return builder.getI64Type(); -+} -+ -+// If `indices_batching_dims` and `updated_index_map` are both sorted, then the -+// `indices_are_sorted` property is preserved. -+// -+// This is because each concatenated iota is monotonically increasing, sorted -+// indices batching dims mean their order corresponds to the order of batching -+// dims in the operand, and a sorted updated start index map means the order of -+// the index vector dim corresponds to the order of operand dims. 
-+bool getUpdatedIndicesAreSorted(bool indices_are_sorted, -+ ArrayRef indices_batching_dims, -+ ArrayRef updated_index_map) { -+ return indices_are_sorted && llvm::is_sorted(indices_batching_dims) && -+ llvm::is_sorted(updated_index_map); -+} -+ - // Returns an updated indices tensor such that an `IotaOp` is prepended for each - // dim in `indicesBatchingDims` with a `ConcatenateOp`. - // -@@ -85,16 +124,31 @@ - PatternRewriter &rewriter) { - Location loc = indices.getLoc(); - auto indicesType = cast(indices.getType()); -+ Type elementType = indicesType.getElementType(); -+ -+ // The batching dim sizes might not fit in the existing element type, -+ // in which case we need to promote it. -+ for (int64_t batchingDim : indicesBatchingDims) { -+ elementType = promoteTypeForSize( -+ elementType, indicesType.getDimSize(batchingDim), rewriter); -+ } -+ if (elementType != indicesType.getElementType()) { -+ indicesType = RankedTensorType::get(indicesType.getShape(), elementType); -+ indices = rewriter.create(loc, indicesType, indices); -+ } -+ - bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); -- - SmallVector iotaShape(indicesType.getShape()); - if (indexVectorDimOnLastDim) { - iotaShape.push_back(1); - } else { - iotaShape[indexVectorDim] = 1; - } -- auto iotaType = -- RankedTensorType::get(iotaShape, indicesType.getElementType()); -+ auto iotaType = RankedTensorType::get(iotaShape, elementType); -+ -+ if (indexVectorDimOnLastDim) { -+ indices = rewriter.create(loc, iotaType, indices); -+ } - - SmallVector indicesToConcat; - indicesToConcat.reserve(indicesBatchingDims.size() + 1); -@@ -102,12 +156,7 @@ - indicesToConcat.push_back( - rewriter.create(loc, iotaType, batchingDim)); - } -- if (indexVectorDimOnLastDim) { -- indicesToConcat.push_back( -- rewriter.create(loc, iotaType, indices)); -- } else { -- indicesToConcat.push_back(indices); -- } -+ indicesToConcat.push_back(indices); - return rewriter.create(loc, indicesToConcat, indexVectorDim); - 
} - -@@ -125,9 +174,17 @@ - PatternRewriter &rewriter) const override { - GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); - ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); -+ ArrayRef startIndicesBatchingDims = -+ dimNumbers.getStartIndicesBatchingDims(); - if (operandBatchingDims.empty()) { - return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - diag << "gather op has no batching dims"; -+ }); -+ } -+ -+ if (!op.getStartIndices().getType().hasStaticShape()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has start indices with dynamic shape, can't expand"; - }); - } - -@@ -136,16 +193,18 @@ - SmallVector newStartIndexMap = - llvm::to_vector(llvm::concat( - operandBatchingDims, dimNumbers.getStartIndexMap())); -- Value newIndices = createConcatIndices( -- op.getStartIndices(), dimNumbers.getIndexVectorDim(), -- dimNumbers.getStartIndicesBatchingDims(), rewriter); -+ Value newIndices = createConcatIndices(op.getStartIndices(), -+ dimNumbers.getIndexVectorDim(), -+ startIndicesBatchingDims, rewriter); - rewriter.replaceOpWithNewOp( - op, op.getOperand(), newIndices, - GatherDimensionNumbersAttr::get( - op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, - /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, - newStartIndexMap, dimNumbers.getIndexVectorDim()), -- op.getSliceSizes(), /*indicesAreSorted=*/false); -+ op.getSliceSizes(), -+ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), -+ startIndicesBatchingDims, newStartIndexMap)); - - return success(); - } -@@ -161,9 +220,17 @@ - PatternRewriter &rewriter) const override { - ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); - ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); -+ ArrayRef scatterIndicesBatchingDims = -+ dimNumbers.getScatterIndicesBatchingDims(); - if (inputBatchingDims.empty()) { - return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - 
diag << "scatter op has no batching dims"; -+ }); -+ } -+ -+ if (!op.getScatterIndices().getType().hasStaticShape()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has start indices with dynamic shape, can't expand"; - }); - } - -@@ -174,7 +241,7 @@ - inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); - Value newIndices = createConcatIndices( - op.getScatterIndices(), dimNumbers.getIndexVectorDim(), -- dimNumbers.getScatterIndicesBatchingDims(), rewriter); -+ scatterIndicesBatchingDims, rewriter); - auto newScatterOp = rewriter.create( - op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, - op.getUpdates(), -@@ -183,7 +250,10 @@ - newInsertedWindowDims, - /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, - newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), -- /*indicesAreSorted=*/false, op.getUniqueIndices()); -+ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), -+ scatterIndicesBatchingDims, -+ newScatterDimsToOperandDims), -+ op.getUniqueIndices()); - - newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); - rewriter.replaceOp(op, newScatterOp.getResults()); -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp -@@ -15,7 +15,7 @@ - - #include "llvm/ADT/SmallVector.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Operation.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/Transforms/DialectConversion.h" // Include for TypeConverter -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp ---- 
stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp -@@ -24,8 +24,8 @@ - #include "llvm/ADT/SmallVector.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/Dialect/Func/Transforms/FuncConversions.h" --#include "mlir/Dialect/Quant/QuantOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/Quant.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" -@@ -1331,7 +1331,7 @@ - populateReturnOpTypeConversionPattern(patterns, converter); - - ConversionTarget target(*op->getContext()); -- target.addIllegalDialect(); -+ target.addIllegalDialect(); - auto isLegal = [&converter](Operation *op) { - return converter.isLegal(op); - }; -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp -@@ -17,7 +17,7 @@ - - #include "llvm/ADT/STLExtras.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" --#include "mlir/Dialect/Quant/QuantTypes.h" -+#include "mlir/Dialect/Quant/IR/QuantTypes.h" - #include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/PatternMatch.h" - #include "mlir/IR/TypeRange.h" diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 0bd9fb077ccc9a..5ad132c608ca77 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee" - STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815" + STABLEHLO_COMMIT = 
"f7f8e4e35296deeff2e12e39421ac8d9599ba340" + STABLEHLO_SHA256 = "c92b55d5512e58d6fefba62c58e60d7762adb184dc3ad489521de562f6ca7aeb" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index 916be8db4f6998..89a40bd891e106 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -20,6 +20,8 @@ limitations under the License. #include "ml_dtypes/include/intn.h" // from @ml_dtypes namespace tsl { +using float8_e3m4 = ::ml_dtypes::float8_e3m4; +using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; diff --git a/xla/BUILD b/xla/BUILD index 44459e0e6df7d9..fb6cdbe281e4a3 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -316,6 +316,7 @@ xla_cc_test( ":util", "@com_google_absl//absl/base", "@com_google_absl//absl/numeric:bits", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test_main", ], @@ -373,6 +374,7 @@ xla_cc_test( ":test", ":types", ":util", + "@ml_dtypes//:float8", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test_main", diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 4d0fbf3732ff9a..4686e2ec5c1ac6 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } +TEST(Array2dTest, LinspaceF8E4M3) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + 
EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + TEST(Array2dTest, LinspaceF8E4M3Fn) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); @@ -190,6 +204,20 @@ TEST(Array2dTest, LinspaceF8E4M3FnNoNan) { } } +TEST(Array2dTest, LinspaceF8E3M4) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index 0e142c42286e12..31a84a7d929e60 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -133,6 +133,10 @@ inline std::ostream& operator<<(std::ostream& os, return os << "TOKEN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; + case XLA_FFI_DataType_F8E3M4: + return os << "F8E3M4"; + case XLA_FFI_DataType_F8E4M3: + return os << "F8E4M3"; case XLA_FFI_DataType_F8E4M3FN: return os << "F8E4M3FN"; case XLA_FFI_DataType_F8E4M3B11FNUZ: diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index d5e2b11538133f..f0c4f40e78ea7a 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -195,6 +195,8 @@ typedef enum { XLA_FFI_DataType_C128 = 18, XLA_FFI_DataType_TOKEN = 17, XLA_FFI_DataType_F8E5M2 = 19, + XLA_FFI_DataType_F8E3M4 = 29, + XLA_FFI_DataType_F8E4M3 = 28, XLA_FFI_DataType_F8E4M3FN = 20, XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index b31da22175333d..e6560833cb4aeb 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -73,10 +73,12 @@ enum class DataType : uint8_t { C128 = XLA_FFI_DataType_C128, TOKEN = XLA_FFI_DataType_TOKEN, F8E5M2 = 
XLA_FFI_DataType_F8E5M2, + F8E4M3 = XLA_FFI_DataType_F8E4M3, F8E4M3FN = XLA_FFI_DataType_F8E4M3FN, F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ, F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, + F8E3M4 = XLA_FFI_DataType_F8E3M4, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -98,10 +100,12 @@ inline constexpr DataType C64 = DataType::C64; inline constexpr DataType C128 = DataType::C128; inline constexpr DataType TOKEN = DataType::TOKEN; inline constexpr DataType F8E5M2 = DataType::F8E5M2; +inline constexpr DataType F8E4M3 = DataType::F8E4M3; inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN; inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; +inline constexpr DataType F8E3M4 = DataType::F8E3M4; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -117,10 +121,12 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::S8: case DataType::U8: case DataType::F8E5M2: + case DataType::F8E4M3: case DataType::F8E4M3FN: case DataType::F8E4M3B11FNUZ: case DataType::F8E5M2FNUZ: case DataType::F8E4M3FNUZ: + case DataType::F8E3M4: return 1; case DataType::S16: case DataType::U16: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 74837790c8449c..315587b94463da 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -130,11 +130,13 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); + EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ), encoded(DataType::F8E4M3B11FNUZ)); 
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); + EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); } TEST(FfiTest, DataTypeByteWidth) { @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), + ByteWidth(DataType::F8E4M3)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN), ByteWidth(DataType::F8E4M3FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ), @@ -187,6 +191,8 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E5M2FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FNUZ), ByteWidth(DataType::F8E4M3FNUZ)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), + ByteWidth(DataType::F8E3M4)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 12fed1ba745440..3fb2ac3c7786fa 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -265,10 +265,12 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C128: case PrimitiveType::TOKEN: case PrimitiveType::F8E5M2: + case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: case PrimitiveType::F8E4M3B11FNUZ: case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: + case PrimitiveType::F8E3M4: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 36f0c5be9d5bde..3eb7c54f919b0a 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include #include "absl/base/casts.h" #include "absl/numeric/bits.h" #include "xla/bit_cast.h" @@ -111,21 +112,74 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest, 0x1.fffffffffffffp-127, 0x1.aaaaaaaaaaaaap-127)); -// Test F8E4M3 floating-point types (F8E4M3FN) +// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN) template class FP8E4M3DistanceTest : public ::testing::Test {}; -using F8E4M3Types = ::testing::Types; +using F8E4M3Types = ::testing::Types; TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); +TEST(FPDistanceTest, F8E3M4Distance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(8.0)), + 0); + + // a & b have the same exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(15.5)), + 15); + + // a & b have different exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(6)), + 8); + + // 1 from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), + tsl::float8_e3m4(0)), + 1); + + // 1 from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + tsl::float8_e3m4(0)), + 1); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), + 2); + + // 1 non denorm from 0 in the positive direction + EXPECT_EQ( + CalculateDistanceInFloats( + std::numeric_limits::min(), tsl::float8_e3m4(0)), + 16); + + // 1 non denorm from 0 in the negative direction + EXPECT_EQ( + CalculateDistanceInFloats( + -std::numeric_limits::min(), tsl::float8_e3m4(0)), + 16); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), + 32); +} + TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) { // a & b are equal, distance should be 0 EXPECT_EQ( 
CalculateDistanceInFloats(TypeParam(8.0), TypeParam(8.0)), 0); // a & b have the same exponents - EXPECT_EQ(CalculateDistanceInFloats(TypeParam(8.0), TypeParam(13)), - 5); + EXPECT_EQ( + CalculateDistanceInFloats(TypeParam(8.0), TypeParam(15.0)), 7); // a & b have different exponents EXPECT_EQ( diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc index f7c00aece14d0e..e7792c65b7370a 100644 --- a/xla/hlo/builder/lib/math.cc +++ b/xla/hlo/builder/lib/math.cc @@ -175,6 +175,8 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F8E3M4: + case F8E4M3: case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: @@ -973,8 +975,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1026,8 +1028,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index d7bc1ba49a9bf7..8157fe34baee0e 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1743,10 +1743,12 @@ extern template class HloEvaluatorTypedVisitor; extern template class 
HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index 7c97c210aa36a5..d425d33c2feab5 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -19,8 +19,10 @@ limitations under the License. namespace xla { template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/xla/hlo/experimental/auto_sharding/cluster_environment.cc index b5ad371562e0c8..9a68b636b79fa5 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/types/span.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -162,7 +163,7 @@ double ClusterEnvironment::ReshardingCostMixedMeshShape( } if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i], (*src_tensor_dim_to_mesh_axis)[i])) { - // do nothing; the src is sharded more than the dest + // do nothing; the dst is sharded more than the src continue; } if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i], @@ -231,17 +232,16 @@ double ClusterEnvironment::CollectivePermuteCost( // Overestimate the cost of replicating a tensor by decomposing the resharding // operation as an all-gather on all mesh dimensions. double ClusterEnvironment::OverestimateReplicationCost( - const Shape& shape, const HloSharding& src_spec, + const Shape& shape, const HloSharding& src_sharding, const DeviceMesh& device_mesh) const { - if (src_spec.IsTileMaximal() || src_spec.IsManual()) { - // TODO(b/238210866) Do not use kInfinityCost. - return kInfinityCost; + if (src_sharding.IsReplicated()) { + return 0; } - int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_spec); + int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_sharding); double cost = 0.0; for (size_t i = 0; i < device_mesh.num_dimensions(); ++i) { - auto this_cost = this->AllGatherCost(bytes_moved, i); - cost += this_cost; + cost += src_sharding.IsTileMaximal() ? 
this->AllReduceCost(bytes_moved, i) + : this->AllGatherCost(bytes_moved, i); bytes_moved *= device_mesh.dimensions()[i]; } return cost; diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 68ad73882e02a7..3a1e7ceabb160f 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -410,11 +410,17 @@ add { // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3B11FNUZ> %constant.9 = f8e4m3b11fnuz[4] constant({1, 2, 3, 4}) - // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> + // CHECK: %[[VAL_10:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> %constant.10 = f8e4m3fnuz[4] constant({1, 2, 3, 4}) - // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + // CHECK: %[[VAL_11:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> %constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + %constant.12 = f8e4m3[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + %constant.13 = f8e3m4[4] constant({1, 2, 3, 4}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -524,7 +530,19 @@ add { %convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10) // CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32> - ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + + // CHECK-NEXT: %10 = mhlo.convert %9 : 
(tensor<4xf32>) -> tensor<4xf8E4M3> + %convert.13 = f8e4m3[4] convert(f32[4] %convert.12) + + // CHECK-NEXT: %11 = mhlo.convert %10 : (tensor<4xf8E4M3>) -> tensor<4xf32> + %convert.14 = f32[4] convert(f8e4m3[4] %convert.13) + + // CHECK-NEXT: %12 = mhlo.convert %11 : (tensor<4xf32>) -> tensor<4xf8E3M4> + %convert.15 = f8e3m4[4] convert(f32[4] %convert.14) + + // CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32> + ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index b4e7a128a5d1ed..c41792b1556338 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -600,6 +600,12 @@ func.func @main() { // CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4}) %cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + // CHECK: f8e4m3[4] constant({1, 2, 3, 4}) + %cst_16 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + + // CHECK: f8e3m4[4] constant({1, 2, 3, 4}) + %cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + func.return } @@ -729,7 +735,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32> %6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ> %7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32> - func.return %7 : tensor<2xf32> + %8 = "mhlo.convert"(%7) : (tensor<2xf32>) -> tensor<2xf8E4M3> + %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> + %10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4> + %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> + func.return %11 : 
tensor<2xf32> } // CHECK: ENTRY @@ -741,7 +751,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]]) // CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]]) // CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]]) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) +// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) +// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) +// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) // ----- diff --git a/xla/literal.cc b/xla/literal.cc index 971b1d48ac563b..ed0716f8ec50e8 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -91,9 +91,10 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || - !proto.f8e5m2s().empty() || !proto.f8e4m3fns().empty() || - !proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() || - !proto.f8e4m3fnuzs().empty() || !proto.f16s().empty() || + !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || + !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || + !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || + !proto.f8e3m4s().empty() || !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || proto.preds_size() || proto.tuple_literals_size(); @@ -1684,7 +1685,15 @@ void ConvertBetweenNativeTypes(absl::Span src_data, return std::numeric_limits::lowest(); } } - return static_cast(src); + // TODO(b/370786669): Once ml_dtypes is updated to 
include + // https://github.com/jax-ml/ml_dtypes/pull/205, do not special-case e3m4 by + // casting to half first. + if constexpr (sizeof(src) == 1 && + std::is_same_v) { + return static_cast(static_cast(src)); + } else { + return static_cast(src); + } }; NativeDestT* dest_data = static_cast(dst_base); @@ -2258,6 +2267,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E4M3: + *proto->mutable_f8e4m3s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F8E4M3FN: *proto->mutable_f8e4m3fns() = std::string( reinterpret_cast(data().data()), @@ -2278,6 +2292,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E3M4: + *proto->mutable_f8e3m4s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2436,6 +2455,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E4M3: { + const std::string& s(proto.f8e4m3s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e4m3) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F8E4M3FN: { const std::string& s(proto.f8e4m3fns()); TF_RET_CHECK(data().size() * @@ -2468,6 +2494,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E3M4: { + const std::string& s(proto.f8e3m4s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e3m4) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index 
fa3a7cda9824cd..c97629594122bb 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -354,8 +354,16 @@ class NearComparator { return primitive_util::FloatingPointTypeSwitch( [&](const auto kType) -> int { using NarrowNativeT = primitive_util::NativeTypeOf; - return CalculateDistanceInFloats(NarrowNativeT(expected), - NarrowNativeT(actual)); + // TODO(b/370786669): Once ml_dtypes is updated to include + // https://github.com/jax-ml/ml_dtypes/pull/205, do not special-case + // e3m4 by casting to half first. + if constexpr (std::is_same_v) { + return CalculateDistanceInFloats(NarrowNativeT(half(expected)), + NarrowNativeT(half(actual))); + } else { + return CalculateDistanceInFloats(NarrowNativeT(expected), + NarrowNativeT(actual)); + } }, error_.low_precision_fp_error_spec.type); } diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 37b7c31f267104..7713aceaaa3bc5 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -29,14 +29,15 @@ namespace { template class LiteralComparisonTest : public ::testing::Test {}; -using TestedTypes = ::testing::Types; +using TestedTypes = + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); - TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0), + TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -44,15 +45,19 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = type == F8E5M2 ? 
10.0 : 9.0; + float expV = 9.0; // F8E4M3* + if (type == F8E5M2) + expV = 10.0; + else if (type == F8E3M4) + expV = 8.5; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -60,17 +65,21 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = type == F8E5M2 ? 
14.0 : 12.0; + float expV = 12.0; // F8E4M3* + if (type == F8E5M2) + expV = 14.0; + else if (type == F8E3M4) + expV = 10.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -78,17 +87,21 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(8.0); - float expV = type == F8E5M2 ? 
13.0 : 12.1; + float expV = 12.1; // F8E4M3* + if (type == F8E5M2) + expV = 13.0; + else if (type == F8E3M4) + expV = 10.125; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 767fe581121db3..65aa09040668fb 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -125,9 +125,10 @@ template class LiteralUtilFloatTest : public LiteralUtilTest {}; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); @@ -184,8 +185,12 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { EXPECT_EQ("f8e5m2[] 3", f8e5m2_lit_truncated.ToString()); auto f8e4m3_lit = + LiteralUtil::CreateR0(tsl::float8_e4m3(0.5)); + EXPECT_EQ("f8e4m3[] 0.5", f8e4m3_lit.ToString()); + + auto f8e4m3fn_lit = LiteralUtil::CreateR0(tsl::float8_e4m3fn(0.5)); - EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3_lit.ToString()); + EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3fn_lit.ToString()); auto f8e4m3b11fnuz_lit = LiteralUtil::CreateR0( tsl::float8_e4m3b11fnuz(0.5)); @@ -198,6 +203,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e5m2fnuz_lit = LiteralUtil::CreateR0(tsl::float8_e5m2fnuz(0.5)); EXPECT_EQ("f8e5m2fnuz[] 0.5", f8e5m2fnuz_lit.ToString()); + + auto f8e3m4_lit = + LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); + EXPECT_EQ("f8e3m4[] 
0.5", f8e3m4_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -650,20 +659,24 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); - tsl::float8_e5m2 q16(8); - EXPECT_TRUE(LiteralUtil::CreateR1({q16}).IsAll(8)); + tsl::float8_e5m2 p16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({p16}).IsAll(8)); // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false - EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(9)); + EXPECT_FALSE(LiteralUtil::CreateR1({p16}).IsAll(9)); - tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3 + tsl::float8_e4m3 q16(9); // Exactly representable in e4m3 + EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({q16}).IsAll(9)); + + tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3fn EXPECT_FALSE(LiteralUtil::CreateR1({r16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({r16}).IsAll(9)); - tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3 + tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3b11fnuz EXPECT_FALSE(LiteralUtil::CreateR1({s16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({s16}).IsAll(9)); - tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3 + tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3fnuz EXPECT_FALSE(LiteralUtil::CreateR1({t16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({t16}).IsAll(9)); @@ -672,6 +685,10 @@ TEST_F(LiteralUtilTest, IsAll) { // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false EXPECT_FALSE(LiteralUtil::CreateR1({u16}).IsAll(9)); + tsl::float8_e3m4 v16(9); // Exactly representable in e3m4 + EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -2200,9 +2217,12 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e5 = tsl::float8_e5m2; auto 
vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); - using e4 = tsl::float8_e4m3fn; + using e4 = tsl::float8_e4m3; auto vector_f8e4m3 = LiteralUtil::CreateR1({e4{10.0}, e4{20.0}, e4{-32.0}}); + using e4fn = tsl::float8_e4m3fn; + auto vector_f8e4m3fn = + LiteralUtil::CreateR1({e4fn{10.0}, e4fn{20.0}, e4fn{-32.0}}); using b11 = tsl::float8_e4m3b11fnuz; auto vector_f8e4m3b11 = LiteralUtil::CreateR1({b11{10.0}, b11{20.0}, b11{-30.0}}); @@ -2212,6 +2232,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e4f = tsl::float8_e4m3fnuz; auto vector_f8e4m3fnuz = LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); + using e3 = tsl::float8_e3m4; + auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2234,9 +2256,11 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); + EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); + EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2521,6 +2545,18 @@ TEST_F(LiteralUtilTest, IsEqualAt) { tsl::float8_e4m3fnuz{val_double}); EXPECT_TRUE(c6.IsEqualAt({}, val_double)); EXPECT_TRUE(c6.IsEqualAt({}, val_integral)); + Literal c8 = + LiteralUtil::CreateR0(tsl::float8_e4m3{val_double}); + EXPECT_TRUE(c8.IsEqualAt({}, val_double)); + EXPECT_TRUE(c8.IsEqualAt({}, val_integral)); + Literal c9 = + 
LiteralUtil::CreateR0(tsl::float8_e4m3fn{val_double}); + EXPECT_TRUE(c9.IsEqualAt({}, val_double)); + EXPECT_TRUE(c9.IsEqualAt({}, val_integral)); + Literal c10 = + LiteralUtil::CreateR0(tsl::float8_e3m4{val_double}); + EXPECT_TRUE(c10.IsEqualAt({}, val_double)); + EXPECT_TRUE(c10.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2846,10 +2882,10 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, - F8E5M2FNUZ, F8E4M3FNUZ, C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index 59b19c34611412..2581390a1e13d7 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -34,6 +34,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getI1Type(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); + case xla::PrimitiveType::F8E4M3: + return b.getFloat8E4M3Type(); case xla::PrimitiveType::F8E4M3FN: return b.getFloat8E4M3FNType(); case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -42,6 +44,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E5M2FNUZType(); case xla::PrimitiveType::F8E4M3FNUZ: return b.getFloat8E4M3FNUZType(); + case xla::PrimitiveType::F8E3M4: + return b.getFloat8E3M4Type(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -76,6 +80,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; + } else if 
(type.isFloat8E4M3()) { + return xla::PrimitiveType::F8E4M3; } else if (type.isFloat8E4M3FN()) { return xla::PrimitiveType::F8E4M3FN; } else if (type.isFloat8E4M3B11FNUZ()) { @@ -84,6 +90,8 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E4M3FNUZ; } else if (type.isFloat8E5M2FNUZ()) { return xla::PrimitiveType::F8E5M2FNUZ; + } else if (type.isFloat8E3M4()) { + return xla::PrimitiveType::F8E3M4; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index 6c19098574dec5..a8043ab0b5f140 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -102,6 +102,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(std::vector( {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, + {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, {F8E4M3B11FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3B11FNUZType(); }}, @@ -109,6 +110,7 @@ INSTANTIATE_TEST_SUITE_P( [](mlir::Builder b) { return b.getFloat8E5M2FNUZType(); }}, {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, + {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/xla/mlir_hlo/BUILD b/xla/mlir_hlo/BUILD index 30eb4a8fccf14d..e8147577018ee9 100644 --- a/xla/mlir_hlo/BUILD +++ b/xla/mlir_hlo/BUILD @@ -1147,7 +1147,7 @@ cc_library( srcs = [ "stablehlo_ext/transforms/chlo_recompose_ops.cpp", "stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp", - "stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp", + 
"stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp", "stablehlo_ext/transforms/stablehlo_refine_shapes.cpp", ], hdrs = [ diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/passes.h b/xla/mlir_hlo/stablehlo_ext/transforms/passes.h index c72a92f112b23d..22c6d637c14091 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/passes.h +++ b/xla/mlir_hlo/stablehlo_ext/transforms/passes.h @@ -35,7 +35,7 @@ void createChloLegalizeToStablehloPipeline(OpPassManager &pm); // Expand backward compatibility with the given StableHLO version by decomposing // newer StableHLO operations into equivalent operations supported by that older // version. -std::unique_ptr createStablehloCreateCompatibilityExpanderPass( +std::unique_ptr createStablehloCompatibilityExpanderPass( std::string targetVersionOption); } // namespace stablehlo_ext diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp similarity index 87% rename from xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp rename to xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp index 0db5fd4780b67d..96dc5b8c645b9e 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_create_compatibility_expander.cpp +++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp @@ -25,9 +25,9 @@ namespace stablehlo_ext { // TODO(b/369406385): remove this method (and file) once issue is resolved. 
-std::unique_ptr<::mlir::Pass> createStablehloCreateCompatibilityExpanderPass( +std::unique_ptr<::mlir::Pass> createStablehloCompatibilityExpanderPass( std::string targetVersionOption) { - return mlir::stablehlo::createStablehloCreateCompatibilityExpanderPass( + return mlir::stablehlo::createStablehloCompatibilityExpanderPass( {std::move(targetVersionOption)}); } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 8baa3e0d3298df..59618001c2d7cc 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1805,6 +1805,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 65594f55fd979d..03b6a21e07210c 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6832,6 +6832,20 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func 
@f8e4m3(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e4m3fn(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 0f2e1b108a710f..66c388b9ed373e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1787,6 +1787,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 477045dfb93adb..20b660ed6eecfd 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,8 @@ # PJRT C API changelog +## 0.55 +* Added types F8E4M3 and F8E3M4. + ## 0.54 * Deprecated PJRT_Buffer_GetMemoryLayout. 
diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 1d5b44c60201c5..a96f35920b9fa1 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 54 +#define PJRT_API_MINOR 55 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -644,6 +644,10 @@ typedef enum { // 2-bit integer types PJRT_Buffer_Type_S2, PJRT_Buffer_Type_U2, + + // More truncated 8 bit floating-point formats. + PJRT_Buffer_Type_F8E4M3, + PJRT_Buffer_Type_F8E3M4, } PJRT_Buffer_Type; typedef enum { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index b9508cf24950b4..5121877e1f4dd0 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -295,6 +295,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; + case xla::PrimitiveType::F8E4M3: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3; case xla::PrimitiveType::F8E4M3FN: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN; case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -303,6 +305,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2FNUZ; case xla::PrimitiveType::F8E4M3FNUZ: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; + case xla::PrimitiveType::F8E3M4: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -358,6 +362,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return 
xla::PrimitiveType::C128; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: + return xla::PrimitiveType::F8E4M3; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN: return xla::PrimitiveType::F8E4M3FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3B11FNUZ: @@ -366,6 +372,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E5M2FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ: return xla::PrimitiveType::F8E4M3FNUZ; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: + return xla::PrimitiveType::F8E3M4; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/pjrt/mlir_to_hlo.cc b/xla/pjrt/mlir_to_hlo.cc index 393377a73f157c..a7adeedd9203de 100644 --- a/xla/pjrt/mlir_to_hlo.cc +++ b/xla/pjrt/mlir_to_hlo.cc @@ -199,7 +199,7 @@ absl::StatusOr SerializeUsingVersionedStablehlo( pm.addNestedPass( mlir::stablehlo::createChloLegalizeToStablehloPass()); pm.addNestedPass( - mlir::stablehlo::createStablehloCreateCompatibilityExpanderPass( + mlir::stablehlo::createStablehloCompatibilityExpanderPass( {std::string(target)})); pm.addNestedPass( mlir::stablehlo::createChloLegalizeToStablehloPass()); diff --git a/xla/primitive_util.h b/xla/primitive_util.h index 8fbeedbff94dad..de5ee4fde11d7b 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -180,6 +180,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E4M3; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3FN; @@ -200,6 +205,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3FNUZ; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E3M4; +} + // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -309,6 +319,11 @@ struct 
PrimitiveTypeToNative { using type = tsl::float8_e5m2; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e4m3; +}; + template <> struct PrimitiveTypeToNative { using type = tsl::float8_e4m3fn; @@ -329,6 +344,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e4m3fnuz; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e3m4; +}; + // Complex template <> struct PrimitiveTypeToNative { @@ -362,8 +382,9 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { } constexpr bool IsF8Type(PrimitiveType type) { - return type == F8E5M2 || type == F8E4M3FN || type == F8E4M3B11FNUZ || - type == F8E5M2FNUZ || type == F8E4M3FNUZ; + return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || + type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ || + type == F8E3M4; } constexpr bool IsFloatingPointType(PrimitiveType type) { @@ -428,6 +449,12 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { + case F8E3M4: + return std::forward(f)( + PrimitiveTypeConstant()); + case F8E4M3: + return std::forward(f)( + PrimitiveTypeConstant()); case F8E4M3FN: return std::forward(f)( PrimitiveTypeConstant()); diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index e8c9dc77087062..850203f17379a4 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -76,10 +76,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][BF16] = true; expecteds[PRED][C128] = true; expecteds[PRED][F8E5M2] = true; + expecteds[PRED][F8E4M3] = true; expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = true; expecteds[PRED][F8E5M2FNUZ] = true; expecteds[PRED][F8E4M3FNUZ] = true; + expecteds[PRED][F8E3M4] = true; expecteds[S2][PRED] = false; expecteds[S2][S2] = true; expecteds[S2][S4] = true; @@ -100,10 +102,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][BF16] = 
true; expecteds[S2][C128] = true; expecteds[S2][F8E5M2] = true; + expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; expecteds[S2][F8E4M3B11FNUZ] = true; expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; + expecteds[S2][F8E3M4] = true; expecteds[S4][PRED] = false; expecteds[S4][S2] = false; expecteds[S4][S4] = true; @@ -124,10 +128,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][BF16] = true; expecteds[S4][C128] = true; expecteds[S4][F8E5M2] = true; + expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; expecteds[S4][F8E4M3B11FNUZ] = true; expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; + expecteds[S4][F8E3M4] = true; expecteds[S8][PRED] = false; expecteds[S8][S2] = false; expecteds[S8][S4] = false; @@ -148,10 +154,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][BF16] = true; expecteds[S8][C128] = true; expecteds[S8][F8E5M2] = false; + expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; expecteds[S8][F8E4M3B11FNUZ] = false; expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; + expecteds[S8][F8E3M4] = false; expecteds[S16][PRED] = false; expecteds[S16][S2] = false; expecteds[S16][S4] = false; @@ -172,10 +180,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][BF16] = false; expecteds[S16][C128] = true; expecteds[S16][F8E5M2] = false; + expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; expecteds[S16][F8E4M3B11FNUZ] = false; expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; + expecteds[S16][F8E3M4] = false; expecteds[S32][PRED] = false; expecteds[S32][S2] = false; expecteds[S32][S4] = false; @@ -196,10 +206,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][BF16] = false; expecteds[S32][C128] = true; expecteds[S32][F8E5M2] = false; + expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; expecteds[S32][F8E4M3B11FNUZ] = false; expecteds[S32][F8E5M2FNUZ] = false; 
expecteds[S32][F8E4M3FNUZ] = false; + expecteds[S32][F8E3M4] = false; expecteds[S64][PRED] = false; expecteds[S64][S2] = false; expecteds[S64][S4] = false; @@ -220,10 +232,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][BF16] = false; expecteds[S64][C128] = false; expecteds[S64][F8E5M2] = false; + expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; expecteds[S64][F8E4M3B11FNUZ] = false; expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; + expecteds[S64][F8E3M4] = false; expecteds[U2][PRED] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; @@ -246,10 +260,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][BF16] = true; expecteds[U2][C128] = true; expecteds[U2][F8E5M2] = true; + expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; expecteds[U2][F8E4M3B11FNUZ] = true; expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; + expecteds[U2][F8E3M4] = true; expecteds[U4][PRED] = false; expecteds[U4][S2] = false; expecteds[U4][S4] = false; @@ -272,10 +288,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][BF16] = true; expecteds[U4][C128] = true; expecteds[U4][F8E5M2] = false; + expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; expecteds[U4][F8E4M3B11FNUZ] = true; expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; + expecteds[U4][F8E3M4] = true; expecteds[U8][PRED] = false; expecteds[U8][S2] = false; expecteds[U8][S4] = false; @@ -298,10 +316,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][BF16] = true; expecteds[U8][C128] = true; expecteds[U8][F8E5M2] = false; + expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; expecteds[U8][F8E4M3B11FNUZ] = false; expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; + expecteds[U8][F8E3M4] = false; expecteds[U16][PRED] = false; expecteds[U16][S2] = false; expecteds[U16][S4] = false; @@ -322,10 +342,12 @@ TEST(PrimitiveUtilTest, 
CastPreservesValues) { expecteds[U16][BF16] = false; expecteds[U16][C128] = true; expecteds[U16][F8E5M2] = false; + expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; expecteds[U16][F8E4M3B11FNUZ] = false; expecteds[U16][F8E5M2FNUZ] = false; expecteds[U16][F8E4M3FNUZ] = false; + expecteds[U16][F8E3M4] = false; expecteds[U32][PRED] = false; expecteds[U32][S2] = false; expecteds[U32][S4] = false; @@ -346,10 +368,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][BF16] = false; expecteds[U32][C128] = true; expecteds[U32][F8E5M2] = false; + expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; expecteds[U32][F8E4M3B11FNUZ] = false; expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; + expecteds[U32][F8E3M4] = false; expecteds[U64][PRED] = false; expecteds[U64][S2] = false; expecteds[U64][S4] = false; @@ -370,10 +394,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][BF16] = false; expecteds[U64][C128] = false; expecteds[U64][F8E5M2] = false; + expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; expecteds[U64][F8E4M3B11FNUZ] = false; expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; + expecteds[U64][F8E3M4] = false; expecteds[F16][PRED] = false; expecteds[F16][S2] = false; expecteds[F16][S4] = false; @@ -394,10 +420,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][BF16] = false; expecteds[F16][C128] = true; expecteds[F16][F8E5M2] = false; + expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; expecteds[F16][F8E4M3B11FNUZ] = false; expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; + expecteds[F16][F8E3M4] = false; expecteds[F32][PRED] = false; expecteds[F32][S2] = false; expecteds[F32][S4] = false; @@ -418,10 +446,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][BF16] = false; expecteds[F32][C128] = true; expecteds[F32][F8E5M2] = false; + expecteds[F32][F8E4M3] = false; 
expecteds[F32][F8E4M3FN] = false; expecteds[F32][F8E4M3B11FNUZ] = false; expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; + expecteds[F32][F8E3M4] = false; expecteds[F64][PRED] = false; expecteds[F64][S2] = false; expecteds[F64][S4] = false; @@ -442,10 +472,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][BF16] = false; expecteds[F64][C128] = true; expecteds[F64][F8E5M2] = false; + expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; expecteds[F64][F8E4M3B11FNUZ] = false; expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; + expecteds[F64][F8E3M4] = false; expecteds[C64][PRED] = false; expecteds[C64][S2] = false; expecteds[C64][S4] = false; @@ -466,10 +498,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][BF16] = false; expecteds[C64][C128] = true; expecteds[C64][F8E5M2] = false; + expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; expecteds[C64][F8E4M3B11FNUZ] = false; expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; + expecteds[C64][F8E3M4] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S2] = false; expecteds[BF16][S4] = false; @@ -490,10 +524,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; expecteds[BF16][F8E5M2] = false; + expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; expecteds[BF16][F8E4M3B11FNUZ] = false; expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; + expecteds[BF16][F8E3M4] = false; expecteds[C128][PRED] = false; expecteds[C128][S2] = false; expecteds[C128][S4] = false; @@ -514,10 +550,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][BF16] = false; expecteds[C128][C128] = true; expecteds[C128][F8E5M2] = false; + expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; expecteds[C128][F8E4M3B11FNUZ] = false; expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = 
false; + expecteds[C128][F8E3M4] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S2] = false; expecteds[F8E5M2][S4] = false; @@ -538,10 +576,38 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; expecteds[F8E5M2][F8E5M2] = true; + expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; expecteds[F8E5M2][F8E4M3B11FNUZ] = false; expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; + expecteds[F8E5M2][F8E3M4] = false; + expecteds[F8E4M3][PRED] = false; + expecteds[F8E4M3][S2] = false; + expecteds[F8E4M3][S4] = false; + expecteds[F8E4M3][S8] = false; + expecteds[F8E4M3][S16] = false; + expecteds[F8E4M3][S32] = false; + expecteds[F8E4M3][S64] = false; + expecteds[F8E4M3][U2] = false; + expecteds[F8E4M3][U4] = false; + expecteds[F8E4M3][U8] = false; + expecteds[F8E4M3][U16] = false; + expecteds[F8E4M3][U32] = false; + expecteds[F8E4M3][U64] = false; + expecteds[F8E4M3][F16] = true; + expecteds[F8E4M3][F32] = true; + expecteds[F8E4M3][F64] = true; + expecteds[F8E4M3][C64] = true; + expecteds[F8E4M3][BF16] = true; + expecteds[F8E4M3][C128] = true; + expecteds[F8E4M3][F8E5M2] = false; + expecteds[F8E4M3][F8E5M2FNUZ] = false; + expecteds[F8E4M3][F8E4M3] = true; + expecteds[F8E4M3][F8E4M3FN] = false; + expecteds[F8E4M3][F8E4M3FNUZ] = false; + expecteds[F8E4M3][F8E4M3B11FNUZ] = false; + expecteds[F8E4M3][F8E3M4] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S2] = false; expecteds[F8E4M3FN][S4] = false; @@ -562,8 +628,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; expecteds[F8E4M3FN][F8E5M2] = false; + expecteds[F8E4M3FN][F8E5M2FNUZ] = false; + expecteds[F8E4M3FN][F8E4M3] = false; expecteds[F8E4M3FN][F8E4M3FN] = true; + expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; + expecteds[F8E4M3FN][F8E3M4] = false; expecteds[F8E4M3B11FNUZ][PRED] = 
false; expecteds[F8E4M3B11FNUZ][S2] = false; expecteds[F8E4M3B11FNUZ][S4] = false; @@ -584,12 +654,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; + expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; expecteds[F8E4M3B11FNUZ][F8E4M3B11FNUZ] = true; expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E4M3FNUZ] = false; + expecteds[F8E4M3B11FNUZ][F8E3M4] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S2] = false; expecteds[F8E5M2FNUZ][S4] = false; @@ -610,10 +680,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; expecteds[F8E5M2FNUZ][F8E5M2] = false; + expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; expecteds[F8E5M2FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; + expecteds[F8E5M2FNUZ][F8E3M4] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S2] = false; expecteds[F8E4M3FNUZ][S4] = false; @@ -634,10 +706,38 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; expecteds[F8E4M3FNUZ][F8E5M2] = false; + expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; expecteds[F8E4M3FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; + expecteds[F8E4M3FNUZ][F8E3M4] = false; + expecteds[F8E3M4][PRED] = false; + expecteds[F8E3M4][S2] = false; + expecteds[F8E3M4][S4] = false; + expecteds[F8E3M4][S8] = false; + expecteds[F8E3M4][S16] = false; + expecteds[F8E3M4][S32] = false; + expecteds[F8E3M4][S64] = false; + expecteds[F8E3M4][U2] = false; + 
expecteds[F8E3M4][U4] = false; + expecteds[F8E3M4][U8] = false; + expecteds[F8E3M4][U16] = false; + expecteds[F8E3M4][U32] = false; + expecteds[F8E3M4][U64] = false; + expecteds[F8E3M4][F16] = true; + expecteds[F8E3M4][F32] = true; + expecteds[F8E3M4][F64] = true; + expecteds[F8E3M4][C64] = true; + expecteds[F8E3M4][BF16] = true; + expecteds[F8E3M4][C128] = true; + expecteds[F8E3M4][F8E5M2] = false; + expecteds[F8E3M4][F8E5M2FNUZ] = false; + expecteds[F8E3M4][F8E4M3] = false; + expecteds[F8E3M4][F8E4M3FN] = false; + expecteds[F8E3M4][F8E4M3FNUZ] = false; + expecteds[F8E3M4][F8E4M3B11FNUZ] = false; + expecteds[F8E3M4][F8E3M4] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index 1de5702b6cc8df..17e2cfa281d251 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -37,6 +37,8 @@ std::optional DType::byte_size() const { case kPred: case kS8: case kU8: + case kF8E3M4: + case kF8E4M3: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -78,6 +80,8 @@ std::optional DType::bit_size() const { case kPred: case kS8: case kU8: + case kF8E3M4: + case kF8E4M3: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -133,6 +137,9 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(BF16); CASE(C64); CASE(C128); + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // CASE(F8E3M4); + // CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -175,6 +182,9 @@ DTypeProto DType::ToProto() const { CASE(BF16); CASE(C64); CASE(C128); + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
+ // CASE(F8E3M4); + // CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index 06a92b67f863c8..911702512c0501 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -78,13 +78,15 @@ class DType { // dtype will have empty dimensions. kToken = 17, + kF8E3M4 = 29, + kF8E4M3 = 28, kF8E4M3FN = 20, kF8E4M3B11FNUZ = 23, kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, - // Next = 26 + // Next = 30 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index eadfd42a3550cd..37976833e7e8c7 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -60,6 +60,8 @@ message DTypeProto { // dtype will have empty dimensions. KIND_TOKEN = 17; + KIND_F8E3M4 = 29; + KIND_F8E4M3 = 28; KIND_F8E4M3FN = 20; KIND_F8E4M3B11FNUZ = 23; KIND_F8E4M3FNUZ = 25; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 5ac531dabcb9ce..57fec6702d277d 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -49,6 +49,8 @@ TEST(DTypeTest, ByteSize) { {DType::kPred, 1}, {DType::kS8, 1}, {DType::kU8, 1}, + {DType::kF8E3M4, 1}, + {DType::kF8E4M3, 1}, {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, {DType::kF8E4M3FNUZ, 1}, @@ -85,6 +87,8 @@ TEST(DTypeTest, BitSize) { {DType::kPred, 8}, {DType::kS8, 8}, {DType::kU8, 8}, + {DType::kF8E3M4, 8}, + {DType::kF8E4M3, 8}, {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, {DType::kF8E4M3FNUZ, 8}, diff --git a/xla/python/pjrt_ifrt/pjrt_dtype.cc b/xla/python/pjrt_ifrt/pjrt_dtype.cc index 36d492f27569a9..10a293778bd467 100644 --- a/xla/python/pjrt_ifrt/pjrt_dtype.cc +++ b/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -44,6 +44,8 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kU16, xla::PrimitiveType::U16); CASE(DType::kU32, 
xla::PrimitiveType::U32); CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4); + CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ); CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); @@ -80,6 +82,8 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U16: case xla::PrimitiveType::U32: case xla::PrimitiveType::U64: + case xla::PrimitiveType::F8E3M4: + case xla::PrimitiveType::F8E4M3: case xla::PrimitiveType::F8E4M3FN: case xla::PrimitiveType::F8E4M3B11FNUZ: case xla::PrimitiveType::F8E4M3FNUZ: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index db7a38d0a50363..9a9c63a922e90d 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -185,6 +185,12 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E4M3FN; @@ -394,6 +400,14 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = + HandleNumpyScalar; + } (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = @@ -583,6 +597,9 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = 
numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 4a1de389cd5b5d..125f96a75fdf25 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -59,6 +59,8 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; + std::optional float8_e3m4; + std::optional float8_e4m3; nb_dtype float8_e4m3fn; nb_dtype float8_e4m3b11fnuz; nb_dtype float8_e4m3fnuz; @@ -75,6 +77,12 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float8_e3m4")) { + dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4")); + } + if (nb::hasattr(ml_dtypes, "float8_e4m3")) { + dtypes->float8_e4m3 = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3")); + } dtypes->float8_e4m3fn = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fn")); dtypes->float8_e5m2 = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2")); @@ -140,6 +148,12 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); + if (custom_dtypes.float8_e3m4.has_value()) { + map->emplace(*custom_dtypes.float8_e3m4, F8E3M4); + } + if (custom_dtypes.float8_e4m3.has_value()) { + map->emplace(*custom_dtypes.float8_e4m3, F8E4M3); + } map->emplace(custom_dtypes.float8_e4m3fn, F8E4M3FN); map->emplace(custom_dtypes.float8_e4m3b11fnuz, F8E4M3B11FNUZ); 
map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); @@ -204,6 +218,16 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); + case F8E3M4: + if (custom_dtypes.float8_e3m4.has_value()) { + return *custom_dtypes.float8_e3m4; + } + break; + case F8E4M3: + if (custom_dtypes.float8_e4m3.has_value()) { + return *custom_dtypes.float8_e4m3; + } + break; case F8E4M3FN: return custom_dtypes.float8_e4m3fn; case F8E4M3B11FNUZ: @@ -284,6 +308,16 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); + case ifrt::DType::kF8E3M4: + if (custom_dtypes.float8_e3m4.has_value()) { + return *custom_dtypes.float8_e3m4; + } + break; + case ifrt::DType::kF8E4M3: + if (custom_dtypes.float8_e4m3.has_value()) { + return *custom_dtypes.float8_e4m3; + } + break; case ifrt::DType::kF8E4M3FN: return custom_dtypes.float8_e4m3fn; case ifrt::DType::kF8E4M3B11FNUZ: @@ -347,6 +381,12 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float8_e3m4")) { + dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4")); + } + if (nb::hasattr(ml_dtypes, "float8_e4m3")) { + dtypes->np_float8_e4m3 = nb::object(ml_dtypes.attr("float8_e4m3")); + } dtypes->np_float8_e4m3fn = nb::object(ml_dtypes.attr("float8_e4m3fn")); dtypes->np_float8_e4m3b11fnuz = nb::object(ml_dtypes.attr("float8_e4m3b11fnuz")); diff --git a/xla/python/types.h b/xla/python/types.h index ed7ca847b1a7f7..fece926edd3017 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -79,6 +79,9 @@ struct NumpyScalarTypes { nanobind::object np_uint32; nanobind::object np_uint64; nanobind::object np_bfloat16; + // Remove std::optional once the minimum ml_dtypes in 
JAX is >= 0.5.0. + std::optional np_float8_e3m4; + std::optional np_float8_e4m3; nanobind::object np_float8_e4m3fn; nanobind::object np_float8_e4m3b11fnuz; nanobind::object np_float8_e4m3fnuz; @@ -128,7 +131,6 @@ nanobind::tuple SpanToNbTuple(absl::Span xs) { // references to the objects. nanobind::tuple MutableSpanToNbTuple(absl::Span xs); - template std::vector IterableToVector(const nanobind::iterable& iterable) { std::vector output; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 868a3aa9d74016..70c8c90c0a04e4 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -202,6 +202,9 @@ NB_MODULE(xla_extension, m_nb) { .value("U32", U32) .value("U64", U64) .value("F16", F16) + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // .value("F8E3M4", F8E3M4) + // .value("F8E4M3", F8E4M3) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 5cc12efa93709b..51d879814b2e68 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -274,6 +274,9 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType = _xla.PrimitiveType bfloat16 = ml_dtypes.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4 = ml_dtypes.float8_e3m4 +# float8_e4m3 = ml_dtypes.float8_e4m3 float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -292,6 +295,9 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
+ # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index 898a632ab340a7..5a1df08e736f64 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -59,6 +59,9 @@ _version: int mlir_api_version: int bfloat16: type[numpy.generic] +# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4: type[numpy.generic] +# float8_e4m3: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index edd582e405c265..441d5fbf450fa4 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -54,6 +54,9 @@ xla_client._xla.jax_jit.global_state().enable_memories = False bfloat16 = xla_client.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4 = xla_client.float8_e3m4 +# float8_e4m3 = xla_client.float8_e4m3 float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -150,6 +153,8 @@ def TestFactory(xla_backend, # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
+ # standard_dtypes += [float8_e3m4, float8_e4m3] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index b5ae4c6431ca66..e363d8d82471cb 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -73,6 +73,8 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType + F8E3M4: PrimitiveType + F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType F8E4M3B11FNUZ: PrimitiveType F8E4M3FNUZ: PrimitiveType diff --git a/xla/service/BUILD b/xla/service/BUILD index 5f1c0407d99fd0..796215cb8a735e 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2572,6 +2572,7 @@ cc_library( srcs = ["triangular_solve_expander.cc"], hdrs = ["triangular_solve_expander.h"], deps = [ + ":hlo_creation_utils", ":hlo_module_config", ":op_expander_pass", "//xla:shape_util", @@ -2614,6 +2615,7 @@ cc_library( srcs = ["cholesky_expander.cc"], hdrs = ["cholesky_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:literal", "//xla:shape_util", @@ -2637,6 +2639,7 @@ cc_library( srcs = ["qr_expander.cc"], hdrs = ["qr_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:literal", "//xla:shape_util", @@ -2692,6 +2695,7 @@ cc_library( srcs = ["eigh_expander.cc"], hdrs = ["eigh_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:literal_util", "//xla:shape_util", @@ -3068,6 +3072,7 @@ cc_library( srcs = ["bitcast_dtypes_expander.cc"], hdrs = ["bitcast_dtypes_expander.h"], deps = [ + ":hlo_creation_utils", ":hlo_module_config", ":op_expander_pass", "//xla:literal_util", @@ -5917,6 +5922,7 @@ cc_library( "//xla:xla_data_proto_cc", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@tsl//tsl/platform:statusor", ], ) @@ -5971,22 +5977,21 @@ xla_test( deps = [ 
":elemental_ir_emitter", ":hlo_module_config", - ":hlo_parser", "//xla:error_spec", - "//xla:execution_options_util", "//xla:literal", "//xla:literal_util", - "//xla:status_macros", "//xla:test", + "//xla:types", "//xla/hlo/ir:hlo", "//xla/service/llvm_ir:ir_array", - "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", ], ) @@ -7477,6 +7482,7 @@ cc_library( srcs = ["rng_bit_generator_expander.cc"], hdrs = ["rng_bit_generator_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:shape_util", "//xla:util", @@ -7591,6 +7597,7 @@ cc_library( srcs = ["topk_rewriter.cc"], hdrs = ["topk_rewriter.h"], deps = [ + ":hlo_creation_utils", ":pattern_matcher", "//xla:shape_util", "//xla:util", diff --git a/xla/service/bitcast_dtypes_expander.cc b/xla/service/bitcast_dtypes_expander.cc index 6a5d12c25c32fa..e12f69f56bc4e2 100644 --- a/xla/service/bitcast_dtypes_expander.cc +++ b/xla/service/bitcast_dtypes_expander.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -104,14 +105,8 @@ absl::StatusOr BitcastDtypesExpander::ExpandInstruction( BitcastConvertType(input, to_shape.element_type()); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, b.Build()); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/cholesky_expander.cc b/xla/service/cholesky_expander.cc index d70e0211103fff..d5a3053e7168a3 100644 --- a/xla/service/cholesky_expander.cc +++ b/xla/service/cholesky_expander.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "xla/client/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -248,15 +249,8 @@ absl::StatusOr CholeskyExpander::ExpandInstruction( MaybeTransposeInMinorDims(l, !options.lower()); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 1a77e4ffb8eb13..1b5da63773d85f 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -600,6 +600,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( #endif FloatSupport f8e5m2_support(F8E5M2, F16); pipeline.AddPass(&f8e5m2_support); + FloatSupport f8e4m3_support(F8E4M3, F16); + pipeline.AddPass(&f8e4m3_support); FloatSupport f8e4m3fn_support(F8E4M3FN, F16); pipeline.AddPass(&f8e4m3fn_support); FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); @@ -608,6 +610,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&f8e5m2fnuz_support); FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16); pipeline.AddPass(&f8e4m3fnuz_support); + FloatSupport f8e3m4_support(F8E3M4, F16); + pipeline.AddPass(&f8e3m4_support); // After canonicalization, there may be more batch dots that can be // simplified. 
pipeline.AddPass(); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index c0c956a32dc0b1..2fef54861722f1 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -71,7 +71,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ + // F8E4M3B11FNUZ, F8E4M3, F8E3M4 default: return dt::undef; } diff --git a/xla/service/eigh_expander.cc b/xla/service/eigh_expander.cc index e95b268c1f3d8b..e34dbbf96d22c0 100644 --- a/xla/service/eigh_expander.cc +++ b/xla/service/eigh_expander.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -582,15 +583,8 @@ absl::StatusOr EighExpander::ExpandInstruction( } XlaOp result = BuildEigh(a, lower, max_iter, tol, sort_eigenvalues); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result)); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index aff56f92b15601..b04e4e554a8a8e 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -220,6 +220,90 @@ absl::StatusOr 
EmitReducePrecisionIR( return result; } +template +llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, + llvm::Value* f8_bits, + llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + // F16 bit patterns that are halfway between consecutive denormal F8 values. + // They determine how to round to denormal F8 values (4 exponent bits here). + const int halfway_points_e4[8] = { + 0x1400, // 0x1.0p-10 ; halfway between [0/8 * 2^-6, 1/8 * 2^-6] + 0x1A00, // 0x1.8p-9 ; halfway between [1/8 * 2^-6, 2/8 * 2^-6] + 0x1D00, // 0x1.4p-8 ; halfway between [2/8 * 2^-6, 3/8 * 2^-6] + 0x1F00, // 0x1.Cp-8 ; halfway between [3/8 * 2^-6, 4/8 * 2^-6] + 0x2080, // 0x1.2p-7 ; halfway between [4/8 * 2^-6, 5/8 * 2^-6] + 0x2180, // 0x1.6p-7 ; halfway between [5/8 * 2^-6, 6/8 * 2^-6] + 0x2280, // 0x1.Ap-7 ; halfway between [6/8 * 2^-6, 7/8 * 2^-6] + 0x2380, // 0x1.Ep-7 ; halfway between [7/8 * 2^-6, 8/8 * 2^-6] + }; + + const int halfway_points_e3[16] = { + 0x2000, // 0x1.0p-7; halfway between [0/16 * 2^-2, 1/16 * 2^-2] + 0x2600, // 0x1.8p-6; halfway between [1/16 * 2^-2, 2/16 * 2^-2] + 0x2900, // 0x1.4p-5; halfway between [2/16 * 2^-2, 3/16 * 2^-2] + 0x2B00, // 0x1.Cp-5; halfway between [3/16 * 2^-2, 4/16 * 2^-2] + 0x2C80, // 0x1.2p-4; halfway between [4/16 * 2^-2, 5/16 * 2^-2] + 0x2D80, // 0x1.6p-4; halfway between [5/16 * 2^-2, 6/16 * 2^-2] + 0x2E80, // 0x1.Ap-4; halfway between [6/16 * 2^-2, 7/16 * 2^-2] + 0x2F80, // 0x1.Ep-4; halfway between [7/16 * 2^-2, 8/16 * 2^-2] + 0x3040, // 0x1.1p-3; halfway between [8/16 * 2^-2, 9/16 * 2^-2] + 0x30C0, // 0x1.3p-3; halfway between [9/16 * 2^-2, 10/16 * 2^-2] + 0x3140, // 0x1.5p-3; halfway between [10/16
* 2^-2, 11/16 * 2^-2] + 0x31C0, // 0x1.7p-3; halfway between [11/16 * 2^-2, 12/16 * 2^-2] + 0x3240, // 0x1.9p-3; halfway between [12/16 * 2^-2, 13/16 * 2^-2] + 0x32C0, // 0x1.Bp-3; halfway between [13/16 * 2^-2, 14/16 * 2^-2] + 0x3340, // 0x1.Dp-3; halfway between [14/16 * 2^-2, 15/16 * 2^-2] + 0x33C0, // 0x1.Fp-3; halfway between [15/16 * 2^-2, 16/16 * 2^-2] + }; + + const int* halfway_points; + int arr_sz; + if constexpr (f8_exponent_bits == 4) { + halfway_points = halfway_points_e4; + arr_sz = 8; + } else if constexpr (f8_exponent_bits == 3) { + halfway_points = halfway_points_e3; + arr_sz = 16; + } + + // Handle case where output is denormal. If we're rounding to a denormal + // value, ignore the current value of f8_bits and set it to the correct + // denormal value. We emit the equivalent of the following: + // + // if (f16_abs_bits <= halfway_points[0]) { + // f8_bits = 0; + // } else if (f16_abs_bits < halfway_points[1]) { + // f8_bits = 1; + // } else if (f16_abs_bits <= halfway_points[2]) { + // ... // More if-else statements. The comparisons alternate between <= + // ... // and < to handle round-to-even properly.
+ // } else if (f16_abs_bits < halfway_points[arr_sz - 1]) { + // f8_bits = arr_sz - 1; + // } + for (int i = arr_sz - 1; i >= 0; i--) { + Value* comparison; + if (i % 2 == 0) { + comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); + } else { + comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); + } + f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); + } + return f8_bits; +} + absl::StatusOr EmitF16ToF8e5m2(llvm::Value* f16_value, llvm::IRBuilder<>* b) { TF_ASSIGN_OR_RETURN( @@ -242,6 +326,223 @@ llvm::Value* EmitF8e5m2ToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { return b->CreateBitCast(shifted, b->getHalfTy()); } +template +absl::StatusOr EmitF16ToF8e(llvm::Value* f16_value, + llvm::IRBuilder<>* b) { + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + constexpr int f8_mantissa_bits = 7 - f8_exponent_bits; + using llvm::APInt; + using llvm::Value; + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value. + // f16_as_int = bitcast(f16_value, int) + // f16_abs_bits = f16_as_int & 0x7FFF + Value* f16_as_int = b->CreateBitCast(f16_value, i16_type); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_as_int, i16_const(0x7FFF)); + + // Get the sign. + // f8_sign = (f16_as_int & 0x8000) >> 8 + Value* f16_sign = b->CreateAnd(f16_as_int, i16_const(0x8000)); + f16_sign = b->CreateLShr(f16_sign, i16_const(8)); + Value* f8_sign = b->CreateTrunc(f16_sign, i8_type); + + // Truncate the mantissa to f8 mantissa bits and exponent to f8 exponent bits. + // Denormal values are not handled properly here and are + // dealt with later in this function.
+ absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/f8_exponent_bits, + /*dest_mantissa_bits=*/f8_mantissa_bits, + /*quiet_nans=*/true, b); + CHECK_OK(f16_reduced_statusor.status()); // Crash OK + Value* f16_reduced = f16_reduced_statusor.value(); + f16_reduced = b->CreateBitCast(f16_reduced, i16_type); + + // Remove the sign bit. + // f16_reduced = f16_reduced & 0x7FFF + f16_reduced = b->CreateAnd(f16_reduced, i16_const(0x7FFF)); + + // F16 inf in binary: 0 11111 0000000000 + constexpr int f16_inf_value = 0x7C00; + constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; + constexpr int exponent_bias_difference = 15 - f8_bias; + constexpr int f16_mantissa_bits = 10; // e5m10 + constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; + constexpr int min_normal_value = (exponent_bias_difference + 1) + << f16_mantissa_bits; + + // Round values smaller than the smallest F8 normal value up to the smallest + // F8 normal value. The case where we round to a denormal value is handled + // later. + // f16_reduced = max(f16_reduced, min_normal_value) + f16_reduced = b->CreateSelect( + b->CreateICmpULT(f16_reduced, i16_const(min_normal_value)), + i16_const(min_normal_value), f16_reduced); + + // Adjust the exponent by subtracting the difference in exponent bias: + // f16_reduced -= (exponent_bias_difference << f16_mantissa_bits) + // For infinity/NaN values, subtract twice the difference in exponent bias + // to ensure the leading exponent bit(s) of f16_reduced are set to zero. + f16_reduced = b->CreateSub( + f16_reduced, + b->CreateSelect( + b->CreateICmpULT(f16_reduced, i16_const(f16_inf_value)), + i16_const(exponent_bias_difference << f16_mantissa_bits), + i16_const(exponent_bias_difference << (f16_mantissa_bits + 1)))); + + // Shift to convert to F8.
+ // f16_reduced = f16_reduced >> mantissa_bits_difference; + f16_reduced = b->CreateLShr(f16_reduced, i16_const(mantissa_bits_difference)); + + Value* f8_bits = b->CreateTrunc(f16_reduced, i8_type); + + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = + handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); + + // Set the sign bit. + // f8_bits |= f8_sign + f8_bits = b->CreateOr(f8_bits, f8_sign); + return f8_bits; +} + +template +llvm::Value* EmitToF16F8e(llvm::Value* f8_value, llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Map from F8 denormal value to F16 value. + const int f8_denormal_to_f16_e4[8] = { + 0x0000, // 0 + 0x1800, // 1/8 * 2^-6 + 0x1C00, // 2/8 * 2^-6 + 0x1E00, // 3/8 * 2^-6 + 0x2000, // 4/8 * 2^-6 + 0x2100, // 5/8 * 2^-6 + 0x2200, // 6/8 * 2^-6 + 0x2300, // 7/8 * 2^-6 + }; + + // Map from F8 denormal value to F16 value. + const int f8_denormal_to_f16_e3[16] = { + 0x0000, // 0 + 0x2400, // 1/16 * 2^-2 + 0x2800, // 2/16 * 2^-2 + 0x2A00, // 3/16 * 2^-2 + 0x2C00, // 4/16 * 2^-2 + 0x2D00, // 5/16 * 2^-2 + 0x2E00, // 6/16 * 2^-2 + 0x2F00, // 7/16 * 2^-2 + 0x3000, // 8/16 * 2^-2 + 0x3080, // 9/16 * 2^-2 + 0x3100, // 10/16 * 2^-2 + 0x3180, // 11/16 * 2^-2 + 0x3200, // 12/16 * 2^-2 + 0x3280, // 13/16 * 2^-2 + 0x3300, // 14/16 * 2^-2 + 0x3380, // 15/16 * 2^-2 + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value.
+ // f8_as_int = bitcast(f16_value, int) + // f8_abs_bits = f8_as_int & 0x7F + Value* f8_as_int = b->CreateBitCast(f8_value, i8_type); + Value* f8_abs_bits = b->CreateAnd(f8_as_int, i8_const(0x7F)); + + // We assume below that the value is neither NaN nor denormal. If it NaN or + // denormal, the output is set to NaN or zero at the end using Select + // instructions. + + // Get the sign: + // f16_sign = (f8_as_int & 0x80) << 8 + Value* f8_sign = b->CreateAnd(f8_as_int, i8_const(0x80)); + Value* f16_sign = b->CreateZExt(f8_sign, i16_type); + f16_sign = b->CreateShl(f16_sign, i16_const(8)); + + int exponent_mask; + const int* f8_denormal_to_f16; + int f8_denormal_size; + if constexpr (f8_exponent_bits == 4) { + exponent_mask = 0x78; + f8_denormal_to_f16 = f8_denormal_to_f16_e4; + f8_denormal_size = 8; + } else if constexpr (f8_exponent_bits == 3) { + exponent_mask = 0x70; + f8_denormal_to_f16 = f8_denormal_to_f16_e3; + f8_denormal_size = 16; + } + constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; + constexpr int exponent_bias_difference = 15 - f8_bias; + constexpr int f16_mantissa_bits = 10; // e5m10 + constexpr int f8_mantissa_bits = 7 - f8_exponent_bits; + constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; + constexpr int f8_mantissa_mask = (1 << f8_mantissa_bits) - 1; + + // Get the exponent: + // f8_exponent = (f8_as_int & exponent_mask) >> f8_mantissa_bits + Value* f8_exponent_bits_v = b->CreateAnd(f8_as_int, i8_const(exponent_mask)); + Value* f8_exponent = + b->CreateLShr(f8_exponent_bits_v, i8_const(f8_mantissa_bits)); + + // Adjust the exponent by adding the difference in exponent bias: + // f16_exponent = (f8_exponent + exponent_bias_difference) + // << f16_mantissa_bits + Value* f16_exponent = + b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); + f16_exponent = b->CreateZExt(f16_exponent, i16_type); + f16_exponent = b->CreateShl(f16_exponent, i16_const(f16_mantissa_bits)); + + // Set output exponent to 11111 
if input exponent is 111 (Inf or NaN) + // 0.11111.0000000000 is 0x7C00 + Value* is_exp_1111 = + b->CreateICmpEQ(f8_exponent_bits_v, i8_const(exponent_mask)); + f16_exponent = b->CreateSelect(is_exp_1111, i16_const(0x7C00), f16_exponent); + + // Get the mantissa: + // f16_mantissa = (f8_mantissa & f8_mantissa_mask) + // << mantissa_bits_difference + Value* f8_mantissa = b->CreateAnd(f8_as_int, i8_const(f8_mantissa_mask)); + Value* f16_mantissa = b->CreateZExt(f8_mantissa, i16_type); + f16_mantissa = + b->CreateShl(f16_mantissa, i16_const(mantissa_bits_difference)); + + // Combine the exponent and mantissa: + // f16_as_int = f16_exponent | f16_mantissa + Value* f16_as_int = b->CreateOr(f16_exponent, f16_mantissa); + + // If the F8 value is denormal, use the map above to determine the correct F16 + // value. + // if (f8_abs_bits < 8) { f16_as_int = f8_denormal_to_f16[f8_abs_bits]; } + for (int i = 0; i < f8_denormal_size; i++) { + Value* is_denormal_value = b->CreateICmpEQ(f8_abs_bits, i8_const(i)); + f16_as_int = b->CreateSelect(is_denormal_value, + i16_const(f8_denormal_to_f16[i]), f16_as_int); + } + + // Set the sign bit. 
+ // f16_as_int |= f16_sign + f16_as_int = b->CreateOr(f16_as_int, f16_sign); + return b->CreateBitCast(f16_as_int, b->getHalfTy()); +} + llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { using llvm::APInt; using llvm::Value; @@ -297,6 +598,7 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { i16_const(min_normal_value), f16_reduced); constexpr int exponent_bias_difference = 15 - 7; + constexpr int f8_exponent_bits = 4; constexpr int f16_mantissa_bits = 10; constexpr int f8_mantissa_bits = 3; constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; @@ -322,42 +624,9 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { b->CreateICmpUGT(f16_abs_bits, i16_const(max_finite_value)), i8_const(0x7F), f8_bits); - // F16 values that are halfway between denormal F8 values. This is used to - // determine how to round to denormal F8 values. - const int halfway_points[8] = { - 0x1400, // 2**-10; halfway between [0, 2**-9] - 0x1A00, // 1.5 * 2**-9; halfway between [2**-9, 2**-8] - 0x1D00, // 1.25 * 2**-8; halfway between [2**-8, 1.5 * 2**-8] - 0x1F00, // 1.75 * 2**-8; halfway between [1.5 * 2**-8, 2**-7] - 0x2080, // 1.125 * 2**-7; halfway between [2**-7, 1.25 * 2**-7] - 0x2180, // 1.375 * 2**-7; halfway between [1.25 * 2**-7, 1.5 * 2**-7] - 0x2280, // 1.625 * 2**-7; halfway between [1.5 * 2**-7, 1.75 * 2**-7] - 0x2380, // 1.875 * 2**-7; halfway between [1.75 * 2**-7, 2**-6] - }; - - // Handle case where output is denormal. If we're rounding to a denormal - // value, ignore the current value of f8_bits and set it to the correct - // denormal value. We emit the equivalent of the following: - // - // if (f16_abs_bits <= halfway_points[0]) { - // f8_bits = 0; - // } else if (f16_abs_bits < halfway_points[1]) { - // f8_bits = 1; - // } else if (f16_abs_bits <= halfway_points[2]) { - // ... // More if-else statements. The comparisons alternate between <= - // ... 
// and < to handle round-to-even properly. - // } else if (f16_abs_bits < halfway_points[7]) { - // f8_bits = 7; - // } - for (int i = ABSL_ARRAYSIZE(halfway_points) - 1; i >= 0; i--) { - Value* comparison; - if (i % 2 == 0) { - comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); - } else { - comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); - } - f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); - } + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = + handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); // Set the sign bit. // f8_bits |= f8_sign @@ -408,7 +677,7 @@ llvm::Value* EmitF8e4m3fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { b->CreateLShr(f8_exponent_bits, i8_const(f8_mantissa_bits)); // Adjust the exponent by adding the difference in exponent bias: - // f16_exponent = (f8_exopnent + exponent_bias_difference) + // f16_exponent = (f8_exponent + exponent_bias_difference) // << f16_mantissa_bits Value* f16_exponent = b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); @@ -435,13 +704,13 @@ llvm::Value* EmitF8e4m3fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { // Map from F8 denormal value to F16 value. 
int f8_denormal_to_f16[8] = { 0x0000, // 0 - 0x1800, // 2**-9 - 0x1C00, // 2**-8 - 0x1E00, // 1.5 * 2**-8 - 0x2000, // 2**-7 - 0x2100, // 1.25 * 2**-7 - 0x2200, // 1.5 * 2**-7 - 0x2300, // 1.75 * 2**-7 + 0x1800, // 1/8 * 2^-6 + 0x1C00, // 2/8 * 2^-6 + 0x1E00, // 3/8 * 2^-6 + 0x2000, // 4/8 * 2^-6 + 0x2100, // 5/8 * 2^-6 + 0x2200, // 6/8 * 2^-6 + 0x2300, // 7/8 * 2^-6 }; // If the F8 value is denormal, use the map above to determine the correct F16 @@ -604,6 +873,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F8E4M3) { + return EmitF16ToF8e<4>( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } if (to_type == F8E4M3FN) { return EmitF16ToF8e4m3fn( EmitIntegralToFloating(operand_value, from_type, F16, module_, @@ -623,6 +898,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), to_type, b_); } + if (to_type == F8E3M4) { + return EmitF16ToF8e<3>( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } return EmitIntegralToFloating(operand_value, from_type, to_type, module_, b_); } @@ -789,6 +1070,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E4M3) { + TF_RET_CHECK(to_type != F8E4M3); + operand_value = EmitToF16F8e<4>(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E4M3FN) { TF_RET_CHECK(to_type != F8E4M3FN); operand_value = EmitF8e4m3fnToF16(operand_value, b_); @@ -817,6 +1106,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E3M4) { + TF_RET_CHECK(to_type != F8E3M4); + operand_value = EmitToF16F8e<3>(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (primitive_util::IsComplexType(to_type)) { PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); @@ -844,6 +1141,14 @@ absl::StatusOr 
ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e5m2(operand_value, b_); } + if (to_type == F8E4M3) { + // Cast to F16 first. Casts to F8E4M3 must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF8e<4>(operand_value, b_); + } if (to_type == F8E4M3FN) { // Cast to F16 first. Casts to F8E4M3FN must be from F16. if (from_type != F16) { @@ -863,6 +1168,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } + if (to_type == F8E3M4) { + // Cast to F16 first. Casts to F8E3M4 must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF8e<3>(operand_value, b_); + } if (to_type == PRED) { return b_->CreateZExt( FCmpUNE(operand_value, @@ -1391,6 +1704,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( if (operand_type == F8E5M2) { lhs_value = EmitF8e5m2ToF16(lhs_value, b_); rhs_value = EmitF8e5m2ToF16(rhs_value, b_); + } else if (operand_type == F8E4M3) { + lhs_value = EmitToF16F8e<4>(lhs_value, b_); + rhs_value = EmitToF16F8e<4>(rhs_value, b_); } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); @@ -1401,6 +1717,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( TF_ASSIGN_OR_RETURN( rhs_value, EmitF8fnuzToFloating(operand_type, rhs_value, F16, b_, module_)); + } else if (operand_type == F8E3M4) { + lhs_value = EmitToF16F8e<3>(lhs_value, b_); + rhs_value = EmitToF16F8e<3>(rhs_value, b_); } switch (op->comparison_direction()) { case ComparisonDirection::kEq: diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index 7c73dd3a1d0dd7..60c4535909d158 100644 --- 
a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include #include +#include #include +#include #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -37,6 +39,8 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/types.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" namespace xla { @@ -89,9 +93,9 @@ class ElementalIrEmitterExecutionTypedTest }; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -249,8 +253,10 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatsToFloat) { auto tname = this->TypeName(); - if (std::is_same() || - std::is_same()) { + if (std::is_same() || + std::is_same() || + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -413,8 +419,10 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, CompareFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { auto tname = this->TypeName(); if (std::is_same() || + std::is_same() || std::is_same() || - std::is_same()) { + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index fe3a1041933cb5..22916aa084fc47 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "llvm/ADT/APFloat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Intrinsics.h" #include "xla/primitive_util.h" @@ -39,6 +40,10 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F8E3M4: + return &llvm::APFloat::Float8E3M4(); + case F8E4M3: + return &llvm::APFloat::Float8E4M3(); case F8E4M3B11FNUZ: return &llvm::APFloat::Float8E4M3B11FNUZ(); case F8E4M3FN: @@ -67,6 +72,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, PrimitiveType type) { switch (type) { + case F8E3M4: + case F8E4M3: case F8E4M3B11FNUZ: case F8E4M3FN: case F8E4M3FNUZ: diff --git a/xla/service/float_normalization_test.cc b/xla/service/float_normalization_test.cc index a140d2e933af9a..d62443a7d9ff06 100644 --- a/xla/service/float_normalization_test.cc +++ b/xla/service/float_normalization_test.cc @@ -144,7 +144,7 @@ class FloatNormalizationF8Test public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, - ::testing::Values(F8E5M2)); + ::testing::Values(F8E3M4, F8E4M3, F8E5M2)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); diff --git a/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/xla/service/gpu/fusions/transforms/expand_float_ops.cc index aebb7f44608559..e9b9731756f7db 100644 --- a/xla/service/gpu/fusions/transforms/expand_float_ops.cc +++ b/xla/service/gpu/fusions/transforms/expand_float_ops.cc @@ -175,12 +175,19 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { } assert(ty.getIntOrFloatBitWidth() == 8); - if (!ty.isFloat8E5M2()) { - // F8E5M2 is the only 8 bit float with infinities. + // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. 
+ if (ty.isFloat8E5M2()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x7C; + } else if (ty.isFloat8E4M3()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x78; + } else if (ty.isFloat8E3M4()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x70; + } else { return b.create(false, b.getI1Type()); } - Val bits{b.create(b.getI8Type(), value), &b}; - return (bits & 0x7F) == 0x7C; } Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { @@ -193,8 +200,12 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { Val bits{b.create(b.getI8Type(), value), &b}; if (ty.isFloat8E5M2()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1100); + } else if (ty.isFloat8E4M3()) { + return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1000); } else if (ty.isFloat8E4M3FN()) { return (bits & 0b0111'1111) == 0b0111'1111; + } else if (ty.isFloat8E3M4()) { + return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); } return bits == 0x80; } @@ -544,7 +555,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. 
if (rhs_cst.isZero() && - (lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { + (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || + lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { int_value = int_value & 0x7f; constant &= 0x7f; } diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index f574d64a2290c4..afb1002cc19461 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1441,19 +1441,23 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // Lambdas and related constants: const GpuFloatSupport bf16_support(gpu_version, BF16); const GpuFloatSupport f8e5m2_support(gpu_version, F8E5M2, F16); + const GpuFloatSupport f8e4m3_support(gpu_version, F8E4M3, F16); const GpuFloatSupport f8e4m3fn_support(gpu_version, F8E4M3FN, F16); const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16); const GpuFloatSupport f8e4m3fnuz_support(gpu_version, F8E4M3FNUZ, F16); + const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); sub_pipeline.AddPass(&bf16_support); sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3_support); sub_pipeline.AddPass(&f8e4m3fn_support); sub_pipeline.AddPass(&f8e4m3b11fnuz_support); sub_pipeline.AddPass(&f8e5m2fnuz_support); sub_pipeline.AddPass(&f8e4m3fnuz_support); + sub_pipeline.AddPass(&f8e3m4_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. 
if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index 406fcd9534a9dc..01b32bf8301d34 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -101,7 +101,8 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { PrimitiveType output_primitive_type = dot.shape().element_type(); bool type_is_allowed = - (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || + (output_primitive_type == F8E3M4 || output_primitive_type == F8E4M3 || + output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || output_primitive_type == F8E4M3FNUZ || output_primitive_type == F8E5M2FNUZ || output_primitive_type == F16 || output_primitive_type == BF16 || output_primitive_type == F32 || diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index b5d571e4c7be3f..16383324dfb016 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -29,8 +29,9 @@ class FloatConversionParamTest INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", - "f8e5m2", "f8e5m2fnuz", "f8e4m3fn", - "f8e4m3fnuz", "f8e4m3b11fnuz")); + "f8e5m2", "f8e5m2fnuz", "f8e4m3", + "f8e4m3fn", "f8e4m3fnuz", + "f8e4m3b11fnuz", "f8e3m4")); TEST_P(FloatConversionParamTest, FloatToF16) { auto type_name = GetParam(); diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index ed276412453607..d81d7b0478ecba 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -1191,6 +1191,7 @@ cc_library( "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", 
diff --git a/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc index 698b8fb73dd579..d35fcd105844d8 100644 --- a/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc +++ b/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -96,24 +97,6 @@ static std::vector GetRelevantConvs( return convs; } -// Converts an XlaBuilder into an HloComputation in the same module as -// `sibling_computation`. -// -// Yes, we serialize/deserialize as a proto. :) -static absl::StatusOr BuilderToHloComputation( - XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) { - TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comp.proto(), config)); - - HloModule* dest_module = sibling_computation->parent(); - HloCloneContext context(dest_module); - return dest_module->DeepCloneComputation(new_module->entry_computation(), - &context); -} - // Reshapes `instr` so that it has an extra dimension of size `vect_size` right // after `dim`. 
static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) { @@ -460,11 +443,11 @@ static absl::StatusOr TryRevectorizeConv( new_conv_result, dnums->output_feature_dimension(), *output_vect_dim, /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim)); + XlaOp root = Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}); + TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); TF_ASSIGN_OR_RETURN( HloComputation * new_conv_comp, - BuilderToHloComputation( - b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}), - conv->parent())); + XlaComputationToHloComputation(comp, conv->parent()->parent())); // Set the name on the new conv. This is purely cosmetic, but we attempt to // preserve e.g. "cudnn-conv.42" instead of "custom-call.42". @@ -599,11 +582,11 @@ static absl::StatusOr TryVectorizeConv( Collapse(new_conv_result, {dnums->output_feature_dimension(), dnums->output_feature_dimension() + 1}); + XlaOp root = Tuple(&b, {conv_result_collapsed, new_conv_scratch}); + TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); TF_ASSIGN_OR_RETURN( HloComputation * new_conv_comp, - BuilderToHloComputation( - b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}), - conv->parent())); + XlaComputationToHloComputation(comp, conv->parent()->parent())); // Create a tuple and replace the old conv with it! 
VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString(); diff --git a/xla/service/hlo_creation_utils.cc b/xla/service/hlo_creation_utils.cc index a94e23d21066e5..c9b5d5b2be361e 100644 --- a/xla/service/hlo_creation_utils.cc +++ b/xla/service/hlo_creation_utils.cc @@ -597,12 +597,22 @@ HloInstruction* MaybeMakeTuple(absl::Span operands) { HloInstruction::CreateTuple(operands)); } +absl::StatusOr XlaComputationToHloComputation( + XlaComputation& src_comp, HloModule* dest_module) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(src_comp.proto(), config)); + HloCloneContext context(dest_module); + return dest_module->DeepCloneComputation(new_module->entry_computation(), + &context); +} + absl::StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module, const OpMetadata* metadata) { CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; - HloComputation* compare_computation; XlaBuilder b("Sort.Compare"); if (metadata != nullptr) { b.SetOpMetadata(*metadata); @@ -612,13 +622,8 @@ absl::StatusOr MakeSortHlo( operand_types[i] = operands[i]->shape().element_type(); } XlaComputation comparator = CreateScalarLtComputation(operand_types, &b); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comparator.proto(), config)); - HloCloneContext context(module); - compare_computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * compare_computation, + XlaComputationToHloComputation(comparator, module)); return builder->AddInstruction(HloInstruction::CreateSort( sort_shape, dimension_to_sort, operands, 
compare_computation, is_stable)); } diff --git a/xla/service/hlo_creation_utils.h b/xla/service/hlo_creation_utils.h index 2db4a7045fc0e2..d9599663ea7fea 100644 --- a/xla/service/hlo_creation_utils.h +++ b/xla/service/hlo_creation_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/literal_util.h" @@ -257,6 +258,11 @@ absl::StatusOr MakeSelectHlo( // instruction with all the operands. Crashes if `operands` is empty. HloInstruction* MaybeMakeTuple(absl::Span operands); +// Creates a HloComputation in the destination module from a builder's +// XlaComputation. +absl::StatusOr XlaComputationToHloComputation( + XlaComputation& src_comp, HloModule* dest_module); + // Creates a Sort HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. Also creates a // default compare sub-computation which sorts the first operand into ascending diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 27630b674d2ce4..036d742c05158e 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -200,9 +200,11 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt16Ty(module->getContext()); case F8E5M2: case F8E5M2FNUZ: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E4M3FNUZ: + case F8E3M4: // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(module->getContext()); case BF16: diff --git a/xla/service/qr_expander.cc b/xla/service/qr_expander.cc index 4f79769d7c6bf8..7f32a0bc1628bd 100644 --- a/xla/service/qr_expander.cc +++ b/xla/service/qr_expander.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "xla/client/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -551,15 +552,8 @@ absl::StatusOr QrExpander::ExpandInstruction( } TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result)); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/rng_bit_generator_expander.cc b/xla/service/rng_bit_generator_expander.cc index 0d78762f47b964..88758e3c0b3667 100644 --- a/xla/service/rng_bit_generator_expander.cc +++ b/xla/service/rng_bit_generator_expander.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -86,15 +87,8 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0); Tuple(&builder, {final_state, output.value}); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - HloComputation* new_computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * new_computation, + XlaComputationToHloComputation(xla_computation, module)); computation_cache_.emplace(cache_key, new_computation); return new_computation; } diff --git a/xla/service/rng_expander.cc b/xla/service/rng_expander.cc index cbc5a1d4549db9..294916f8fb68a9 100644 --- a/xla/service/rng_expander.cc +++ b/xla/service/rng_expander.cc @@ -111,16 +111,7 @@ absl::StatusOr GetComputationForRng(HloInstruction* rng) { } TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloModule* module = rng->GetModule(); - HloCloneContext context(module); - return module->DeepCloneComputation(new_module->entry_computation(), - &context); + return XlaComputationToHloComputation(xla_computation, rng->GetModule()); } } // namespace diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD index 
6c2a9321819535..fccd701ab5277c 100644 --- a/xla/service/spmd/BUILD +++ b/xla/service/spmd/BUILD @@ -61,6 +61,7 @@ cc_library( "//xla/service:custom_call_sharding_helper", "//xla/service:dot_as_convolution_util", "//xla/service:flatten_call_graph", + "//xla/service:hlo_creation_utils", "//xla/service:hlo_cse", "//xla/service:hlo_dce", "//xla/service:hlo_lexer", diff --git a/xla/service/spmd/custom_call_handler.cc b/xla/service/spmd/custom_call_handler.cc index dab26f5985a0c5..018a6f7f444337 100644 --- a/xla/service/spmd/custom_call_handler.cc +++ b/xla/service/spmd/custom_call_handler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal_util.h" #include "xla/service/custom_call_sharding_helper.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_lexer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/host_memory_offload_annotations.h" @@ -207,13 +208,8 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallTopK( XlaComputation comparator = CreateScalarComparisonComputation( "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt}, &b); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comparator.proto(), config)); - HloCloneContext context(module_); - auto compare_computation = - module_->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * compare_computation, + XlaComputationToHloComputation(comparator, module_)); // Each partition needs to do TopK separately, thus the base shape for sort // becomes [ceil(batch_size / batch_dim_partition), k * shard_count]. 
const Shape sort_shape = ShapeUtil::MakeTupleShape( diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc index bb65d436acedbd..d25076f2bc938e 100644 --- a/xla/service/topk_rewriter.cc +++ b/xla/service/topk_rewriter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -41,20 +42,6 @@ namespace xla { namespace m = match; -// TODO(cheshire): Avoid duplication w/ cudnn_vectorize_convolutions. -static absl::StatusOr BuilderToHloComputation( - XlaComputation& comp, HloComputation* sibling_computation) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comp.proto(), config)); - - HloModule* dest_module = sibling_computation->parent(); - HloCloneContext context(dest_module); - return dest_module->DeepCloneComputation(new_module->entry_computation(), - &context); -} - static bool IsNanSafeGt(HloComputation* comp) { namespace m = match; auto match_bitcast_f32 = [](int64_t parameter_number) { @@ -500,9 +487,9 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { XlaComputation comparison = topk->largest() ? 
CreateScalarGtComputation(ptypes, &b) : CreateScalarLtComputation(ptypes, &b); - - TF_ASSIGN_OR_RETURN(HloComputation * comparator, - BuilderToHloComputation(comparison, topk->parent())); + TF_ASSIGN_OR_RETURN( + HloComputation * comparator, + XlaComputationToHloComputation(comparison, topk->parent()->parent())); return comparator; } diff --git a/xla/service/triangular_solve_expander.cc b/xla/service/triangular_solve_expander.cc index c61dc148c0ec33..3bc7ba36b60b81 100644 --- a/xla/service/triangular_solve_expander.cc +++ b/xla/service/triangular_solve_expander.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -599,15 +600,8 @@ absl::StatusOr TriangularSolveExpander::ExpandInstruction( /*block_size=*/block_size_, /*precision=*/PrecisionConfig::HIGHEST); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index 03b09f9b644f07..f5246389e485c3 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -37,6 +37,14 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. 
template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E3M4; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E4M3; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kF8E4M3FN; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 10270d0b3c1be2..f46506c0fda2c7 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -66,6 +66,8 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace +constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 6a604e20619455..6aee86bf2cbc19 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -46,12 +46,16 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { switch (dtype) { case PrimitiveType::F8E5M2: return DataType::kF8E5M2; + case PrimitiveType::F8E4M3: + return DataType::kF8E4M3; case PrimitiveType::F8E4M3FN: return DataType::kF8E4M3FN; case PrimitiveType::F8E5M2FNUZ: return DataType::kF8E5M2FNUZ; case PrimitiveType::F8E4M3FNUZ: return DataType::kF8E4M3FNUZ; + case PrimitiveType::F8E3M4: + return DataType::kF8E3M4; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -79,12 +83,16 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { switch (dtype) { case DataType::kF8E5M2: return PrimitiveType::F8E5M2; + case DataType::kF8E4M3: + return PrimitiveType::F8E4M3; case DataType::kF8E4M3FN: return PrimitiveType::F8E4M3FN; case DataType::kF8E5M2FNUZ: return PrimitiveType::F8E5M2FNUZ; case DataType::kF8E4M3FNUZ: return PrimitiveType::F8E4M3FNUZ; + case DataType::kF8E3M4: + return PrimitiveType::F8E3M4; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -141,9 
+149,11 @@ absl::StatusOr GetBlasComputationType( if (algorithm == xla::PrecisionConfig::ALG_UNSET) { switch (output_dtype) { case PrimitiveType::F8E5M2: // fall-through + case PrimitiveType::F8E4M3: // fall-through case PrimitiveType::F8E4M3FN: // fall-through case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through + case PrimitiveType::F8E3M4: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 57e9f9651e0f50..f6e8a867e2b484 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -745,13 +745,11 @@ cc_library( deps = [ ":hipblas_lt_header", ":rocblas_plugin", - "//xla/stream_executor", "//xla/stream_executor:blas", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", ], ) diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index a59c935614cd8f..e5730121addd8d 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "xla/stream_executor/rocm/hip_blas_utils.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "xla/stream_executor/blas.h" @@ -35,8 +36,11 @@ absl::Status ToStatus(hipblasStatus_t status, const char* prefix) { hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: + case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: - LOG(FATAL) << "hipblaslt does not support F8E5M2 and F8E4M3FN"; + case blas::DataType::kF8E3M4: + LOG(FATAL) + << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 34f124c69fb640..9472d3f5b6f31d 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1182,10 +1182,12 @@ xla_test( "//xla:array3d", "//xla:array4d", "//xla:literal_util", + "//xla:types", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:constants", "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", ], diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index b8b1613768c2bc..c12ce79a06e8fa 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -1423,7 +1423,8 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types #include +#include #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" @@ -32,6 +33,7 @@ limitations under the License. 
#include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/types.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" @@ -46,10 +48,11 @@ class ConstantsTest : public ClientLibraryTestBase { template class ConstantsFloatTest : public ConstantsTest {}; -typedef ::testing::Types - FloatTypes; +using FloatTypes = + ::testing::Types; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index f5a68c32886410..4f06ea0cc290c7 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -54,10 +54,9 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -873,22 +872,200 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive) { XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive2) { // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName()); + if constexpr (std::is_same_v) { + // TODO(b/370786669): Enable this test. + GTEST_SKIP() << "Skipping test for E3M4 as it requires an ml_dtypes " + "release with https://github.com/jax-ml/ml_dtypes/pull/205"; + } else { + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. 
+ XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e5m2; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +// ----- F8E4M3 + +XLA_TEST_F(ConvertTest, ConvertF16F8e4m3Roundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFCp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, inf}, // Smallest number that overflows + {0x1p8, inf}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.2p-6, 0x0.2p-6}, // Smallest denormal + {0x0.Ep-6, 0x0.Ep-6}, // Largest denormal + {0x0.8p-6, 0x0.8p-6}, // Denormal without rounding + {0x0.9p-6, 0x0.8p-6}, // Round-to-even down + {0x0.Fp-6, 0x0.8p-5}, // Round-to-even up + {0x0.8Fp-6, 0x0.8p-6}, // Round-to-nearest down + {0x0.91p-6, 0x0.Ap-6}, // 
Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.004p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x0.EFCp-6, 0x0.Ep-6}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3); + ConvertElementType(f8, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFFFFEp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, inf}, // Smallest number that overflows + {0x1p8, inf}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.2p-6, 0x0.2p-6}, // Smallest denormal + {0x0.Ep-6, 0x0.Ep-6}, // Largest denormal + {0x0.8p-6, 0x0.8p-6}, // Denormal without rounding + {0x0.9p-6, 0x0.8p-6}, // Round-to-even down + {0x0.Fp-6, 0x0.8p-5}, // Round-to-even up + {0x0.8Fp-6, 0x0.8p-6}, // Round-to-nearest down + {0x0.91p-6, 0x0.Ap-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.000002p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x0.EFFFFEp-6, 0x0.Ep-6}, // 
Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e4m3; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); std::vector all_f8; for (int i = 0; i < 256; i++) { all_f8.push_back(static_cast( - Eigen::numext::bit_cast(static_cast(i)))); + Eigen::numext::bit_cast(static_cast(i)))); } - ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive3) { // Convert from FP8 to supported floating point type. 
XlaBuilder builder(this->TestName()); - using From = tsl::float8_e5m2; + using From = tsl::float8_e4m3; std::vector all_f8; for (int i = 0; i < 256; i++) { all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); @@ -899,7 +1076,7 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3F16RoundtripExhaustive4) { // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); @@ -910,7 +1087,7 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { } xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E5M2); + ConvertElementType(all_f16_to_f8, F8E4M3); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } @@ -1366,15 +1543,21 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive2) { // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName()); - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - static_cast(Eigen::numext::bit_cast( - static_cast(i)))); - } + if constexpr (std::is_same_v) { + // TODO(b/370786669): Enable this test. 
+ GTEST_SKIP() << "Skipping test for E3M4 as it requires an ml_dtypes " + "release with https://github.com/jax-ml/ml_dtypes/pull/205"; + } else { + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + static_cast(Eigen::numext::bit_cast( + static_cast(i)))); + } - ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2FNUZ); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } } XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive3) { @@ -1569,5 +1752,178 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3fnuzF16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E3M4 + +XLA_TEST_F(ConvertTest, ConvertF16F8e3m4Roundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.08p0, 0x1p0}, // Round-to-even down + {0x1.18p0, 0x1.2p0}, // Round-to-even up + {0x1.Fp3, 0x1.Fp3}, // Max value + {0x1.F7Cp3, 0x1.Fp3}, // Largest number that doesn't overflow + {0x1.F8p3, inf}, // Smallest number that overflows + {0x1p4, inf}, // Overflow + {0x1p-2, 0x1p-2}, // Smallest F8 normal + {0x1.Fp-3, 0x1p-2}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.1p-2, 0x0.1p-2}, // Smallest denormal + {0x0.Fp-2, 0x0.Fp-2}, // Largest denormal + {0x0.8p-2, 0x0.8p-2}, // Denormal without rounding + {0x0.88p-2, 0x0.8p-2}, // Round-to-even down + {0x0.F8p-2, 0x0.8p-1}, // Round-to-even up + {0x0.87p-2, 0x0.8p-2}, // Round-to-nearest down + {0x0.89p-2, 0x0.9p-2}, // Round-to-nearest up + {0x1p-7, 0}, // 
Largest number that underflows + {0x1.004p-7, 0x0.1p-2}, // Smallest number that doesn't underflow + {0x0.F7Cp-2, 0x0.Fp-2}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E3M4); + ConvertElementType(f8, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e3m4Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.08p0, 0x1p0}, // Round-to-even down + {0x1.18p0, 0x1.2p0}, // Round-to-even up + {0x1.Fp3, 0x1.Fp3}, // Max value + {0x1.F7FFFEp3, 0x1.Fp3}, // Largest number that doesn't overflow + {0x1.F8p3, inf}, // Smallest number that overflows + {0x1p4, inf}, // Overflow + {0x1p-2, 0x1p-2}, // Smallest F8 normal + {0x1.Fp-3, 0x1p-2}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.1p-2, 0x0.1p-2}, // Smallest denormal + {0x0.Fp-2, 0x0.Fp-2}, // Largest denormal + {0x0.8p-2, 0x0.8p-2}, // Denormal without rounding + {0x0.88p-2, 0x0.8p-2}, // Round-to-even down + {0x0.F8p-2, 0x0.8p-1}, // Round-to-even up + {0x0.87p-2, 0x0.8p-2}, // Round-to-nearest down + {0x0.89p-2, 0x0.9p-2}, // Round-to-nearest up + {0x1p-7, 0}, // Largest number that underflows + {0x1.000002p-7, 0x0.1p-2}, // Smallest number that doesn't underflow + {0x0.F7FFFEp-2, 0x0.Fp-2}, // Largest number that rounds to denormal 
+ }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E3M4); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e3m4; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e3m4; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. 
+ XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + } // namespace } // namespace xla diff --git a/xla/tests/float8_test.cc b/xla/tests/float8_test.cc index 02be9bfa9356ea..648c718d7cd958 100644 --- a/xla/tests/float8_test.cc +++ b/xla/tests/float8_test.cc @@ -27,11 +27,12 @@ limitations under the License. namespace xla { namespace { -// Test FP8 floating-point types (F8E5M2, F8E4M3FN) +// Test FP8 floating-point types template class Float8Test : public ClientLibraryTestBase {}; -using DataTypes = ::testing::Types; +using DataTypes = ::testing::Types; TYPED_TEST_SUITE(Float8Test, DataTypes); XLA_TYPED_TEST(Float8Test, ScalarOperation) { diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 780968098cf32b..4f4895b57123ae 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -120,22 +120,28 @@ enum PrimitiveType { C64, C128, F8E5M2, + F8E4M3, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, + F8E3M4, }; const std::vector& primitive_strings() { - static auto vec = - new std::vector({"s2", "s4", "s8", - "s16", "s32", "s64", - "u2", "u4", "u8", - "u16", "u32", "u64", - "f16", "bf16", "f32", - "f64", "c64", "c128", - "f8e5m2", "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz"}); + static auto vec = new std::vector({"s2", "s4", + "s8", "s16", + "s32", "s64", + "u2", "u4", + "u8", "u16", + "u32", "u64", + "f16", "bf16", + "f32", "f64", + "c64", "c128", + "f8e5m2", "f8e4m3", + "f8e4m3fn", "f8e4m3b11fnuz", + "f8e5m2fnuz", "f8e4m3fnuz", + "f8e3m4"}); return *vec; } @@ -413,10 +419,12 @@ void Fill(void* buffer, const ArrayShape& shape) { return FillFloatT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: case F8E4M3FNUZ: + 
case F8E3M4: case F16: case BF16: case C64: @@ -469,10 +477,12 @@ void Display(const void* buffer, const ArrayShape& shape) { return DisplayT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: case F8E4M3FNUZ: + case F8E3M4: case F16: case BF16: case C64: diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index 46fa640ee62298..39644589d309e6 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -70,6 +70,8 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto index 695db935f6a0b4..2ac31005c16629 100644 --- a/xla/tsl/protobuf/dnn.proto +++ b/xla/tsl/protobuf/dnn.proto @@ -22,6 +22,8 @@ enum DataType { kF8E5M2FNUZ = 10; kF8E4M3FNUZ = 11; kInt64 = 12; + kF8E4M3 = 13; + kF8E3M4 = 14; } // Describes how a convolution input or output layer's data is formatted. diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index 717ab3e462a7bf..e2c5eb295c6b12 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,6 +61,10 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); + numpy_dtypes.float8_e3m4 = + py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num(); + numpy_dtypes.float8_e4m3 = + py::dtype::from_args(ml_dtypes.attr("float8_e4m3")).num(); numpy_dtypes.float8_e4m3fn = py::dtype::from_args(ml_dtypes.attr("float8_e4m3fn")).num(); numpy_dtypes.float8_e5m2 = @@ -81,6 +85,8 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. 
if (numpy_dtypes.bfloat16 == NPY_NOTYPE || + numpy_dtypes.float8_e3m4 == NPY_NOTYPE || + numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || numpy_dtypes.float8_e4m3fnuz == NPY_NOTYPE || numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index bf9eab2200a76b..b3aa94e430239a 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,6 +24,8 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; + int float8_e3m4; + int float8_e4m3; int float8_e4m3fn; int float8_e4m3b11fnuz; int float8_e4m3fnuz; diff --git a/xla/util.cc b/xla/util.cc index 9b1a6db1fa22c0..0c92df2c69e76b 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -137,7 +137,7 @@ std::string Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { static_assert(!std::is_same::value, - "RoundTripNanPayload does not support E4M3"); + "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FNUZ"); static_assert(!std::is_same::value, @@ -168,6 +168,12 @@ std::string RoundTripFpToString(tsl::float8_e5m2 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e4m3 value) { + std::string result = GenericRoundTripFpToString(value); + RoundTripNanPayload(value, &result); + return result; +} + std::string RoundTripFpToString(tsl::float8_e4m3fnuz value) { std::string result = GenericRoundTripFpToString(value); return result; @@ -188,6 +194,12 @@ std::string RoundTripFpToString(tsl::float8_e4m3b11fnuz value) { return result; } +std::string RoundTripFpToString(tsl::float8_e3m4 value) { + std::string result = GenericRoundTripFpToString(value); + RoundTripNanPayload(value, &result); + return result; +} + std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); 
RoundTripNanPayload(value, &result); diff --git a/xla/util.h b/xla/util.h index a6e74601a809f7..a62096c866d1f5 100644 --- a/xla/util.h +++ b/xla/util.h @@ -420,6 +420,9 @@ std::string VectorString(const std::initializer_list& c) { std::string RoundTripFpToString(tsl::float8_e5m2 value); // Returns a string which can losslessly round trip to a float8 E4M3. +std::string RoundTripFpToString(tsl::float8_e4m3 value); + +// Returns a string which can losslessly round trip to a float8 E4M3FN. std::string RoundTripFpToString(tsl::float8_e4m3fn value); // Returns a string which can losslessly round trip to a float8 E4M3B11. @@ -431,6 +434,9 @@ std::string RoundTripFpToString(tsl::float8_e5m2fnuz value); // Returns a string which can losslessly round trip to a float8 E4M3FNUZ. std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); +// Returns a string which can losslessly round trip to a float8 E3M4. +std::string RoundTripFpToString(tsl::float8_e3m4 value); + // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); diff --git a/xla/util_test.cc b/xla/util_test.cc index 707696ea1c3a99..83b1b149c6916d 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include #include +#include "ml_dtypes/include/float8.h" #include "xla/maybe_owning.h" #include "xla/test.h" #include "xla/types.h" @@ -130,6 +131,18 @@ TEST(UtilTest, RoundTripFpToString) { EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( true, QuietNanWithoutPayload())), "-nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + false, QuietNanWithoutPayload())), + "nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + true, QuietNanWithoutPayload())), + "-nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + false, QuietNanWithoutPayload())), + "nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + true, QuietNanWithoutPayload())), + "-nan"); EXPECT_EQ( RoundTripFpToString(std::numeric_limits::quiet_NaN()), "nan"); @@ -237,6 +250,18 @@ TEST(UtilTest, TotalOrder_F8E5M2) { } } +TEST(UtilTest, TotalOrder_F8E4M3) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e4m3 x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e4m3 y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + TEST(UtilTest, TotalOrder_F8E4M3FN) { for (int a = 0; a < 256; ++a) { tsl::float8_e4m3fn x = @@ -287,6 +312,18 @@ TEST(UtilTest, TotalOrder_F8E5M2FNUZ) { } } +TEST(UtilTest, TotalOrder_F8E3M4) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e3m4 x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e3m4 y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 57e8f5a93a7073..c67116a167eea5 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -66,6 +66,9 @@ enum PrimitiveType { // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the // existing IEEE types. 
// + // F8E4M3 has 4 exponent bits and 3 mantissa bits, and is similar to the + // existing IEEE types. + // // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only // Finite and NaN values are supported. Unlike IEEE types, infinities are not // supported. NaN is represented when the exponent and mantissa bits are all @@ -77,12 +80,17 @@ enum PrimitiveType { // the exponent and mantissa bits are all 0s with a sign bit of 1. All other // values are finite. // + // F8E3M4 has 3 exponent bits and 4 mantissa bits, and is similar to the + // existing IEEE types. + // // Support for these dtypes is under development. They do not yet work // properly in most cases. // TODO(b/259609697): Fully support FP8. F8E5M2 = 19; + F8E4M3 = 28; F8E4M3FN = 20; F8E4M3B11FNUZ = 23; + F8E3M4 = 29; // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915 // @@ -126,7 +134,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 28 + // Next = 30 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc @@ -572,12 +580,14 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; bytes f8e5m2s = 19; + bytes f8e4m3s = 28; bytes f8e4m3fns = 20; bytes f8e4m3b11fnuzs = 23; bytes f8e5m2fnuzs = 24; bytes f8e4m3fnuzs = 25; + bytes f8e3m4s = 29; repeated int64 sparse_indices = 14; - // Next = 28 + // Next = 30 } message WindowDimension {