Pass to transfer the strided access pattern from L3 to L2 #792

Closed · wants to merge 4 commits · Changes from 2 commits
@@ -0,0 +1,377 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-amdaie-transfer-strided-access-pattern"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Utility to copy a vector except for certain dim positions.
static SmallVector<OpFoldResult> copyExcludeDims(
SmallVector<OpFoldResult> origVals, DenseSet<size_t> excludeDims) {
if (excludeDims.empty()) return origVals;
SmallVector<OpFoldResult> results;
for (size_t i = 0; i < origVals.size(); i++) {
if (!excludeDims.contains(i)) {
results.push_back(origVals[i]);
}
}
return results;
}

/// Utility to check if any dimension from the L3 dma addressing can be combined
/// with the innermost dimension; if so, return the position of that dimension.
/// Two dimensions (i and innermost) can be combined if the following conditions
/// are satisfied: 1) stride[i] = innermost_stride * innermost_size;
/// 2) offset[i] = 0.
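/// Example: for offsets [0, 0, 0], sizes [2, 32, 32], strides [32, 128, 1],
/// the innermost dim spans 32 * 1 = 32 elements; dim 0 has stride 32 and
/// offset 0, so dim 0 can be combined with the innermost dim.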
static FailureOr<size_t> isL3AddressingCombinable(
SmallVector<OpFoldResult> &dmaOffsets, SmallVector<OpFoldResult> &dmaSizes,
SmallVector<OpFoldResult> &dmaStrides) {
// Offsets could be dynamic but sizes and strides should be static.
std::optional<SmallVector<int64_t>> maybeSizes =
getConstantIntValues(dmaSizes);
std::optional<SmallVector<int64_t>> maybeStrides =
getConstantIntValues(dmaStrides);
if (!maybeSizes.has_value() || !maybeStrides.has_value()) {
return failure();
}
SmallVector<int64_t> sizeVals = maybeSizes.value();
SmallVector<int64_t> strideVals = maybeStrides.value();

// Get the index of the dim that can be potentially combined with the
// innermost dim. If there is no such dim, return the last index of the
// vector.
auto getPos = [&](SmallVector<int64_t> values, int64_t target) {
size_t i = 0;
for (; i < values.size() - 1; i++) {
if (values[i] == target) return i;
}
return i;
};

int64_t innerDimTotal = strideVals.back() * sizeVals.back();
size_t dimForCombine = getPos(strideVals, innerDimTotal);
if (dimForCombine >= (dmaSizes.size() - 1)) return failure();

std::optional<int64_t> offsetAtPos =
getConstantIntValue(dmaOffsets[dimForCombine]);
if (!offsetAtPos.has_value() || offsetAtPos.value() != 0) return failure();
return dimForCombine;
}

/// Utility to check if the L2 dma addressing is linear. Note this assumes the
/// dma ops are already canonicalized, so the L2 addressing is either empty or
/// a 1-d vector.
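/// Example: offsets [0], sizes [2048], strides [1] is linear, and so is a
/// fully empty addressing.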
static bool isL2AddressingLinear(SmallVector<OpFoldResult> &dmaOffsets,
SmallVector<OpFoldResult> &dmaSizes,
SmallVector<OpFoldResult> &dmaStrides) {
assert(dmaOffsets.size() == dmaSizes.size() &&
dmaOffsets.size() == dmaStrides.size() &&
"expected same number of source offsets and sizes");
if (dmaOffsets.size() == 0) return true;
if (dmaOffsets.size() != 1) return false;
if (!isConstantIntValue(dmaOffsets[0], 0)) return false;
if (!isConstantIntValue(dmaStrides[0], 1)) return false;
return true;
}

/// Utility to check if all users of the connection op satisfy the conditions
/// for dma access pattern transfer.
static FailureOr<bool> checkConnectionUsers(AMDAIE::ConnectionOp connectionOp) {
for (Operation *user : connectionOp->getUsers()) {
// Check if L3 addressing is combinable.
if (auto dmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(user)) {
if (dmaOp.hasSourceAddressing() && dmaOp.hasTargetAddressing()) {
dmaOp.emitOpError()
<< "should not have both source and target addressing";
return failure();
}
if (!dmaOp.hasSourceAddressing() && !dmaOp.hasTargetAddressing()) {
dmaOp.emitOpError() << "should have either source or target addressing";
return failure();
}

SmallVector<OpFoldResult> dmaOffsets;
SmallVector<OpFoldResult> dmaSizes;
SmallVector<OpFoldResult> dmaStrides;
if (dmaOp.hasSourceAddressing()) {
dmaOffsets = dmaOp.getSourceMixedOffsets();
dmaSizes = dmaOp.getSourceMixedSizes();
dmaStrides = dmaOp.getSourceMixedStrides();
} else {
dmaOffsets = dmaOp.getTargetMixedOffsets();
dmaSizes = dmaOp.getTargetMixedSizes();
dmaStrides = dmaOp.getTargetMixedStrides();
}

if (failed(isL3AddressingCombinable(dmaOffsets, dmaSizes, dmaStrides))) {
return false;
}
}
// Check if L2 addressing is linear.
if (auto circularDma = dyn_cast<AMDAIE::NpuCircularDmaCpyNdOp>(user)) {
// Circular dma op could have both source and target addressing empty.
if (circularDma.hasSourceAddressing() &&
circularDma.hasTargetAddressing()) {
circularDma.emitOpError()
<< "should not have both source and target addressing";
return failure();
}

SmallVector<OpFoldResult> circularOffsets;
SmallVector<OpFoldResult> circularSizes;
SmallVector<OpFoldResult> circularStrides;

if (circularDma.hasSourceAddressing()) {
circularOffsets = circularDma.getSourceMixedOffsets();
circularSizes = circularDma.getSourceMixedSizes();
circularStrides = circularDma.getSourceMixedStrides();
}
if (circularDma.hasTargetAddressing()) {
circularOffsets = circularDma.getTargetMixedOffsets();
circularSizes = circularDma.getTargetMixedSizes();
circularStrides = circularDma.getTargetMixedStrides();
}
if (!isL2AddressingLinear(circularOffsets, circularSizes,
circularStrides)) {
return false;
}
}
}
return true;
}

/// Utility to change the addressing of NpuDmaCpyNdOp and NpuCircularDmaCpyNdOp
/// in place. If the source of the NpuDmaCpyNdOp is in L3, then the source
/// addressing of the NpuDmaCpyNdOp and the target addressing of the
/// NpuCircularDmaCpyNdOp need to be changed; if its target is in L3, the roles
/// are reversed.
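/// Example (based on the discussion below): for
///   npu.circular_dma_cpy_nd ([0] [2048] [1], [] [] [])
///   npu.dma_cpy_nd ([] [] [], [0, 0, 0, %41] [4, 2, 32, 32] [4096, 32, 128, 1])
/// the L3 source addressing becomes [0, 0, %41] [4, 32, 64] [4096, 128, 1] and
/// the L2 target addressing becomes [0, 0, 0, 0] [4, 32, 2, 32]
/// [2048, 32, 1024, 1].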
static LogicalResult createNewAddressing(
MLIRContext *ctx, SmallVector<OpFoldResult> &dmaOffsets,
SmallVector<OpFoldResult> &dmaSizes, SmallVector<OpFoldResult> &dmaStrides,
SmallVector<OpFoldResult> &circularDmaOffsets,
SmallVector<OpFoldResult> &circularDmaSizes,
SmallVector<OpFoldResult> &circularDmaStrides) {
IRRewriter rewriter(ctx);

// Make copies of L3 original sizes and strides which will be needed later
// when creating new L2 addressing.
SmallVector<OpFoldResult> l3OrigSizes = dmaSizes;
SmallVector<OpFoldResult> l3OrigStrides = dmaStrides;

FailureOr<size_t> isCombinable =
isL3AddressingCombinable(dmaOffsets, dmaSizes, dmaStrides);
if (failed(isCombinable)) {
return emitError(rewriter.getUnknownLoc())
<< "failed to get dim position for combination";
}
size_t dimForCombine = isCombinable.value();

Collaborator:
Right now, you're looking for a single dimension that can be combined with the innermost dimension? Ideally, this would work for multiple dimensions as well. For example:

[[0, 0, 0, 0] [3, 2, 32, 32] [64, 32, 128, 1]]

should become:

[[0, 0] [32, 192] [128, 1]]
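
(A rough standalone sketch of that generalization — not code from this PR; combineAll is a hypothetical name and std::vector is used only for self-containment:)

#include <cstdint>
#include <iostream>
#include <vector>

// Repeatedly fold any dimension whose stride equals the current innermost
// size * stride and whose offset is 0 into the innermost dimension.
void combineAll(std::vector<int64_t> &offsets, std::vector<int64_t> &sizes,
                std::vector<int64_t> &strides) {
  bool changed = true;
  while (changed && sizes.size() > 1) {
    changed = false;
    int64_t innerTotal = strides.back() * sizes.back();
    for (size_t i = 0; i + 1 < sizes.size(); i++) {
      if (strides[i] == innerTotal && offsets[i] == 0) {
        // The combined dimension keeps the innermost stride.
        sizes.back() *= sizes[i];
        offsets.erase(offsets.begin() + i);
        sizes.erase(sizes.begin() + i);
        strides.erase(strides.begin() + i);
        changed = true;
        break;
      }
    }
  }
}

int main() {
  std::vector<int64_t> offsets = {0, 0, 0, 0};
  std::vector<int64_t> sizes = {3, 2, 32, 32};
  std::vector<int64_t> strides = {64, 32, 128, 1};
  combineAll(offsets, sizes, strides);
  // offsets = [0, 0], sizes = [32, 192], strides = [128, 1].
  for (size_t i = 0; i < sizes.size(); i++)
    std::cout << sizes[i] << ":" << strides[i] << " ";
  std::cout << "\n";  // prints: 32:128 192:1
  return 0;
}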

Contributor Author:
I agree this case should be considered. However, I'd leave such a change out of this revision until the logic for the new sizes/strides (see the other comments below) is confirmed correct.


// Generate L3 side new source offsets/sizes/strides.
// Example: [[0, 0, 0] [2, 32, 32] [32, 128, 1]] will become
// [[0, 0] [32, 64] [128, 1]] after the first and the innermost dims are
// combined.
DenseSet<size_t> excludeDims = {dimForCombine};
dmaOffsets = copyExcludeDims(dmaOffsets, excludeDims);
dmaStrides = copyExcludeDims(dmaStrides, excludeDims);

std::optional<SmallVector<int64_t>> maybeSizes =
getConstantIntValues(l3OrigSizes);
std::optional<SmallVector<int64_t>> maybeStrides =
getConstantIntValues(l3OrigStrides);
if (!maybeSizes.has_value() || !maybeStrides.has_value()) {
return emitError(rewriter.getUnknownLoc())
<< "failed to get original source sizes / strides.";
}
SmallVector<int64_t> sizeVals = maybeSizes.value();
SmallVector<int64_t> strideVals = maybeStrides.value();

int64_t innerDimTotal = strideVals.back() * sizeVals.back();
int64_t newInnerSize = sizeVals[dimForCombine] * innerDimTotal;

size_t lastIndex = l3OrigSizes.size() - 1;
excludeDims.insert(lastIndex);
dmaSizes = copyExcludeDims(dmaSizes, excludeDims);
dmaSizes.push_back(getAsIndexOpFoldResult(ctx, newInnerSize));

// Generate L2 side new target offsets/sizes/strides.
SmallVector<OpFoldResult> newCircularOffsets(l3OrigSizes.size(),
rewriter.getIndexAttr(0));
circularDmaOffsets = newCircularOffsets;

circularDmaSizes = copyExcludeDims(l3OrigSizes, excludeDims);

Collaborator:
What's the purpose of this copy? I would have expected copyIncludeDims, i.e. the dimensions that were excluded on the L3 side, should be included on the L2 side?

Contributor Author:
The intention is to copy, from the original L3 addressing, the dimensions that are not supposed to be combined, and to "exclude" the dimensions that will be combined.

circularDmaSizes.push_back(
getAsIndexOpFoldResult(ctx, sizeVals[dimForCombine]));
circularDmaSizes.push_back(getAsIndexOpFoldResult(ctx, innerDimTotal));
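// Continuing the example above, circularDmaSizes is now [32, 2, 32].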

// Function to create new strides for NpuCircularDmaCpyNdOp.
auto getNewL2Strides = [&](SmallVector<int64_t> values) {
SmallVector<OpFoldResult> res = {getAsIndexOpFoldResult(ctx, 1)};
int64_t initial = values.back();
// Leave out one dimension for insertion afterwards
for (size_t i = values.size() - 2; i > 0; i--) {
initial *= values[i];
res.push_back(getAsIndexOpFoldResult(ctx, initial));
}
return llvm::to_vector(llvm::reverse(res));
};

Collaborator (on the getNewL2Strides lambda, lines +212 to +221):
Could you explain how to calculate the new strides? I don't understand how it's just initial *= values[i];?

Contributor Author (@yzhang93, Sep 23, 2024):

Take the following dma ops for example:

%45 = amdaie.npu.circular_dma_cpy_nd %8([0] [2048] [1], [] [] [])
%46 = amdaie.npu.dma_cpy_nd %8([] [] [], %31[0, 0, 0, %41] [4, 2, 32, 32] [4096, 32, 128, 1])

The logic is to create L2 side strides from the innermost dimension, and then reverse the vector to have the final order. The new L2 side strides always start with [1], and should have the same number of dimensions as the original L3 side source addressing. The next dimensions are calculated by the logic initial *= l3OrigSizes[i].

Here initial is the number of contiguous innermost elements, i.e. l3OrigSizes[-1] * l3OrigStrides[-1] (the implementation omits l3OrigStrides[-1] because l3OrigStrides[-1] == 1). The combined elements are now contiguous on the L3 side, but need strided addressing on the L2 side; that stride is initial * l3OrigSizes[-2]. So after this iteration, the strides are [1, 32 * 32].

The same logic applies to the next iteration, and the strides become [1, 32 * 32, 32 * 32 * 2] = [1, 1024, 2048]. After reversing, it's [2048, 1024, 1]. Finally, insert the stride for the position of the combined dimension (index 1 in this example), which is l3OrigStrides[dimForCombine], giving the final strides [2048, 32, 1024, 1].

Let me know if this is the correct logic to get the L2 side strides, or if there's a better way to calculate this.
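
(For reference, a rough standalone sketch of this calculation — not code from this PR; computeNewL2Strides is a hypothetical name and std::vector is used only for self-containment:)

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors getNewL2Strides plus the final insert: build the new L2 strides from
// the original static L3 sizes/strides and the index of the dimension that is
// combined with the innermost one.
std::vector<int64_t> computeNewL2Strides(const std::vector<int64_t> &l3Sizes,
                                         const std::vector<int64_t> &l3Strides,
                                         size_t dimForCombine) {
  assert(l3Sizes.size() >= 2 && l3Sizes.size() == l3Strides.size());
  // Build strides from the innermost dimension outwards, as if the L3 data
  // were contiguous; skip the outermost dimension so one slot is left for the
  // combined dimension's stride.
  std::vector<int64_t> strides = {1};
  int64_t running = l3Sizes.back();  // innermost stride is assumed to be 1
  for (size_t i = l3Sizes.size() - 2; i > 0; i--) {
    running *= l3Sizes[i];
    strides.push_back(running);
  }
  std::reverse(strides.begin(), strides.end());
  // Re-insert the stride of the combined dimension at its original position.
  strides.insert(strides.begin() + dimForCombine, l3Strides[dimForCombine]);
  return strides;
}

int main() {
  // L3 source addressing from the example above: sizes [4, 2, 32, 32],
  // strides [4096, 32, 128, 1]; dimension 1 is combined with the innermost.
  std::vector<int64_t> strides = computeNewL2Strides(
      {4, 2, 32, 32}, {4096, 32, 128, 1}, /*dimForCombine=*/1);
  for (int64_t s : strides) std::cout << s << " ";
  std::cout << "\n";  // prints: 2048 32 1024 1
  return 0;
}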

Collaborator:

"The combined elements are now contiguous on the L3 side, but need strided addressing on the L2 side; that stride is initial * l3OrigSizes[-2]."

Yeah, I think the core idea is good: i.e. that the strides should be created as if the original L3 was contiguous and then rearranged based on which dimension(s) are combined with the innermost one.

However, I do think this needs extensive tests to ensure correctness, as it will otherwise lead to hard-to-debug numerical errors in the future. So, it would be good to create a standalone utility function that takes in a set of static offsets/sizes/strides and produces the new static L3 and L2 offsets/sizes/strides, so that it can be tested standalone (ctest, not lit) on a lot of different cases; see for example: https://github.com/nod-ai/iree-amd-aie/blob/main/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp


circularDmaStrides = getNewL2Strides(sizeVals);
circularDmaStrides.insert(
circularDmaStrides.begin() + dimForCombine,
getAsIndexOpFoldResult(ctx, strideVals[dimForCombine]));
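// Continuing the example above, circularDmaStrides is now [32, 1024, 1].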
return success();
}

/// Walk through all users of a connection op and change the dma addressing of
/// NpuDmaCpyNdOp and NpuCircularDmaCpyNdOp at the same time. A connection op
/// can have multiple NpuDmaCpyNdOp users (with different offsets) but only one
/// NpuCircularDmaCpyNdOp user.
static LogicalResult transferDmaAddressing(MLIRContext *ctx,
AMDAIE::ConnectionOp connectionOp) {
IRRewriter rewriter(ctx);

FailureOr<AMDAIE::NpuCircularDmaCpyNdOp> maybeNpuDmaUserOp =
connectionOp.getNpuCircularDmaCpyNdUser();
if (failed(maybeNpuDmaUserOp)) {
connectionOp.emitOpError() << "failed to get circular NPU DMA op user";
return failure();
}

AMDAIE::NpuCircularDmaCpyNdOp circularDma = maybeNpuDmaUserOp.value();
SmallVector<OpFoldResult> srcCircularOffsets =
circularDma.getSourceMixedOffsets();
SmallVector<OpFoldResult> srcCircularSizes =
circularDma.getSourceMixedSizes();
SmallVector<OpFoldResult> srcCircularStrides =
circularDma.getSourceMixedStrides();
SmallVector<OpFoldResult> tgtCircularOffsets =
circularDma.getTargetMixedOffsets();
SmallVector<OpFoldResult> tgtCircularSizes =
circularDma.getTargetMixedSizes();
SmallVector<OpFoldResult> tgtCircularStrides =
circularDma.getTargetMixedStrides();

// Change the source/target addressing of all users from a connection op.
for (Operation *user : connectionOp->getUsers()) {

Collaborator:
What if different NPU DMA users have different strides/sizes/offsets?

Contributor Author:
I don't know if we would have such cases. I looked through the current IR and only found cases where the connection op has multiple NpuDmaCpyNdOp users (just with different offsets) and one NpuCircularDmaCpyNdOp user.

Collaborator:
We do see it with peeled matmul, and regardless, we should check for it and return/abort.

if (auto dmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(user)) {
SmallVector<OpFoldResult> srcOffsets = dmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult> srcSizes = dmaOp.getSourceMixedSizes();
SmallVector<OpFoldResult> srcStrides = dmaOp.getSourceMixedStrides();
SmallVector<OpFoldResult> tgtOffsets = dmaOp.getTargetMixedOffsets();
SmallVector<OpFoldResult> tgtSizes = dmaOp.getTargetMixedSizes();
SmallVector<OpFoldResult> tgtStrides = dmaOp.getTargetMixedStrides();

// Generate new L3 source addressing, and L2 target addressing.
if (dmaOp.getSourceMemorySpaceAsUInt() == 0) {
if (circularDma.getTargetMemorySpaceAsUInt() != 1) {
dmaOp.emitOpError() << "has source in L3, but circular dma doesn't "
"have target in L2.";
return failure();
}
if (failed(createNewAddressing(ctx, srcOffsets, srcSizes, srcStrides,
tgtCircularOffsets, tgtCircularSizes,
tgtCircularStrides))) {
return failure();
}
}

// Generate new L3 target addressing, and L2 source addressing.
if (dmaOp.getTargetMemorySpaceAsUInt() == 0) {
if (circularDma.getSourceMemorySpaceAsUInt() != 1) {
dmaOp.emitOpError() << "has target in L3, but circular dma doesn't "
"have source in L2.";
return failure();
}
if (failed(createNewAddressing(ctx, tgtOffsets, tgtSizes, tgtStrides,
srcCircularOffsets, srcCircularSizes,
srcCircularStrides))) {
return failure();
}
}

// Replace the npu.dma_cpy_nd with the combined access pattern.
rewriter.setInsertionPoint(dmaOp);
dmaOp = rewriter.replaceOpWithNewOp<AMDAIE::NpuDmaCpyNdOp>(
dmaOp, dmaOp.getConnection(), dmaOp.getTarget(), tgtOffsets, tgtSizes,
tgtStrides, dmaOp.getTargetBdId(), dmaOp.getSource(), srcOffsets,
srcSizes, srcStrides, dmaOp.getSourceBdId());
}
}

// Replace the npu.circular_dma_cpy_nd with the new access pattern.
rewriter.setInsertionPoint(circularDma);
circularDma = rewriter.replaceOpWithNewOp<AMDAIE::NpuCircularDmaCpyNdOp>(
circularDma, circularDma.getConnection(), tgtCircularOffsets,
tgtCircularSizes, tgtCircularStrides, srcCircularOffsets,
srcCircularSizes, srcCircularStrides);
return success();
}

class AMDAIETransferStridedAccessPatternPass
: public impl::AMDAIETransferStridedAccessPatternBase<
AMDAIETransferStridedAccessPatternPass> {
public:
AMDAIETransferStridedAccessPatternPass() = default;
AMDAIETransferStridedAccessPatternPass(
const AMDAIETransferStridedAccessPatternPass &pass){};
void runOnOperation() override;
};

void AMDAIETransferStridedAccessPatternPass::runOnOperation() {
Operation *parentOp = getOperation();
MLIRContext *ctx = &getContext();

// Walk the NpuDmaCpyNdOp ops and get the defining connections between L3 and
// L2 objectFifos. Then go through all users of each connection op and check
// if there is optimization opportunity to transfer strided access pattern
// from L3 to L2 side. Currently, a connection op can have multiple
// NpuDmaCpyNdOp users but only one NpuCircularDmaCpyNdOp user.
DenseSet<AMDAIE::ConnectionOp> connectionOps;
WalkResult walkRes = parentOp->walk([&](NpuDmaCpyNdOp dmaOp) {
AMDAIE::ConnectionOp connectionOp = dmaOp.getConnectionOp();
if (!connectionOp) {
dmaOp.emitOpError() << "no connection op is found";
return WalkResult::interrupt();
}
if (connectionOps.contains(connectionOp)) {
return WalkResult::advance();
}

FailureOr<bool> checkRes = checkConnectionUsers(connectionOp);
if (failed(checkRes)) {
return WalkResult::interrupt();
}
if (checkRes.value()) {
connectionOps.insert(connectionOp);
}
return WalkResult::advance();
});
if (walkRes.wasInterrupted()) return signalPassFailure();

// Walk through all users of each connection op and change the dma addressing
// from NpuDmaCpyNdOp and NpuCircularDmaCpyNdOp at the same time.
for (AMDAIE::ConnectionOp connectionOp : connectionOps) {
if (failed(transferDmaAddressing(ctx, connectionOp))) {
return signalPassFailure();
}
}
}

} // namespace

std::unique_ptr<Pass> createAMDAIETransferStridedAccessPatternPass() {
return std::make_unique<AMDAIETransferStridedAccessPatternPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
@@ -95,6 +95,7 @@ iree_cc_library(
"AMDAIETemporaryAllocBufferization.cpp"
"AMDAIETile.cpp"
"AMDAIETileAndFuse.cpp"
"AMDAIETransferStridedAccessPattern.cpp"
"AMDAIEUtils.cpp"
"AMDAIEVectorization.cpp"
"BridgeToAIRPass.cpp"
@@ -74,6 +74,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIETEMPORARYALLOCBUFFERIZATION
#define GEN_PASS_DEF_AMDAIETILE
#define GEN_PASS_DEF_AMDAIETILEANDFUSE
#define GEN_PASS_DEF_AMDAIETRANSFERSTRIDEDACCESSPATTERN
#define GEN_PASS_DEF_AMDAIEVECTORIZATION
#include "iree-amd-aie/Transforms/Passes.h.inc"
