
Commit 5c570ba

Authored by Garra1980, chencha3, and Jianhui-Li
[XeGPU] add lowering pass for load_matrix and store_matrix (#1244) (#1118)
Co-authored-by: Chao Chen <[email protected]>
Co-authored-by: Jianhui Li <[email protected]>
1 parent 52df995 commit 5c570ba

12 files changed: +999 −9 lines changed
Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index cb4b8c2468d7..4dcf95a0f87a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -573,6 +573,10 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
      return getAttrs().getAs<ArrayAttr>("stride");
    }

+    ArrayAttr getBlockAttr() {
+      return getAttrs().getAs<ArrayAttr>("block");
+    }
+
  }];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index f8b371db498e..93642c2166e1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -232,7 +232,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
      return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
    }

-    ArrayAttr getStrides() {
+    ArrayAttr getStridesAttr() {
      auto layout = getMemLayout();
      if (layout && layout.hasAttr("stride")) {
        return layout.getStrides();
@@ -245,6 +245,106 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
      Builder builder(getContext());
      return builder.getI64ArrayAttr(defaultStrides);
    }
+
+    /// Heuristic to determine if the MemDesc uses a column-major layout,
+    /// based on the rank and the value of the first stride dimension.
+    bool isColMajor() {
+      auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+      return getRank() == 2 && dim0 && dim0.getInt() == 1;
+    }
+
+    // Get the blocking shape for a MemDescType, which is represented
+    // as an attribute in MemDescType. By default it is the shape
+    // of the mdescTy.
+    SmallVector<int64_t> getBlockSize() {
+      SmallVector<int64_t> size(getShape());
+      MemLayoutAttr layout = getMemLayout();
+      if (layout && layout.hasAttr("block")) {
+        ArrayAttr attr = layout.getBlockAttr();
+        size.clear();
+        llvm::for_each(attr, [&](Attribute elem) {
+          if (auto intElem = dyn_cast<IntegerAttr>(elem))
+            size.push_back(intElem.getInt());
+        });
+      }
+      return size;
+    }
+
+    // Get the strides as a vector of integers.
+    // If the layout contains a block attribute, the strides are blocked strides.
+    //
+    // The blocking is applied against the original matrix shape
+    // so that the linear offset is not impacted by the subview.
+    //
+    // It first computes the original matrix shape using the stride info,
+    // then computes the number of blocks in each dimension of the original shape,
+    // then computes the outer block shape and stride,
+    // then combines the inner and outer block shape and stride.
+    // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+    //      its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+    //      for mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1]
+    //      its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+    SmallVector<int64_t> getStrides() {
+
+      SmallVector<int64_t> matrixShape(getShape().begin(),
+                                       getShape().end());
+
+      ArrayAttr strideAttr = getStridesAttr();
+      SmallVector<int64_t> strides;
+      for (Attribute attr : strideAttr.getValue()) {
+        strides.push_back(cast<IntegerAttr>(attr).getInt());
+      }
+
+      SmallVector<int64_t> innerBlkShape = getBlockSize();
+      if (innerBlkShape.empty())
+        return strides;
+
+      SmallVector<int, 4> perm = llvm::to_vector<4>(
+          llvm::seq<int>(0, strides.size()));
+      llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+      assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+
+      SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+      innerBlkStride[perm[0]] = 1;
+      for (size_t i = 1; i < perm.size(); ++i)
+        innerBlkStride[perm[i]] =
+            innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
+
+      // Compute the original matrix shape using the stride info,
+      // and the number of blocks in each dimension.
+      // The shape of the highest dim can't be derived from the stride info,
+      // but it doesn't impact the stride computation for the blocked layout.
+      SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+      SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+      for (size_t i = 0; i < perm.size() - 1; ++i) {
+        matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+        BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+      }
+
+      int64_t innerBlkSize = 1;
+      for (auto s : innerBlkShape)
+        innerBlkSize *= s;
+
+      SmallVector<int64_t> outerBlkStride(matrixShape.size());
+      outerBlkStride[perm[0]] = innerBlkSize;
+      for (size_t i = 0; i < perm.size() - 1; ++i) {
+        outerBlkStride[perm[i + 1]] =
+            outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+      }
+
+      // Combine the inner and outer strides.
+      SmallVector<int64_t> blockedStrides;
+      blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+      blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+      return blockedStrides;
+    }
+    /// Generates instructions to compute the linearized offset.
+    /// If the memory descriptor is blocked, the offset is computed against the blocked layout.
+    /// The strides of the memory descriptor are always taken into account, blocked or not.
+    Value getLinearOffsets(OpBuilder &builder,
+                           Location loc, ArrayRef<OpFoldResult> offsets);
+
  }];

  let hasCustomAssemblyFormat = true;
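The two layout-tuple examples in the comment above can be reproduced with a small standalone sketch (plain C++, not part of the patch; the helper name blockedStrides is invented for illustration) that mirrors the getStrides() logic:

// Standalone sketch of the blocked-stride computation described above.
// For mem_desc<256x32xf16, @block=[8, 16]> with default strides [32, 1]
// it should print the blocked strides [256, 128, 16, 1].
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<int64_t> blockedStrides(const std::vector<int64_t> &strides,
                                    const std::vector<int64_t> &block) {
  // Sort the dims by increasing stride; the innermost dim must have stride 1.
  std::vector<int> perm(strides.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(),
            [&](int a, int b) { return strides[a] < strides[b]; });
  assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");

  // Strides inside one block; one block of block[0]*...*block[n-1] elements is contiguous.
  std::vector<int64_t> inner(block.size());
  inner[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    inner[perm[i]] = inner[perm[i - 1]] * block[perm[i - 1]];

  // Number of blocks along each dim of the original (pre-subview) matrix,
  // derived from the stride ratios; the outermost dim is not needed.
  std::vector<int64_t> blkCount(block.size(), 0);
  for (size_t i = 0; i + 1 < perm.size(); ++i)
    blkCount[perm[i]] =
        strides[perm[i + 1]] / strides[perm[i]] / block[perm[i]];

  // Strides between blocks: advancing the innermost blocked dim skips one full block.
  int64_t blockSize = 1;
  for (int64_t b : block)
    blockSize *= b;
  std::vector<int64_t> outer(block.size());
  outer[perm[0]] = blockSize;
  for (size_t i = 0; i + 1 < perm.size(); ++i)
    outer[perm[i + 1]] = outer[perm[i]] * blkCount[perm[i]];

  outer.insert(outer.end(), inner.begin(), inner.end());
  return outer; // [outer strides..., inner strides...]
}

int main() {
  for (int64_t s : blockedStrides({32, 1}, {8, 16}))
    std::cout << s << ' '; // prints: 256 128 16 1
  std::cout << '\n';
}

For the column-major example, blockedStrides({1, 32}, {16, 8}) returns [128, 256, 1, 16].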
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f4597..808270534459 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -703,6 +703,89 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
  }
  printer << ">";
}
+// A helper utility to perform a binary operation on OpFoldResults.
+// If both a and b are attributes, it will simply return the result.
+// Otherwise, the corresponding arith op will be generated, and a
+// constant op will be created if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+                      OpBuilder &builder) {
+  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+  return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// A helper utility to perform a division operation on an OpFoldResult and an int64_t.
+#define div(a, b) \
+  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a remainder operation on an OpFoldResult and an int64_t.
+#define rem(a, b) \
+  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a multiply operation on an OpFoldResult and an int64_t.
+#define mul(a, b) \
+  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform an addition operation on two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// Block the given offsets according to the block shape:
+// say the original offset is [y, x] and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+                                            ArrayRef<OpFoldResult> offsets,
+                                            ArrayRef<int64_t> blockShape) {
+
+  assert(offsets.size() == blockShape.size() &&
+         "offsets and blockShape must have the same size");
+  SmallVector<OpFoldResult> blockedOffsets;
+  SmallVector<OpFoldResult> divs, rems;
+
+  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+    divs.push_back(div(offset, block));
+    rems.push_back(rem(offset, block));
+  }
+  blockedOffsets.append(divs.begin(), divs.end());
+  blockedOffsets.append(rems.begin(), rems.end());
+
+  return blockedOffsets;
+}
+
+// Calculate the linear offset using the blocked offsets and strides.
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+                                    ArrayRef<OpFoldResult> offsets) {
+
+  SmallVector<int64_t> blockShape = getBlockSize();
+  SmallVector<int64_t> strides = getStrides();
+  if (!blockShape.empty()) {
+    assert(offsets.size() == blockShape.size() &&
+           "offsets and blockShape must have the same size");
+    // Say the original offset is [y, x] and the block shape is [By, Bx],
+    // then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+    SmallVector<OpFoldResult> blockedOffsets;
+    SmallVector<OpFoldResult> divs, rems;
+
+    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+      divs.push_back(div(offset, block));
+      rems.push_back(rem(offset, block));
+    }
+    blockedOffsets.append(divs.begin(), divs.end());
+    blockedOffsets.append(rems.begin(), rems.end());
+
+    offsets = blockedOffsets;
+  }
+
+  // Start with the matrix descriptor's base offset as the initial value.
+  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    OpFoldResult mulResult = mul(offsets[i], strides[i]);
+    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+  }
+
+  return linearOffset;
+}

} // namespace xegpu
} // namespace mlir
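Constant-folding the same arithmetic by hand (a hedged sketch, not part of the patch; linearOffset is an illustrative helper) shows how getLinearOffsets combines blocked offsets with the blocked strides for the row-major example mem_desc<256x32xf16, @block=[8, 16]>:

// Sketch of the linear-offset computation: block the (y, x) offsets and take a
// dot product with the blocked strides [256, 128, 16, 1] computed above.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t linearOffset(int64_t y, int64_t x) {
  const int64_t By = 8, Bx = 16;                        // @block
  const std::vector<int64_t> strides{256, 128, 16, 1};  // blocked strides
  const std::vector<int64_t> blocked{y / By, x / Bx, y % By, x % Bx};
  int64_t off = 0;
  for (size_t i = 0; i < blocked.size(); ++i)
    off += blocked[i] * strides[i];
  return off; // offset in elements, not bytes
}

int main() {
  // (y, x) = (10, 20): block coordinates (1, 1), in-block coordinates (2, 4)
  // -> 1*256 + 1*128 + 2*16 + 4*1 = 420 elements from the base of the mem_desc.
  std::cout << linearOffset(10, 20) << '\n'; // prints 420
}

The real implementation emits the equivalent arith divsi/remsi/muli/addi ops on index values (via the div/rem/mul/add helpers above) instead of folding constants.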
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ecee53c56a54..ba38d74f3c7f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1069,7 +1069,7 @@ LogicalResult MemDescSubviewOp::verify() {
                 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
    return emitOpError("result shape must not exceed source shape.");

-  if (srcTy.getStrides() != resTy.getStrides())
+  if (srcTy.getStridesAttr() != resTy.getStridesAttr())
    return emitOpError("result must inherit the source strides.");

  return success();
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 2e00b42f4a56..15529d4c9b54 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -393,8 +393,65 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
  return success();
}

+class ViewOpPattern final : public OpConversionPattern<memref::ViewOp> {
+public:
+  using OpConversionPattern<memref::ViewOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ViewOp operation, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+
//===----------------------------------------------------------------------===//
-// AllocOp
+// ViewOp
+// %view = memref.view %alloc[%c0][] : memref<2048xi8, 3> to memref<512xf32, 3>
+// spirv.GlobalVariable @__workgroup_mem__1 : !spirv.ptr<!spirv.array<2048 x i8>, Workgroup>
+// %1 = spirv.Bitcast @__workgroup_mem__1 : !spirv.ptr<!spirv.array<2048 x i8>, Workgroup> to !spirv.ptr<!spirv.array<512 x f32>, Workgroup>
+//
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ViewOpPattern::matchAndRewrite(memref::ViewOp operation, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+  MemRefType ToType = operation.getType();
+
+  // Insert a spirv.Bitcast that casts the pointer from the converted source type to spirvToType.
+  Type spirvToType = getTypeConverter()->convertType(ToType);
+  if (!spirvToType)
+    return rewriter.notifyMatchFailure(operation, "type conversion failed");
+
+  // Limit the pattern to the case where the source is a statically shaped memref
+  // with element type i8; the result memref must also have static sizes.
+  MemRefType FromType = cast<MemRefType>(operation.getSource().getType());
+  if (!FromType.getElementType().isInteger(8) || !FromType.hasStaticShape())
+    return rewriter.notifyMatchFailure(operation, "unhandled view type");
+  if (!ToType.hasStaticShape())
+    return rewriter.notifyMatchFailure(operation, "unhandled view type");
+
+  // Get the base pointer from adaptor.getSource().
+  Value basePtr = adaptor.getSource();
+  // Get the byte offset and, if present, add it to the base pointer.
+  Value offset = adaptor.getByteShift();
+  if (offset) {
+    Location loc = operation.getLoc();
+    auto *spirvTypeConverter = getTypeConverter<SPIRVTypeConverter>();
+    Type materializedIndexType = spirvTypeConverter->getIndexType();
+    Value basePtrAsInt = rewriter.createOrFold<spirv::ConvertPtrToUOp>(loc, materializedIndexType, basePtr);
+    Value newPtrAsInt = rewriter.createOrFold<spirv::IAddOp>(loc, materializedIndexType, basePtrAsInt, offset);
+    Value newPtr = rewriter.createOrFold<spirv::ConvertUToPtrOp>(loc, basePtr.getType(), newPtrAsInt);
+    basePtr = newPtr;
+  }
+
+  Location loc = operation.getLoc();
+  Value castOp = rewriter.createOrFold<spirv::BitcastOp>(
+      loc, spirvToType, basePtr);
+  rewriter.replaceOp(operation, castOp);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
@@ -1071,7 +1128,7 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+  patterns.add<AllocaOpPattern, AllocOpPattern, ViewOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
               IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
               StoreOpPattern, ReinterpretCastPattern, CastPattern,
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e6321e99693a..7308f000cdbe 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -446,6 +446,30 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr

// -----

+// Check memref.view
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0,
+      [
+        Kernel, Addresses, GenericPointer, Int8, Int64, StorageBuffer8BitAccess, Shader], [SPV_KHR_8bit_storage]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @memory_view
+// CHECK-SAME: (%[[ARG0:.+]]: memref<2048xi8, #spirv.storage_class<Function>>)
+func.func @memory_view(%arg0: memref<2048xi8, #spirv.storage_class<Function>>)
+    -> memref<512xf32, #spirv.storage_class<Function>> {
+// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2048xi8, #spirv.storage_class<Function>> to !spirv.ptr<!spirv.array<2048 x i8>, Function>
+// CHECK: %[[BITCAST:.+]] = spirv.Bitcast %[[ARG0_CAST]] : !spirv.ptr<!spirv.array<2048 x i8>, Function> to !spirv.ptr<!spirv.array<512 x f32>, Function>
+  %c0 = arith.constant 0: index
+  %view = memref.view %arg0[%c0][] : memref<2048xi8, #spirv.storage_class<Function>> to memref<512xf32, #spirv.storage_class<Function>>
+  return %view : memref<512xf32, #spirv.storage_class<Function>>
+}
+
+}
+
+// -----
+
module attributes {
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Kernel, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
} {

include/imex/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ std::unique_ptr<mlir::Pass> createHoistTransposePass();
std::unique_ptr<mlir::Pass> createVnniTransformationPass();
std::unique_ptr<mlir::Pass> createEmulateNonNativeBF16Pass();
std::unique_ptr<mlir::Pass> createTileLoopsPass();
+std::unique_ptr<mlir::Pass> createMaterializeMatrixOpPass();

#define GEN_PASS_DECL
#include "imex/Transforms/Passes.h.inc"

include/imex/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
@@ -226,4 +226,19 @@ def TileLoops : Pass<"tile-loops", "::mlir::func::FuncOp"> {
  ];
}

+def MaterializeMatrixOp: Pass<"imex-xegpu-materialize-matrix-op"> {
+  let summary = "materialize matrix ops for Xe2/Xe3";
+  let description = [{
+    Converts mem_desc operations (load_matrix, store_matrix) into other xegpu memory operations
+    (load/store chunk, 1d block load) over shared local memory. It computes the physical address
+    using the matrix's layout attributes (@strides, @block) and the logical lane coordinates.
+  }];
+  let constructor = "imex::createMaterializeMatrixOpPass()";
+  let dependentDialects = [
+    "::mlir::xegpu::XeGPUDialect",
+    "::mlir::vector::VectorDialect",
+    "::mlir::memref::MemRefDialect"
+  ];
+}
+
#endif // _IMEX_TRANSFORMS_PASSES_TD_INCLUDED_
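A minimal usage sketch (hypothetical pipeline setup, not part of the commit; it only assumes the constructor declared above in include/imex/Transforms/Passes.h):

// Hypothetical pipeline snippet: schedule the new materialization pass so that
// xegpu.load_matrix / xegpu.store_matrix over shared local memory are lowered
// to the other XeGPU memory operations named in the pass description.
#include "imex/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

static void addMatrixOpLowering(mlir::PassManager &pm) {
  pm.addPass(imex::createMaterializeMatrixOpPass());
}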
