Commit 13104db
fixes
Signed-off-by: Benoit Jacob <[email protected]>
bjacob committed Oct 2, 2024
1 parent 6ac2138 commit 13104db
Showing 9 changed files with 19 additions and 16 deletions.

@@ -491,15 +491,15 @@ struct AllGatherOpConversion final
         op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
 
     // Get the collective element type attribute.
-    auto resultType = cast<RankedTensorType>(op.getResult().getType());
+    auto resultType = cast<RankedTensorType>(op.getResult(0).getType());
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
         IREE::Flow::getCollectiveElementTypeAttr(resultType);
     if (!elementTypeAttr) {
       return rewriter.notifyMatchFailure(
           op, "unsupported element type for collective op");
     }
     uint64_t allGatherDim = op.getAllGatherDim();
-    Value gatherInput = adaptor.getOperand();
+    Value gatherInput = adaptor.getOperands()[0];
     SmallVector<int64_t> gatherResultShape(resultType.getShape());
 
     // When all_gather_dim != 0, we need to transpose between 0 and
@@ -513,7 +513,7 @@ struct AllGatherOpConversion final
     // Create an empty tensor for the result.
     Value target = rewriter.create<tensor::EmptyOp>(
         loc, gatherResultShape,
-        getElementTypeOrSelf(adaptor.getOperand().getType()));
+        getElementTypeOrSelf(adaptor.getOperands()[0].getType()));
     Value gatherResult = rewriter.create<IREE::Flow::CollectiveAllGatherOp>(
         op.getLoc(), elementTypeAttr, target, gatherInput, channel);
 
@@ -585,7 +585,7 @@ struct AllReduceOpConversion final
     auto reductionOpAttr =
         IREE::Flow::CollectiveReductionOpAttr::get(op.getContext(), *redOp);
 
-    auto inputType = cast<RankedTensorType>(op.getOperand().getType());
+    auto inputType = cast<RankedTensorType>(op.getOperand(0).getType());
 
     // Get the collective element type attribute.
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
@@ -597,10 +597,11 @@ struct AllReduceOpConversion final
     // Create an empty tensor for the result.
     ArrayRef<int64_t> inputShape = inputType.getShape();
     Value target = rewriter.create<tensor::EmptyOp>(
-        loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType()));
+        loc, inputShape,
+        getElementTypeOrSelf(adaptor.getOperands()[0].getType()));
     auto allReduceOp = rewriter.create<IREE::Flow::CollectiveAllReduceOp>(
         op.getLoc(), reductionOpAttr, elementTypeAttr, target,
-        adaptor.getOperand(), channel);
+        adaptor.getOperands()[0], channel);
     rewriter.replaceOp(op, allReduceOp.getResult());
     return success();
   }
@@ -676,7 +677,7 @@ struct AllToAllOpConversion final
         op.getReplicaGroups(), /*useGlobalDeviceIds=*/std::nullopt, rewriter);
 
     // Get the collective element type attribute.
-    auto resultType = cast<RankedTensorType>(op.getType());
+    auto resultType = cast<RankedTensorType>(op.getType(0));
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
         IREE::Flow::getCollectiveElementTypeAttr(resultType);
     if (!elementTypeAttr) {
@@ -687,7 +688,7 @@ struct AllToAllOpConversion final
     uint64_t splitDim = op.getSplitDimension();
     uint64_t concatDim = op.getConcatDimension();
     uint64_t splitCount = op.getSplitCount();
-    Value allToAllInput = adaptor.getOperand();
+    Value allToAllInput = adaptor.getOperands()[0];
 
     // When splitDim != 0, we need to transpose splitDim to 0 before and after
     // the all-to-all.
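Note: the collective-op conversion hunks above appear to track an upstream StableHLO change that made these collective ops variadic, so the single-value convenience accessors are replaced with indexed ones. A minimal sketch of the updated pattern (variable names illustrative, not from this commit):

    // Inside a conversion pattern's matchAndRewrite(op, adaptor, rewriter):
    auto resultType = cast<RankedTensorType>(op.getResult(0).getType());  // first result
    Value input = adaptor.getOperands()[0];  // first converted operand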

@@ -198,7 +198,7 @@ struct GeneralDotRemoveBatch final
 
     auto dot = rewriter.create<mlir::stablehlo::DotGeneralOp>(
         op.getLoc(), ty.clone(ty.getShape().drop_front()), lhs, rhs,
-        newDimNumbers, op.getPrecisionConfigAttr());
+        newDimNumbers, op.getPrecisionConfigAttr(), op.getAlgorithmAttr());
     rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, ty,
                                                             dot.getResult());
     return success();

@@ -141,7 +141,7 @@ struct EinsumToDotGeneralPattern final
     auto dotGeneralOp = rewriter.create<mlir::stablehlo::DotGeneralOp>(
         einsum.getLoc(), dotGeneralResultType, einsum.getLhs(), einsum.getRhs(),
         dimNumbers,
-        /*precision_config=*/ArrayAttr{});
+        /*precision_config=*/ArrayAttr{}, mlir::stablehlo::DotAlgorithmAttr{});
 
     if (isNaturalOrder) {
       // The dot_general is already in an appropriate result order.

@@ -424,7 +424,7 @@ struct TransposeReshapeGenericDotGeneral final
 
     auto newOp = rewriter.create<mlir::stablehlo::DotGeneralOp>(
        op.getLoc(), newResultType, lhs, rhs, dimensionNumbers,
-       op.getPrecisionConfigAttr());
+       op.getPrecisionConfigAttr(), op.getAlgorithmAttr());
 
     // Copy over unknown attributes as we currently rely on it to let user tune
     // lowering parameters.

@@ -200,7 +200,8 @@ struct HouseholderReflectorRewriter final
         auto dotNums = mlir::stablehlo::DotDimensionNumbersAttr::get(
             b.getContext(), batch, batch, lhsContract, rhsContract);
         Value dot = b.create<mlir::stablehlo::DotGeneralOp>(
-            householder0.getType(), args[0], householder, dotNums, nullptr);
+            householder0.getType(), args[0], householder, dotNums, nullptr,
+            mlir::stablehlo::DotAlgorithmAttr{});
         b.create<scf::YieldOp>(loc, dot);
       });
 
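Note: the four DotGeneralOp hunks above all adapt to a stablehlo::DotGeneralOp builder that now takes an extra dot-algorithm attribute. Call sites that rewrite an existing dot_general forward op.getAlgorithmAttr(); call sites that synthesize a new op pass a null attribute. A hedged sketch of the latter, assuming the post-update builder signature:

    // Create a stablehlo.dot_general with no precision config and no algorithm.
    auto dot = rewriter.create<mlir::stablehlo::DotGeneralOp>(
        loc, resultType, lhs, rhs, dimNumbers,
        /*precision_config=*/ArrayAttr{},
        /*algorithm=*/mlir::stablehlo::DotAlgorithmAttr{});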
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -113,6 +113,7 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:MLProgramDialect",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:SCFToGPU",
         "@llvm-project//mlir:SCFTransforms",
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Tools/init_mlir_dialects.h
@@ -40,7 +40,7 @@
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@@ -81,7 +81,7 @@ inline void registerMlirDialects(DialectRegistry &registry) {
                   pdl::PDLDialect,
                   pdl_interp::PDLInterpDialect,
                   scf::SCFDialect,
-                  quant::QuantizationDialect,
+                  quant::QuantDialect,
                   spirv::SPIRVDialect,
                   arm_neon::ArmNeonDialect,
                   arm_sve::ArmSVEDialect,
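Note: the two hunks above follow the upstream MLIR move of the quantization dialect: the header now lives under mlir/Dialect/Quant/IR/ and the class is quant::QuantDialect. A minimal registration sketch under that assumption (function name illustrative):

    #include "mlir/Dialect/Quant/IR/Quant.h"
    #include "mlir/IR/DialectRegistry.h"

    // Register the renamed Quant dialect with a registry.
    void registerQuantDialect(mlir::DialectRegistry &registry) {
      registry.insert<mlir::quant::QuantDialect>();
    }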
2 changes: 1 addition & 1 deletion third_party/stablehlo
2 changes: 1 addition & 1 deletion third_party/torch-mlir
