-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][spirv] Add 8-bit float type emulation #148811
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
ab237ea
to
3755267
Compare
@llvm/pr-subscribers-mlir-spirv Author: Md Abdullah Shahneous Bari (mshahneo) Changes: 8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integers during conversion. Full diff: https://github.com/llvm/llvm-project/pull/148811.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 50c67da91a4af..0eb9720351027 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -404,7 +407,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -488,7 +494,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -1151,7 +1160,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 3d22ec918f4c5..03ae54a8ae30a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};
+ /// Whether to emulate unsupported floats with integer types of same bit
+ /// width.
+ bool emulateUnsupportedFloatTypes{true};
+
/// How sub-byte values are storaged in memory.
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..f42f779a69d33 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+/// Reinterprets the bits of `floatAttr` as an IntegerAttr of `dstType`.
+/// Returns a null attribute when `dstType` is not an integer type whose width
+/// equals the float's bit width (for example if the converted element type
+/// was widened further), so the callers' `if (!dstAttr) return failure();`
+/// checks bail out instead of building a width-mismatched IntegerAttr.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+                            ConversionPatternRewriter &rewriter) {
+  APInt intVal = floatAttr.getValue().bitcastToAPInt();
+  auto intType = dyn_cast<IntegerType>(dstType);
+  if (!intType || intType.getWidth() != intVal.getBitWidth())
+    return nullptr;
+  return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -296,8 +304,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +379,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to a 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -1351,6 +1377,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4df4912..56b6181018153 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f659afb10..c0439a4033eac 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386ea80124..8cd650e649008 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..6b2580b6541f2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+
if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
+ // Handle 8-bit floats.
+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+ auto bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 8)
+ return bitWidth / 8;
+ else
+ return std::nullopt;
+ }
+
if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
@@ -318,6 +328,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}
+/// Emulates an 8-bit float type as the signless integer type of the same
+/// width. Yields a null Type when emulation is turned off or when `type` is
+/// not one of the recognized 8-bit float kinds.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+                                 FloatType type) {
+  if (!options.emulateUnsupportedFloatTypes)
+    return nullptr;
+  // Every recognized f8 kind maps to an integer with identical bit width.
+  bool isEmulatedF8 =
+      isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float8E8M0FNUType>(type);
+  if (!isEmulatedF8) {
+    LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
+    return nullptr;
+  }
+  return IntegerType::get(type.getContext(), type.getWidth());
+}
+
+/// Returns `type` with its element type replaced by the same-width signless
+/// integer type when the element is an 8-bit float that
+/// `convert8BitFloatType` knows how to emulate. This is a noop when the
+/// element type is not an 8-bit float or the emulation flag is set to false.
+///
+/// Delegating to `convert8BitFloatType` keeps the list of emulated 8-bit
+/// float types in a single place, so scalar and shaped conversions cannot
+/// drift apart when a new f8 variant is added.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+                           const SPIRVConversionOptions &options) {
+  auto floatType = dyn_cast<FloatType>(type.getElementType());
+  if (!floatType || floatType.getWidth() != 8)
+    return type;
+
+  // Null when emulation is disabled or the f8 variant is not an emulated
+  // kind; keep the original shaped type in that case.
+  Type convertedElementType = convert8BitFloatType(options, floatType);
+  if (!convertedElementType)
+    return type;
+  return type.clone(convertedElementType);
+}
+
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
@@ -337,6 +385,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
+ type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +482,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}
type = cast<TensorType>(convertIndexElementType(type, options));
+ type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +646,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+ // Hnadle 8 bit float types.
+ type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+ arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
@@ -1439,6 +1493,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (floatType.getWidth() == 8)
+ return convert8BitFloatType(this->options, floatType);
return Type();
});
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 1abe0fd2ec468..751e727534efe 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -559,6 +559,17 @@ func.func @constant() {
return
}
+// CHECK-LABEL: @constant_8bit_float
+// 8-bit float constants become integer constants carrying the raw bit
+// pattern: 1.0 is 0x38 (56) in f8E4M3 and 0x3C (60) in f8E5M2.
+func.func @constant_8bit_float() {
+ // CHECK: spirv.Constant 56 : i8
+ %cst = arith.constant 1.0 : f8E4M3
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+ return
+}
+
// CHECK-LABEL: @constant_16bit
func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 1737f4a906bf8..0c77c88334572 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
+// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
//===----------------------------------------------------------------------===//
// Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
} // end module
+
+
+// -----
+
+// Check that 8-bit float types are emulated as i8 by default, and that they
+// are left unconverted when emulate-unsupported-float-types=false.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
+} {
+
+ // CHECK: spirv.func @float8_to_integer8
+ // CHECK-SAME: (%arg0: i8
+ // CHECK-SAME: %arg1: i8
+ // CHECK-SAME: %arg2: i8
+ // CHECK-SAME: %arg3: i8
+ // CHECK-SAME: %arg4: i8
+ // CHECK-SAME: %arg5: i8
+ // CHECK-SAME: %arg6: i8
+ // CHECK-SAME: %arg7: i8
+ // CHECK-SAME: %arg8: vector<4xi8>
+ // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+ // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
+ // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
+ // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
+ // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
+ // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
+ // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
+ // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
+ // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+ // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+ // UNSUPPORTED_FLOAT-SAME: ) {
+
+ func.func @float8_to_integer8(
+ %arg0: f8E5M2, // CHECK-NOT: f8E5M2
+ %arg1: f8E4M3, // CHECK-NOT: f8E4M3
+ %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
+ %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
+ %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
+ %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
+ %arg6: f8E3M4, // CHECK-NOT: f8E3M4
+ %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
+ %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+ %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+ %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
+ ) {
+ // CHECK: spirv.Return
+ return
+ }
+}
|
@llvm/pr-subscribers-mlir Author: Md Abdullah Shahneous Bari (mshahneo) Changes: 8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integers during conversion. Full diff: https://github.com/llvm/llvm-project/pull/148811.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 50c67da91a4af..0eb9720351027 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -404,7 +407,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -488,7 +494,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -1151,7 +1160,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 3d22ec918f4c5..03ae54a8ae30a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};
+ /// Whether to emulate unsupported floats with integer types of same bit
+ /// width.
+ bool emulateUnsupportedFloatTypes{true};
+
/// How sub-byte values are storaged in memory.
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..f42f779a69d33 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+/// Reinterprets the bits of `floatAttr` as an IntegerAttr of `dstType`.
+/// Returns a null attribute when `dstType` is not an integer type whose width
+/// equals the float's bit width (for example if the converted element type
+/// was widened further), so the callers' `if (!dstAttr) return failure();`
+/// checks bail out instead of building a width-mismatched IntegerAttr.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+                            ConversionPatternRewriter &rewriter) {
+  APInt intVal = floatAttr.getValue().bitcastToAPInt();
+  auto intType = dyn_cast<IntegerType>(dstType);
+  if (!intType || intType.getWidth() != intVal.getBitWidth())
+    return nullptr;
+  return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -296,8 +304,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +379,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to a 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -1351,6 +1377,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4df4912..56b6181018153 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f659afb10..c0439a4033eac 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386ea80124..8cd650e649008 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..6b2580b6541f2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+
if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
+ // Handle 8-bit floats.
+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+ auto bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 8)
+ return bitWidth / 8;
+ else
+ return std::nullopt;
+ }
+
if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
@@ -318,6 +328,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}
+/// Emulates an 8-bit float type as the signless integer type of the same
+/// width. Yields a null Type when emulation is turned off or when `type` is
+/// not one of the recognized 8-bit float kinds.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+                                 FloatType type) {
+  if (!options.emulateUnsupportedFloatTypes)
+    return nullptr;
+  // Every recognized f8 kind maps to an integer with identical bit width.
+  bool isEmulatedF8 =
+      isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float8E8M0FNUType>(type);
+  if (!isEmulatedF8) {
+    LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
+    return nullptr;
+  }
+  return IntegerType::get(type.getContext(), type.getWidth());
+}
+
+/// Returns `type` with its element type replaced by the same-width signless
+/// integer type when the element is an 8-bit float that
+/// `convert8BitFloatType` knows how to emulate. This is a noop when the
+/// element type is not an 8-bit float or the emulation flag is set to false.
+///
+/// Delegating to `convert8BitFloatType` keeps the list of emulated 8-bit
+/// float types in a single place, so scalar and shaped conversions cannot
+/// drift apart when a new f8 variant is added.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+                           const SPIRVConversionOptions &options) {
+  auto floatType = dyn_cast<FloatType>(type.getElementType());
+  if (!floatType || floatType.getWidth() != 8)
+    return type;
+
+  // Null when emulation is disabled or the f8 variant is not an emulated
+  // kind; keep the original shaped type in that case.
+  Type convertedElementType = convert8BitFloatType(options, floatType);
+  if (!convertedElementType)
+    return type;
+  return type.clone(convertedElementType);
+}
+
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
@@ -337,6 +385,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
+ type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +482,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}
type = cast<TensorType>(convertIndexElementType(type, options));
+ type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +646,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+ // Hnadle 8 bit float types.
+ type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+ arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
@@ -1439,6 +1493,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (floatType.getWidth() == 8)
+ return convert8BitFloatType(this->options, floatType);
return Type();
});
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 1abe0fd2ec468..751e727534efe 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -559,6 +559,17 @@ func.func @constant() {
return
}
+// CHECK-LABEL: @constant_8bit_float
+// 8-bit float constants become integer constants carrying the raw bit
+// pattern: 1.0 is 0x38 (56) in f8E4M3 and 0x3C (60) in f8E5M2.
+func.func @constant_8bit_float() {
+ // CHECK: spirv.Constant 56 : i8
+ %cst = arith.constant 1.0 : f8E4M3
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+ return
+}
+
// CHECK-LABEL: @constant_16bit
func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 1737f4a906bf8..0c77c88334572 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
+// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
//===----------------------------------------------------------------------===//
// Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
} // end module
+
+
+// -----
+
+// Check that 8-bit float types are emulated as i8.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
+} {
+
+ // CHECK: spirv.func @float8_to_integer8
+ // CHECK-SAME: (%arg0: i8
+ // CHECK-SAME: %arg1: i8
+ // CHECK-SAME: %arg2: i8
+ // CHECK-SAME: %arg3: i8
+ // CHECK-SAME: %arg4: i8
+ // CHECK-SAME: %arg5: i8
+ // CHECK-SAME: %arg6: i8
+ // CHECK-SAME: %arg7: i8
+ // CHECK-SAME: %arg8: vector<4xi8>
+ // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+ // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
+ // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
+ // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
+ // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
+ // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
+ // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
+ // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
+ // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+ // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+ // UNSUPPORTED_FLOAT-SAME: ) {
+
+ func.func @float8_to_integer8(
+ %arg0: f8E5M2, // CHECK-NOT: f8E5M2
+ %arg1: f8E4M3, // CHECK-NOT: f8E4M3
+ %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
+ %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
+ %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
+ %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
+ %arg6: f8E3M4, // CHECK-NOT: f8E3M4
+ %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
+ %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+ %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+ %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
+ ) {
+ // CHECK: spirv.Return
+ return
+ }
+}
|
SPIR-V does not support any 8-bit float types. Therefore, 8-bit floats are emulated as 8-bit integers of the same bit width.
Both scalar and vector types are handled.
This approach minimizes code changes.
5cde869
to
2411797
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Call for a test, otherwise approved
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions { | |||
/// The number of bits to store a boolean value. | |||
unsigned boolNumBits{8}; | |||
|
|||
/// Whether to emulate unsupported floats with integer types of same bit | |||
/// width. | |||
bool emulateUnsupportedFloatTypes{true}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure we want this on by default? I think this can be a footgun when users inadvertently use unsupported fp types and get a dialect conversion error over integer types down the line...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather this was opt-in unless we have good justification for keeping this opt-out
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the theory is that you'll generally want to do software emulation of the small floats earlier?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm concerned about cases when you ended up with unsupported types by accident -- this often comes up with all the variants of fp8 and smaller fp types that are present in the input MLIR at the level of linalg/arith. This is much easier to diagnose when you can see the original type. IMO, dialect conversion should error out by default, unless someone opts into these types being handled in some other way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LLVM precedent is to just do the f8* => i8, though?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LLVM IR or convert-to-llvm? LLVM has the same issue as SPIR-V here: its type system has fewer primitive types than MLIR's, so it is up to MLIR to figure out how to handle unsupported types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
convert-to-llvm, which, IIRC, doesn't even provide this opt-out mechanism; it just defines the FP8 types to be i8.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, so we can leave this on by default to match llvm conversion if this is the case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @krzysz00. Yes, I was following the LLVM precedent.
8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integers during conversion.