Skip to content

Commit

Permalink
[StatefulTransform] Extract aie.flow generation (#749)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls authored Sep 6, 2024
1 parent a9cfbf4 commit b6a3ba8
Show file tree
Hide file tree
Showing 42 changed files with 505 additions and 91 deletions.
13 changes: 13 additions & 0 deletions compiler/plugins/target/AMD-AIE/aie/AIEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,4 +577,17 @@ LogicalResult DMABDOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AIE_FlowOp
//===----------------------------------------------------------------------===//

/// Convenience builder for `aie.flow` that does not take a symbol: every
/// argument is forwarded unchanged to the `build` overload that accepts one
/// additional trailing argument, and a null value is supplied for it.
void FlowOp::build(OpBuilder &b, OperationState &result, Value source,
                   mlir::iree_compiler::AMDAIE::StrmSwPortType source_bundle,
                   uint8_t source_channel, Value dest,
                   mlir::iree_compiler::AMDAIE::StrmSwPortType dest_bundle,
                   uint8_t dest_channel) {
  FlowOp::build(b, result, source, source_bundle, source_channel, dest,
                dest_bundle, dest_channel, /*symbol=*/nullptr);
}

} // namespace xilinx::AIE
12 changes: 11 additions & 1 deletion compiler/plugins/target/AMD-AIE/aie/AIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,22 @@ def AIE_FlowOp: AIE_Op<"flow"> {
ConfinedAttr<I8Attr, [IntMinValue<0>]>:$source_channel,
Index:$dest,
StrmSwPortTypeAttr:$dest_bundle,
ConfinedAttr<I8Attr, [IntMinValue<0>]>:$dest_channel
ConfinedAttr<I8Attr, [IntMinValue<0>]>:$dest_channel,
OptionalAttr<FlatSymbolRefAttr>:$symbol
);
let summary = "A logical circuit-switched connection between cores";
let assemblyFormat = [{
`(` $source `,` $source_bundle `:` $source_channel `,` $dest `,` $dest_bundle `:` $dest_channel `)` attr-dict
}];
let builders = [
OpBuilder<(
ins "::mlir::Value":$source,
"::mlir::iree_compiler::AMDAIE::StrmSwPortType":$source_bundle,
"uint8_t":$source_channel,
"::mlir::Value":$dest,
"::mlir::iree_compiler::AMDAIE::StrmSwPortType":$dest_bundle,
"uint8_t":$dest_channel)>
];
}

def AIE_AMSelOp: AIE_Op<"amsel", [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "Passes.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand Down Expand Up @@ -114,24 +115,6 @@ std::optional<Value> getOptionalSharedTile(ObjectFifoLinkOp op) {

} // namespace

/// Hands out DMA channel indices on a per-tile basis. Every tile has two
/// independent counters (one for producer channels, one for consumer
/// channels); each request returns the current counter value and then
/// advances that counter by one.
class DMAChannelAnalysis {
 public:
  DMAChannelAnalysis() = default;

  /// Given an AIE tile, returns its next usable producer (MM2S) channel and
  /// marks that channel as taken.
  SwitchDMAConnection getProducerDMAChannel(Value tile) {
    // operator[] inserts a zero-initialized counter the first time a tile is
    // seen, so the first channel handed out for any tile is 0.
    uint8_t &nextChannel = producerChannelsPerTile[tile];
    const uint8_t assigned = nextChannel;
    ++nextChannel;
    return {DMAChannelDir::MM2S, assigned};
  }

  /// Given an AIE tile, returns its next usable consumer (S2MM) channel and
  /// marks that channel as taken.
  SwitchDMAConnection getConsumerDMAChannel(Value tile) {
    uint8_t &nextChannel = consumerChannelsPerTile[tile];
    const uint8_t assigned = nextChannel;
    ++nextChannel;
    return {DMAChannelDir::S2MM, assigned};
  }

 private:
  // Next free channel index per tile, tracked separately for each direction.
  DenseMap<Value, uint8_t> producerChannelsPerTile;
  DenseMap<Value, uint8_t> consumerChannelsPerTile;
};

enum SharedMemoryDirection { LHS = -1, RHS = 1, NONE = 0 };

/// Retrieve ObjectFifoLinkOp of ObjectFifoCreateOp,
Expand Down Expand Up @@ -800,78 +783,92 @@ void createBuffersAndLocks(

/// Translate ObjectFifoCreateOp ops into routing primitives (Flows) and DMA
/// primitives (DMABD, DMAStart, Buffer, UseLock).
void createFlowsAndTileDMAs(
LogicalResult createFlowsAndTileDMAs(
OpBuilder builder, DeviceOp device, ObjectFifoCreateOp producer,
const std::vector<ObjectFifoCreateOp> &consumers,
DMAChannelAnalysis &dmaAnalysis,
std::vector<ObjectFifoCreateOp> &consumers,
const DenseMap<ObjectFifoCreateOp, std::vector<LockOp>> &locksPerFifo,
const DenseMap<ObjectFifoLinkOp, ObjectFifoCreateOp> &objFifoLinks,
const DenseMap<ObjectFifoCreateOp, std::vector<BufferOp>> &buffersPerFifo) {
const DenseMap<ObjectFifoCreateOp, std::vector<BufferOp>> &buffersPerFifo,
const DenseMap<StringRef, SmallVector<FlowOp>> &symbolToFlowOps) {
AMDAIEDeviceModel deviceModel =
getDeviceModel(static_cast<AMDAIEDevice>(device.getDevice()));
auto createDMA = [&deviceModel, &device, &builder, &locksPerFifo,
&objFifoLinks, &buffersPerFifo](
ObjectFifoCreateOp op, DMAChannelDir channelDir,
int channelIndex, BDDimLayoutArrayAttr dims) {
uint8_t channelIndex, BDDimLayoutArrayAttr dims) {
TileOp producerOp = cast<TileOp>(op.getProducerTile().getDefiningOp());
if (deviceModel.isShimTile(producerOp.getCol(), producerOp.getRow()))
if (deviceModel.isShimTile(producerOp.getCol(), producerOp.getRow())) {
return;
else if (deviceModel.isMemTile(producerOp.getCol(), producerOp.getRow()))
} else if (deviceModel.isMemTile(producerOp.getCol(),
producerOp.getRow())) {
createMemTileDMA(device, builder, op, channelDir, channelIndex, dims,
objFifoLinks, buffersPerFifo, locksPerFifo);
else
} else {
createAMDAIETileDMA(device, builder, op, channelDir, channelIndex, dims,
objFifoLinks, buffersPerFifo, locksPerFifo);
}
};
// create producer tile DMA

// Collect producer and consumer DMA channels
if (!symbolToFlowOps.contains(producer.getSymName())) {
return producer.emitOpError()
<< "symbol name not found in symbol to flow ops map";
}
SmallVector<FlowOp> flowOps = symbolToFlowOps.at(producer.getSymName());
SmallVector<uint8_t> producerChannelsVec = llvm::map_to_vector(
flowOps, [](FlowOp flowOp) { return flowOp.getSourceChannel(); });
llvm::SmallSetVector<uint8_t, 1> producerChannels(producerChannelsVec.begin(),
producerChannelsVec.end());
if (producerChannels.size() != 1)
return producer.emitOpError() << "expected a single producer channel";
DenseMap<Value, uint8_t> consumerChannelsMap;
for (FlowOp flowOp : flowOps)
consumerChannelsMap[flowOp.getDest()] = flowOp.getDestChannel();
if (consumerChannelsMap.size() != consumers.size()) {
return producer.emitOpError() << "expected same number of consumers as the "
"number of consumer objectfifos provided";
}

// create producer tile DMA
TileOp producerProducerTileOp =
cast<TileOp>(producer.getProducerTile().getDefiningOp());
SwitchDMAConnection producerChan =
dmaAnalysis.getProducerDMAChannel(producer.getProducerTile());
createDMA(producer, static_cast<DMAChannelDir>(producerChan.direction),
producerChan.channel, producer.getDimensionsToStreamAttr());
createDMA(producer, DMAChannelDir::MM2S, producerChannels[0],
producer.getDimensionsToStreamAttr());
// generate objectFifo allocation info
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(&device.getBody()->back());
if (deviceModel.isShimTile(producerProducerTileOp.getCol(),
producerProducerTileOp.getRow()))
producerProducerTileOp.getRow())) {
builder.create<ShimDMAAllocationOp>(
builder.getUnknownLoc(), producer.getName(),
static_cast<xilinx::AIE::DMAChannelDir>(producerChan.direction),
producerChan.channel, producerProducerTileOp.getCol());
builder.getUnknownLoc(), producer.getName(), DMAChannelDir::MM2S,
producerChannels[0], producerProducerTileOp.getCol());
}

for (ObjectFifoCreateOp consumer : consumers) {
if (!consumerChannelsMap.contains(consumer.getProducerTile())) {
return consumer.emitOpError()
<< "did not find producer tile in consumerChannelsMap";
}
uint8_t consumerChannel = consumerChannelsMap[consumer.getProducerTile()];

for (auto consumer : consumers) {
// create consumer tile DMA
SwitchDMAConnection consumerChan =
dmaAnalysis.getConsumerDMAChannel(consumer.getProducerTile());
BDDimLayoutArrayAttr consumerDims =
consumer.getDimensionsFromStreamPerConsumer()[0];
createDMA(consumer, static_cast<DMAChannelDir>(consumerChan.direction),
consumerChan.channel, consumerDims);
createDMA(consumer, DMAChannelDir::S2MM, consumerChannel, consumerDims);
// generate objectFifo allocation info
OpBuilder::InsertionGuard gg(builder);
builder.setInsertionPoint(&device.getBody()->back());

TileOp consumerProducerTileOp =
cast<TileOp>(consumer.getProducerTile().getDefiningOp());
if (deviceModel.isShimTile(consumerProducerTileOp.getCol(),
consumerProducerTileOp.getRow()))
consumerProducerTileOp.getRow())) {
builder.create<ShimDMAAllocationOp>(
builder.getUnknownLoc(), producer.getName(),
static_cast<xilinx::AIE::DMAChannelDir>(consumerChan.direction),
consumerChan.channel, consumerProducerTileOp.getCol());

// create flow
{
OpBuilder::InsertionGuard ggg(builder);
builder.setInsertionPointAfter(producer);
builder.create<FlowOp>(builder.getUnknownLoc(),
producer.getProducerTile(), WireBundle::DMA,
producerChan.channel, consumer.getProducerTile(),
WireBundle::DMA, consumerChan.channel);
builder.getUnknownLoc(), producer.getName(), DMAChannelDir::S2MM,
consumerChannel, consumerProducerTileOp.getCol());
}
}
return success();
}

namespace mlir::iree_compiler::AMDAIE {
Expand Down Expand Up @@ -905,7 +902,6 @@ struct AMDAIEObjectFifoStatefulTransformPass : mlir::OperationPass<DeviceOp> {

void runOnOperation() override {
DeviceOp device = getOperation();
DMAChannelAnalysis dmaAnalysis;
OpBuilder builder = OpBuilder::atBlockEnd(device.getBody());
// maps each objFifo to its corresponding buffer
DenseMap<ObjectFifoCreateOp, std::vector<BufferOp>> buffersPerFifo;
Expand All @@ -926,20 +922,32 @@ struct AMDAIEObjectFifoStatefulTransformPass : mlir::OperationPass<DeviceOp> {
llvm::to_vector(device.getOps<ObjectFifoCreateOp>());
for (ObjectFifoCreateOp createOp : createFifoOps) {
if (auto _shareDirection = NONE;
!requiresDMAs(createOp, _shareDirection, splitBecauseLink))
!requiresDMAs(createOp, _shareDirection, splitBecauseLink)) {
continue;
}
splitFifo(device, createOp, builder, splitFifos);
}

for (ObjectFifoCreateOp createOp : device.getOps<ObjectFifoCreateOp>())
for (ObjectFifoCreateOp createOp : device.getOps<ObjectFifoCreateOp>()) {
createBuffersAndLocks(builder, device, createOp, splitBecauseLink,
objFifoLinks, buffersPerFifo, locksPerFifo);
}

DenseMap<StringRef, SmallVector<FlowOp>> symbolToFlowOps;
device.walk([&](FlowOp op) {
std::optional<StringRef> symbolAttr = op.getSymbol();
if (symbolAttr) symbolToFlowOps[symbolAttr.value()].push_back(op);
});

// Only the objectFifos we split above require DMA communication; the others
// rely on shared memory and share the same buffers.
for (auto &[producer, consumers] : splitFifos)
createFlowsAndTileDMAs(builder, device, producer, consumers, dmaAnalysis,
locksPerFifo, objFifoLinks, buffersPerFifo);
for (auto &[producer, consumers] : splitFifos) {
if (failed(createFlowsAndTileDMAs(builder, device, producer, consumers,
locksPerFifo, objFifoLinks,
buffersPerFifo, symbolToFlowOps))) {
return signalPassFailure();
}
}

// Replace ops
for (auto coreOp : device.getOps<CoreOp>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ module @aie2_cyclostatic_dma {
%tile22 = aie.tile(2, 2) // producer tile
%tile83 = aie.tile(8, 3) // consumer tile
%buf83 = aie.buffer(%tile83) {sym_name = "buf83"} : memref<4xi32>
aie.flow(%tile22, DMA : 0, %tile83, DMA : 0) {symbol = @fifo}
// ObjectFifo that can hold 4 memref<i32>s, populated by tile22 and
// consumed by tile83
aie.objectfifo @fifo (%tile22, {%tile83}, 4 : i32) : !aie.objectfifo<memref<i32>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ module @aie2_cyclostatic_l2 {
%memtile = aie.tile(2, 1) // mem tile
%tile83 = aie.tile(8, 3) // consumer tile
%buf83 = aie.buffer(%tile83) {sym_name = "buf83"} : memref<1xi32>
aie.flow(%tile22, DMA : 0, %memtile, DMA : 0) {symbol = @fifo0}
aie.flow(%memtile, DMA : 0, %tile83, DMA : 0) {symbol = @fifo1}
// ObjectFifo that can hold 4 memref<1xi32>s, populated by tile22 and
// consumed by the mem tile
aie.objectfifo @fifo0 (%tile22, {%memtile}, 4 : i32) : !aie.objectfifo<memref<1xi32>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ module @alloc {
%tile20 = aie.tile(2, 0)
%tile22 = aie.tile(2, 2)
%tile23 = aie.tile(2, 3)
aie.flow(%tile20, DMA : 0, %tile22, DMA : 0) {symbol = @of_in_0}
aie.flow(%tile22, DMA : 0, %tile20, DMA : 0) {symbol = @of_out_0}
aie.flow(%tile20, DMA : 1, %tile23, DMA : 0) {symbol = @of_in_1}
aie.flow(%tile23, DMA : 0, %tile20, DMA : 1) {symbol = @of_out_1}
aie.objectfifo @of_in_0 (%tile20, {%tile22}, 2 : i32) : !aie.objectfifo<memref<64xi16>>
aie.objectfifo @of_out_0 (%tile22, {%tile20}, 2 : i32) : !aie.objectfifo<memref<64xi16>>
aie.objectfifo @of_in_1 (%tile20, {%tile23}, 2 : i32) : !aie.objectfifo<memref<64xi16>>
Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/AMD-AIE/aie/test/base_test_AIE1.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ module @elementGenerationAIE1 {
%tile12 = aie.tile(1, 2)
%tile13 = aie.tile(1, 3)
%tile33 = aie.tile(3, 3)
aie.flow(%tile12, DMA : 0, %tile33, DMA : 0) {symbol = @of1}
aie.flow(%tile12, DMA : 1, %tile13, DMA : 0) {symbol = @of0}
// In the shared memory case, the number of elements does not change.
aie.objectfifo @of0 (%tile12, {%tile13}, 4 : i32) : !aie.objectfifo<memref<16xi32>>
// In the non-adjacent memory case, the number of elements depends on the max amount acquired by
Expand Down
6 changes: 4 additions & 2 deletions compiler/plugins/target/AMD-AIE/aie/test/base_test_AIE2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,16 @@
// CHECK: }

module @elementGenerationAIE2 {
aie.device(xcve2302) {
aie.device(xcve2302) {
%tile12 = aie.tile(1, 2)
%tile13 = aie.tile(1, 3)
%tile33 = aie.tile(3, 3)
aie.flow(%tile12, DMA : 0, %tile33, DMA : 0) {symbol = @of1}
aie.flow(%tile12, DMA : 1, %tile13, DMA : 0) {symbol = @of0}
// In the shared memory case, the number of elements does not change.
aie.objectfifo @of0 (%tile12, {%tile13}, 4 : i32) : !aie.objectfifo<memref<16xi32>>
// In the non-adjacent memory case, the number of elements depends on the max amount acquired by
// the processes running on each core (here nothing is specified so it cannot be derived).
aie.objectfifo @of1 (%tile12, {%tile33}, 2 : i32) : !aie.objectfifo<memref<16xi32>>
}
}
}
4 changes: 4 additions & 0 deletions compiler/plugins/target/AMD-AIE/aie/test/broadcast_test.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ module @broadcast {
%tile14 = aie.tile(1, 4)
%tile32 = aie.tile(3, 2)
%tile33 = aie.tile(3, 3)
aie.flow(%tile13, DMA : 0, %tile33, DMA : 0) {symbol = @broadcast_of}
aie.flow(%tile13, DMA : 0, %tile32, DMA : 0) {symbol = @broadcast_of}
aie.flow(%tile13, DMA : 0, %tile14, DMA : 0) {symbol = @broadcast_of}
aie.flow(%tile13, DMA : 0, %tile12, DMA : 0) {symbol = @broadcast_of}
aie.objectfifo @broadcast_of (%tile13, {%tile12, %tile14, %tile32, %tile33}, [2, 2, 3, 4, 3]) : !aie.objectfifo<memref<16xi32>>
func.func @some_work(%lineOut : memref<16xi32>) -> () {
return
Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/AMD-AIE/aie/test/link_test_AIE1.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ module @link_AIE1 {
%tile20 = aie.tile(2, 0)
%tile12 = aie.tile(1, 2)
%tile22 = aie.tile(2, 2)
aie.flow(%tile20, DMA : 0, %tile12, DMA : 0) {symbol = @of1}
aie.flow(%tile12, DMA : 0, %tile22, DMA : 0) {symbol = @of2}
aie.objectfifo @of1 (%tile20, {%tile12}, 2 : i32) : !aie.objectfifo<memref<16xi32>>
aie.objectfifo @of2 (%tile12, {%tile22}, 2 : i32) : !aie.objectfifo<memref<16xi32>>
aie.objectfifo.link [@of1] -> [@of2] ()
Expand Down
3 changes: 3 additions & 0 deletions compiler/plugins/target/AMD-AIE/aie/test/link_test_AIE2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ module @link_AIE2 {
%tile01 = aie.tile(2, 1)
%tile02 = aie.tile(2, 2)
%tile03 = aie.tile(2, 3)
aie.flow(%tile00, DMA : 0, %tile01, DMA : 0) {symbol = @mem_in}
aie.flow(%tile00, DMA : 0, %tile02, DMA : 0) {symbol = @mem_in}
aie.flow(%tile01, DMA : 0, %tile03, DMA : 0) {symbol = @mem_out}
aie.objectfifo @mem_in (%tile00, {%tile02, %tile01}, [2,2,7]) : !aie.objectfifo<memref<3000xi32>>
aie.objectfifo @mem_out (%tile01, {%tile03}, 7 : i32) : !aie.objectfifo<memref<3000xi32>>
aie.objectfifo.link [@mem_in] -> [@mem_out] ()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ module @link_DDR_L1 {
%tile20 = aie.tile(2, 0)
%tile21 = aie.tile(2, 1)
%tile22 = aie.tile(2, 2)
aie.flow(%tile20, DMA : 0, %tile21, DMA : 0) {symbol = @to_memTile}
aie.flow(%tile21, DMA : 0, %tile22, DMA : 0) {symbol = @from_memTile}
aie.objectfifo @to_memTile (%tile20, {%tile21}, 2 : i32) : !aie.objectfifo<memref<16xi32>>
aie.objectfifo @from_memTile (%tile21, {%tile22}, 2 : i32) : !aie.objectfifo<memref<16xi32>>
aie.objectfifo.link [@to_memTile] -> [@from_memTile] ()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ module @link_L1_DDR {
%tile20 = aie.tile(2, 0)
%tile21 = aie.tile(2, 1)
%tile22 = aie.tile(2, 2)
aie.flow(%tile22, DMA : 0, %tile21, DMA : 0) {symbol = @to_memTile}
aie.flow(%tile21, DMA : 0, %tile20, DMA : 0) {symbol = @from_memTile}
aie.objectfifo @to_memTile (%tile22, {%tile21}, 2 : i32) : !aie.objectfifo<memref<16xi32>>
aie.objectfifo @from_memTile (%tile21, {%tile20}, 2 : i32) : !aie.objectfifo<memref<48xi32>>
aie.objectfifo.link [@to_memTile] -> [@from_memTile] ()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ module @link_broadcast {
%tile21 = aie.tile(2, 1)
%tile22 = aie.tile(2, 2)
%tile33 = aie.tile(3, 3)
aie.flow(%tile20, DMA : 0, %tile21, DMA : 0) {symbol = @link1}
aie.flow(%tile21, DMA : 0, %tile33, DMA : 0) {symbol = @link2}
aie.flow(%tile21, DMA : 0, %tile22, DMA : 0) {symbol = @link2}
aie.flow(%tile22, DMA : 0, %tile33, DMA : 1) {symbol = @skip_connection}
aie.objectfifo @link1 (%tile20, {%tile21}, 2 : i32) : !aie.objectfifo<memref<48xi32>>
aie.objectfifo @link2 (%tile21, {%tile22, %tile33}, [2, 2, 3]) : !aie.objectfifo<memref<16xi32>>
aie.objectfifo @skip_connection (%tile22, {%tile33}, 2 : i32) : !aie.objectfifo<memref<16xi32>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ module @link_distribute {
%tile22 = aie.tile(2, 2)
%tile23 = aie.tile(2, 3)
%tile33 = aie.tile(3, 3)
aie.flow(%tile20, DMA : 0, %tile21, DMA : 0) {symbol = @link1}
aie.flow(%tile21, DMA : 0, %tile22, DMA : 0) {symbol = @link2}
aie.flow(%tile21, DMA : 1, %tile23, DMA : 0) {symbol = @link3}
aie.flow(%tile21, DMA : 2, %tile33, DMA : 0) {symbol = @link4}
aie.objectfifo @link1 (%tile20, {%tile21}, 2 : i32) : !aie.objectfifo<memref<48xi32>>
aie.objectfifo @link2 (%tile21, {%tile22}, 2 : i32) : !aie.objectfifo<memref<4x4xi32>>
aie.objectfifo @link3 (%tile21, {%tile23}, 2 : i32) : !aie.objectfifo<memref<20xi32>>
Expand Down
Loading

0 comments on commit b6a3ba8

Please sign in to comment.