Skip to content

Commit

Permalink
[LinalgExt] Implement pad_attention TransformOp.
Browse files Browse the repository at this point in the history
Signed-off-by: stanley-nod <[email protected]>
  • Loading branch information
raikonenfnu authored and monorimet committed Jun 17, 2024
1 parent fa355a6 commit d36dd73
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ LinalgExt::LinalgExtTransformOpsExtension::LinalgExtTransformOpsExtension() {
void LinalgExt::LinalgExtTransformOpsExtension::init() {}

//===---------------------------------------------------------------------===//
// TileAndDecomposeAttention
// Attention related transformOps
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure LinalgExt::TileAttentionOp::applyToOne(
Expand All @@ -51,5 +51,21 @@ DiagnosedSilenceableFailure LinalgExt::DecomposeTiledAttentionOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
LinalgExt::PadAttentionOp::applyToOne(transform::TransformRewriter &rewriter,
                                      LinalgExt::AttentionOp attentionOp,
                                      transform::ApplyToEachResultList &results,
                                      transform::TransformState &state) {
  // Multiples to pad each attention iteration dimension to; an entry of 0
  // means that dimension is left unpadded.
  SmallVector<int64_t> padMultiples =
      extractFromIntegerArrayAttr<int64_t>(getPadToMultipleOf());

  // Apply the padding rewrite, then forward every op it created (the padded
  // attention op and the result extract_slice) as result handles.
  SmallVector<Operation *> createdOps;
  LinalgExt::padAttention(attentionOp, createdOps, rewriter, padMultiples);
  for (Operation *createdOp : createdOps)
    results.push_back(createdOp);
  return DiagnosedSilenceableFailure::success();
}

#define GET_OP_CLASSES
#include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp.inc"
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,44 @@ def DecomposeTiledAttentionOp : Op<Transform_Dialect, "iree.decompose_tiled_atte
}];
}

def PadAttentionOp : Op<Transform_Dialect, "iree.pad_attention",
    [FunctionalStyleTransformOpTrait,
     MemoryEffectsOpInterface,
     TransformOpInterface,
     TransformEachOpTrait,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Target iree_linalg_ext.attention ops and pad them to multiples of
    `pad_to_multiple_of` (one entry per attention dimension, 0 = no padding).
    This transform consumes the target handle and produces a result handle.
    Outputs:
      1. Padded attention op handle.
      2. Subtensor output (extract_slice restoring the original shape).
  }];

  let arguments = (
      ins TransformHandleTypeInterface:$target,
          DefaultValuedAttr<I64ArrayAttr, "{}">:$pad_to_multiple_of
  );
  let results = (outs Variadic<TransformHandleTypeInterface>:$result);

  let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";

  let builders = [
    OpBuilder<(ins "Value":$target)>
  ];

  // Note: exactly one assemblyFormat — a duplicate `let assemblyFormat`
  // assignment is rejected by TableGen.
  let assemblyFormat = [{
    $target attr-dict `:` functional-type(operands, results)
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure applyToOne(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
        ::mlir::transform::ApplyToEachResultList &results,
        ::mlir::transform::TransformState &state);
  }];
}

#endif // IREE_DIALECT_LINALGEXT_TRANSFORMOPS
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ iree_cc_library(
"DecomposeAttention.cpp"
"DecomposeWinogradPass.cpp"
"PadContractionToBlockSize.cpp"
"Padding.cpp"
"PassDetail.h"
"Passes.cpp"
"SplitReduction.cpp"
Expand Down
200 changes: 200 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Padding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/APFloat.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::LinalgExt {

namespace {

/// Returns the amount of padding needed to round `bound` up to the next
/// multiple of `padMultiple`: ceilDiv(bound, padMultiple) * padMultiple -
/// bound, built as a folded affine.apply.
static OpFoldResult getPadding(RewriterBase &rewriter, Location loc,
                               OpFoldResult bound, int64_t padMultiple) {
  AffineExpr sym;
  bindSymbols(rewriter.getContext(), sym);
  AffineExpr padAmountExpr = sym.ceilDiv(padMultiple) * padMultiple - sym;
  return affine::makeComposedFoldedAffineApply(rewriter, loc, padAmountExpr,
                                              {bound});
}

/// High-pads `padSource` by the amounts in `padding` (one entry per dim),
/// using `padValueAttr` as the fill value (zero of the element type when not
/// provided). A result dim is dynamic whenever the source dim or the pad
/// amount is not a compile-time constant.
static Value
getPaddedValue(RewriterBase &rewriter, Location loc, Value padSource,
               ArrayRef<OpFoldResult> padding,
               std::optional<TypedAttr> padValueAttr = std::nullopt) {
  auto sourceType = cast<RankedTensorType>(padSource.getType());
  ArrayRef<int64_t> sourceShape = sourceType.getShape();

  // Compute the static result shape.
  SmallVector<int64_t> paddedShape;
  paddedShape.reserve(sourceShape.size());
  for (auto [dimSize, padOfr] : llvm::zip_equal(sourceShape, padding)) {
    std::optional<int64_t> padInt = getConstantIntValue(padOfr);
    if (ShapedType::isDynamic(dimSize) || !padInt.has_value())
      paddedShape.push_back(ShapedType::kDynamic);
    else
      paddedShape.push_back(dimSize + *padInt);
  }
  auto paddedResultType =
      RankedTensorType::get(paddedShape, sourceType.getElementType());

  // Materialize the fill value; default to zero of the element type.
  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
  Value paddingValue =
      rewriter.create<arith::ConstantOp>(loc, padValueAttr.value_or(zeroAttr));

  // All padding goes on the high side; low padding is zero everywhere.
  SmallVector<OpFoldResult> lowPad(padding.size(), rewriter.getIndexAttr(0));
  return rewriter.create<tensor::PadOp>(loc, paddedResultType, padSource,
                                        lowPad, padding, paddingValue);
}

// Function pass that pads every iree_linalg_ext.attention op in the function
// to the multiples supplied via the pass's `padToMultipleOf` list option.
struct PadAttentionPass : public PadAttentionBase<PadAttentionPass> {
  // Register the dialects whose ops the padding rewrite may create
  // (affine.apply, tensor.pad/empty/extract_slice, the padded attention op).
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<
        affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
        linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
  }
  void runOnOperation() override;
};

} // namespace

/// Pads an iree_linalg_ext.attention op along its five iteration dimensions
/// (batch, m, k1, k2, n) up to the next multiple of `padToMultipleOf[dim]`;
/// an entry of 0 disables padding for that dimension.
///
/// Appends the newly created padded attention op and the extract_slice that
/// restores the original result shape to `ops`, replaces `attnOp` with the
/// slice, and returns the padded attention op.
IREE::LinalgExt::AttentionOp padAttention(IREE::LinalgExt::AttentionOp attnOp,
                                          SmallVectorImpl<Operation *> &ops,
                                          RewriterBase &rewriter,
                                          ArrayRef<int64_t> padToMultipleOf) {
  SmallVector<AffineMap> maps = attnOp.getIndexingMapsArray();
  FailureOr<IREE::LinalgExt::AttentionOpDetail> maybeOpInfo =
      IREE::LinalgExt::AttentionOpDetail::get(maps);
  assert(succeeded(maybeOpInfo) && "failed to infer attention dims");
  auto opInfo = maybeOpInfo.value();
  Location loc = attnOp.getLoc();
  rewriter.setInsertionPoint(attnOp);

  int64_t domainRank = maps[0].getNumDims();
  assert(domainRank == 5 &&
         "Currently only support base-case of attention dims.");
  assert(padToMultipleOf.size() == static_cast<size_t>(domainRank) &&
         "Expected pad_to_multiple_of to have same rank as dimensions of "
         "attention.");
  SmallVector<Range> bounds = attnOp.getIterationDomain(rewriter);

  int64_t batchIdx = opInfo.getBatchDims().back();
  int64_t mIdx = opInfo.getMDims().back();
  int64_t k1Idx = opInfo.getK1Dims().back();
  int64_t k2Idx = opInfo.getK2Dims().back();
  int64_t nIdx = opInfo.getNDims().back();

  // Per-dimension pad amounts (high padding); defaults to 0 for dims that do
  // not need padding.
  SmallVector<OpFoldResult> padValues(domainRank, rewriter.getIndexAttr(0));
  for (auto [idx, bound] : enumerate(bounds)) {
    if (padToMultipleOf[idx] != 0) {
      padValues[idx] =
          getPadding(rewriter, loc, bound.size, padToMultipleOf[idx]);
    }
  }

  Value paddedQuery = attnOp.getQuery();
  Value paddedKey = attnOp.getKey();
  Value paddedValue = attnOp.getValue();
  Value paddedAcc = attnOp.getOutput();
  Value scale = attnOp.getScale();

  OpFoldResult zero = rewriter.getIndexAttr(0);

  // Pad Q-tensor (batch x m x k1) if any of its dims needs padding.
  if (!isConstantIntValue(padValues[batchIdx], 0) ||
      !isConstantIntValue(padValues[mIdx], 0) ||
      !isConstantIntValue(padValues[k1Idx], 0)) {
    paddedQuery = getPaddedValue(
        rewriter, loc, paddedQuery,
        {padValues[batchIdx], padValues[mIdx], padValues[k1Idx]});
  }

  // Pad K2-dim of K-tensor by a large negative S.T when used by softmax it
  // will generate the correct numerics: padded keys get ~zero attention
  // weight. (Padding with 0.0 here would give the padded keys nonzero
  // softmax weight and corrupt the result.)
  if (!isConstantIntValue(padValues[k2Idx], 0)) {
    Type keyElType = attnOp.getKeyType().getElementType();
    auto largeNeg = rewriter.getFloatAttr(
        keyElType,
        APFloat::getLargest(cast<FloatType>(keyElType).getFloatSemantics(),
                            /*Negative=*/true)
            .convertToDouble());
    paddedKey = getPaddedValue(rewriter, loc, paddedKey,
                               {zero, padValues[k2Idx], zero}, largeNeg);
  }

  // Zero-pad the remaining (non-K2) dims of the K-tensor (batch x k2 x k1)
  // if needed.
  if (!isConstantIntValue(padValues[batchIdx], 0) ||
      !isConstantIntValue(padValues[k1Idx], 0)) {
    paddedKey = getPaddedValue(rewriter, loc, paddedKey,
                               {padValues[batchIdx], zero, padValues[k1Idx]});
  }

  // Pad V-tensor (batch x k2 x n) if any of its dims needs padding.
  if (!isConstantIntValue(padValues[batchIdx], 0) ||
      !isConstantIntValue(padValues[k2Idx], 0) ||
      !isConstantIntValue(padValues[nIdx], 0)) {
    paddedValue = getPaddedValue(
        rewriter, loc, paddedValue,
        {padValues[batchIdx], padValues[k2Idx], padValues[nIdx]});
  }

  // Pad Acc-tensor (batch x m x n) if any of its dims needs padding. When the
  // accumulator is an uninitialized tensor.empty we can simply create a larger
  // empty tensor instead of padding.
  if (!isConstantIntValue(padValues[batchIdx], 0) ||
      !isConstantIntValue(padValues[mIdx], 0) ||
      !isConstantIntValue(padValues[nIdx], 0)) {
    if (isa_and_nonnull<tensor::EmptyOp>(paddedAcc.getDefiningOp())) {
      SmallVector<OpFoldResult> paddedQueryShape =
          tensor::getMixedSizes(rewriter, loc, paddedQuery);
      SmallVector<OpFoldResult> paddedValueShape =
          tensor::getMixedSizes(rewriter, loc, paddedValue);
      SmallVector<OpFoldResult> paddedOutputShape = {
          paddedQueryShape[0], paddedQueryShape[1], paddedValueShape[2]};
      paddedAcc = rewriter.create<tensor::EmptyOp>(
          loc, paddedOutputShape, attnOp.getOutputType().getElementType());
    } else {
      paddedAcc = getPaddedValue(
          rewriter, loc, paddedAcc,
          {padValues[batchIdx], padValues[mIdx], padValues[nIdx]});
    }
  }

  // Generate padded attention op.
  auto paddedAttnOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
      loc, paddedAcc.getType(),
      SmallVector<Value>{paddedQuery, paddedKey, paddedValue, scale},
      paddedAcc);

  ops.push_back(paddedAttnOp);

  // Extract the original-shaped subtensor of the padded result.
  IntegerAttr one = rewriter.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> offsets(3, zero);
  SmallVector<OpFoldResult> strides(3, one);
  SmallVector<OpFoldResult> sizes =
      tensor::getMixedSizes(rewriter, loc, attnOp.getOutput());
  Operation *extracted = rewriter.create<tensor::ExtractSliceOp>(
      loc, paddedAttnOp->getResults()[0], offsets, sizes, strides);
  ops.push_back(extracted);

  rewriter.replaceOp(attnOp, extracted);

  return paddedAttnOp;
}

void PadAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
getOperation().walk([&](AttentionOp attnOp) {
SmallVector<Operation *> ops;
padAttention(attnOp, ops, rewriter, padToMultipleOf);
});
}

/// Factory for the `iree-linalg-ext-pad-attention` pass.
std::unique_ptr<Pass> createPadAttentionPass() {
  auto pass = std::make_unique<PadAttentionPass>();
  return pass;
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ tileAttention(IREE::LinalgExt::AttentionOp attnOp,
SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
std::optional<uint64_t> tileSize = std::nullopt);

// Pads an iree_linalg_ext.attention op to multiples of `padToMultipleOf`
// (one entry per attention iteration dim; 0 disables padding for that dim).
// Newly created ops are appended to `ops`; returns the padded attention op.
IREE::LinalgExt::AttentionOp padAttention(IREE::LinalgExt::AttentionOp attnOp,
                                          SmallVectorImpl<Operation *> &ops,
                                          RewriterBase &rewriter,
                                          ArrayRef<int64_t> padToMultipleOf);

void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
SmallVectorImpl<Operation *> &ops,
RewriterBase &rewriter,
Expand All @@ -62,6 +67,9 @@ std::unique_ptr<Pass> createTileAttentionPass();
// Creates a pass to convert the attention op into a sequence of linalg ops.
std::unique_ptr<Pass> createDecomposeAttentionPass();

// Creates a pass to pad the attention op along the specified dims.
std::unique_ptr<Pass> createPadAttentionPass();

//===---------------------------------------------------------------------===//
// Codegen Strategy passes that are moved into IREE.
//===---------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ def TileAttention :
];
}

def PadAttention :
    InterfacePass<"iree-linalg-ext-pad-attention", "mlir::FunctionOpInterface"> {
  // Summary must describe padding, not tiling (previous text was copied from
  // the TileAttention pass definition).
  let summary =
      "Pad the attention op to specified multiples of its dimensions";
  let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
                    "createPadAttentionPass()";
  let options = [
    ListOption<"padToMultipleOf", "pad-to-multiple-of", "int64_t",
               "Array to represent the number to pad to multiple of for attention dims.">,
  ];
}

def DecomposeAttention :
InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> {
let summary =
Expand Down
Loading

0 comments on commit d36dd73

Please sign in to comment.