Skip to content

Commit

Permalink
[VectorDistribution] Add option to set a default layout (iree-org#19367)
Browse files Browse the repository at this point in the history
This PR adds a way to allow vector distribution to set a default layout
for a vector for which a layout cannot be inferred. We currently enable
this for 0-D vectors, since they are trivially distributed.
  • Loading branch information
Groverkss authored Dec 5, 2024
1 parent 5dee2c8 commit 543fb31
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,42 @@ constexpr StringLiteral kVectorLayoutFetcherStorageAttrName =
constexpr StringLiteral kVectorLayoutRedistributeAttrName =
"__vector_layout_redistribute";

static void setOpSignature(Operation *op, VectorLayoutAnalysis &analysis) {
/// Set signature for the operation based on the analysis. Returns failure if
/// an operation contains vectors that cannot be distributed i.e. they have no
/// layout.
LogicalResult setOpSignature(Operation *op, VectorLayoutAnalysis &analysis,
const VectorLayoutOptions &options) {
SmallVector<Attribute> operands;
SmallVector<Attribute> results;

for (Value operand : op->getOperands()) {
if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
operands.push_back(
analysis.getLayout<VectorLayoutInterface>(vectorOperand));
continue;
if (auto layout =
analysis.getLayout<VectorLayoutInterface>(vectorOperand)) {
operands.push_back(layout);
continue;
}
if (auto layout = options.getDefaultLayout(vectorOperand.getType())) {
operands.push_back(layout);
continue;
}
return failure();
}
operands.push_back(UnitAttr::get(op->getContext()));
}

for (Value result : op->getResults()) {
if (auto vectorResult = dyn_cast<VectorValue>(result)) {
results.push_back(
analysis.getLayout<VectorLayoutInterface>(vectorResult));
continue;
if (auto layout =
analysis.getLayout<VectorLayoutInterface>(vectorResult)) {
results.push_back(layout);
continue;
}
if (auto layout = options.getDefaultLayout(vectorResult.getType())) {
results.push_back(layout);
continue;
}
return failure();
}
results.push_back(UnitAttr::get(op->getContext()));
}
Expand All @@ -58,6 +76,7 @@ static void setOpSignature(Operation *op, VectorLayoutAnalysis &analysis) {
Attribute signature[] = {operandsAttr, resultsAttr};
op->setAttr(kVectorLayoutFetcherStorageAttrName,
ArrayAttr::get(op->getContext(), signature));
return success();
}

static bool hasOpSignature(Operation *op) {
Expand Down Expand Up @@ -264,21 +283,6 @@ static void applyVectorDistribution(Operation *root,
}
}

static bool canDistribute(Operation *op, VectorLayoutAnalysis &analysis) {
auto values = llvm::to_vector_of<Value>(op->getOperands());
llvm::append_range(values, op->getResults());

// First check if any of them are vector values.
if (llvm::none_of(values, llvm::IsaPred<VectorValue>))
return false;

// Check if all operands and results of this operation have a layout.
return llvm::all_of(values, [&analysis](Value value) {
auto vectorValue = dyn_cast<VectorValue>(value);
return !vectorValue || analysis.getLayout<Attribute>(vectorValue);
});
}

LogicalResult distributeVectorOps(Operation *root,
RewritePatternSet &distributionPatterns,
VectorLayoutOptions &options) {
Expand All @@ -294,8 +298,12 @@ LogicalResult distributeVectorOps(Operation *root,
LLVM_DEBUG(
llvm::dbgs() << "Setting distribution signatures for operations\n");
root->walk([&](Operation *op) {
if (canDistribute(op, analysis)) {
setOpSignature(op, analysis);
if (failed(setOpSignature(op, analysis, options))) {
LLVM_DEBUG({
llvm::dbgs() << "Skipping operation because not all vector "
"operands/results have a layout:\n";
op->print(llvm::dbgs());
});
}
});
LLVM_DEBUG(llvm::dbgs() << "Distribution signatures set\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class VectorLayoutOptions {

bool verifyConversion() const { return fullConversion; }

virtual VectorLayoutInterface getDefaultLayout(VectorType type) const = 0;

protected:
Operation *root;
bool fullConversion = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ func.func @distribute_scf_for(%arr: memref<32x32xf16>, %a: vector<32x32xf16>) ->
%rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<32x32xf16>
%b = arith.addf %rootl, %a : vector<32x32xf16>
%c = arith.extf %b : vector<32x32xf16> to vector<32x32xf32>
%init = vector.extractelement %arg0[] : vector<f32>
%init = vector.extract %arg0[] : f32 from vector<f32>
%root_red = vector.multi_reduction<add>, %c, %init [0, 1] : vector<32x32xf32> to f32
%d = vector.broadcast %root_red : f32 to vector<f32>
scf.yield %d : vector<f32>
Expand All @@ -1197,7 +1197,7 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf16> -> vector<2x2x1x1x1x4xf16>
// CHECK: %[[B:.*]] = arith.addf %{{.*}}, %[[A]]
// CHECK: %[[C:.*]] = arith.extf %[[B]]
// CHECK-NEXT: %[[D:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
// CHECK-NEXT: %[[D:.*]] = vector.extract %[[ARG0]][]
// Local reduction
// CHECK: vector.multi_reduction <add>, %[[C]], %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32
// Global reduction
Expand Down Expand Up @@ -1291,7 +1291,7 @@ func.func @distribute_scf_for_contraction(%arr: memref<32x32xf16>, %a: vector<32
%rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<32x32xf16>
%b = arith.addf %rootl, %a : vector<32x32xf16>
%c = arith.extf %b : vector<32x32xf16> to vector<32x32xf32>
%init = vector.extractelement %arg0[] : vector<f32>
%init = vector.extract %arg0[] : f32 from vector<f32>
%root_red = vector.contract #contraction_trait %c, %c, %init : vector<32x32xf32>, vector<32x32xf32> into f32
%d = vector.broadcast %root_red : f32 to vector<f32>
scf.yield %d : vector<f32>
Expand All @@ -1313,7 +1313,7 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf16> -> vector<2x2x1x1x1x4xf16>
// CHECK: %[[B:.*]] = arith.addf %{{.*}}, %[[A]]
// CHECK: %[[C:.*]] = arith.extf %[[B]]
// CHECK-NEXT: %[[D:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
// CHECK-NEXT: %[[D:.*]] = vector.extract %[[ARG0]][]
// Local contraction
// CHECK: vector.contract {{.*}} vector<2x2x1x1x1x4xf32>, vector<2x2x1x1x1x4xf32> into f32
// Global reduction
Expand Down Expand Up @@ -1395,3 +1395,42 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK-NEXT: gpu.subgroup_reduce maxnumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
// Accumulator reduction
// CHECK: arith.maxnumf %{{.*}}, %{{.*}} : vector<2x1x1xf32>

// -----

#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [1],
batch_tile = [1],
outer_tile = [1],
thread_tile = [32],
element_tile = [2],

subgroup_strides = [1],
thread_strides = [1]
>

func.func @zero_d_vector_extract(%vec : vector<64xf32>, %acc : vector<f32>) -> f32 {
%layouted = iree_vector_ext.to_layout %vec to layout(#layout) : vector<64xf32>
%scalar = vector.extract %acc[] : f32 from vector<f32>
%out = vector.multi_reduction <add>, %layouted, %scalar [0] : vector<64xf32> to f32
return %out : f32
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @zero_d_vector_extract
// CHECK-SAME: %[[VEC:.+]]: vector<64xf32>, %[[ACC:.+]]: vector<f32>
// CHECK-DAG: %[[SIMT_ACC:.+]] = iree_vector_ext.to_simt %[[ACC]] : vector<f32> -> vector<f32>
// CHECK-DAG: %[[SCALAR:.+]] = vector.extract %[[SIMT_ACC]][] : f32 from vector<f32>
// CHECK-DAG: %[[SIMT:.+]] = iree_vector_ext.to_simt %[[VEC]] : vector<64xf32> -> vector<1x1x2xf32>
// CHECK: %[[LOCAL:.+]] = vector.multi_reduction <add>, %[[SIMT]], %{{.*}}
// CHECK: gpu.subgroup_reduce add %[[LOCAL]]
// Accumulator addition
// CHECK: %[[BROADCASTED:.+]] = vector.broadcast %[[SCALAR]] : f32 to vector<1xf32>
// CHECK: arith.addf %{{.*}}, %[[BROADCASTED]]
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,16 @@ class TestVectorLayoutOptions : public VectorLayoutOptions {
public:
TestVectorLayoutOptions(Operation *root)
: VectorLayoutOptions(root, /*fullConversion=*/false) {}

VectorLayoutInterface getDefaultLayout(VectorType type) const override {
// We only allow a default layout for 0-d vectors for now.
if (type.getRank() > 0) {
return VectorLayoutInterface();
}
ArrayRef<int64_t> empty = {};
return IREE::VectorExt::NestedLayoutAttr::get(
type.getContext(), empty, empty, empty, empty, empty, empty, empty);
}
};

DiagnosedSilenceableFailure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions {

RewritePatternSet &getPatterns() { return patterns; }

VectorLayoutInterface getDefaultLayout(VectorType type) const override {
// We only allow a default layout for 0-d vectors for now.
if (type.getRank() > 0) {
return VectorLayoutInterface();
}
ArrayRef<int64_t> empty = {};
return IREE::VectorExt::NestedLayoutAttr::get(
type.getContext(), empty, empty, empty, empty, empty, empty, empty);
}

private:
RewritePatternSet patterns;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,10 @@ class TransformVectorLayoutOptions : public VectorLayoutOptions {
public:
TransformVectorLayoutOptions(Operation *root, bool fullConversion)
: VectorLayoutOptions(root, fullConversion) {}

VectorLayoutInterface getDefaultLayout(VectorType type) const override {
return VectorLayoutInterface();
}
};

DiagnosedSilenceableFailure
Expand Down

0 comments on commit 543fb31

Please sign in to comment.