Skip to content

Commit

Permalink
[VectorDistribution] Fix 0-rank vector.broadcast distribution (iree-o…
Browse files Browse the repository at this point in the history
…rg#19007)

Fixes: iree-org#18955

Also removes locally carried revert
  • Loading branch information
Groverkss authored Nov 11, 2024
1 parent 55f2fce commit 7f7cfb0
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ struct DistributeBroadcast final : OpDistributionPattern<vector::BroadcastOp> {
auto vectorType = VectorType::get(distShape, elementType);

VectorValue srcVector = dyn_cast<VectorValue>(broadcastOp.getSource());
// If the srcVector is a scalar (like f32) or a rank-0 vector (like
// vector<f32>), we proceed with the scalar distribution branch.
if (!srcVector || !isNonZeroRank(srcVector)) {
// If the srcVector is a scalar (like f32) we proceed with the scalar
// distribution branch.
if (!srcVector) {
// The way distribution currently works, there is no partial thread
// distribution, so a scalar is available to all threads. Scalar
// distribution is simply a broadcast from scalar to the distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,14 @@ void DistributionPattern::replaceOpWithDistributedValues(
for (auto [opResult, replacement] :
llvm::zip_equal(op->getOpResults(), values)) {
// If this value is a vector type, it must be converted back to simd.
if (auto replacementType = dyn_cast<VectorType>(replacement.getType())) {
if (replacementType.getRank() != 0) {
auto oldResult = cast<VectorValue>(opResult);
// Create a toSIMD op to convert the value back to the simd.
rewriter.setInsertionPointAfterValue(oldResult);
Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
oldResult.getLoc(), oldResult.getType(), replacement);
// Add to replacements.
replacement = toSIMD;
}
if (isa<VectorType>(replacement.getType())) {
auto oldResult = cast<VectorValue>(opResult);
// Create a toSIMD op to convert the value back to the simd.
rewriter.setInsertionPointAfterValue(oldResult);
Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
oldResult.getLoc(), oldResult.getType(), replacement);
// Add to replacements.
replacement = toSIMD;
}
replacements.push_back(replacement);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,47 @@ builtin.module attributes { transform.with_named_sequence } {

// -----

#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [2, 2, 2],
batch_tile = [2, 2, 1],
outer_tile = [2, 1, 1],
thread_tile = [4, 16, 8],
element_tile = [1, 4, 4],
subgroup_strides = [4, 2, 1],
thread_strides = [128, 8, 1]
>

func.func @zero_rank_broadcast(%src: vector<f16>) -> (vector<32x256x64xf16>) {
%bcast = vector.broadcast %src : vector<f16> to vector<32x256x64xf16>
%bcastl = iree_vector_ext.to_layout %bcast to layout(#layout) : vector<32x256x64xf16>
return %bcastl : vector<32x256x64xf16>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @zero_rank_broadcast
// CHECK-SAME: (%[[SRC:.*]]: vector<f16>)
// CHECK: %[[SRC_SIMT:.*]] = iree_vector_ext.to_simt %[[SRC]] : vector<f16>
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC_SIMT]]
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f16 to vector<1x4x4xf16>
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: %[[OUT:.*]] = vector.insert %[[BCAST]], %{{.*}}
// CHECK: iree_vector_ext.to_simd %[[OUT]] : vector<2x2x1x2x1x1x1x4x4xf16> -> vector<32x256x64xf16>

// -----

#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [2, 2, 2],
batch_tile = [2, 2, 1],
Expand Down

0 comments on commit 7f7cfb0

Please sign in to comment.