From 13104db6a11d89cb1121cda2e95534fa3b250d45 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Wed, 2 Oct 2024 12:57:38 -0400 Subject: [PATCH] fixes Signed-off-by: Benoit Jacob --- .../StableHLO/Conversion/ConvertCollectives.cpp | 17 +++++++++-------- .../Preprocessing/DotGeneralToDot.cpp | 2 +- .../Preprocessing/EinsumToDotGeneral.cpp | 2 +- .../Preprocessing/StableHLOToStableHLO.cpp | 2 +- .../Conversion/StableHLOCustomCalls.cpp | 3 ++- compiler/src/iree/compiler/Tools/BUILD.bazel | 1 + .../iree/compiler/Tools/init_mlir_dialects.h | 4 ++-- third_party/stablehlo | 2 +- third_party/torch-mlir | 2 +- 9 files changed, 19 insertions(+), 16 deletions(-) diff --git a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp index 21a0086381a8..1525aff4aab6 100644 --- a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp @@ -491,7 +491,7 @@ struct AllGatherOpConversion final op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter); // Get the collective element type attribute. - auto resultType = cast(op.getResult().getType()); + auto resultType = cast(op.getResult(0).getType()); IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = IREE::Flow::getCollectiveElementTypeAttr(resultType); if (!elementTypeAttr) { @@ -499,7 +499,7 @@ struct AllGatherOpConversion final op, "unsupported element type for collective op"); } uint64_t allGatherDim = op.getAllGatherDim(); - Value gatherInput = adaptor.getOperand(); + Value gatherInput = adaptor.getOperands()[0]; SmallVector gatherResultShape(resultType.getShape()); // When all_gather_dim != 0, we need to transpose between 0 and @@ -513,7 +513,7 @@ struct AllGatherOpConversion final // Create an empty tensor for the result. Value target = rewriter.create( loc, gatherResultShape, - getElementTypeOrSelf(adaptor.getOperand().getType())); + getElementTypeOrSelf(adaptor.getOperands()[0].getType())); Value gatherResult = rewriter.create( op.getLoc(), elementTypeAttr, target, gatherInput, channel); @@ -585,7 +585,7 @@ struct AllReduceOpConversion final auto reductionOpAttr = IREE::Flow::CollectiveReductionOpAttr::get(op.getContext(), *redOp); - auto inputType = cast(op.getOperand().getType()); + auto inputType = cast(op.getOperand(0).getType()); // Get the collective element type attribute. IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = @@ -597,10 +597,11 @@ struct AllReduceOpConversion final // Create an empty tensor for the result. ArrayRef inputShape = inputType.getShape(); Value target = rewriter.create( - loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType())); + loc, inputShape, + getElementTypeOrSelf(adaptor.getOperands()[0].getType())); auto allReduceOp = rewriter.create( op.getLoc(), reductionOpAttr, elementTypeAttr, target, - adaptor.getOperand(), channel); + adaptor.getOperands()[0], channel); rewriter.replaceOp(op, allReduceOp.getResult()); return success(); } @@ -676,7 +677,7 @@ struct AllToAllOpConversion final op.getReplicaGroups(), /*useGlobalDeviceIds=*/std::nullopt, rewriter); // Get the collective element type attribute. - auto resultType = cast(op.getType()); + auto resultType = cast(op.getType(0)); IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = IREE::Flow::getCollectiveElementTypeAttr(resultType); if (!elementTypeAttr) { @@ -687,7 +688,7 @@ struct AllToAllOpConversion final uint64_t splitDim = op.getSplitDimension(); uint64_t concatDim = op.getConcatDimension(); uint64_t splitCount = op.getSplitCount(); - Value allToAllInput = adaptor.getOperand(); + Value allToAllInput = adaptor.getOperands()[0]; // When splitDim != 0, we need to transpose splitDim to 0 before and after // the all-to-all. diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp index 34d1a93ef531..60b824a687e7 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp @@ -198,7 +198,7 @@ struct GeneralDotRemoveBatch final auto dot = rewriter.create( op.getLoc(), ty.clone(ty.getShape().drop_front()), lhs, rhs, - newDimNumbers, op.getPrecisionConfigAttr()); + newDimNumbers, op.getPrecisionConfigAttr(), op.getAlgorithmAttr()); rewriter.replaceOpWithNewOp(op, ty, dot.getResult()); return success(); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/EinsumToDotGeneral.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/EinsumToDotGeneral.cpp index 8ca503fdf446..b84d283234e8 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/EinsumToDotGeneral.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/EinsumToDotGeneral.cpp @@ -141,7 +141,7 @@ struct EinsumToDotGeneralPattern final auto dotGeneralOp = rewriter.create( einsum.getLoc(), dotGeneralResultType, einsum.getLhs(), einsum.getRhs(), dimNumbers, - /*precision_config=*/ArrayAttr{}); + /*precision_config=*/ArrayAttr{}, mlir::stablehlo::DotAlgorithmAttr{}); if (isNaturalOrder) { // The dot_general is already in an appropriate result order. diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp index 4970994e66f9..f3cbff47f8fb 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp @@ -424,7 +424,7 @@ struct TransposeReshapeGenericDotGeneral final auto newOp = rewriter.create( op.getLoc(), newResultType, lhs, rhs, dimensionNumbers, - op.getPrecisionConfigAttr()); + op.getPrecisionConfigAttr(), op.getAlgorithmAttr()); // Copy over unknown attributes as we currently rely on it to let user tune // lowering parameters. diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp index a4118689191b..7021f09744fb 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp @@ -200,7 +200,8 @@ struct HouseholderReflectorRewriter final auto dotNums = mlir::stablehlo::DotDimensionNumbersAttr::get( b.getContext(), batch, batch, lhsContract, rhsContract); Value dot = b.create( - householder0.getType(), args[0], householder, dotNums, nullptr); + householder0.getType(), args[0], householder, dotNums, nullptr, + mlir::stablehlo::DotAlgorithmAttr{}); b.create(loc, dot); }); diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index ad3bf8767236..7c717d641ebf 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -113,6 +113,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToGPU", "@llvm-project//mlir:SCFTransforms", diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h index 6dde522a94ee..e399e63dffe8 100644 --- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h @@ -40,7 +40,7 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" @@ -81,7 +81,7 @@ inline void registerMlirDialects(DialectRegistry ®istry) { pdl::PDLDialect, pdl_interp::PDLInterpDialect, scf::SCFDialect, - quant::QuantizationDialect, + quant::QuantDialect, spirv::SPIRVDialect, arm_neon::ArmNeonDialect, arm_sve::ArmSVEDialect, diff --git a/third_party/stablehlo b/third_party/stablehlo index f7f8e4e35296..ecc76d6cb7c5 160000 --- a/third_party/stablehlo +++ b/third_party/stablehlo @@ -1 +1 @@ -Subproject commit f7f8e4e35296deeff2e12e39421ac8d9599ba340 +Subproject commit ecc76d6cb7c564a6fabc7dc44f4539426138ee23 diff --git a/third_party/torch-mlir b/third_party/torch-mlir index 9938abf25e1e..b423ae6d2b18 160000 --- a/third_party/torch-mlir +++ b/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit 9938abf25e1e7526ca7f43a8c49e9078c14fc55c +Subproject commit b423ae6d2b18e7e556b17bb12ba10496fae1e8cd