diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index cb4b8c2468d7..4dcf95a0f87a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -573,6 +573,10 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
       return getAttrs().getAs<ArrayAttr>("stride");
     }
 
+    ArrayAttr getBlockAttr() {
+      return getAttrs().getAs<ArrayAttr>("block");
+    }
+
   }];
 
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index f8b371db498e..93642c2166e1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -232,7 +232,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
     }
 
-    ArrayAttr getStrides() {
+    ArrayAttr getStridesAttr() {
       auto layout = getMemLayout();
       if (layout && layout.hasAttr("stride")) {
         return layout.getStrides();
@@ -245,6 +245,106 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       Builder builder(getContext());
       return builder.getI64ArrayAttr(defaultStrides);
     }
+
+    /// Heuristic to determine if the MemDesc uses column-major layout,
+    /// based on the rank and the value of the first stride dimension.
+    bool isColMajor() {
+      auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+      return getRank() == 2 && dim0 && dim0.getInt() == 1;
+    }
+
+    // Get the blocking shape for a MemDescType, which is represented
+    // as a "block" attribute in its memory layout. By default it is the
+    // shape of the MemDescType itself.
+    SmallVector<int64_t> getBlockSize() {
+      SmallVector<int64_t> size(getShape());
+      MemLayoutAttr layout = getMemLayout();
+      if (layout && layout.hasAttr("block")) {
+        ArrayAttr attr = layout.getBlockAttr();
+        size.clear();
+        llvm::for_each(attr, [&](Attribute elem) {
+          if (auto intElem = dyn_cast<IntegerAttr>(elem))
+            size.push_back(intElem.getInt());
+        });
+      }
+      return size;
+    }
+
+    // Get the strides as a vector of integers.
+    // If the layout contains a block attribute, the strides are blocked strides.
+    //
+    // The blocking is applied against the original matrix shape
+    // so that the linear offset is not impacted by a subview.
+    //
+    // It first computes the original matrix shape using the stride info,
+    // then the number of blocks in each dimension of the original shape,
+    // then the outer block shape and strides, and finally combines the
+    // inner and outer block shapes and strides.
+    // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+    // its memory layout tuple is ([2,32,16,8], [128,256,1,16]);
+    // for mem_desc<256x32xf16, @block=[8, 16]> with the default @strides=[32, 1]
+    // its memory layout tuple is ([32,2,8,16], [256,128,16,1]).
+    SmallVector<int64_t> getStrides() {
+
+      SmallVector<int64_t> matrixShape(getShape().begin(),
+                                       getShape().end());
+
+      ArrayAttr strideAttr = getStridesAttr();
+      SmallVector<int64_t> strides;
+      for (Attribute attr : strideAttr.getValue()) {
+        strides.push_back(cast<IntegerAttr>(attr).getInt());
+      }
+
+      SmallVector<int64_t> innerBlkShape = getBlockSize();
+      if (innerBlkShape.empty())
+        return strides;
+
+      SmallVector<int, 4> perm = llvm::to_vector<4>(
+          llvm::seq<int>(0, strides.size()));
+      llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+      assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");
+
+      SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+      innerBlkStride[perm[0]] = 1;
+      for (size_t i = 1; i < perm.size(); ++i)
+        innerBlkStride[perm[i]] =
+            innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
+
+      // Compute the original matrix shape using the stride info,
+      // and the number of blocks in each dimension.
+      // The shape of the highest dim can't be derived from the stride info,
+      // but it doesn't impact the stride computation for the blocked layout.
+      SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+      SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+      for (size_t i = 0; i < perm.size() - 1; ++i) {
+        matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+        BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+      }
+
+      int64_t innerBlkSize = 1;
+      for (auto s : innerBlkShape)
+        innerBlkSize *= s;
+
+      SmallVector<int64_t> outerBlkStride(matrixShape.size());
+      outerBlkStride[perm[0]] = innerBlkSize;
+      for (size_t i = 0; i < perm.size() - 1; ++i) {
+        outerBlkStride[perm[i + 1]] =
+            outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+      }
+
+      // Combine the outer and inner strides.
+      SmallVector<int64_t> blockedStrides;
+      blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+      blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+      return blockedStrides;
+    }
+    /// Generates instructions to compute the linearized offset. If the memory
+    /// descriptor is blocked, the offset is computed based on the blocked
+    /// layout. The strides of the memory descriptor are always applied.
+    Value getLinearOffsets(OpBuilder &builder,
+                           Location loc, ArrayRef<OpFoldResult> offsets);
+
   }];
 
   let hasCustomAssemblyFormat = true;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f4597..808270534459 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -703,6 +703,89 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+// A helper utility to perform a binary operation on two OpFoldResults.
+// Both operands are materialized as values (creating a constant index op
+// for any attribute operand), and the corresponding arith op is created
+// on those values.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+                      OpBuilder &builder) {
+  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+  return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// A helper utility to perform a division operation on an OpFoldResult and an int64_t.
+#define div(a, b) \
+  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a remainder operation on an OpFoldResult and an int64_t.
+#define rem(a, b) \
+  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a multiplication operation on an OpFoldResult and an int64_t.
+#define mul(a, b) \
+  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform an addition operation on two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// Block the given offsets according to the block shape:
+// say the original offset is [y, x], and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+                                            ArrayRef<OpFoldResult> offsets,
+                                            ArrayRef<int64_t> blockShape) {
+
+  assert(offsets.size() == blockShape.size() &&
+         "offsets and blockShape must have the same size");
+  SmallVector<OpFoldResult> blockedOffsets;
+  SmallVector<OpFoldResult> divs, rems;
+
+  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+    divs.push_back(div(offset, block));
+    rems.push_back(rem(offset, block));
+  }
+  blockedOffsets.append(divs.begin(), divs.end());
+  blockedOffsets.append(rems.begin(), rems.end());
+
+  return blockedOffsets;
+}
+
+// Calculate the linear offset using the blocked offsets and strides.
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+                                    ArrayRef<OpFoldResult> offsets) {
+
+  SmallVector<int64_t> blockShape = getBlockSize();
+  SmallVector<int64_t> strides = getStrides();
+  // Declared outside the `if` below so that `offsets` does not dangle.
+  SmallVector<OpFoldResult> blockedOffsets;
+  if (!blockShape.empty()) {
+    assert(offsets.size() == blockShape.size() &&
+           "offsets and blockShape must have the same size");
+    // Say the original offset is [y, x], and the block shape is [By, Bx];
+    // then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+    SmallVector<OpFoldResult> divs, rems;
+
+    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+      divs.push_back(div(offset, block));
+      rems.push_back(rem(offset, block));
+    }
+    blockedOffsets.append(divs.begin(), divs.end());
+    blockedOffsets.append(rems.begin(), rems.end());
+    offsets = blockedOffsets;
+  }
+
+  // Start with the matrix descriptor's base offset as the initial value.
+  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    OpFoldResult mulResult = mul(offsets[i], strides[i]);
+    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+  }
+
+  return linearOffset;
+}
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ecee53c56a54..ba38d74f3c7f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1069,7 +1069,7 @@ LogicalResult MemDescSubviewOp::verify() {
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
     return emitOpError("result shape must not exceed source shape.");
 
-  if (srcTy.getStrides() != resTy.getStrides())
+  if (srcTy.getStridesAttr() != resTy.getStridesAttr())
     return emitOpError("result must inherit the source strides.");
 
   return success();
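
The blocked-stride arithmetic in the patch is compact, so here is a minimal standalone C++ sketch (not part of the patch) that mirrors the logic of MemDescType::getStrides and the offset blocking in getLinearOffsets using plain integers instead of MLIR ops. The helper names computeBlockedStrides and linearize are invented for illustration; the expected values come from the mem_desc<256x32xf16, @block=[8, 16]> example in the comments above.

// Standalone sketch (assumed helper names; not part of the patch) mirroring
// the blocked-stride computation and the offset linearization.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Mirrors getStrides(): derive inner (in-block) and outer (block) strides,
// then concatenate outer strides followed by inner strides.
std::vector<int64_t> computeBlockedStrides(const std::vector<int64_t> &strides,
                                           const std::vector<int64_t> &block) {
  std::vector<int> perm(strides.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(),
            [&](int a, int b) { return strides[a] < strides[b]; });
  assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");

  // Inner strides: elements of one block are contiguous, in the original
  // dimension order.
  std::vector<int64_t> inner(block.size());
  inner[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    inner[perm[i]] = inner[perm[i - 1]] * block[perm[i - 1]];

  // Number of blocks per dimension, derived from the original strides.
  // The outermost dimension's count is not needed for the stride computation.
  std::vector<int64_t> blocksPerDim(block.size(), 0);
  for (size_t i = 0; i + 1 < perm.size(); ++i)
    blocksPerDim[perm[i]] =
        (strides[perm[i + 1]] / strides[perm[i]]) / block[perm[i]];

  int64_t blockSize = 1;
  for (int64_t b : block)
    blockSize *= b;

  // Outer strides: whole blocks laid out in the original dimension order.
  std::vector<int64_t> outer(block.size());
  outer[perm[0]] = blockSize;
  for (size_t i = 0; i + 1 < perm.size(); ++i)
    outer[perm[i + 1]] = outer[perm[i]] * blocksPerDim[perm[i]];

  std::vector<int64_t> result = outer;
  result.insert(result.end(), inner.begin(), inner.end());
  return result;
}

// Mirrors getLinearOffsets(): block the offsets, then dot with the strides.
int64_t linearize(const std::vector<int64_t> &offsets,
                  const std::vector<int64_t> &block,
                  const std::vector<int64_t> &blockedStrides) {
  std::vector<int64_t> blocked;
  for (size_t i = 0; i < offsets.size(); ++i)
    blocked.push_back(offsets[i] / block[i]); // outer (block) indices
  for (size_t i = 0; i < offsets.size(); ++i)
    blocked.push_back(offsets[i] % block[i]); // inner (in-block) indices
  int64_t linear = 0;
  for (size_t i = 0; i < blocked.size(); ++i)
    linear += blocked[i] * blockedStrides[i];
  return linear;
}

int main() {
  // mem_desc<256x32xf16, @block=[8, 16]> with default strides [32, 1]:
  // expected blocked strides are [256, 128, 16, 1].
  std::vector<int64_t> strides = computeBlockedStrides({32, 1}, {8, 16});
  for (int64_t s : strides)
    std::cout << s << ' '; // prints: 256 128 16 1
  std::cout << '\n';

  // Element (9, 17): block (1, 1), in-block (1, 1) -> 256 + 128 + 16 + 1.
  std::cout << linearize({9, 17}, {8, 16}, strides) << '\n'; // prints: 401
  return 0;
}

Running the sketch prints the blocked strides 256 128 16 1 and the linear offset 401 for element (9, 17), matching the layout tuple documented in the getStrides comment.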