From 0cafee98706cf5683bb6fb5cd5ea1a7816c3dad7 Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:51:26 +0530 Subject: [PATCH] [vm] Add support for SI64 to F32 casts (#19455) Adds support to the VM for casting from `si64` type to `f32` type. Enables the lowering of `arith.sitofp %arg0 : i64 to f32` after demotion. --- .../VM/Conversion/ArithToVM/Patterns.cpp | 100 ++++++++++++++---- .../ArithToVM/test/conversion_ops.mlir | 12 +++ .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 1 + .../compiler/Dialect/VM/IR/VMOpFolders.cpp | 10 ++ .../compiler/Dialect/VM/IR/VMOpcodesF32.td | 3 + .../src/iree/compiler/Dialect/VM/IR/VMOps.td | 7 ++ runtime/src/iree/vm/bytecode/disassembler.c | 10 ++ runtime/src/iree/vm/bytecode/dispatch.c | 5 + .../vm/bytecode/utils/generated/op_table.h | 12 +-- runtime/src/iree/vm/bytecode/verifier.c | 4 + runtime/src/iree/vm/ops.h | 1 + .../onnx_ops/onnx_ops_gpu_vulkan.json | 3 - 12 files changed, 139 insertions(+), 29 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp index 8e5f96f65f81..6d902e549611 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp @@ -526,30 +526,93 @@ struct TruncateIOpConversion : public OpConversionPattern { } }; -template -struct IntToFPOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct SIToFPOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(OpTy srcOp, typename OpTy::Adaptor adaptor, + matchAndRewrite(arith::SIToFPOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = srcOp.getIn().getType(); + auto input = srcOp.getIn(); + auto srcType = input.getType(); auto dstType = srcOp.getResult().getType(); - if (!dstType.isF32() || - !(srcType.isSignedInteger() || srcType.isSignlessInteger())) { + auto resultType = getTypeConverter()->convertType(dstType); + + if (!(dstType.isF32() || dstType.isF64())) { return rewriter.notifyMatchFailure(srcOp, "unsupported type"); } - Value input = srcOp.getIn(); - if (!(srcType.isSignlessInteger(32) || srcType.isSignedInteger(32))) { - if (srcType.getIntOrFloatBitWidth() < 32) { - input = rewriter.create( - srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input); + + if (srcType.isSignedInteger(32) || srcType.isSignlessInteger(32)) { + if (dstType.isF32()) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } + if (dstType.isF64()) { + return rewriter.notifyMatchFailure(srcOp, "unsupported type"); + } + } + if (srcType.isSignedInteger(64) || srcType.isSignlessInteger(64)) { + if (dstType.isF32()) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); } else { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + } + return success(); + } + + if (srcType.getIntOrFloatBitWidth() < 32) { + input = rewriter.create( + srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input); + } + + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } +}; + +struct UIToFPOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::UIToFPOp srcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto input = srcOp.getIn(); + auto srcType = input.getType(); + auto dstType = srcOp.getResult().getType(); + + if (!(dstType.isF32() || dstType.isF64())) { + return rewriter.notifyMatchFailure(srcOp, "unsupported type"); + } + + auto resultType = getTypeConverter()->convertType(dstType); + if (srcType.isUnsignedInteger(32) || srcType.isSignlessInteger(32)) { + if (dstType.isF32()) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } + if (dstType.isF64()) { + return rewriter.notifyMatchFailure(srcOp, "unsupported type"); + } + } + if (srcType.isUnsignedInteger(64) || srcType.isSignlessInteger(64)) { + if (dstType.isF32()) { return rewriter.notifyMatchFailure(srcOp, "unsupported type"); } + + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } + + if (srcType.getIntOrFloatBitWidth() < 32) { + input = rewriter.create( + srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input); } - auto resultType = this->getTypeConverter()->convertType(dstType); - rewriter.replaceOpWithNewOp(srcOp, resultType, input); + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); return success(); } }; @@ -742,12 +805,9 @@ void populateArithToVMPatterns(MLIRContext *context, IREE::VM::MaxF64Op>>(typeConverter, context); // Floating-point conversion ops. - patterns.insert, - IntToFPOpConversion, - FPToSIOpConversion, FPToUIOpConversion, BitcastOpConversion>( - typeConverter, context); + patterns.insert(typeConverter, + context); // Shift ops. patterns diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir index be4ec1f83b87..5b8da0ba9f03 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir @@ -275,6 +275,18 @@ module @sitofp_i8_f32 { // ----- +// CHECK-LABEL: @sitofp_i64_f32 +module @sitofp_i64_f32 { + // CHECK: vm.func private @fn(%[[ARG0:.+]]: i64) + func.func @fn(%arg0: i64) -> f32 { + // CHECK: vm.cast.si64.f32 %[[ARG0]] : i64 -> f32 + %0 = arith.sitofp %arg0 : i64 to f32 + return %0 : f32 + } +} + +// ----- + // CHECK-LABEL: @uitofp_i8_f32 module @uitofp_i8_f32 { // CHECK: vm.func private @fn(%[[ARG0:.+]]: i32) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 71eaea620cb2..845c3e3ace4b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -4521,6 +4521,7 @@ void populateVMToEmitCPatterns(ConversionTarget &conversionTarget, ADD_GENERIC_PATTERN(IREE::VM::CastF32UI32Op, "vm_cast_f32ui32"); ADD_GENERIC_PATTERN(IREE::VM::CastF32UI64Op, "vm_cast_f32ui64"); ADD_GENERIC_PATTERN(IREE::VM::CastSI32F32Op, "vm_cast_si32f32"); + ADD_GENERIC_PATTERN(IREE::VM::CastSI64F32Op, "vm_cast_si64f32"); ADD_GENERIC_PATTERN(IREE::VM::CastUI32F32Op, "vm_cast_ui32f32"); ADD_GENERIC_PATTERN(IREE::VM::CeilF32Op, "vm_ceil_f32"); ADD_GENERIC_PATTERN(IREE::VM::CmpEQF32OOp, "vm_cmp_eq_f32o"); diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index bec93d4c43b5..f227292e3e4a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -1683,6 +1683,16 @@ OpFoldResult CastSI32F32Op::fold(FoldAdaptor operands) { }); } +OpFoldResult CastSI64F32Op::fold(FoldAdaptor operands) { + return constFoldCastOp( + Float32Type::get(getContext()), operands.getOperand(), + [&](const APInt &a) { + APFloat b = APFloat(0.0f); + b.convertFromAPInt(a, /*IsSigned=*/true, APFloat::rmNearestTiesToAway); + return b; + }); +} + OpFoldResult CastUI32F32Op::fold(FoldAdaptor operands) { return constFoldCastOp( Float32Type::get(getContext()), operands.getOperand(), diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td index af9295f165f4..1ce37ae032f0 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td @@ -45,6 +45,7 @@ def VM_OPC_MinF32 : VM_OPC<0x37, "MinF32">; def VM_OPC_MaxF32 : VM_OPC<0x38, "MaxF32">; def VM_OPC_CastSI32F32 : VM_OPC<0x14, "CastSI32F32">; +def VM_OPC_CastSI64F32 : VM_OPC<0x3C, "CastSI64F32">; def VM_OPC_CastUI32F32 : VM_OPC<0x15, "CastUI32F32">; def VM_OPC_CastF32SI32 : VM_OPC<0x16, "CastF32SI32">; def VM_OPC_CastF32SI64 : VM_OPC<0x3A, "CastF32SI64">; @@ -116,10 +117,12 @@ def VM_ExtF32OpcodeAttr : VM_OPC_CeilF32, VM_OPC_FloorF32, VM_OPC_RoundF32, + VM_OPC_RoundF32Even, VM_OPC_MinF32, VM_OPC_MaxF32, VM_OPC_CastSI32F32, + VM_OPC_CastSI64F32, VM_OPC_CastUI32F32, VM_OPC_CastF32SI32, VM_OPC_CastF32SI64, diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index c23e687a8c6d..f7e59449211b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -3167,6 +3167,13 @@ def VM_CastSI64F64Op : let hasFolder = 1; } +def VM_CastSI64F32Op : + VM_ConversionOp { + let summary = [{cast from a signed integer to a float-point value}]; + let hasFolder = 1; +} + def VM_CastUI64F64Op : VM_ConversionOp { diff --git a/runtime/src/iree/vm/bytecode/disassembler.c b/runtime/src/iree/vm/bytecode/disassembler.c index 02d93e9070e6..4b24843a4a07 100644 --- a/runtime/src/iree/vm/bytecode/disassembler.c +++ b/runtime/src/iree/vm/bytecode/disassembler.c @@ -2111,6 +2111,16 @@ iree_status_t iree_vm_bytecode_disassemble_op( EMIT_OPTIONAL_VALUE_I32(regs->i32[operand_reg]); break; } + DISASM_OP(EXT_F32, CastSI64F32) { + uint16_t operand_reg = VM_ParseOperandRegI64("operand"); + uint16_t result_reg = VM_ParseResultRegF32("result"); + EMIT_F32_REG_NAME(result_reg); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_cstring(b, " = vm.cast.si64.f32 ")); + EMIT_I64_REG_NAME(operand_reg); + EMIT_OPTIONAL_VALUE_I64(regs->i32[operand_reg]); + break; + } DISASM_OP(EXT_F32, CastUI32F32) { uint16_t operand_reg = VM_ParseOperandRegI32("operand"); uint16_t result_reg = VM_ParseResultRegF32("result"); diff --git a/runtime/src/iree/vm/bytecode/dispatch.c b/runtime/src/iree/vm/bytecode/dispatch.c index 40ae195b660d..ba48f3228477 100644 --- a/runtime/src/iree/vm/bytecode/dispatch.c +++ b/runtime/src/iree/vm/bytecode/dispatch.c @@ -2046,6 +2046,11 @@ static iree_status_t iree_vm_bytecode_dispatch( float* result = VM_DecResultRegF32("result"); *result = vm_cast_si32f32(operand); }); + DISPATCH_OP(EXT_F32, CastSI64F32, { + int64_t operand = (int64_t)VM_DecOperandRegI64("operand"); + float* result = VM_DecResultRegF32("result"); + *result = vm_cast_si64f32(operand); + }); DISPATCH_OP(EXT_F32, CastUI32F32, { int32_t operand = (int32_t)VM_DecOperandRegI32("operand"); float* result = VM_DecResultRegF32("result"); diff --git a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h index 2a5a76c0d7ab..2c760731a023 100644 --- a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h +++ b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h @@ -388,10 +388,10 @@ typedef enum { OPC(0x77, AbsI32) \ OPC(0x78, AbsI64) \ OPC(0x79, Block) \ - OPC(0x7A, MinI64S) \ - OPC(0x7B, MinI64U) \ - OPC(0x7C, MaxI64S) \ - OPC(0x7D, MaxI64U) \ + OPC(0x7A, MinI32S) \ + OPC(0x7B, MinI32U) \ + OPC(0x7C, MaxI32S) \ + OPC(0x7D, MaxI32U) \ OPC(0x7E, MinI64S) \ OPC(0x7F, MinI64U) \ OPC(0x80, MaxI64S) \ @@ -584,7 +584,7 @@ typedef enum { IREE_VM_OP_EXT_F32_RoundF32Even = 0x39, IREE_VM_OP_EXT_F32_CastF32SI64 = 0x3A, IREE_VM_OP_EXT_F32_CastF32UI64 = 0x3B, - IREE_VM_OP_EXT_F32_RSV_0x3C, + IREE_VM_OP_EXT_F32_CastSI64F32 = 0x3C, IREE_VM_OP_EXT_F32_RSV_0x3D, IREE_VM_OP_EXT_F32_RSV_0x3E, IREE_VM_OP_EXT_F32_RSV_0x3F, @@ -843,7 +843,7 @@ typedef enum { OPC(0x39, RoundF32Even) \ OPC(0x3A, CastF32SI64) \ OPC(0x3B, CastF32UI64) \ - RSV(0x3C) \ + OPC(0x3C, CastSI64F32) \ RSV(0x3D) \ RSV(0x3E) \ RSV(0x3F) \ diff --git a/runtime/src/iree/vm/bytecode/verifier.c b/runtime/src/iree/vm/bytecode/verifier.c index c5b9d635f220..5c726db74e8f 100644 --- a/runtime/src/iree/vm/bytecode/verifier.c +++ b/runtime/src/iree/vm/bytecode/verifier.c @@ -1823,6 +1823,10 @@ static iree_status_t iree_vm_bytecode_function_verify_bytecode_op( VM_VerifyOperandRegI32(operand); VM_VerifyResultRegF32(result); }); + VERIFY_OP(EXT_F32, CastSI64F32, { + VM_VerifyOperandRegI64(operand); + VM_VerifyResultRegF32(result); + }); VERIFY_OP(EXT_F32, CastUI32F32, { VM_VerifyOperandRegI32(operand); VM_VerifyResultRegF32(result); diff --git a/runtime/src/iree/vm/ops.h b/runtime/src/iree/vm/ops.h index b9ffd70122da..68c939e62350 100644 --- a/runtime/src/iree/vm/ops.h +++ b/runtime/src/iree/vm/ops.h @@ -599,6 +599,7 @@ static inline float vm_erf_f32(float operand) { return erff(operand); } //===------------------------------------------------------------------===// static inline float vm_cast_si32f32(int32_t operand) { return (float)operand; } +static inline float vm_cast_si64f32(int64_t operand) { return (float)operand; } static inline float vm_cast_ui32f32(int32_t operand) { return (float)(uint32_t)operand; } diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json index 721c34c4ad82..eb8e94e5aa36 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json @@ -373,9 +373,6 @@ "onnx/node/generated/test_slice_start_out_of_bounds", "onnx/node/generated/test_stft", "onnx/node/generated/test_stft_with_window", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip0", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip5", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", "onnx/node/generated/test_tfidfvectorizer_tf_only_bigrams_skip0", "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_levelempty", "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_skip5",