[DT][NFC] Internalize transposeNarrowN logic to LayoutAttrInterface Impl #19453

Merged 2 commits on Dec 16, 2024
@@ -67,13 +67,6 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
IREE::HAL::ExecutableTargetAttr targetAttr) {
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet materializeEncodingPattern(ctx);
// On CPU, we use transposeNarrowN=true for a combination of reasons:
// 1. As linalg.matmul materializes into linalg.mmt4d, which has a transposed
// RHS and therefore LHS<->RHS symmetry, transposeNarrowN is easy to
// implement at that level.
// 2. We use ukernels, and this allows writing 2x fewer narrow ukernels.
// 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU,
// so it is nice that they have fewer narrow cases to consider.
DictionaryAttr targetConfig = targetAttr.getConfiguration();
IREE::Codegen::LayoutAttrInterface layoutAttr;
if (isVMVXBackend(targetAttr)) {
@@ -85,8 +78,7 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
IREE::CPU::CPUEncodingLayoutAttr::get(ctx, targetConfig));
}
MaterializeEncodingTypeConverter typeConverter(
/*transposeNarrowN=*/true, layoutAttr);
MaterializeEncodingTypeConverter typeConverter(layoutAttr);
MaterializeEncodingConversionTarget target(*ctx);
auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr);
populateMaterializeEncodingIntoPackUnPackPatterns(
77 changes: 4 additions & 73 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -20,76 +20,9 @@ using IREE::Encoding::EncodingAttr;
using IREE::Encoding::getEncodingAttr;
using IREE::Encoding::getEncodingContractionDims;

// If tensorType has the encoding of a matmul RESULT with narrow N, returns
// the transposed type. Otherwise, just returns tensorType.
static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
Contributor:
I was expecting to see this code copied to some CPU-specific file, but I can't see it. Or is it no longer needed because we have a different way of performing the equivalent logic?

Contributor:
Oh I see, this is what you are explaining in the PR description. Wow --- this is fantastic! I'm just now realizing the power of the refactoring you've been doing. Being able to simply transpose the encoding because it hasn't been materialized yet in a RankedTensorType. Amazing!

Contributor:
Side note: this is making transpose-narrow-N so neat that it removes much of the objection to doing it on GPU, if a need arises -- such as maybe with GPU ukernels.

Contributor (Author):
It is not needed because it is handled by the getEncodingInfo method. First, we get the encoding info; then we transpose it in place if it is a narrow-N case. I think the reason for constructing the ranked tensor type with the new encoding was that we wanted to infer the same encoding info. That logic can now be handled by the attribute, so we do not need to expose it to the type converter.

Contributor (Author):
yep, what you said. you beat me to this. :D
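
To make the new flow concrete, this is roughly what the attribute-side logic looks like after this change (a simplified sketch of the CPUEncodingExternalModels.cpp hunks further down in this diff; surrounding boilerplate and helper arguments are elided):

    // Inside the LayoutAttrInterface implementation: compute the encoding
    // info for the chosen tile, then transpose it in place when the result
    // is narrow-N. The type converter never sees narrow-N anymore.
    MaterializeEncodingInfo info =
        getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
    if (IREE::Encoding::isNarrowNResult(encoding)) {
      transposeInPlace(info);
    }
    return info;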

auto encoding =
llvm::dyn_cast_or_null<EncodingAttr>(tensorType.getEncoding());
if (!encoding) {
return tensorType;
}
if (!isNarrowNResult(encoding)) {
return tensorType;
}
SmallVector<int64_t> newOriginalShape(tensorType.getShape());
auto userIndexingMaps = encoding.getUserIndexingMaps();
SmallVector<AffineMap> maps;
for (auto a : userIndexingMaps) {
maps.push_back(cast<AffineMapAttr>(a).getAffineMap());
}
auto cDims = linalg::inferContractionDims(maps);
SmallVector<int64_t> newShape(tensorType.getShape());
SmallVector<int64_t> permIndices(maps[0].getNumDims());
std::iota(std::begin(permIndices), std::end(permIndices), 0);
// Matrix case: there are both M and N dimensions. Transposing means swapping
// them.
if (cDims->m.size() == 1 && cDims->n.size() == 1) {
int m = cDims->m[0];
int n = cDims->n[0];
std::swap(permIndices[m], permIndices[n]);
std::optional<unsigned> mDim = encoding.mapDimToOperandIndex(m);
std::optional<unsigned> nDim = encoding.mapDimToOperandIndex(n);
if (mDim.has_value() && nDim.has_value()) {
std::swap(newShape[mDim.value()], newShape[nDim.value()]);
std::swap(newOriginalShape[mDim.value()], newOriginalShape[nDim.value()]);
}
}
// Vector case: there is no N dimension to swap the M dimension with. We
// swap the maps themselves.
if (cDims->n.empty()) {
std::swap(maps[0], maps[1]);
}

SmallVector<int64_t> newRoundDimsTo(encoding.getRoundDimsToArray());
assert(newRoundDimsTo.size() == 0 || newRoundDimsTo.size() >= 3);
if (newRoundDimsTo.size() != 0) {
std::swap(newRoundDimsTo[newRoundDimsTo.size() - 3],
newRoundDimsTo[newRoundDimsTo.size() - 2]);
}
auto context = tensorType.getContext();
AffineMap permutation = AffineMap::getPermutationMap(permIndices, context);
for (auto &map : maps) {
map = map.compose(permutation);
}
auto elemType = tensorType.getElementType();
auto operandIndex = encoding.getOperandIndex().getInt();

// TODO(#17718): Handle the broadcast map for transpose cases. It is on the
// experimental path, so it is not clear what needs to be done here. For now
// just use the original map for the new encoding.
std::optional<AffineMap> newBcastMap;
if (encoding.getBcastMap()) {
newBcastMap = encoding.getBcastMap().getValue();
}
auto newEncoding = IREE::Encoding::EncodingAttr::get(
context, operandIndex, encoding.getOpType().getValue(),
encoding.getElementTypesArray(), maps, newBcastMap, newRoundDimsTo);
return RankedTensorType::get(newShape, elemType, newEncoding);
}

MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
bool transposeNarrowN, IREE::Codegen::LayoutAttrInterface layoutAttr)
: transposeNarrowN(transposeNarrowN), layoutAttr(layoutAttr) {
IREE::Codegen::LayoutAttrInterface layoutAttr)
: layoutAttr(layoutAttr) {
addConversion([](IntegerType intType) { return intType; });
addConversion([](IndexType indexType) { return indexType; });
addConversion([](FloatType floatType) { return floatType; });
@@ -98,14 +31,12 @@ MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
// For a given tensor type with an encoding, return the materialized
// type to use for it. If no encoding is set, then return the tensor type
// itself.
RankedTensorType tensorType =
transposeNarrowN ? transposeIfNarrowNResult(type) : type;
MaterializeEncodingInfo encodingInfo = getEncodingInfo(tensorType);
MaterializeEncodingInfo encodingInfo = getEncodingInfo(type);
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return dropEncoding(type);
}
auto packedType = cast<RankedTensorType>(tensor::PackOp::inferPackedType(
tensorType, encodingInfo.innerTileSizes, encodingInfo.innerDimsPos,
type, encodingInfo.innerTileSizes, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm));

// There is no swizzle, we are already done. Typically the case on CPU.
7 changes: 1 addition & 6 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -36,7 +36,7 @@ using MaterializeEncodingValueFn =
class MaterializeEncodingTypeConverter : public TypeConverter {
public:
MaterializeEncodingTypeConverter(
bool transposeNarrowN, IREE::Codegen::LayoutAttrInterface layoutAttr);
IREE::Codegen::LayoutAttrInterface layoutAttr);

const IREE::Codegen::LayoutAttrInterface &getLayoutAttr() const {
return layoutAttr;
@@ -47,12 +47,7 @@ class MaterializeEncodingTypeConverter : public TypeConverter {
return layoutAttr.getEncodingInfo(type);
}

bool getTransposeNarrowN() const { return transposeNarrowN; }

private:
bool transposeNarrowN = false;
// TODO(hanchung): Move the logic that takes `transposeNarrowN` into account
// to their own attribute implementation.
const IREE::Codegen::LayoutAttrInterface layoutAttr;
};

@@ -271,22 +271,13 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
MLIRContext *ctx = funcOp.getContext();
{
RewritePatternSet patterns(ctx);
// On GPU, we use transposeNarrowN=false for a combination of reasons:
// 1. As linalg.matmul materializes into iree_gpu.multi_mma, which inherits
// its semantics from the wrapped intrinsic, we can't rely on any kind of
// LHS<->RHS symmetry.
// 2. We do not currently use ukernels, which would be one of the main areas
// to benefit from transposeNarrowN.
// 3. Heuristics for cache-friendly dispatch tiling are internal to the GPU
// runtime, so we don't need a simplification at that level either.
IREE::GPU::TargetAttr gpuTargetAttr;
if (targetAttr) {
gpuTargetAttr = getGPUTargetAttr(targetAttr);
} else {
gpuTargetAttr = getCLGPUTarget(ctx);
}
MaterializeEncodingTypeConverter typeConverter(
/*transposeNarrowN=*/false,
cast<IREE::Codegen::LayoutAttrInterface>(
IREE::GPU::GPUEncodingLayoutAttr::get(ctx, gpuTargetAttr)));
MaterializeEncodingConversionTarget target(*ctx);
@@ -46,7 +46,6 @@ struct MaterializeEncodingIntoNopPass final

RewritePatternSet materializeEncodingPattern(context);
MaterializeEncodingTypeConverter typeConverter(
/*transposeNarrowN=*/false,
IREE::Codegen::EncodingNopLayoutAttr::get(context));
MaterializeEncodingConversionTarget target(*context);
populateMaterializeEncodingIntoPackUnPackPatterns(
@@ -101,22 +101,6 @@ getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
return result;
}

static void transposeInPlace(MaterializeEncodingInfo &info) {
// Vector cases: nothing to do.
if (info.innerTileSizes.size() < 2) {
return;
}
// Not a vector case, so all three arrays in `info` have size at least 2,
// outerDimsPerm may have size 3 if there is a batch dimension, but in all
// cases, the last 2 entries of each array are M and N, not batch.
auto transpose = [](SmallVector<int64_t> &a) {
std::swap(a[a.size() - 2], a[a.size() - 1]);
};
transpose(info.innerDimsPos);
transpose(info.innerTileSizes);
transpose(info.outerDimsPerm);
}

//===---------------------------------------------------------------------===//
// Methods to convert `set_encoding` and `unset_encoding` operations
// to `pack` and `unpack` operations respectively.
@@ -139,9 +123,6 @@ FailureOr<Value> lowerSetEncodingOpToPackOp(
if (!encoding) {
return failure();
}
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
transposeInPlace(encodingInfo);
}

// Create `tensor.empty` operation for the result of the pack operation.
Location loc = encodingOp.getLoc();
@@ -180,10 +161,6 @@ FailureOr<Value> lowerUnsetEncodingToUnpackOp(
return packedValue;
}

auto encoding = IREE::Encoding::getEncodingAttr(sourceType);
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
transposeInPlace(encodingInfo);
}
// Create an `tensor.empty` for the result of the unpack operation.
Location loc = encodingOp.getLoc();
SmallVector<OpFoldResult> resultDims =
@@ -222,11 +199,6 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
.getOperation();
}

if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
transposeInPlace(encodingInfo);
}

FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, emptyType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
@@ -389,10 +361,6 @@ static FailureOr<SmallVector<OpFoldResult>> getPackedDimsForDispatchTensor(
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return failure();
}
if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(boundTensorType))) {
transposeInPlace(encodingInfo);
}

SmallVector<OpFoldResult> targetShape =
getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
@@ -3,6 +3,30 @@
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- CPUEncodingExternalModels.cpp --------------------------------------===//
//
// This file implements the IREE::Codegen::LayoutAttrInterface for CPU backends
// and the VMVX backend. In these backends, we transpose narrow-N into narrow-M
// for a combination of reasons:
//
// 1. As linalg.matmul materializes into linalg.mmt4d, which has a transposed
// RHS and therefore LHS<->RHS symmetry, transposeNarrowN is easy to
// implement at that level.
// 2. We use ukernels, and this allows writing 2x fewer narrow ukernels.
// 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU,
// so it is nice that they have fewer narrow cases to consider.
//
// This transposition is made easier by (and was all along part of the idea in)
// the RHS-transposition in mmt4d (the t in mmt4d), as generally with matrix
// multiplication
//
// B * Transpose(A) == Transpose( A * Transpose(B) )
//
// so in mmt4d terms
//
// mmt4d(B, A) == Transpose(mmt4d(A, B))
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.h"

@@ -12,6 +36,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

@@ -28,6 +53,22 @@ namespace {
// Utilities.
//===----------------------------------------------------------------------===//

static void transposeInPlace(MaterializeEncodingInfo &info) {
// Vector cases: nothing to do.
if (info.innerTileSizes.size() < 2) {
return;
}
// Not a vector case, so all three arrays in `info` have size at least 2,
// outerDimsPerm may have size 3 if there is a batch dimension, but in all
// cases, the last 2 entries of each array are M and N, not batch.
auto transpose = [](SmallVector<int64_t> &a) {
std::swap(a[a.size() - 2], a[a.size() - 1]);
};
transpose(info.innerDimsPos);
transpose(info.innerTileSizes);
transpose(info.outerDimsPerm);
}

static RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}
@@ -576,7 +617,11 @@ struct CPUDeviceEncodingLayoutAttrInterface
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK = chooseMatmulTile(
enumeratedTileMxNxK, narrowDim, encoding.getRoundDimsToArray());
return getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
info = getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
if (Encoding::isNarrowNResult(encoding)) {
transposeInPlace(info);
}
return info;
}

Operation *lowerOp(Attribute attr, OpBuilder &b, Operation *op,
@@ -660,7 +705,11 @@ struct VMVXDeviceEncodingLayoutAttrInterface
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK = chooseMatmulTile(
enumeratedTileMxNxK, narrowDim, encoding.getRoundDimsToArray());
return getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
info = getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
if (Encoding::isNarrowNResult(encoding)) {
transposeInPlace(info);
}
return info;
}

Operation *lowerOp(Attribute attr, OpBuilder &b, Operation *op,
@@ -3,6 +3,21 @@
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- GPUEncodingExternalModels.cpp --------------------------------------===//
//
// This file implements the IREE::Codegen::LayoutAttrInterface for GPU backends.
// Different from CPU backends, we do not transpose narrow-N to narrow-M for a
// combination of reasons:
//
// 1. As linalg.matmul materializes into iree_gpu.multi_mma, which inherits
// its semantics from the wrapped intrinsic, we can't rely on any kind of
// LHS<->RHS symmetry.
// 2. We do not currently use ukernels, which would be one of the main areas
// to benefit from transposeNarrowN.
// 3. Heuristics for cache-friendly dispatch tiling are internal to the GPU
// runtime, so we don't need a simplification at that level either.
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.h"
