diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index e2d773f13b37..1b8e0f8d82e8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -857,7 +857,7 @@ struct DistributeTranspose final : OpDistributionPattern {
   }
 };
 
-struct DistributeBatchOuterToLayoutConversions final
+struct DistributeInThreadToLayoutConversions final
     : OpDistributionPattern {
   using OpDistributionPattern::OpDistributionPattern;
 
@@ -874,7 +874,7 @@ struct DistributeBatchOuterToLayoutConversions final
       return rewriter.notifyMatchFailure(toLayoutOp, "non-nested layout");
     }
 
-    // Check if everything other than batch and outer tile matches.
+    // Check if everything other than out-of-thread tiles matches.
     if (layoutA.getSubgroupTile() != layoutB.getSubgroupTile()) {
       return failure();
     }
@@ -887,9 +887,6 @@ struct DistributeBatchOuterToLayoutConversions final
     if (layoutA.getThreadStrides() != layoutB.getThreadStrides()) {
       return failure();
     }
-    if (layoutA.getElementTile() != layoutB.getElementTile()) {
-      return failure();
-    }
 
     auto batchTileA = SmallVector(layoutA.getBatchTile());
     auto outerTileA = SmallVector(layoutA.getOuterTile());
@@ -906,17 +903,19 @@ struct DistributeBatchOuterToLayoutConversions final
     SmallVector shapeB = layoutB.getDistributedShape();
     int64_t rank = layoutA.getRank();
 
-    // Interleave batch and outer dims by transposing.
+    // Interleave in-thread elements by transposing.
     // Build a permutation for interleaving.
     auto interleavePermutation =
         llvm::to_vector(llvm::seq(shapeA.size()));
     for (int i = 0; i < rank; ++i) {
       // Batch tile : [0...rank]
-      // OuterTile : [rank+1...2*rank]
-      // Interleave : [batch0, outer0, batch1, outer1,...]
-      interleavePermutation[2 * i] = i;
-      interleavePermutation[2 * i + 1] = i + rank;
+      // OuterTile : [rank+1...2*rank]
+      // ElementTile : [2*rank+1...3*rank]
+      // Interleave : [batch0, outer0, element0, batch1, outer1, element1, ...]
+      interleavePermutation[3 * i] = i;
+      interleavePermutation[3 * i + 1] = i + rank;
+      interleavePermutation[3 * i + 2] = i + 2 * rank;
     }
 
     auto interleaved = rewriter.create(
@@ -1153,7 +1152,7 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
                maxBitsPerShuffle);
   patterns.add(patterns.getContext(), subgroupSize,
                maxBitsPerShuffle);
-  patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
+  patterns.add<DistributeInThreadToLayoutConversions>(patterns.getContext());
   patterns.add(patterns.getContext(), threadId,
               subgroupSize);
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
index 3eae712f02f9..dac15a32d7bb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
@@ -135,7 +135,54 @@ struct GPUVectorAllocPass final
         VectorType inputTy = cast<VectorType>(op.getType());
         Value read = readVectorFromTensor(builder, inputTy, synced.getResult(0));
-        operand.set(read);
+
+        if (IREE::VectorExt::NestedLayoutAttr vectorLayout =
+                dyn_cast<IREE::VectorExt::NestedLayoutAttr>(
+                    op.getLayoutAttr())) {
+          // Re-arrange the layout to read as large as possible vectors
+          // from shared memory. This is done by pulling in elements
+          // from in-thread tiles : batch & outer into element tile.
+          SmallVector<int64_t> elementTile =
+              llvm::to_vector(vectorLayout.getElementTile());
+          SmallVector<int64_t> batchTile =
+              llvm::to_vector(vectorLayout.getBatchTile());
+          SmallVector<int64_t> outerTile =
+              llvm::to_vector(vectorLayout.getOuterTile());
+          int64_t &elementTileLen = elementTile.back();
+          int64_t &batchTileLen = batchTile.back();
+          int64_t &outerTileLen = outerTile.back();
+          // TODO: maybe we should obtain this from somewhere ?
+          constexpr int64_t maxVecLenBits = 128;
+          // Pull in in-thread elements to reach max
+          // vector length.
+          Type elemType = op.getType().getElementType();
+          int64_t maxVecLen = maxVecLenBits / elemType.getIntOrFloatBitWidth();
+          int64_t remainingElementsForVec = maxVecLen / elementTileLen;
+          if (remainingElementsForVec <= outerTileLen) {
+            elementTileLen *= remainingElementsForVec;
+            outerTileLen = outerTileLen / remainingElementsForVec;
+          } else {
+            // Fold the whole outer tile into the element tile. Divide the
+            // remaining factor by the outer tile length *before* resetting
+            // it to 1, so the batch fold below pulls in only what is left.
+            remainingElementsForVec /= outerTileLen;
+            elementTileLen *= outerTileLen;
+            outerTileLen = 1;
+            if (remainingElementsForVec <= batchTileLen) {
+              elementTileLen *= remainingElementsForVec;
+              batchTileLen = batchTileLen / remainingElementsForVec;
+            } else {
+              elementTileLen *= batchTileLen;
+              batchTileLen = 1;
+            }
+          }
+          // Use the mutated batch/outer/element tiles; the thread-visible
+          // distribution (subgroup/thread tiles and strides) is unchanged.
+          auto betterVecLayout = IREE::VectorExt::NestedLayoutAttr::get(
+              op.getContext(), vectorLayout.getSubgroupTile(), batchTile,
+              outerTile, vectorLayout.getThreadTile(), elementTile,
+              vectorLayout.getSubgroupStrides(),
+              vectorLayout.getThreadStrides());
+          auto betterVecLayoutOp = builder.create<IREE::VectorExt::ToLayoutOp>(
+              op.getLoc(), read, betterVecLayout);
+          operand.set(betterVecLayoutOp);
+        } else {
+          operand.set(read);
+        }
 
         // Remove the shared_memory_conversion attribute from the to_layout
         // operation.