Fixup a bug when removing redundant dimensions; enable clearing of wrap-and-stride lists if default access pattern (Xilinx#479)

* Fixup a bug when removing redundant dimensions; enable clearing of wrap-and-stride lists if default access pattern

* Rewrite success/failure conditions

* Code format

* Clang format
erwei-xilinx authored Mar 7, 2024
1 parent 1dff131 commit 9503917
Showing 5 changed files with 101 additions and 47 deletions.
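For readers skimming the diff: the pass canonicalizes the offsets/sizes/strides ("wrap-and-stride") lists attached to DMA and channel ops. Below is a minimal standalone sketch of the two behaviors this commit touches, written against plain integers instead of MLIR Values and ignoring offsets; names and shapes are illustrative, not the pass itself.

#include <cassert>
#include <vector>

// Merge adjacent dims, then clear the lists if only the identity remains.
void canonicalizeSketch(std::vector<int> &sizes, std::vector<int> &strides,
                        int memref_volume) {
  // 1) Remove redundant dimensions: dims (i - 1, i) describe one contiguous
  //    run whenever strides[i - 1] == sizes[i] * strides[i].
  for (int i = (int)sizes.size() - 1; i >= 1; i--) {
    if (strides[i - 1] == sizes[i] * strides[i]) {
      sizes[i] *= sizes[i - 1];
      sizes.erase(sizes.begin() + i - 1);
      strides.erase(strides.begin() + i - 1);
    }
  }
  // 2) Clear the lists when only the default access pattern remains: a
  //    single stride-1 dimension covering the whole memref.
  if (sizes.size() == 1 && strides[0] == 1 && sizes[0] == memref_volume) {
    sizes.clear();
    strides.clear();
  }
}

int main() {
  // Mirrors test1 in the updated tests below: [4 x 32] elements at strides
  // [32, 1] over a memref<128xf32> is simply the whole buffer, so the
  // wrap-and-stride lists vanish.
  std::vector<int> sizes = {4, 32}, strides = {32, 1};
  canonicalizeSketch(sizes, strides, /*memref_volume=*/128);
  assert(sizes.empty() && strides.empty());
  return 0;
}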
3 changes: 2 additions & 1 deletion mlir/include/air/Util/Util.h
@@ -144,7 +144,8 @@ void foldForLoopNestAsExtendedSizesAndStrides(
 LogicalResult canonicalizeWrapAndStrideList(OpBuilder builder,
                                             SmallVector<Value> &offsets,
                                             SmallVector<Value> &sizes,
-                                            SmallVector<Value> &strides);
+                                            SmallVector<Value> &strides,
+                                            int memref_volume);
 
 } // namespace air
 } // namespace xilinx
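The only interface change is the new memref_volume parameter. A hedged standalone sketch of the check this enables (the predicate below is a reconstruction for illustration, not the library's API):

#include <cassert>
#include <vector>

// A single constant stride-1 dimension whose size equals the memref's
// element count is the "default" (identity) access pattern: it reads the
// whole buffer in order, so the explicit lists carry no information.
bool isDefaultAccessPattern(const std::vector<int> &sizes,
                            const std::vector<int> &strides,
                            int memref_volume) {
  return sizes.size() == 1 && strides.size() == 1 && strides[0] == 1 &&
         sizes[0] == memref_volume;
}

int main() {
  assert(isDefaultAccessPattern({128}, {1}, 128));  // all of memref<128xf32>
  assert(!isDefaultAccessPattern({64}, {1}, 128));  // only the first half
  assert(!isDefaultAccessPattern({64}, {2}, 128));  // strided, has gaps
  return 0;
}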
19 changes: 17 additions & 2 deletions mlir/lib/Conversion/AIRRtToIpuPass.cpp
@@ -667,8 +667,23 @@ specializeAffineForInAIRRtDmaWrapAndStride(OpBuilder builder,
   strides.push_back(i64_one);
 
   // Canonicalize wraps and strides
-  (void)air::canonicalizeWrapAndStrideList(builder, offsets, wraps, strides);
-
+  (void)air::canonicalizeWrapAndStrideList(
+      builder, offsets, wraps, strides, air::getTensorVolume(memref.getType()));
+
+  // If empty offsets/sizes/strides, then populate the lists with default
+  // values.
+  if (offsets.empty() && wraps.empty() && strides.empty()) {
+    auto memref_shape = air::getTensorShape(memref.getType());
+    int current_stride = air::getTensorVolume(memref.getType());
+    for (unsigned i = 0; i < memref_shape.size(); i++) {
+      offsets.push_back(builder.create<arith::ConstantIndexOp>(loc, 0));
+      wraps.push_back(
+          builder.create<arith::ConstantIndexOp>(loc, memref_shape[i]));
+      current_stride /= memref_shape[i];
+      strides.push_back(
+          builder.create<arith::ConstantIndexOp>(loc, current_stride));
+    }
+  }
   xilinx::air::foldForLoopNestAsExtendedSizesAndStrides(
       builder, for_op.getOperation(), memcpy_ops[0].getOperation(), offsets,
       wraps, strides, memcpy_ops[0]->getOperand(3));
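When canonicalization empties the lists, the added block above rebuilds the identity row-major pattern from the memref shape. A standalone model of that loop, with a hypothetical shape (the real code emits arith.constant ops rather than plain ints):

#include <cassert>
#include <vector>

int main() {
  std::vector<int> memref_shape = {8, 16}; // hypothetical memref<8x16xf32>
  int current_stride = 8 * 16;             // air::getTensorVolume == 128
  std::vector<int> offsets, wraps, strides;
  for (int dim : memref_shape) {
    offsets.push_back(0);
    wraps.push_back(dim);
    current_stride /= dim; // stride of a dim = product of the dims inside it
    strides.push_back(current_stride);
  }
  assert(offsets == (std::vector<int>{0, 0}));
  assert(wraps == (std::vector<int>{8, 16}));
  assert(strides == (std::vector<int>{16, 1})); // row-major identity pattern
  return 0;
}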
52 changes: 31 additions & 21 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
@@ -1699,6 +1699,22 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
     SmallVector<Value> offsets = channel_ops[0].getOffsets();
     SmallVector<Value> wraps = channel_ops[0].getSizes();
     SmallVector<Value> strides = channel_ops[0].getStrides();
+    for (auto o : for_loops) {
+      // Check for perfect loop nest containing only air.channel ops
+      if (!hasNElements(o.getBody(), 1))
+        return failure();
+      if (isa<air::ChannelInterface>(o.getBody()->begin())) {
+      } else if (isa<scf::ForOp>(o.getBody()->begin())) {
+      } else
+        return failure();
+      if (!getStaticScfForTripCountAsInt(o))
+        return failure();
+    }
+
+    (void)canonicalizeWrapAndStrideList(
+        rewriter, offsets, wraps, strides,
+        air::getTensorVolume(channel_ops[0].getMemref().getType()));
+
     // If empty offsets/sizes/strides, then populate the lists with default
     // values.
     if (offsets.empty() && wraps.empty() && strides.empty()) {
@@ -1714,25 +1730,13 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
           rewriter.create<arith::ConstantIndexOp>(loc, current_stride));
       }
     }
-    for (auto o : for_loops) {
-      // Check for perfect loop nest containing only air.channel ops
-      if (!hasNElements(o.getBody(), 1))
-        return failure();
-      if (isa<air::ChannelInterface>(o.getBody()->begin())) {
-      } else if (isa<scf::ForOp>(o.getBody()->begin())) {
-      } else
-        return failure();
-      if (!getStaticScfForTripCountAsInt(o))
-        return failure();
-    }
-
-    (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides);
-
     foldForLoopNestAsExtendedSizesAndStrides(
         rewriter, for_op.getOperation(), channel_ops[0].getOperation(), offsets,
         wraps, strides, channel_ops[0].getMemref());
 
-    (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides);
+    (void)canonicalizeWrapAndStrideList(
+        rewriter, offsets, wraps, strides,
+        air::getTensorVolume(channel_ops[0].getMemref().getType()));
 
     Operation *new_chan_op = nullptr;
     SmallVector<Type, 1> tys;
@@ -1823,13 +1827,17 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
       return failure();
     }
 
-    (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides);
+    (void)canonicalizeWrapAndStrideList(
+        rewriter, offsets, wraps, strides,
+        air::getTensorVolume(channel_ops[0].getMemref().getType()));
 
     foldForLoopNestAsExtendedSizesAndStrides(
         rewriter, for_op.getOperation(), channel_ops[0].getOperation(), offsets,
         wraps, strides, channel_ops[0].getMemref());
 
-    (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides);
+    (void)canonicalizeWrapAndStrideList(
+        rewriter, offsets, wraps, strides,
+        air::getTensorVolume(channel_ops[0].getMemref().getType()));
 
     Operation *new_chan_op = nullptr;
     SmallVector<Type, 1> tys;
@@ -1881,8 +1889,9 @@ struct AIRCanonicalizeChannelPutOpWrapAndStrideList
     SmallVector<Value> sizes = op.getSizes();
     SmallVector<Value> strides = op.getStrides();
 
-    if (failed(
-            canonicalizeWrapAndStrideList(rewriter, offsets, sizes, strides)))
+    if (failed(canonicalizeWrapAndStrideList(
+            rewriter, offsets, sizes, strides,
+            air::getTensorVolume(op.getMemref().getType()))))
       return failure();
 
     auto new_op = rewriter.create<air::ChannelPutOp>(
@@ -1917,8 +1926,9 @@ struct AIRCanonicalizeChannelGetOpWrapAndStrideList
     SmallVector<Value> sizes = op.getSizes();
     SmallVector<Value> strides = op.getStrides();
 
-    if (failed(
-            canonicalizeWrapAndStrideList(rewriter, offsets, sizes, strides)))
+    if (failed(canonicalizeWrapAndStrideList(
+            rewriter, offsets, sizes, strides,
+            air::getTensorVolume(op.getMemref().getType()))))
      return failure();
 
     auto new_op = rewriter.create<air::ChannelGetOp>(
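Both reordered patterns follow the same pipeline: canonicalize, fold the surrounding perfect loop nest into extra wrap-and-stride entries, then canonicalize again. A standalone sketch of the folding step under simplifying assumptions (static trip count, induction variable feeding a stride-1 offset; illustrative, not the pass itself):

#include <cassert>
#include <vector>

int main() {
  // Hypothetical input: scf.for %i = 0 to 128 step 32 wrapping a channel put
  // of %arg0[%i] [32] [1] on memref<128xf32>.
  std::vector<int> wraps = {32}, strides = {1};
  int lb = 0, ub = 128, step = 32;
  int trip_count = (ub - lb) / step; // 4; must be static, hence the checks
  wraps.insert(wraps.begin(), trip_count); // the loop becomes a new outer dim
  strides.insert(strides.begin(), step);   // the IV advances the offset by step
  assert(wraps == (std::vector<int>{4, 32}));
  assert(strides == (std::vector<int>{32, 1}));
  // The second canonicalization then merges this pair (32 == 32 * 1) into
  // [128] [1], which equals the memref volume and clears to [] [] [].
  return 0;
}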
45 changes: 29 additions & 16 deletions mlir/lib/Util/Util.cpp
@@ -806,8 +806,10 @@ void air::getDefiningOpsToOperands(Operation *op,
 LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder,
                                                  SmallVector<Value> &offsets,
                                                  SmallVector<Value> &sizes,
-                                                 SmallVector<Value> &strides) {
+                                                 SmallVector<Value> &strides,
+                                                 int memref_volume) {
 
+  bool listsHaveChanged = false;
   // Match offsets size with sizes and strides
   int max_dim_size =
       std::max(std::max(offsets.size(), sizes.size()), strides.size());
@@ -816,6 +818,7 @@ LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder,
       offsets.insert(offsets.begin(), builder.create<arith::ConstantIndexOp>(
                                           builder.getUnknownLoc(), 0));
     }
+    listsHaveChanged = true;
   }
 
   SmallVector<int> unit_dims;
@@ -831,40 +834,50 @@ LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder,
     offsets.erase(offsets.begin() + i);
     sizes.erase(sizes.begin() + i);
     strides.erase(strides.begin() + i);
+    listsHaveChanged = true;
   }
 
-  SmallVector<int> redundant_dims;
-  if (!sizes.empty())
+  if (!sizes.empty()) {
     for (int i = sizes.size() - 1; i >= 1; i--) {
-      if (getConstantIntValue(sizes[i]) && getConstantIntValue(sizes[i - 1]) &&
+      if (getConstantIntValue(offsets[i]) &&
+          getConstantIntValue(offsets[i - 1]) &&
+          getConstantIntValue(sizes[i]) && getConstantIntValue(sizes[i - 1]) &&
           getConstantIntValue(strides[i]) &&
           getConstantIntValue(strides[i - 1])) {
         auto const_size = *getConstantIntValue(sizes[i]);
         auto const_size_next = *getConstantIntValue(sizes[i - 1]);
         auto const_stride = *getConstantIntValue(strides[i]);
         auto const_stride_next = *getConstantIntValue(strides[i - 1]);
+        // Skip over the first dimension if stride is 1
+        if (const_stride == 1 && i == (int)sizes.size() - 1)
+          continue;
         if (const_stride_next == const_size * const_stride) {
-          redundant_dims.push_back(i - 1);
           sizes[i] = builder.create<arith::ConstantIndexOp>(
               builder.getUnknownLoc(), const_size * const_size_next);
+          offsets.erase(offsets.begin() + i - 1);
+          sizes.erase(sizes.begin() + i - 1);
+          strides.erase(strides.begin() + i - 1);
+          listsHaveChanged = true;
         }
       }
     }
+  }
 
-  for (auto i : redundant_dims) {
-    offsets.erase(offsets.begin() + i);
-    sizes.erase(sizes.begin() + i);
-    strides.erase(strides.begin() + i);
-  }
-
-  if (unit_dims.empty() && redundant_dims.empty()) {
-    return failure();
-  }
+  // If default data access pattern, then clear the offsets, sizes and strides.
+  if (offsets.size() == 1 && sizes.size() == 1 && strides.size() == 1) {
+    if (getConstantIntValue(offsets[0]) && getConstantIntValue(sizes[0]) &&
+        getConstantIntValue(strides[0])) {
+      if (*getConstantIntValue(strides[0]) == 1 &&
+          *getConstantIntValue(sizes[0]) == memref_volume) {
+        offsets.erase(offsets.begin());
+        sizes.erase(sizes.begin());
+        strides.erase(strides.begin());
+        listsHaveChanged = true;
+      }
+    }
+  }
 
-  return success();
+  if (listsHaveChanged)
+    return success();
+  else
+    return failure();
 }
 
 // Fold perfectly nested for loops as extra entries in wraps and strides
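The "Rewrite success/failure conditions" bullet lands here: the function now tracks listsHaveChanged and returns success only when it actually mutated the lists. A standalone stand-in showing why that contract matters to callers that loop until a fixed point (illustrative, not the real API):

#include <cassert>
#include <vector>

// Stand-in for canonicalizeWrapAndStrideList: true plays the role of
// success(), i.e. "I changed something"; false plays failure(), i.e.
// "already canonical".
bool canonicalizeStandIn(std::vector<int> &sizes) {
  bool listsHaveChanged = false;
  while (sizes.size() > 1 && sizes.back() == 1) {
    sizes.pop_back(); // drop a trailing unit dimension
    listsHaveChanged = true;
  }
  return listsHaveChanged;
}

int main() {
  std::vector<int> sizes = {8, 1, 1};
  assert(canonicalizeStandIn(sizes));  // first application rewrites the list
  assert(!canonicalizeStandIn(sizes)); // second is a no-op, so it "fails"
  // A greedy rewrite driver uses exactly this signal to stop re-applying a
  // pattern, instead of spinning forever on an op that no longer changes.
  return 0;
}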
@@ -66,7 +66,7 @@ module {
 }
 
 // CHECK-LABEL: test1
-// CHECK: put @channel_1[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
+// CHECK: put @channel_1[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>)
 // CHECK: get @channel_2[%c0, %c0] (%arg1[%c0, %c0] [%c128, %c32] [%c128, %c1]) : (memref<128x128xf32>)
 // CHECK: put @channel_3[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
 // CHECK: put @channel_4[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
@@ -101,7 +101,7 @@ module {
 }
 
 // CHECK-LABEL: test2
-// CHECK: put @channel_6[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
+// CHECK: put @channel_6[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>)
 // CHECK: get @channel_7[%c0, %c0] (%arg1[%c0, %c0] [%c128, %c32] [%c128, %c1]) : (memref<128x128xf32>)
 // CHECK: put @channel_8[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
 // CHECK: put @channel_9[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
@@ -136,10 +136,10 @@ module {
 }
 
 // CHECK-LABEL: test3
-// CHECK: put @channel_11[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
+// CHECK: put @channel_11[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>)
 // CHECK: put @channel_12[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
-// CHECK: put @channel_13[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
-// CHECK: put async [%0] @channel_14[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
+// CHECK: put @channel_13[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>)
+// CHECK: put async [%0] @channel_14[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>)
 
 func.func @test3(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> {
   %c0 = arith.constant 0 : index
@@ -177,7 +177,7 @@ module {
 
 // CHECK-LABEL: test4
 // CHECK: put async @channel_15[%c0, %c0] (%arg0[%c0] [%c32] [%c1]) : (memref<128xf32>)
-// CHECK: put async @channel_16[%c0, %c0] (%arg1[%c0, %c0] [%c16, %c4] [%c4, %c1]) : (memref<128x128xf32>)
+// CHECK: put async @channel_16[%c0, %c0] (%arg1[%c0] [%c64] [%c1]) : (memref<128x128xf32>)
 
 func.func @test4(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> {
   %c0 = arith.constant 0 : index
@@ -196,7 +196,7 @@ module {
 }
 
 // CHECK-LABEL: test5
-// CHECK: put async @channel_17[] (%arg0[%c0, %c0, %c0] [%c8, %c32, %c32] [%c0, %c32, %c1]) : (memref<32x32xf32>)
+// CHECK: put async @channel_17[] (%arg0[%c0, %c0] [%c8, %c1024] [%c0, %c1]) : (memref<32x32xf32>)
 
 func.func @test5(%arg0: memref<32x32xf32>) -> memref<32x32xf32> {
   %c0 = arith.constant 0 : index
@@ -209,4 +209,19 @@ module {
   }
   return %alloc : memref<32x32xf32>
 }
+
+// CHECK-LABEL: test6
+// CHECK: put async @channel_18[] (%arg0[] [] []) : (memref<1x1x4x2x8x4xi32>)
+
+func.func @test6(%arg0: memref<1x1x4x2x8x4xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %1 = air.channel.put async @channel_18[] (%arg0[%c0, %c0, %c0, %c0] [%c4, %c2, %c8, %c4] [%c64, %c32, %c4, %c1]) : (memref<1x1x4x2x8x4xi32>)
+  return
+}
 }
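A worked check of the new test6 expectation (standalone, illustrative): every adjacent pair in [4, 2, 8, 4] / [64, 32, 4, 1] satisfies the merge rule, so the four dimensions collapse into one contiguous span of 256 elements, the full volume of memref<1x1x4x2x8x4xi32>, and the lists clear to [] [] [].

#include <cassert>

int main() {
  int sizes[] = {4, 2, 8, 4};
  int strides[] = {64, 32, 4, 1};
  int volume = 1 * 1 * 4 * 2 * 8 * 4; // 256 elements
  for (int i = 3; i >= 1; i--)
    assert(strides[i - 1] == sizes[i] * strides[i]); // each pair merges
  assert(sizes[0] * sizes[1] * sizes[2] * sizes[3] == volume); // whole memref
  return 0;
}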
