[LLVMGPUVectorDistribute] Re-arrange nested layouts for better conversions #19437

Draft · wants to merge 1 commit into base: main
@@ -857,7 +857,7 @@ struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
}
};

struct DistributeBatchOuterToLayoutConversions final
struct DistributeInThreadToLayoutConversions final
: OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
using OpDistributionPattern::OpDistributionPattern;

@@ -874,7 +874,7 @@ struct DistributeBatchOuterToLayoutConversions final
return rewriter.notifyMatchFailure(toLayoutOp, "non-nested layout");
}

// Check if everything other than batch and outer tile matches.
// Check if everything other than the in-thread tiles matches.
if (layoutA.getSubgroupTile() != layoutB.getSubgroupTile()) {
return failure();
}
@@ -887,9 +887,6 @@ struct DistributeBatchOuterToLayoutConversions final
if (layoutA.getThreadStrides() != layoutB.getThreadStrides()) {
return failure();
}
if (layoutA.getElementTile() != layoutB.getElementTile()) {
return failure();
}

auto batchTileA = SmallVector<int64_t>(layoutA.getBatchTile());
auto outerTileA = SmallVector<int64_t>(layoutA.getOuterTile());
@@ -906,17 +903,19 @@ struct DistributeBatchOuterToLayoutConversions final
SmallVector<int64_t> shapeB = layoutB.getDistributedShape();
int64_t rank = layoutA.getRank();

// Interleave batch and outer dims by transposing.
// Interleave in-thread elements by transposing.
Contributor Author:
Ok, this is generally not right; thanks @Groverkss for pointing it out.

Contributor Author:
This works currently because the producer of this is a transfer_read


// Build a permutation for interleaving.
auto interleavePermutation =
llvm::to_vector(llvm::seq<int64_t>(shapeA.size()));
for (int i = 0; i < rank; ++i) {
// Batch tile : [0...rank]
// OuterTile : [rank+1...2*rank]
// Interleave : [batch0, outer0, batch1, outer1,...]
interleavePermutation[2 * i] = i;
interleavePermutation[2 * i + 1] = i + rank;
// OuterTile : [rank+1...2*rank]
// ElementTile : [2*rank+1...3*rank]
// Interleave : [batch0, outer0, element0, batch1, outer1, element1, ...]
interleavePermutation[3 * i] = i;
interleavePermutation[3 * i + 1] = i + rank;
interleavePermutation[3 * i + 2] = i + 2 * rank;
}
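// Illustrative note (not part of the original diff), assuming rank == 2:
// the distributed shape lists the in-thread dims as
// [batch0, batch1, outer0, outer1, element0, element1], so the loop above
// yields interleavePermutation = [0, 2, 4, 1, 3, 5], i.e.
// [batch0, outer0, element0, batch1, outer1, element1].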

auto interleaved = rewriter.create<vector::TransposeOp>(
@@ -1153,7 +1152,7 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
maxBitsPerShuffle);
patterns.add<DistributeContract>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
patterns.add<DistributeInThreadToLayoutConversions>(patterns.getContext());
patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
}

@@ -135,7 +135,54 @@ struct GPUVectorAllocPass final

VectorType inputTy = cast<VectorType>(op.getType());
Value read = readVectorFromTensor(builder, inputTy, synced.getResult(0));
operand.set(read);

if (IREE::VectorExt::NestedLayoutAttr vectorLayout =
dyn_cast<IREE::VectorExt::NestedLayoutAttr>(op.getLayoutAttr())) {
// Re-arrange the layout to read vectors that are as large as possible
// from shared memory. This is done by pulling elements from the other
// in-thread tiles (batch & outer) into the element tile.
SmallVector<int64_t> elementTile =
llvm::to_vector(vectorLayout.getElementTile());
SmallVector<int64_t> batchTile =
llvm::to_vector(vectorLayout.getBatchTile());
SmallVector<int64_t> outerTile =
llvm::to_vector(vectorLayout.getOuterTile());
int64_t &elementTileLen = elementTile.back();
int64_t &batchTileLen = batchTile.back();
int64_t &outerTileLen = outerTile.back();
// TODO: maybe we should obtain this from somewhere?
constexpr int64_t maxVecLenBits = 128;
Contributor Author:
Is there a way to obtain arch-specific information in a pass?

// Pull in in-thread elements to reach max
// vector length.
Type elemType = op.getType().getElementType();
int64_t maxVecLen = maxVecLenBits / elemType.getIntOrFloatBitWidth();
int64_t remainingElementsForVec = maxVecLen / elementTileLen;
if (remainingElementsForVec <= outerTileLen) {
elementTileLen *= remainingElementsForVec;
outerTileLen = outerTileLen / remainingElementsForVec;
} else {
// Consume the whole outer tile before pulling from the batch tile;
// compute the remaining budget before the outer tile is reset to 1.
remainingElementsForVec /= outerTileLen;
elementTileLen *= outerTileLen;
outerTileLen = 1;
if (remainingElementsForVec <= batchTileLen) {
elementTileLen *= remainingElementsForVec;
batchTileLen = batchTileLen / remainingElementsForVec;
} else {
elementTileLen *= batchTileLen;
batchTileLen = 1;
}
}
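// Illustrative example (not part of the original diff): for f16 (16-bit
// elements) with maxVecLenBits = 128, maxVecLen = 8. If the trailing tile
// sizes are elementTile = 2, outerTile = 2, batchTile = 4, then
// remainingElementsForVec = 8 / 2 = 4 > outerTile, so the outer tile is
// folded in (element = 4, outer = 1, remaining = 2) and two batch elements
// are then pulled in, ending with element = 8, outer = 1, batch = 2,
// i.e. a full 128-bit read per thread.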
auto betterVecLayout = IREE::VectorExt::NestedLayoutAttr::get(
op.getContext(), vectorLayout.getSubgroupTile(), batchTile,
outerTile, vectorLayout.getThreadTile(),
elementTile, vectorLayout.getSubgroupStrides(),
vectorLayout.getThreadStrides());
auto betterVecLayoutOp = builder.create<IREE::VectorExt::ToLayoutOp>(
op.getLoc(), read, betterVecLayout);
operand.set(betterVecLayoutOp);
} else {
operand.set(read);
}

// Remove the shared_memory_conversion attribute from the to_layout
// operation.