Skip to content

Commit 0069c12

Browse files
authored
[Test] SG level Flash Attention Implementation using SIMT lowering. (#1125)
1 parent 4c816f8 commit 0069c12

File tree

4 files changed

+853
-2
lines changed

4 files changed

+853
-2
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
2+
index 9ead1d89069d..3822d24c8579 100644
3+
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
4+
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
5+
@@ -189,9 +189,10 @@ class CreateNdDescToXeVMPattern
6+
// If source is a memref, we need to extract the aligned pointer as index.
7+
// Pointer type is passed as i32 or i64 by type converter.
8+
if (sourceMemrefTy) {
9+
- if (!sourceMemrefTy.hasStaticShape()) {
10+
- return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
11+
- }
12+
+ // if (!sourceMemrefTy.hasStaticShape()) {
13+
+ // return rewriter.notifyMatchFailure(op, "Expected static memref
14+
+ // shape.");
15+
+ // }
16+
baseAddr =
17+
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
18+
} else {
19+
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
20+
index e95338f7d18b..2615d225dc1d 100644
21+
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
22+
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
23+
@@ -348,6 +349,9 @@ private:
24+
/// d1) and return vector<16x2x64>
25+
static VectorType getDistributedType(VectorType originalType, AffineMap map,
26+
int64_t warpSize) {
27+
+ // If the map has zero results, that means no distribution.
28+
+ if (map.getNumResults() == 0)
29+
+ return originalType;
30+
SmallVector<int64_t> targetShape(originalType.getShape());
31+
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
32+
unsigned position = map.getDimPosition(i);
33+
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
34+
index f1dbc5ddb202..3023c65d4bc3 100644
35+
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
36+
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
37+
@@ -1506,9 +1506,14 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
38+
if (!layout)
39+
return AffineMap::getMultiDimMapWithTargets(
40+
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
41+
+ // Expecting layout and vector rank to match.
42+
+ assert(layout.getRank() == vecRank &&
43+
+ "vector rank and layout rank must match");
44+
+ // A dimension is distributed if its layout value is > 1 and the dimension
45+
+ // size is evenly divisible by the layout value.
46+
SmallVector<unsigned int> distributedDims;
47+
for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
48+
- if (v > 1)
49+
+ if (v > 1 && vecType.getShape()[i] % v == 0)
50+
distributedDims.push_back(i);
51+
}
52+
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,

lib/ExecutionEngine/ImexRunnerUtils.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <cstdlib>
2020
#include <cstring>
2121
#include <iostream>
22+
#include <limits>
2223
#include <random>
2324

2425
// NOLINTBEGIN(*-identifier-naming)
@@ -223,21 +224,33 @@ void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
223224
std::pair<double, DynamicMemRefIterator<T>> max_rel_err_idx{0.0, DM.begin()};
224225
std::pair<double, DynamicMemRefIterator<T>> max_abs_err_idx{0.0, DM.begin()};
225226
uint64_t idx = 0;
227+
double max_rel_error_i = std::numeric_limits<double>::infinity(),
228+
max_rel_error_j = std::numeric_limits<double>::infinity();
229+
double max_abs_error_i = std::numeric_limits<double>::infinity(),
230+
max_abs_error_j = std::numeric_limits<double>::infinity();
226231
for (; i != DM.end() && j != DN.end(); ++i, ++j, ++idx) {
227232
const double i_val = getFloat(*i);
228233
const double j_val = getFloat(*j);
229234
const double delta = fabs(i_val - j_val);
230235
const double rel_error = delta / fmax(fabs(i_val), fabs(j_val));
231-
if (delta > max_abs_err_idx.first)
236+
if (delta > max_abs_err_idx.first) {
232237
max_abs_err_idx = {delta, i};
233-
if (rel_error > max_rel_err_idx.first)
238+
max_abs_error_i = i_val;
239+
max_abs_error_j = j_val;
240+
}
241+
if (rel_error > max_rel_err_idx.first) {
234242
max_rel_err_idx = {rel_error, i};
243+
max_rel_error_i = i_val;
244+
max_rel_error_j = j_val;
245+
}
235246
}
236247
std::cout << "Max absolute error " << max_abs_err_idx.first
237248
<< " at idx=" << std::distance(DM.begin(), max_abs_err_idx.second)
249+
<< " (i=" << max_abs_error_i << ", j=" << max_abs_error_j << ")"
238250
<< '\n';
239251
std::cout << "Max relative error " << max_rel_err_idx.first
240252
<< " at idx=" << std::distance(DM.begin(), max_rel_err_idx.second)
253+
<< " (i=" << max_rel_error_i << ", j=" << max_rel_error_j << ")"
241254
<< '\n';
242255
}
243256

0 commit comments

Comments
 (0)