Skip to content

Commit

Permalink
Collapse multi-dimensional offsets to 1d for bf16 func arguments (Xil…
Browse files Browse the repository at this point in the history
…inx#468)

* Bump mlir-air

* Fixups on variable names and op erasing

* Fixup variable shadowing
  • Loading branch information
erwei-xilinx authored Mar 4, 2024
1 parent 982001f commit 5243a62
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 6 deletions.
40 changes: 34 additions & 6 deletions mlir/lib/Conversion/AIRRtToIpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,17 @@ struct DmaToIpuPattern : public OpConversionPattern<DmaMemcpyNdOp> {
SmallVector<Value> offsets;
SmallVector<int64_t> staticOffsets;
if (auto const_int = getConstantIntValue(adaptor.getOffset3()))
staticOffsets.push_back(*const_int / div);
staticOffsets.push_back(*const_int);
else
offsets.push_back(divOp(adaptor.getOffset3()));
offsets.push_back(adaptor.getOffset3());
if (auto const_int = getConstantIntValue(adaptor.getOffset2()))
staticOffsets.push_back(*const_int / div);
staticOffsets.push_back(*const_int);
else
offsets.push_back(divOp(adaptor.getOffset2()));
offsets.push_back(adaptor.getOffset2());
if (auto const_int = getConstantIntValue(adaptor.getOffset1()))
staticOffsets.push_back(*const_int / div);
staticOffsets.push_back(*const_int);
else
offsets.push_back(divOp(adaptor.getOffset1()));
offsets.push_back(adaptor.getOffset1());
if (auto const_int = getConstantIntValue(adaptor.getOffset0()))
staticOffsets.push_back(*const_int / div);
else
Expand Down Expand Up @@ -344,6 +344,34 @@ static LogicalResult CastFunctionArgs(func::FuncOp funcOp,
rewriter.setInsertionPointToStart(&entry);
auto cast = rewriter.create<UnrealizedConversionCastOp>(
rewriter.getUnknownLoc(), memrefTy, entry.getArgument(i));
// With memref shape collapsed to 1d, the multi-dimensional offset also
// needs to be collapsed.
SmallVector<Operation *> users;
for (auto user : entry.getArgument(i + 1).getUsers()) {
if (auto cast_user = dyn_cast<UnrealizedConversionCastOp>(user)) {
assert(cast_user.getNumResults() == 1);
for (auto cast_r_user : cast_user.getResult(0).getUsers())
users.push_back(cast_r_user);
} else
users.push_back(user);
}
for (Operation *user : users) {
if (auto dmaUser = dyn_cast<AIEX::IpuDmaMemcpyNdOp>(user)) {
int oneDOffset = *getConstantIntValue(dmaUser.getMixedOffsets().back());
for (int j = dmaUser.getMixedOffsets().size() - 2; j >= 0; j--)
oneDOffset += *getConstantIntValue(dmaUser.getMixedOffsets()[j]) *
*getConstantIntValue(dmaUser.getMixedStrides()[j]);
rewriter.setInsertionPoint(dmaUser);
const std::vector<int64_t> newStaticOffsets = {0, 0, 0, oneDOffset};
rewriter.create<AIEX::IpuDmaMemcpyNdOp>(
rewriter.getUnknownLoc(), dmaUser.getX(), dmaUser.getY(),
dmaUser.getMemref(), SmallVector<Value>{}, dmaUser.getSizes(),
dmaUser.getStrides(), ArrayRef(newStaticOffsets),
dmaUser.getStaticSizes(), dmaUser.getStaticStrides(),
dmaUser.getMetadata(), dmaUser.getId());
rewriter.eraseOp(dmaUser);
}
}
entry.getArgument(i + 1).replaceAllUsesWith(cast.getResult(0));
entry.eraseArgument(i + 1);
}
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Conversion/AIRRtToIpu/airrt_to_ipu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,25 @@ module {

// -----

// Multi-dimensional offset collapsing

// CHECK-LABEL: func.func @func13
// CHECK-SAME: %arg0: memref<512xi32>
// CHECK-NEXT: aiex.ipu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 264][1, 1, 16, 8][0, 0, 16]) {id = 0 : i64, metadata = @md0} : memref<512xi32>
module {
func.func @func13(%arg0 : memref<32x32xbf16>) {
%c1_i32 = arith.constant 1 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c16_i64 = arith.constant 16 : i64
%c32_i64 = arith.constant 32 : i64
airrt.dma_memcpy_nd(%c1_i32, %c0_i64, %c0_i64, %arg0[%c0_i64, %c0_i64, %c16_i64, %c16_i64], [%c1_i64, %c1_i64, %c16_i64, %c16_i64], [%c0_i64, %c0_i64, %c32_i64]) {metadata = @md0} : (i32, i64, i64, memref<32x32xbf16>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
return
}
}

// -----

// Loop carried event

// CHECK-LABEL: func.func @func14
Expand Down

0 comments on commit 5243a62

Please sign in to comment.