diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 374ad4920..5ca20995e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: CI on: workflow_call: - + workflow_dispatch: pull_request: merge_group: diff --git a/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp b/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp index 61496586a..77405227e 100644 --- a/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp +++ b/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp @@ -13,23 +13,47 @@ #include "d_ary_heap.h" #include "iree-amd-aie/aie_runtime/iree_aie_runtime.h" #include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/DirectedGraph.h" -#include "llvm/ADT/GraphTraits.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_os_ostream.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; -using namespace xilinx; -using namespace xilinx::AIE; +using namespace mlir::iree_compiler::AMDAIE; + +using xilinx::AIE::AMSelOp; +using xilinx::AIE::ConnectOp; +using xilinx::AIE::DeviceOp; +using xilinx::AIE::DMAChannelDir; +using xilinx::AIE::EndOp; +using xilinx::AIE::FlowOp; +using xilinx::AIE::Interconnect; +using xilinx::AIE::MasterSetOp; +using xilinx::AIE::PacketDestOp; +using xilinx::AIE::PacketFlowOp; +using xilinx::AIE::PacketRuleOp; +using xilinx::AIE::PacketRulesOp; +using xilinx::AIE::PacketSourceOp; +using xilinx::AIE::PLIOOp; +using xilinx::AIE::ShimMuxOp; +using xilinx::AIE::SwitchboxOp; +using xilinx::AIE::TileOp; +using xilinx::AIE::WireBundle; +using xilinx::AIE::WireOp; + +using xilinx::AIE::Connect; +using xilinx::AIE::Port; +using xilinx::AIE::TileID; + +using AMDAIEDeviceModel = xilinx::AIE::AIETargetModel; #define DEBUG_TYPE "amdaie-create-pathfinder-flows" #define OVER_CAPACITY_COEFF 0.02 #define USED_CAPACITY_COEFF 0.02 #define DEMAND_COEFF 1.1 -namespace { +namespace mlir::iree_compiler::AMDAIE { + StrmSwPortType toStrmT(WireBundle w) { switch (w) { case WireBundle::Core: @@ -47,9 +71,9 @@ StrmSwPortType toStrmT(WireBundle w) { case WireBundle::East: return StrmSwPortType::EAST; case WireBundle::PLIO: - llvm::report_fatal_error("unhandled PLIO"); + return StrmSwPortType::PLIO; case WireBundle::NOC: - llvm::report_fatal_error("unhandled NOC"); + return StrmSwPortType::NOC; case WireBundle::Trace: return StrmSwPortType::TRACE; case WireBundle::Ctrl: @@ -58,156 +82,300 @@ StrmSwPortType toStrmT(WireBundle w) { llvm::report_fatal_error("unhandled WireBundle"); } } -} // namespace + +WireBundle toWireB(StrmSwPortType w) { + switch (w) { + case StrmSwPortType::CORE: + return WireBundle::Core; + case StrmSwPortType::DMA: + return WireBundle::DMA; + case StrmSwPortType::FIFO: + return WireBundle::FIFO; + case StrmSwPortType::SOUTH: + return WireBundle::South; + case StrmSwPortType::WEST: + return WireBundle::West; + case StrmSwPortType::NORTH: + return WireBundle::North; + case StrmSwPortType::EAST: + return WireBundle::East; + case StrmSwPortType::TRACE: + return WireBundle::Trace; + case StrmSwPortType::PLIO: + return WireBundle::PLIO; + case StrmSwPortType::NOC: + return WireBundle::NOC; + case StrmSwPortType::CTRL: + return WireBundle::Ctrl; + default: + llvm::report_fatal_error("unhandled WireBundle"); + } +} + +} // namespace mlir::iree_compiler::AMDAIE namespace mlir::iree_compiler::AMDAIE { -struct Port { - xilinx::AIE::WireBundle bundle; - int channel; +enum class Connectivity { INVALID = -1, AVAILABLE = 0, OCCUPIED = 1 }; + +using SwitchboxNode = struct SwitchboxNode { + SwitchboxNode(int col, int row, int id, int maxCol, int maxRow, + const AMDAIEDeviceModel &targetModel) + : col{col}, row{row}, id{id} { + std::vector bundles = { + WireBundle::Core, WireBundle::DMA, WireBundle::FIFO, + WireBundle::South, WireBundle::West, WireBundle::North, + WireBundle::East, WireBundle::PLIO, WireBundle::NOC, + WireBundle::Trace, WireBundle::Ctrl}; + + for (WireBundle bundle : bundles) { + int maxCapacity = targetModel.getNumSourceSwitchboxConnections( + col, row, toStrmT(bundle)); + if (targetModel.isShimNOCorPLTile(col, row) && maxCapacity == 0) { + // wordaround for shimMux, todo: integrate shimMux into routable grid + maxCapacity = targetModel.getNumSourceShimMuxConnections( + col, row, toStrmT(bundle)); + } - bool operator==(const Port &rhs) const { - return std::tie(bundle, channel) == std::tie(rhs.bundle, rhs.channel); - } + for (int channel = 0; channel < maxCapacity; channel++) { + Port inPort = {bundle, channel}; + inPortToId[inPort] = inPortId; + inPortId++; + } - bool operator!=(const Port &rhs) const { return !(*this == rhs); } + maxCapacity = + targetModel.getNumDestSwitchboxConnections(col, row, toStrmT(bundle)); + if (targetModel.isShimNOCorPLTile(col, row) && maxCapacity == 0) { + // wordaround for shimMux, todo: integrate shimMux into routable grid + maxCapacity = + targetModel.getNumDestShimMuxConnections(col, row, toStrmT(bundle)); + } + for (int channel = 0; channel < maxCapacity; channel++) { + Port outPort = {bundle, channel}; + outPortToId[outPort] = outPortId; + outPortId++; + } + } - bool operator<(const Port &rhs) const { - return std::tie(bundle, channel) < std::tie(rhs.bundle, rhs.channel); + connectionMatrix.resize(inPortId, std::vector( + outPortId, Connectivity::AVAILABLE)); + + // illegal connection + for (const auto &[inPort, inId] : inPortToId) { + for (const auto &[outPort, outId] : outPortToId) { + if (!targetModel.isLegalTileConnection( + col, row, toStrmT(inPort.bundle), inPort.channel, + toStrmT(outPort.bundle), outPort.channel)) + connectionMatrix[inId][outId] = Connectivity::INVALID; + + if (targetModel.isShimNOCorPLTile(col, row)) { + // wordaround for shimMux, todo: integrate shimMux into routable grid + auto isBundleInList = [](WireBundle bundle, + std::vector bundles) { + return std::find(bundles.begin(), bundles.end(), bundle) != + bundles.end(); + }; + std::vector bundles = {WireBundle::DMA, WireBundle::NOC, + WireBundle::PLIO}; + if (isBundleInList(inPort.bundle, bundles) || + isBundleInList(outPort.bundle, bundles)) + connectionMatrix[inId][outId] = Connectivity::AVAILABLE; + } + } + } } - friend std::ostream &operator<<(std::ostream &os, const Port &port) { - os << "("; - switch (port.bundle) { - case xilinx::AIE::WireBundle::Core: - os << "Core"; - break; - case xilinx::AIE::WireBundle::DMA: - os << "DMA"; - break; - case xilinx::AIE::WireBundle::North: - os << "N"; - break; - case xilinx::AIE::WireBundle::East: - os << "E"; - break; - case xilinx::AIE::WireBundle::South: - os << "S"; - break; - case xilinx::AIE::WireBundle::West: - os << "W"; - break; - default: - os << "X"; - break; + // given a outPort, find availble input channel + std::vector findAvailableChannelIn(WireBundle inBundle, Port outPort, + bool isPkt) { + std::vector availableChannels; + if (outPortToId.count(outPort) > 0) { + int outId = outPortToId[outPort]; + if (isPkt) { + for (const auto &[inPort, inId] : inPortToId) { + if (inPort.bundle == inBundle && + connectionMatrix[inId][outId] != Connectivity::INVALID) { + bool available = true; + if (inPortPktCount.count(inPort) == 0) { + for (const auto &[outPort, outId] : outPortToId) { + if (connectionMatrix[inId][outId] == Connectivity::OCCUPIED) { + // occupied by others as circuit-switched + available = false; + break; + } + } + } else { + if (inPortPktCount[inPort] >= maxPktStream) { + // occupied by others as packet-switched but exceed max packet + // stream capacity + available = false; + } + } + if (available) availableChannels.push_back(inPort.channel); + } + } + } else { + for (const auto &[inPort, inId] : inPortToId) { + if (inPort.bundle == inBundle && + connectionMatrix[inId][outId] == Connectivity::AVAILABLE) { + bool available = true; + for (const auto &[outPort, outId] : outPortToId) { + if (connectionMatrix[inId][outId] == Connectivity::OCCUPIED) { + available = false; + break; + } + } + if (available) availableChannels.push_back(inPort.channel); + } + } + } } - os << ": " << std::to_string(port.channel) << ")"; - return os; + return availableChannels; } - GENERATE_TO_STRING(Port) + bool allocate(Port inPort, Port outPort, bool isPkt) { + // invalid port + if (outPortToId.count(outPort) == 0 || inPortToId.count(inPort) == 0) + return false; - friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const Port &port) { - os << to_string(port); - return os; - } -}; -} // namespace mlir::iree_compiler::AMDAIE + int inId = inPortToId[inPort]; + int outId = outPortToId[outPort]; -namespace std { -template <> -struct less { - bool operator()(const mlir::iree_compiler::AMDAIE::Port &a, - const mlir::iree_compiler::AMDAIE::Port &b) const { - return a.bundle == b.bundle ? a.channel < b.channel : a.bundle < b.bundle; - } -}; + // invalid connection + if (connectionMatrix[inId][outId] == Connectivity::INVALID) return false; -template <> -struct hash { - size_t operator()(const mlir::iree_compiler::AMDAIE::Port &p) const noexcept { - size_t h1 = hash{}(p.bundle); - size_t h2 = hash{}(p.channel); - return h1 ^ h2 << 1; + if (isPkt) { + // a packet-switched stream to be allocated + if (inPortPktCount.count(inPort) == 0) { + for (const auto &[outPort, outId] : outPortToId) { + if (connectionMatrix[inId][outId] == Connectivity::OCCUPIED) { + // occupied by others as circuit-switched, allocation fail! + return false; + } + } + // empty channel, allocation succeed! + inPortPktCount[inPort] = 1; + connectionMatrix[inId][outId] = Connectivity::OCCUPIED; + return true; + } else { + if (inPortPktCount[inPort] >= maxPktStream) { + // occupied by others as packet-switched but exceed max packet stream + // capacity, allocation fail! + return false; + } else { + // valid packet-switched, allocation succeed! + inPortPktCount[inPort]++; + return true; + } + } + } else { + // a circuit-switched stream to be allocated + if (connectionMatrix[inId][outId] == Connectivity::AVAILABLE) { + // empty channel, allocation succeed! + connectionMatrix[inId][outId] = Connectivity::OCCUPIED; + return true; + } else { + // occupied by others, allocation fail! + return false; + } + } } -}; -} // namespace std -namespace mlir::iree_compiler::AMDAIE { + void clearAllocation() { + for (int inId = 0; inId < inPortId; inId++) { + for (int outId = 0; outId < outPortId; outId++) { + if (connectionMatrix[inId][outId] != Connectivity::INVALID) { + connectionMatrix[inId][outId] = Connectivity::AVAILABLE; + } + } + } + inPortPktCount.clear(); + } -#define GENERATE_TO_STRING(TYPE_WITH_INSERTION_OP) \ - friend std::string to_string(const TYPE_WITH_INSERTION_OP &s) { \ - std::ostringstream ss; \ - ss << s; \ - return ss.str(); \ + friend std::ostream &operator<<(std::ostream &os, const SwitchboxNode &s) { + os << "Switchbox(" << s.col << ", " << s.row << ")"; + return os; } -typedef struct Connect { - Port src; - Port dst; + GENERATE_TO_STRING(SwitchboxNode); - bool operator==(const Connect &rhs) const { - return std::tie(src, dst) == std::tie(rhs.src, rhs.dst); + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const SwitchboxNode &s) { + os << to_string(s); + return os; } -} Connect; - -typedef struct DMAChannel { - xilinx::AIE::DMAChannelDir direction; - int channel; - bool operator==(const DMAChannel &rhs) const { - return std::tie(direction, channel) == std::tie(rhs.direction, rhs.channel); + bool operator<(const SwitchboxNode &rhs) const { + return std::tie(col, row) < std::tie(rhs.col, rhs.row); } -} DMAChannel; -struct Switchbox : TileLoc { - // Necessary for initializer construction? - Switchbox(TileLoc t) : TileLoc(t) {} - Switchbox(int col, int row) : TileLoc{col, row} {} - friend std::ostream &operator<<(std::ostream &os, const Switchbox &s) { - os << "Switchbox(" << s.col << ", " << s.row << ")"; - return os; + bool operator==(const SwitchboxNode &rhs) const { + return std::tie(col, row) == std::tie(rhs.col, rhs.row); } - GENERATE_TO_STRING(Switchbox); + int col, row, id; + int inPortId = 0, outPortId = 0; + std::map inPortToId, outPortToId; - bool operator==(const Switchbox &rhs) const { - return static_cast(*this) == rhs; - } + // tenary representation of switchbox connectivity + // -1: invalid in arch, 0: empty and available, 1: occupued + std::vector> connectionMatrix; + + // input ports with incoming packet-switched streams + std::map inPortPktCount; + + // up to 32 packet-switched stram through a port + const int maxPktStream = 32; }; -struct Channel { - Channel(Switchbox &src, Switchbox &target, xilinx::AIE::WireBundle bundle, - int maxCapacity) - : src(src), target(target), bundle(bundle), maxCapacity(maxCapacity) {} +using ChannelEdge = struct ChannelEdge { + ChannelEdge(SwitchboxNode *src, SwitchboxNode *target) + : src(src), target(target) { + // get bundle from src to target coordinates + if (src->col == target->col) { + if (src->row > target->row) + bundle = WireBundle::South; + else + bundle = WireBundle::North; + } else { + if (src->col > target->col) + bundle = WireBundle::West; + else + bundle = WireBundle::East; + } + + // maximum number of routing resources + maxCapacity = 0; + for (auto &[outPort, _] : src->outPortToId) { + if (outPort.bundle == bundle) { + maxCapacity++; + } + } + } - friend std::ostream &operator<<(std::ostream &os, const Channel &c) { + friend std::ostream &operator<<(std::ostream &os, const ChannelEdge &c) { os << "Channel(src=" << c.src << ", dst=" << c.target << ")"; return os; } - GENERATE_TO_STRING(Channel) + GENERATE_TO_STRING(ChannelEdge) friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const Channel &c) { + const ChannelEdge &c) { os << to_string(c); return os; } - Switchbox &src; - Switchbox ⌖ - xilinx::AIE::WireBundle bundle; - int maxCapacity = 0; // maximum number of routing resources - double demand = 0.0; // indicates how many flows want to use this Channel - int usedCapacity = 0; // how many flows are actually using this Channel - std::set fixedCapacity; // channels not available to the algorithm - int overCapacityCount = 0; // history of Channel being over capacity + SwitchboxNode *src; + SwitchboxNode *target; + + int maxCapacity; + WireBundle bundle; }; -// A SwitchSetting defines the required settings for a Switchbox for a flow +// A SwitchSetting defines the required settings for a SwitchboxNode for a flow // SwitchSetting.src is the incoming signal // SwitchSetting.dsts is the fanout -struct SwitchSetting { +using SwitchSetting = struct SwitchSetting { SwitchSetting() = default; SwitchSetting(Port src) : src(src) {} SwitchSetting(Port src, std::set dsts) @@ -219,7 +387,8 @@ struct SwitchSetting { // namespace surrounding the class). friend std::ostream &operator<<(std::ostream &os, const SwitchSetting &setting) { - os << setting.src << " -> " << "{" + os << setting.src << " -> " + << "{" << join(llvm::map_range(setting.dsts, [](const Port &port) { std::ostringstream ss; @@ -242,10 +411,12 @@ struct SwitchSetting { bool operator<(const SwitchSetting &rhs) const { return src < rhs.src; } }; +using SwitchSettings = std::map; + // A Flow defines source and destination vertices // Only one source, but any number of destinations (fanout) -struct PathEndPoint { - Switchbox sb; +using PathEndPoint = struct PathEndPoint { + SwitchboxNode sb; Port port; friend std::ostream &operator<<(std::ostream &os, const PathEndPoint &s) { @@ -271,132 +442,93 @@ struct PathEndPoint { } }; -} // namespace mlir::iree_compiler::AMDAIE +// A Flow defines source and destination vertices +// Only one source, but any number of destinations (fanout) +using PathEndPointNode = struct PathEndPointNode : PathEndPoint { + PathEndPointNode(SwitchboxNode *sb, Port port) + : PathEndPoint{*sb, port}, sb(sb) {} + SwitchboxNode *sb; +}; -namespace std { -template <> -struct hash { - size_t operator()( - const mlir::iree_compiler::AMDAIE::TileLoc &s) const noexcept { - size_t h1 = hash{}(s.col); - size_t h2 = hash{}(s.row); - return h1 ^ (h2 << 1); - } +using FlowNode = struct FlowNode { + bool isPacketFlow; + PathEndPointNode src; + std::vector dsts; }; -// For some mysterious reason, the only way to get the priorityQueue(cmp) -// comparison in dijkstraShortestPaths to work correctly is to define -// this template specialization for the pointers. Overloading operator -// will not work. Furthermore, if you try to move this into AIEPathFinder.cpp -// you'll get a compile error about -// "specialization of ‘std::less’ after -// instantiation" because one of the graph traits below is doing the comparison -// internally (try moving this below the llvm namespace...) -template <> -struct less { - bool operator()(const mlir::iree_compiler::AMDAIE::Switchbox *a, - const mlir::iree_compiler::AMDAIE::Switchbox *b) const { - return *a < *b; + +} // namespace mlir::iree_compiler::AMDAIE + +namespace llvm { + +inline raw_ostream &operator<<( + raw_ostream &os, const mlir::iree_compiler::AMDAIE::SwitchSettings &ss) { + std::stringstream s; + s << "\tSwitchSettings: "; + for (const auto &[sb, setting] : ss) { + s << sb << ": " << setting << " | "; } -}; + s << "\n"; + os << s.str(); + return os; +} + +} // namespace llvm template <> -struct hash { - size_t operator()( - const mlir::iree_compiler::AMDAIE::Switchbox &s) const noexcept { - return hash{}(s); +struct std::hash { + std::size_t operator()( + const mlir::iree_compiler::AMDAIE::SwitchboxNode &s) const noexcept { + return std::hash{}({s.col, s.row}); } }; template <> -struct hash { - size_t operator()( +struct std::hash { + std::size_t operator()( const mlir::iree_compiler::AMDAIE::PathEndPoint &pe) const noexcept { - size_t h1 = hash{}(pe.port); - size_t h2 = hash{}(pe.sb); + std::size_t h1 = std::hash{}(pe.port); + std::size_t h2 = + std::hash{}(pe.sb); return h1 ^ (h2 << 1); } }; -} // namespace std - namespace mlir::iree_compiler::AMDAIE { -struct SwitchboxNode; -struct ChannelEdge; -using SwitchboxNodeBase = llvm::DGNode; -using ChannelEdgeBase = llvm::DGEdge; -using SwitchboxGraphBase = llvm::DirectedGraph; - -struct SwitchboxNode : SwitchboxNodeBase, Switchbox { - using Switchbox::Switchbox; - SwitchboxNode(int col, int row, int id) : Switchbox{col, row}, id{id} {} - int id; -}; - -// warning: 'mlir::iree_compiler::AMDAIE::ChannelEdge::src' will be initialized -// after SwitchboxNode &src; [-Wreorder] -struct ChannelEdge : ChannelEdgeBase, Channel { - using Channel::Channel; - - explicit ChannelEdge(SwitchboxNode &target) = delete; - ChannelEdge(SwitchboxNode &src, SwitchboxNode &target, - xilinx::AIE::WireBundle bundle, int maxCapacity) - : ChannelEdgeBase(target), - Channel(src, target, bundle, maxCapacity), - src(src) {} - - // This class isn't designed to copied or moved. - ChannelEdge(const ChannelEdge &E) = delete; - ChannelEdge &operator=(ChannelEdge &&E) = delete; - - SwitchboxNode &src; -}; - -class SwitchboxGraph : public SwitchboxGraphBase { - public: - SwitchboxGraph() = default; - ~SwitchboxGraph() = default; -}; - -using SwitchSettings = std::map; - -// A Flow defines source and destination vertices -// Only one source, but any number of destinations (fanout) -struct PathEndPointNode : PathEndPoint { - PathEndPointNode(SwitchboxNode *sb, Port port) - : PathEndPoint{*sb, port}, sb(sb) {} - SwitchboxNode *sb; -}; - -struct FlowNode { - PathEndPointNode src; - std::vector dsts; -}; - class Pathfinder { public: Pathfinder() = default; - void initialize(int maxCol, int maxRow, AMDAIEDeviceModel &deviceModel); - void addFlow(TileLoc srcCoords, Port srcPort, TileLoc dstCoords, - Port dstPort); - bool addFixedConnection(xilinx::AIE::ConnectOp connectOp); + void initialize(int maxCol, int maxRow, const AMDAIEDeviceModel &targetModel); + void addFlow(TileID srcCoords, Port srcPort, TileID dstCoords, Port dstPort, + bool isPacketFlow); + bool addFixedConnection(SwitchboxOp switchboxOp); std::optional> findPaths( int maxIterations); - Switchbox *getSwitchbox(TileLoc coords) { - auto *sb = std::find_if(graph.begin(), graph.end(), [&](SwitchboxNode *sb) { - return sb->col == coords.col && sb->row == coords.row; - }); - assert(sb != graph.end() && "couldn't find sb"); - return *sb; - } + std::map dijkstraShortestPaths( + SwitchboxNode *src); + + SwitchboxNode getSwitchboxNode(TileID coords) { return grid.at(coords); } private: - SwitchboxGraph graph; + // Flows to be routed std::vector flows; - std::map grid; + + // Grid of switchboxes available + std::map grid; + // Use a list instead of a vector because nodes have an edge list of raw // pointers to edges (so growing a vector would invalidate the pointers). std::list edges; + + // Use Dijkstra's shortest path to find routes, and use "demand" as the + // weights. + std::map demand; + + // History of Channel being over capacity + std::map overCapacity; + + // how many flows are actually using this Channel + std::map usedCapacity; }; // DynamicTileAnalysis integrates the Pathfinder class into the MLIR @@ -410,10 +542,10 @@ class DynamicTileAnalysis { std::map flowSolutions; std::map processedFlows; - llvm::DenseMap coordToTile; - llvm::DenseMap coordToSwitchbox; - llvm::DenseMap coordToShimMux; - llvm::DenseMap coordToPLIO; + llvm::DenseMap coordToTile; + llvm::DenseMap coordToSwitchbox; + llvm::DenseMap coordToShimMux; + llvm::DenseMap coordToPLIO; const int maxIterations = 1000; // how long until declared unroutable @@ -421,137 +553,20 @@ class DynamicTileAnalysis { DynamicTileAnalysis(std::shared_ptr p) : pathfinder(std::move(p)) {} - mlir::LogicalResult runAnalysis(xilinx::AIE::DeviceOp &device); + mlir::LogicalResult runAnalysis(DeviceOp &device); int getMaxCol() const { return maxCol; } int getMaxRow() const { return maxRow; } - xilinx::AIE::TileOp getTile(mlir::OpBuilder &builder, int col, int row); + TileOp getTile(mlir::OpBuilder &builder, int col, int row); - xilinx::AIE::SwitchboxOp getSwitchbox(mlir::OpBuilder &builder, int col, - int row); + SwitchboxOp getSwitchbox(mlir::OpBuilder &builder, int col, int row); - xilinx::AIE::ShimMuxOp getShimMux(mlir::OpBuilder &builder, int col); + ShimMuxOp getShimMux(mlir::OpBuilder &builder, int col); }; } // namespace mlir::iree_compiler::AMDAIE - -namespace llvm { -template <> -struct DenseMapInfo { - using FirstInfo = DenseMapInfo; - using SecondInfo = DenseMapInfo; - - static mlir::iree_compiler::AMDAIE::DMAChannel getEmptyKey() { - return {FirstInfo::getEmptyKey(), SecondInfo::getEmptyKey()}; - } - - static mlir::iree_compiler::AMDAIE::DMAChannel getTombstoneKey() { - return {FirstInfo::getTombstoneKey(), SecondInfo::getTombstoneKey()}; - } - - static unsigned getHashValue( - const mlir::iree_compiler::AMDAIE::DMAChannel &d) { - return detail::combineHashValue(FirstInfo::getHashValue(d.direction), - SecondInfo::getHashValue(d.channel)); - } - - static bool isEqual(const mlir::iree_compiler::AMDAIE::DMAChannel &lhs, - const mlir::iree_compiler::AMDAIE::DMAChannel &rhs) { - return lhs == rhs; - } -}; - -template <> -struct DenseMapInfo { - using FirstInfo = DenseMapInfo; - using SecondInfo = DenseMapInfo; - - static mlir::iree_compiler::AMDAIE::Port getEmptyKey() { - return {FirstInfo::getEmptyKey(), SecondInfo::getEmptyKey()}; - } - - static mlir::iree_compiler::AMDAIE::Port getTombstoneKey() { - return {FirstInfo::getTombstoneKey(), SecondInfo::getTombstoneKey()}; - } - - static unsigned getHashValue(const mlir::iree_compiler::AMDAIE::Port &d) { - return detail::combineHashValue(FirstInfo::getHashValue(d.bundle), - SecondInfo::getHashValue(d.channel)); - } - - static bool isEqual(const mlir::iree_compiler::AMDAIE::Port &lhs, - const mlir::iree_compiler::AMDAIE::Port &rhs) { - return lhs == rhs; - } -}; - -template <> -struct GraphTraits { - using NodeRef = mlir::iree_compiler::AMDAIE::SwitchboxNode *; - - static mlir::iree_compiler::AMDAIE::SwitchboxNode *SwitchboxGraphGetSwitchbox( - DGEdge *P) { - return &P->getTargetNode(); - } - - // Provide a mapped iterator so that the GraphTrait-based implementations - // can find the target nodes without having to explicitly go through the - // edges. - using ChildIteratorType = - mapped_iterator; - using ChildEdgeIteratorType = - mlir::iree_compiler::AMDAIE::SwitchboxNode::iterator; - - static NodeRef getEntryNode(NodeRef N) { return N; } - static ChildIteratorType child_begin(NodeRef N) { - return {N->begin(), &SwitchboxGraphGetSwitchbox}; - } - static ChildIteratorType child_end(NodeRef N) { - return {N->end(), &SwitchboxGraphGetSwitchbox}; - } - - static ChildEdgeIteratorType child_edge_begin(NodeRef N) { - return N->begin(); - } - static ChildEdgeIteratorType child_edge_end(NodeRef N) { return N->end(); } -}; - -template <> -struct GraphTraits - : GraphTraits { - using nodes_iterator = mlir::iree_compiler::AMDAIE::SwitchboxGraph::iterator; - static NodeRef getEntryNode(mlir::iree_compiler::AMDAIE::SwitchboxGraph *DG) { - return *DG->begin(); - } - static nodes_iterator nodes_begin( - mlir::iree_compiler::AMDAIE::SwitchboxGraph *DG) { - return DG->begin(); - } - static nodes_iterator nodes_end( - mlir::iree_compiler::AMDAIE::SwitchboxGraph *DG) { - return DG->end(); - } -}; - -inline raw_ostream &operator<<( - raw_ostream &os, const mlir::iree_compiler::AMDAIE::SwitchSettings &ss) { - std::stringstream s; - s << "\tSwitchSettings: "; - for (const auto &[sb, setting] : ss) { - s << sb << ": " << setting << " | "; - } - s << "\n"; - os << s.str(); - return os; -} - -} // namespace llvm - namespace mlir::iree_compiler::AMDAIE { - LogicalResult DynamicTileAnalysis::runAnalysis(DeviceOp &device) { LLVM_DEBUG(llvm::dbgs() << "\t---Begin DynamicTileAnalysis Constructor---\n"); // find the maxCol and maxRow @@ -562,17 +577,19 @@ LogicalResult DynamicTileAnalysis::runAnalysis(DeviceOp &device) { maxRow = std::max(maxRow, tileOp.rowIndex()); } - AMDAIEDeviceModel deviceModel = + AMDAIEDeviceModel targetModel = getDeviceModel(static_cast(device.getDevice())); - pathfinder->initialize(maxCol, maxRow, deviceModel); + pathfinder->initialize(maxCol, maxRow, targetModel); + + // pathfinder->initialize(maxCol, maxRow, device.getTargetModel()); // for each flow in the device, add it to pathfinder // each source can map to multiple different destinations (fanout) for (FlowOp flowOp : device.getOps()) { TileOp srcTile = cast(flowOp.getSource().getDefiningOp()); TileOp dstTile = cast(flowOp.getDest().getDefiningOp()); - TileLoc srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; - TileLoc dstCoords = {dstTile.colIndex(), dstTile.rowIndex()}; + TileID srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; + TileID dstCoords = {dstTile.colIndex(), dstTile.rowIndex()}; Port srcPort = {flowOp.getSourceBundle(), flowOp.getSourceChannel()}; Port dstPort = {flowOp.getDestBundle(), flowOp.getDestChannel()}; LLVM_DEBUG(llvm::dbgs() @@ -581,16 +598,42 @@ LogicalResult DynamicTileAnalysis::runAnalysis(DeviceOp &device) { << " -> (" << dstCoords.col << ", " << dstCoords.row << ")" << stringifyWireBundle(dstPort.bundle) << dstPort.channel << "\n"); - pathfinder->addFlow(srcCoords, srcPort, dstCoords, dstPort); + pathfinder->addFlow(srcCoords, srcPort, dstCoords, dstPort, false); + } + + for (PacketFlowOp pktFlowOp : device.getOps()) { + Region &r = pktFlowOp.getPorts(); + Block &b = r.front(); + Port srcPort, dstPort; + TileOp srcTile, dstTile; + TileID srcCoords, dstCoords; + for (Operation &Op : b.getOperations()) { + if (auto pktSource = dyn_cast(Op)) { + srcTile = dyn_cast(pktSource.getTile().getDefiningOp()); + srcPort = pktSource.port(); + srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; + } else if (auto pktDest = dyn_cast(Op)) { + dstTile = dyn_cast(pktDest.getTile().getDefiningOp()); + dstPort = pktDest.port(); + dstCoords = {dstTile.colIndex(), dstTile.rowIndex()}; + LLVM_DEBUG(llvm::dbgs() + << "\tAdding Packet Flow: (" << srcCoords.col << ", " + << srcCoords.row << ")" + << stringifyWireBundle(srcPort.bundle) << srcPort.channel + << " -> (" << dstCoords.col << ", " << dstCoords.row << ")" + << stringifyWireBundle(dstPort.bundle) << dstPort.channel + << "\n"); + // todo: support many-to-one & many-to-many? + pathfinder->addFlow(srcCoords, srcPort, dstCoords, dstPort, true); + } + } } // add existing connections so Pathfinder knows which resources are // available search all existing SwitchBoxOps for exising connections for (SwitchboxOp switchboxOp : device.getOps()) { - for (ConnectOp connectOp : switchboxOp.getOps()) { - if (!pathfinder->addFixedConnection(connectOp)) - return switchboxOp.emitOpError() << "Couldn't connect " << connectOp; - } + if (!pathfinder->addFixedConnection(switchboxOp)) + return switchboxOp.emitOpError() << "Unable to add fixed connections"; } // all flows are now populated, call the congestion-aware pathfinder @@ -628,7 +671,7 @@ LogicalResult DynamicTileAnalysis::runAnalysis(DeviceOp &device) { for (auto shimmuxOp : device.getOps()) { int col = shimmuxOp.colIndex(); int row = shimmuxOp.rowIndex(); - assert(coordToShimMux.count(TileLoc{col, row}) == 0); + assert(coordToShimMux.count({col, row}) == 0); coordToShimMux[{col, row}] = shimmuxOp; } @@ -670,7 +713,7 @@ ShimMuxOp DynamicTileAnalysis::getShimMux(OpBuilder &builder, int col) { if (coordToShimMux.count({col, row})) { return coordToShimMux[{col, row}]; } - assert(getTile(builder, col, row).isShimNOCTile()); + // assert(getTile(builder, col, row).isShimNOCTile()); auto switchboxOp = builder.create(builder.getUnknownLoc(), getTile(builder, col, row)); SwitchboxOp::ensureTerminator(switchboxOp.getConnections(), builder, @@ -682,49 +725,41 @@ ShimMuxOp DynamicTileAnalysis::getShimMux(OpBuilder &builder, int col) { } void Pathfinder::initialize(int maxCol, int maxRow, - AMDAIEDeviceModel &deviceModel) { + const AMDAIEDeviceModel &targetModel) { // make grid of switchboxes int id = 0; for (int row = 0; row <= maxRow; row++) { for (int col = 0; col <= maxCol; col++) { - auto [it, _] = grid.insert({{col, row}, SwitchboxNode{col, row, id++}}); - (void)graph.addNode(it->second); + grid.insert({{col, row}, + SwitchboxNode{col, row, id++, maxCol, maxRow, targetModel}}); SwitchboxNode &thisNode = grid.at({col, row}); if (row > 0) { // if not in row 0 add channel to North/South SwitchboxNode &southernNeighbor = grid.at({col, row - 1}); // get the number of outgoing connections on the south side - outgoing // because these correspond to rhs of a connect op - if (uint32_t maxCapacity = deviceModel.getNumDestSwitchboxConnections( + if (targetModel.getNumDestSwitchboxConnections( col, row, toStrmT(WireBundle::South))) { - edges.emplace_back(thisNode, southernNeighbor, WireBundle::South, - maxCapacity); - (void)graph.connect(thisNode, southernNeighbor, edges.back()); + edges.emplace_back(&thisNode, &southernNeighbor); } // get the number of incoming connections on the south side - incoming // because they correspond to connections on the southside that are then // routed using internal connect ops through the switchbox (i.e., lhs of // connect ops) - if (uint32_t maxCapacity = deviceModel.getNumSourceSwitchboxConnections( + if (targetModel.getNumSourceSwitchboxConnections( col, row, toStrmT(WireBundle::South))) { - edges.emplace_back(southernNeighbor, thisNode, WireBundle::North, - maxCapacity); - (void)graph.connect(southernNeighbor, thisNode, edges.back()); + edges.emplace_back(&southernNeighbor, &thisNode); } } if (col > 0) { // if not in col 0 add channel to East/West SwitchboxNode &westernNeighbor = grid.at({col - 1, row}); - if (uint32_t maxCapacity = deviceModel.getNumDestSwitchboxConnections( + if (targetModel.getNumDestSwitchboxConnections( col, row, toStrmT(WireBundle::West))) { - edges.emplace_back(thisNode, westernNeighbor, WireBundle::West, - maxCapacity); - (void)graph.connect(thisNode, westernNeighbor, edges.back()); + edges.emplace_back(&thisNode, &westernNeighbor); } - if (uint32_t maxCapacity = deviceModel.getNumSourceSwitchboxConnections( + if (targetModel.getNumSourceSwitchboxConnections( col, row, toStrmT(WireBundle::West))) { - edges.emplace_back(westernNeighbor, thisNode, WireBundle::East, - maxCapacity); - (void)graph.connect(westernNeighbor, thisNode, edges.back()); + edges.emplace_back(&westernNeighbor, &thisNode); } } } @@ -733,81 +768,67 @@ void Pathfinder::initialize(int maxCol, int maxRow, // Add a flow from src to dst can have an arbitrary number of dst locations due // to fanout. -void Pathfinder::addFlow(TileLoc srcCoords, Port srcPort, TileLoc dstCoords, - Port dstPort) { +void Pathfinder::addFlow(TileID srcCoords, Port srcPort, TileID dstCoords, + Port dstPort, bool isPacketFlow) { // check if a flow with this source already exists - for (auto &[src, dsts] : flows) { - SwitchboxNode *existingSrc = src.sb; - assert(existingSrc && "nullptr flow source"); - if (Port existingPort = src.port; existingSrc->col == srcCoords.col && - existingSrc->row == srcCoords.row && + for (auto &[isPkt, src, dsts] : flows) { + SwitchboxNode *existingSrcPtr = src.sb; + assert(existingSrcPtr && "nullptr flow source"); + if (Port existingPort = src.port; existingSrcPtr->col == srcCoords.col && + existingSrcPtr->row == srcCoords.row && existingPort == srcPort) { // find the vertex corresponding to the destination - auto *matchingSb = std::find_if( - graph.begin(), graph.end(), [&](const SwitchboxNode *sb) { - return sb->col == dstCoords.col && sb->row == dstCoords.row; - }); - assert(matchingSb != graph.end() && "didn't find flow dest"); - dsts.emplace_back(*matchingSb, dstPort); + SwitchboxNode *matchingDstSbPtr = &grid.at(dstCoords); + dsts.emplace_back(matchingDstSbPtr, dstPort); return; } } // If no existing flow was found with this source, create a new flow. - auto *matchingSrcSb = - std::find_if(graph.begin(), graph.end(), [&](const SwitchboxNode *sb) { - return sb->col == srcCoords.col && sb->row == srcCoords.row; - }); - assert(matchingSrcSb != graph.end() && "didn't find flow source"); - auto *matchingDstSb = - std::find_if(graph.begin(), graph.end(), [&](const SwitchboxNode *sb) { - return sb->col == dstCoords.col && sb->row == dstCoords.row; - }); - assert(matchingDstSb != graph.end() && "didn't add flow destinations"); - flows.push_back({PathEndPointNode{*matchingSrcSb, srcPort}, - std::vector{{*matchingDstSb, dstPort}}}); + SwitchboxNode *matchingSrcSbPtr = &grid.at(srcCoords); + SwitchboxNode *matchingDstSbPtr = &grid.at(dstCoords); + flows.push_back({isPacketFlow, PathEndPointNode{matchingSrcSbPtr, srcPort}, + std::vector{{matchingDstSbPtr, dstPort}}}); } // Keep track of connections already used in the AIE; Pathfinder algorithm will // avoid using these. -bool Pathfinder::addFixedConnection(ConnectOp connectOp) { - auto sb = connectOp->getParentOfType(); - // TODO: keep track of capacity? - if (sb.getTileOp().isShimNOCTile()) return true; - - TileLoc sbTile(sb.getTileID().col, sb.getTileID().row); - WireBundle sourceBundle = connectOp.getSourceBundle(); - WireBundle destBundle = connectOp.getDestBundle(); - - // find the correct Channel and indicate the fixed direction - // outgoing connection - auto matchingCh = - std::find_if(edges.begin(), edges.end(), [&](ChannelEdge &ch) { - return static_cast(ch.src) == sbTile && - ch.bundle == destBundle; - }); - if (matchingCh != edges.end()) - return matchingCh->fixedCapacity.insert(connectOp.getDestChannel()) - .second || - true; - - // incoming connection - matchingCh = std::find_if(edges.begin(), edges.end(), [&](ChannelEdge &ch) { - return static_cast(ch.target) == sbTile && - ch.bundle == getConnectingBundle(sourceBundle); - }); - if (matchingCh != edges.end()) - return matchingCh->fixedCapacity.insert(connectOp.getSourceChannel()) - .second || - true; +bool Pathfinder::addFixedConnection(SwitchboxOp switchboxOp) { + int col = switchboxOp.colIndex(); + int row = switchboxOp.rowIndex(); + SwitchboxNode &sb = grid.at({col, row}); + std::set invalidInId, invalidOutId; + + for (ConnectOp connectOp : switchboxOp.getOps()) { + Port srcPort = connectOp.sourcePort(); + Port destPort = connectOp.destPort(); + if (sb.inPortToId.count(srcPort) == 0 || + sb.outPortToId.count(destPort) == 0) + return false; + int inId = sb.inPortToId.at(srcPort); + int outId = sb.outPortToId.at(destPort); + if (sb.connectionMatrix[inId][outId] != Connectivity::AVAILABLE) + return false; + invalidInId.insert(inId); + invalidOutId.insert(outId); + } - return false; + for (const auto &[inPort, inId] : sb.inPortToId) { + for (const auto &[outPort, outId] : sb.outPortToId) { + if (invalidInId.find(inId) != invalidInId.end() || + invalidOutId.find(outId) != invalidOutId.end()) { + // mark as invalid + sb.connectionMatrix[inId][outId] = Connectivity::INVALID; + } + } + } + return true; } static constexpr double INF = std::numeric_limits::max(); -std::map dijkstraShortestPaths( - const SwitchboxGraph &graph, SwitchboxNode *src) { +std::map Pathfinder::dijkstraShortestPaths( + SwitchboxNode *src) { // Use std::map instead of DenseMap because DenseMap doesn't let you overwrite // tombstones. auto distance = std::map(); @@ -821,19 +842,24 @@ std::map dijkstraShortestPaths( MutableQueue; MutableQueue Q(distance, indexInHeap); - for (SwitchboxNode *sb : graph) distance.emplace(sb, INF); + for (auto &[_, sb] : grid) distance.emplace(&sb, INF); distance[src] = 0.0; - std::map> edges; + std::map> channels; enum Color { WHITE, GRAY, BLACK }; std::map colors; - for (SwitchboxNode *sb : graph) { - colors[sb] = WHITE; - edges[sb] = {sb->getEdges().begin(), sb->getEdges().end()}; - std::sort(edges[sb].begin(), edges[sb].end(), + for (auto &[_, sb] : grid) { + SwitchboxNode *sbPtr = &sb; + colors[sbPtr] = WHITE; + for (auto &e : edges) { + if (e.src == sbPtr) { + channels[sbPtr].push_back(&e); + } + } + std::sort(channels[sbPtr].begin(), channels[sbPtr].end(), [](const ChannelEdge *c1, ChannelEdge *c2) { - return c1->getTargetNode().id < c2->getTargetNode().id; + return c1->target->id < c2->target->id; }); } @@ -841,18 +867,18 @@ std::map dijkstraShortestPaths( while (!Q.empty()) { src = Q.top(); Q.pop(); - for (ChannelEdge *e : edges[src]) { - SwitchboxNode *dest = &e->getTargetNode(); - bool relax = distance[src] + e->demand < distance[dest]; + for (ChannelEdge *e : channels[src]) { + SwitchboxNode *dest = e->target; + bool relax = distance[src] + demand[e] < distance[dest]; if (colors[dest] == WHITE) { if (relax) { - distance[dest] = distance[src] + e->demand; + distance[dest] = distance[src] + demand[e]; preds[dest] = src; colors[dest] = GRAY; } Q.push(dest); } else if (colors[dest] == GRAY && relax) { - distance[dest] = distance[src] + e->demand; + distance[dest] = distance[src] + demand[e]; preds[dest] = src; } } @@ -874,42 +900,21 @@ std::optional> Pathfinder::findPaths( std::map routingSolution; // initialize all Channel histories to 0 - for (auto &ch : edges) ch.overCapacityCount = 0; - - // Check that every channel does not exceed max capacity. - auto isLegal = [&] { - bool legal = true; // assume legal until found otherwise - for (auto &e : edges) { - if (e.usedCapacity > e.maxCapacity) { - LLVM_DEBUG(llvm::dbgs() - << "Too much capacity on Edge (" << e.getTargetNode().col - << ", " << e.getTargetNode().row << ") . " - << stringifyWireBundle(e.bundle) << "\t: used_capacity = " - << e.usedCapacity << "\t: Demand = " << e.demand << "\n"); - e.overCapacityCount++; - LLVM_DEBUG(llvm::dbgs() - << "over_capacity_count = " << e.overCapacityCount << "\n"); - legal = false; - break; - } - } - - return legal; - }; + for (auto &ch : edges) { + overCapacity[&ch] = 0; + usedCapacity[&ch] = 0; + } + // assume legal until found otherwise + bool isLegal = true; do { LLVM_DEBUG(llvm::dbgs() << "Begin findPaths iteration #" << iterationCount << "\n"); // update demand on all channels for (auto &ch : edges) { - if (ch.fixedCapacity.size() >= - static_cast::size_type>(ch.maxCapacity)) { - ch.demand = INF; - } else { - double history = 1.0 + OVER_CAPACITY_COEFF * ch.overCapacityCount; - double congestion = 1.0 + USED_CAPACITY_COEFF * ch.usedCapacity; - ch.demand = history * congestion; - } + double history = 1.0 + OVER_CAPACITY_COEFF * overCapacity[&ch]; + double congestion = 1.0 + USED_CAPACITY_COEFF * usedCapacity[&ch]; + demand[&ch] = history * congestion; } // if reach maxIterations, throw an error since no routing can be found if (++iterationCount > maxIterations) { @@ -920,13 +925,15 @@ std::optional> Pathfinder::findPaths( return std::nullopt; } - // "rip up" all routes, i.e. set used capacity in each Channel to 0 + // "rip up" all routes routingSolution.clear(); - for (auto &ch : edges) ch.usedCapacity = 0; + for (auto &[tileID, node] : grid) node.clearAllocation(); + for (auto &ch : edges) usedCapacity[&ch] = 0; + isLegal = true; // for each flow, find the shortest path from source to destination // update used_capacity for the path between them - for (const auto &[src, dsts] : flows) { + for (const auto &[isPkt, src, dsts] : flows) { // Use dijkstra to find path given current demand from the start // switchbox; find the shortest paths to each other switchbox. Output is // in the predecessor map, which must then be processed to get individual @@ -934,7 +941,16 @@ std::optional> Pathfinder::findPaths( assert(src.sb && "nonexistent flow source"); std::set processed; std::map preds = - dijkstraShortestPaths(graph, src.sb); + dijkstraShortestPaths(src.sb); + + auto findIncomingEdge = [&](SwitchboxNode *sb) -> ChannelEdge * { + for (auto &e : edges) { + if (e.src == preds[sb] && e.target == sb) { + return &e; + } + } + return nullptr; + }; // trace the path of the flow backwards via predecessors // increment used_capacity for the associated channels @@ -942,62 +958,140 @@ std::optional> Pathfinder::findPaths( // set the input bundle for the source endpoint switchSettings[*src.sb].src = src.port; processed.insert(src.sb); + // track destination ports used by src.sb + std::vector srcDestPorts; for (const PathEndPointNode &endPoint : dsts) { SwitchboxNode *curr = endPoint.sb; assert(curr && "endpoint has no source switchbox"); // set the output bundle for this destination endpoint switchSettings[*curr].dsts.insert(endPoint.port); - + Port lastDestPort = endPoint.port; // trace backwards until a vertex already processed is reached while (!processed.count(curr)) { - // find the edge from the pred to curr by searching incident edges - SmallVector channels; - graph.findIncomingEdgesToNode(*curr, channels); - auto *matchingCh = std::find_if( - channels.begin(), channels.end(), - [&](ChannelEdge *ch) { return ch->src == *preds[curr]; }); - assert(matchingCh != channels.end() && "couldn't find ch"); - // incoming edge - ChannelEdge *ch = *matchingCh; - - // don't use fixed channels - while (ch->fixedCapacity.count(ch->usedCapacity)) ch->usedCapacity++; + // find the incoming edge from the pred to curr + ChannelEdge *ch = findIncomingEdge(curr); + assert(ch != nullptr && "couldn't find ch"); + int channel; + // find all available channels in + std::vector availableChannels = curr->findAvailableChannelIn( + getConnectingBundle(ch->bundle), lastDestPort, isPkt); + if (availableChannels.size() > 0) { + // if possible, choose the channel that predecessor can also use + // todo: consider all predecessors? + int bFound = false; + auto &pred = preds[curr]; + if (!processed.count(pred) && pred != src.sb) { + ChannelEdge *predCh = findIncomingEdge(pred); + assert(predCh != nullptr && "couldn't find ch"); + for (int availableCh : availableChannels) { + channel = availableCh; + std::vector availablePredChannels = + pred->findAvailableChannelIn( + getConnectingBundle(predCh->bundle), + {ch->bundle, channel}, isPkt); + if (availablePredChannels.size() > 0) { + bFound = true; + break; + } + } + } + if (!bFound) channel = availableChannels[0]; + bool succeed = + curr->allocate({getConnectingBundle(ch->bundle), channel}, + lastDestPort, isPkt); + if (!succeed) assert(false && "invalid allocation"); + LLVM_DEBUG(llvm::dbgs() + << *curr << ", connecting: " + << stringifyWireBundle(getConnectingBundle(ch->bundle)) + << channel << " -> " + << stringifyWireBundle(lastDestPort.bundle) + << lastDestPort.channel << "\n"); + } else { + // if no channel available, use a virtual channel id and mark + // routing as being invalid + channel = usedCapacity[ch]; + if (isLegal) { + overCapacity[ch]++; + LLVM_DEBUG(llvm::dbgs() + << *curr << ", congestion: " + << stringifyWireBundle(getConnectingBundle(ch->bundle)) + << ", used_capacity = " << usedCapacity[ch] + << ", over_capacity_count = " << overCapacity[ch] + << "\n"); + } + isLegal = false; + } + usedCapacity[ch]++; // add the entrance port for this Switchbox - switchSettings[*curr].src = {getConnectingBundle(ch->bundle), - ch->usedCapacity}; + Port currSourcePort = {getConnectingBundle(ch->bundle), channel}; + switchSettings[*curr].src = {currSourcePort}; + // add the current Switchbox to the map of the predecessor - switchSettings[*preds[curr]].dsts.insert( - {ch->bundle, ch->usedCapacity}); + Port PredDestPort = {ch->bundle, channel}; + switchSettings[*preds[curr]].dsts.insert(PredDestPort); + lastDestPort = PredDestPort; - ch->usedCapacity++; // if at capacity, bump demand to discourage using this Channel - if (ch->usedCapacity >= ch->maxCapacity) { - LLVM_DEBUG(llvm::dbgs() << "ch over capacity: " << ch << "\n"); + if (usedCapacity[ch] >= ch->maxCapacity) { // this means the order matters! - ch->demand *= DEMAND_COEFF; + demand[ch] *= DEMAND_COEFF; + LLVM_DEBUG(llvm::dbgs() + << *curr << ", bump demand: " + << stringifyWireBundle(getConnectingBundle(ch->bundle)) + << ", demand = " << demand[ch] << "\n"); } processed.insert(curr); curr = preds[curr]; + + // allocation may fail, as we start from the dest of flow while + // src.port is not chosen by router + if (curr == src.sb && + std::find(srcDestPorts.begin(), srcDestPorts.end(), + lastDestPort) == srcDestPorts.end()) { + bool succeed = src.sb->allocate(src.port, lastDestPort, isPkt); + if (!succeed) { + isLegal = false; + overCapacity[ch]++; + LLVM_DEBUG(llvm::dbgs() + << *curr << ", unable to connect: " + << stringifyWireBundle(src.port.bundle) + << src.port.channel << " -> " + << stringifyWireBundle(lastDestPort.bundle) + << lastDestPort.channel << "\n"); + } + srcDestPorts.push_back(lastDestPort); + } } } // add this flow to the proposed solution routingSolution[src] = switchSettings; } - } while (!isLegal()); // continue iterations until a legal routing is found + + } while (!isLegal); // continue iterations until a legal routing is found return routingSolution; } -// allocates channels between switchboxes ( but does not assign them) -// instantiates shim-muxes AND allocates channels ( no need to rip these up in ) + +} // namespace mlir::iree_compiler::AMDAIE + +static std::vector flowOps; + +namespace mlir::iree_compiler::AMDAIE { + struct ConvertFlowsToInterconnect : OpConversionPattern { using OpConversionPattern::OpConversionPattern; DeviceOp &device; DynamicTileAnalysis &analyzer; + bool keepFlowOp; ConvertFlowsToInterconnect(MLIRContext *context, DeviceOp &d, - DynamicTileAnalysis &a, PatternBenefit benefit = 1) - : OpConversionPattern(context, benefit), device(d), analyzer(a) {} + DynamicTileAnalysis &a, bool keepFlowOp, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), + device(d), + analyzer(a), + keepFlowOp(keepFlowOp) {} LogicalResult match(FlowOp op) const override { return success(); } @@ -1027,14 +1121,22 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { Operation *Op = flowOp.getOperation(); auto srcTile = cast(flowOp.getSource().getDefiningOp()); - TileLoc srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; + TileID srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; auto srcBundle = flowOp.getSourceBundle(); auto srcChannel = flowOp.getSourceChannel(); Port srcPort = {srcBundle, srcChannel}; + if (keepFlowOp) { + auto *clonedOp = Op->clone(); + flowOps.push_back(clonedOp); + } + + AMDAIEDeviceModel targetModel = + getDeviceModel(static_cast(device.getDevice())); + #ifndef NDEBUG auto dstTile = cast(flowOp.getDest().getDefiningOp()); - TileLoc dstCoords = {dstTile.colIndex(), dstTile.rowIndex()}; + TileID dstCoords = {dstTile.colIndex(), dstTile.rowIndex()}; auto dstBundle = flowOp.getDestBundle(); auto dstChannel = flowOp.getDestChannel(); LLVM_DEBUG(llvm::dbgs() @@ -1047,7 +1149,8 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { // if the flow (aka "net") for this FlowOp hasn't been processed yet, // add all switchbox connections to implement the flow - Switchbox srcSB = {srcCoords.col, srcCoords.row}; + SwitchboxNode srcSB = + analyzer.pathfinder->getSwitchboxNode({srcCoords.col, srcCoords.row}); if (PathEndPoint srcPoint = {srcSB, srcPort}; !analyzer.processedFlows[srcPoint]) { SwitchSettings settings = analyzer.flowSolutions[srcPoint]; @@ -1056,8 +1159,7 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { SwitchboxOp swOp = analyzer.getSwitchbox(rewriter, curr.col, curr.row); int shimCh = srcChannel; // TODO: must reserve N3, N7, S2, S3 for DMA connections - if (curr == srcSB && - analyzer.getTile(rewriter, srcSB.col, srcSB.row).isShimNOCTile()) { + if (curr == srcSB && targetModel.isShimNOCTile(srcSB.col, srcSB.row)) { // shim DMAs at start of flows if (srcBundle == WireBundle::DMA) { shimCh = srcChannel == 0 @@ -1069,7 +1171,7 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { srcBundle, srcChannel, WireBundle::North, shimCh); } else if (srcBundle == WireBundle::NOC) { // must be NOC0/NOC1 -> N2/N3 or - // NOC2/NOC3 -> N6/N7 + // NOC2/NOC3 -> N6/N7 shimCh = srcChannel >= 2 ? srcChannel + 4 : srcChannel + 2; ShimMuxOp shimMuxOp = analyzer.getShimMux(rewriter, srcSB.col); addConnection(rewriter, @@ -1086,19 +1188,18 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { } } } + for (const auto &[bundle, channel] : setting.dsts) { // handle special shim connectivity - if (curr == srcSB && analyzer.getTile(rewriter, srcSB.col, srcSB.row) - .isShimNOCorPLTile()) { + if (curr == srcSB && + targetModel.isShimNOCorPLTile(srcSB.col, srcSB.row)) { addConnection(rewriter, cast(swOp.getOperation()), flowOp, WireBundle::South, shimCh, bundle, channel); - } else if (analyzer.getTile(rewriter, curr.col, curr.row) - .isShimNOCorPLTile() && + } else if (targetModel.isShimNOCorPLTile(curr.col, curr.row) && (bundle == WireBundle::DMA || bundle == WireBundle::PLIO || bundle == WireBundle::NOC)) { shimCh = channel; - if (analyzer.getTile(rewriter, curr.col, curr.row) - .isShimNOCTile()) { + if (targetModel.isShimNOCTile(curr.col, curr.row)) { // shim DMAs at end of flows if (bundle == WireBundle::DMA) { shimCh = channel == 0 @@ -1133,7 +1234,8 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { } } - LLVM_DEBUG(llvm::dbgs() << curr << ": " << setting << " | " << "\n"); + LLVM_DEBUG(llvm::dbgs() << curr << ": " << setting << " | " + << "\n"); } LLVM_DEBUG(llvm::dbgs() @@ -1146,65 +1248,124 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { } }; -/// Overall Flow: -/// rewrite switchboxes to assign unassigned connections, ensure this can be -/// done concurrently ( by different threads) -/// 1. Goal is to rewrite all flows in the device into switchboxes + shim-mux -/// 2. multiple passes of the rewrite pattern rewriting streamswitch -/// configurations to routes -/// 3. rewrite flows to stream-switches using 'weights' from analysis pass. -/// 4. check a region is legal -/// 5. rewrite stream-switches (within a bounding box) back to flows -struct AMDAIEPathfinderPass : mlir::OperationPass { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AMDAIEPathfinderPass) - - AMDAIEPathfinderPass() : mlir::OperationPass(resolveTypeID()) {} - - llvm::StringRef getArgument() const override { +} // namespace mlir::iree_compiler::AMDAIE + +namespace mlir::iree_compiler::AMDAIE { + +template +class AIERoutePathfinderFlowsBase : public ::mlir::OperationPass { + public: + using Base = AIERoutePathfinderFlowsBase; + + AIERoutePathfinderFlowsBase() + : ::mlir::OperationPass(::mlir::TypeID::get()) {} + AIERoutePathfinderFlowsBase(const AIERoutePathfinderFlowsBase &other) + : ::mlir::OperationPass(other) {} + AIERoutePathfinderFlowsBase &operator=(const AIERoutePathfinderFlowsBase &) = + delete; + AIERoutePathfinderFlowsBase(AIERoutePathfinderFlowsBase &&) = delete; + AIERoutePathfinderFlowsBase &operator=(AIERoutePathfinderFlowsBase &&) = + delete; + ~AIERoutePathfinderFlowsBase() = default; + + /// Returns the command-line argument attached to this pass. + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral("amdaie-create-pathfinder-flows"); + } + ::llvm::StringRef getArgument() const override { return "amdaie-create-pathfinder-flows"; } - llvm::StringRef getName() const override { return "AMDAIEPathfinderPass"; } + ::llvm::StringRef getDescription() const override { + return "Route aie.flow and aie.packetflow operations through switchboxes"; + } - std::unique_ptr clonePass() const override { - return std::make_unique( - *static_cast(this)); + /// Returns the derived pass name. + static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral("AIERoutePathfinderFlows"); + } + ::llvm::StringRef getName() const override { + return "AIERoutePathfinderFlows"; } + /// Support isa/dyn_cast functionality for the derived pass class. + static bool classof(const ::mlir::Pass *pass) { + return pass->getTypeID() == ::mlir::TypeID::get(); + } + + /// A clone method to create a copy of this pass. + std::unique_ptr<::mlir::Pass> clonePass() const override { + return std::make_unique(*static_cast(this)); + } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + /// Explicitly declare the TypeID for this class. We declare an explicit + /// private instantiation because Pass classes should only be visible by the + /// current library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + AIERoutePathfinderFlowsBase) + + AIERoutePathfinderFlowsBase(const AIERoutePathfinderFlowsOptions &options) + : AIERoutePathfinderFlowsBase() { + clRouteCircuit = options.clRouteCircuit; + clRoutePacket = options.clRoutePacket; + clKeepFlowOp = options.clKeepFlowOp; + } + + protected: + ::mlir::Pass::Option clRouteCircuit{ + *this, "route-circuit", + ::llvm::cl::desc("Flag to enable aie.flow lowering."), + ::llvm::cl::init(true)}; + ::mlir::Pass::Option clRoutePacket{ + *this, "route-packet", + ::llvm::cl::desc("Flag to enable aie.packetflow lowering."), + ::llvm::cl::init(true)}; + ::mlir::Pass::Option clKeepFlowOp{ + *this, "keep-flow-op", + ::llvm::cl::desc("Flag to not erase aie.flow/packetflow after its " + "lowering,used for routing visualization."), + ::llvm::cl::init(false)}; + + private: +}; + +struct AIEPathfinderPass : AIERoutePathfinderFlowsBase { DynamicTileAnalysis analyzer; - AMDAIEPathfinderPass(DynamicTileAnalysis analyzer) - : mlir::OperationPass(resolveTypeID()), - analyzer(std::move(analyzer)) {} + mlir::DenseMap tiles; + + AIEPathfinderPass() = default; + AIEPathfinderPass(DynamicTileAnalysis analyzer) + : analyzer(std::move(analyzer)) {} void runOnOperation() override; + void runOnFlow(DeviceOp d, mlir::OpBuilder &builder); + void runOnPacketFlow(DeviceOp d, mlir::OpBuilder &builder); - bool attemptFixupMemTileRouting(const mlir::OpBuilder &builder, - SwitchboxOp northSwOp, SwitchboxOp southSwOp, - ConnectOp &problemConnect); + typedef std::pair PhysPort; - bool reconnectConnectOps(const mlir::OpBuilder &builder, SwitchboxOp sw, - ConnectOp problemConnect, bool isIncomingToSW, - WireBundle problemBundle, int problemChan, - int emptyChan); + typedef struct { + SwitchboxOp sw; + Port sourcePort; + Port destPort; + } SwConnection; - ConnectOp replaceConnectOpWithNewDest(mlir::OpBuilder builder, - ConnectOp connect, WireBundle newBundle, - int newChannel); - ConnectOp replaceConnectOpWithNewSource(mlir::OpBuilder builder, - ConnectOp connect, - WireBundle newBundle, int newChannel); + bool findPathToDest(SwitchSettings settings, TileID currTile, + WireBundle currDestBundle, int currDestChannel, + TileID finalTile, WireBundle finalDestBundle, + int finalDestChannel); SwitchboxOp getSwitchbox(DeviceOp &d, int col, int row); -}; -void AMDAIEPathfinderPass::runOnOperation() { - // create analysis pass with routing graph for entire device - LLVM_DEBUG(llvm::dbgs() << "---Begin AMDAIEPathfinderPass---\n"); - - DeviceOp d = getOperation(); - if (failed(analyzer.runAnalysis(d))) return signalPassFailure(); - OpBuilder builder = OpBuilder::atBlockEnd(d.getBody()); + mlir::Operation *getOrCreateTile(mlir::OpBuilder &builder, int col, int row); + SwitchboxOp getOrCreateSwitchbox(mlir::OpBuilder &builder, TileOp tile); +}; +void AIEPathfinderPass::runOnFlow(DeviceOp d, OpBuilder &builder) { // Apply rewrite rule to switchboxes to add assignments to every 'connect' // operation inside ConversionTarget target(getContext()); @@ -1214,11 +1375,19 @@ void AMDAIEPathfinderPass::runOnOperation() { target.addLegalOp(); target.addLegalOp(); + AMDAIEDeviceModel targetModel = + getDeviceModel(static_cast(d.getDevice())); + RewritePatternSet patterns(&getContext()); - patterns.insert(d.getContext(), d, analyzer); + patterns.insert(d.getContext(), d, analyzer, + clKeepFlowOp); if (failed(applyPartialConversion(d, target, std::move(patterns)))) return signalPassFailure(); + // Keep for visualization + if (clKeepFlowOp) + for (auto op : flowOps) builder.insert(op); + // Populate wires between switchboxes and tiles. for (int col = 0; col <= analyzer.getMaxCol(); col++) { for (int row = 0; row <= analyzer.getMaxRow(); row++) { @@ -1255,7 +1424,7 @@ void AMDAIEPathfinderPass::runOnOperation() { WireBundle::North, sw, WireBundle::South); } } else if (row == 0) { - if (tile.isShimNOCTile()) { + if (targetModel.isShimNOCTile(tile.getCol(), tile.getRow())) { if (analyzer.coordToShimMux.count({col, 0})) { auto shimsw = analyzer.coordToShimMux[{col, 0}]; builder.create( @@ -1275,7 +1444,7 @@ void AMDAIEPathfinderPass::runOnOperation() { builder.create(builder.getUnknownLoc(), tile, WireBundle::DMA, shimsw, WireBundle::DMA); } - } else if (tile.isShimPLTile()) { + } else if (targetModel.isShimPLTile(tile.getCol(), tile.getRow())) { // PLIO is attached directly to switch if (analyzer.coordToPLIO.count(col)) { auto plio = analyzer.coordToPLIO[col]; @@ -1286,153 +1455,627 @@ void AMDAIEPathfinderPass::runOnOperation() { } } } +} - // If the routing violates architecture-specific routing constraints, then - // attempt to partially reroute. - AMDAIEDeviceModel deviceModel = - getDeviceModel(static_cast(d.getDevice())); - std::vector problemConnects; - d.walk([&](ConnectOp connect) { - if (auto sw = connect->getParentOfType()) { - // Constraint: memtile stream switch constraints - if (auto tile = sw.getTileOp(); - tile.isMemTile() && - !deviceModel.isLegalMemtileConnection( - tile.getCol(), tile.getRow(), toStrmT(connect.getSourceBundle()), - connect.getSourceChannel(), toStrmT(connect.getDestBundle()), - connect.getDestChannel())) { - problemConnects.push_back(connect); - } - } - }); +Operation *AIEPathfinderPass::getOrCreateTile(OpBuilder &builder, int col, + int row) { + TileID index = {col, row}; + Operation *tileOp = tiles[index]; + if (!tileOp) { + auto tile = builder.create(builder.getUnknownLoc(), col, row); + tileOp = tile.getOperation(); + tiles[index] = tileOp; + } + return tileOp; +} - for (auto connect : problemConnects) { - auto swBox = connect->getParentOfType(); - builder.setInsertionPoint(connect); - auto northSw = getSwitchbox(d, swBox.colIndex(), swBox.rowIndex() + 1); - if (auto southSw = getSwitchbox(d, swBox.colIndex(), swBox.rowIndex() - 1); - !attemptFixupMemTileRouting(builder, northSw, southSw, connect)) - return signalPassFailure(); +SwitchboxOp AIEPathfinderPass::getOrCreateSwitchbox(OpBuilder &builder, + TileOp tile) { + for (auto i : tile.getResult().getUsers()) { + if (llvm::isa(*i)) { + return llvm::cast(*i); + } } + return builder.create(builder.getUnknownLoc(), tile); } -bool AMDAIEPathfinderPass::attemptFixupMemTileRouting( - const OpBuilder &builder, SwitchboxOp northSwOp, SwitchboxOp southSwOp, - ConnectOp &problemConnect) { - int problemNorthChannel; - if (problemConnect.getSourceBundle() == WireBundle::North) { - problemNorthChannel = problemConnect.getSourceChannel(); - } else if (problemConnect.getDestBundle() == WireBundle::North) { - problemNorthChannel = problemConnect.getDestChannel(); - } else - return false; // Problem is not about n-s routing - int problemSouthChannel; - if (problemConnect.getSourceBundle() == WireBundle::South) { - problemSouthChannel = problemConnect.getSourceChannel(); - } else if (problemConnect.getDestBundle() == WireBundle::South) { - problemSouthChannel = problemConnect.getDestChannel(); - } else - return false; // Problem is not about n-s routing - - // Attempt to reroute northern neighbouring sw - if (reconnectConnectOps(builder, northSwOp, problemConnect, true, - WireBundle::South, problemNorthChannel, - problemSouthChannel)) - return true; - if (reconnectConnectOps(builder, northSwOp, problemConnect, false, - WireBundle::South, problemNorthChannel, - problemSouthChannel)) - return true; - // Otherwise, attempt to reroute southern neighbouring sw - if (reconnectConnectOps(builder, southSwOp, problemConnect, true, - WireBundle::North, problemSouthChannel, - problemNorthChannel)) - return true; - if (reconnectConnectOps(builder, southSwOp, problemConnect, false, - WireBundle::North, problemSouthChannel, - problemNorthChannel)) +template +struct AIEOpRemoval : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename MyOp::Adaptor; + + explicit AIEOpRemoval(MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit) {} + + LogicalResult matchAndRewrite( + MyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Operation *Op = op.getOperation(); + + rewriter.eraseOp(Op); + return success(); + } +}; + +bool AIEPathfinderPass::findPathToDest(SwitchSettings settings, TileID currTile, + WireBundle currDestBundle, + int currDestChannel, TileID finalTile, + WireBundle finalDestBundle, + int finalDestChannel) { + if ((currTile == finalTile) && (currDestBundle == finalDestBundle) && + (currDestChannel == finalDestChannel)) { return true; + } + + WireBundle neighbourSourceBundle; + TileID neighbourTile; + if (currDestBundle == WireBundle::East) { + neighbourSourceBundle = WireBundle::West; + neighbourTile = {currTile.col + 1, currTile.row}; + } else if (currDestBundle == WireBundle::West) { + neighbourSourceBundle = WireBundle::East; + neighbourTile = {currTile.col - 1, currTile.row}; + } else if (currDestBundle == WireBundle::North) { + neighbourSourceBundle = WireBundle::South; + neighbourTile = {currTile.col, currTile.row + 1}; + } else if (currDestBundle == WireBundle::South) { + neighbourSourceBundle = WireBundle::North; + neighbourTile = {currTile.col, currTile.row - 1}; + } else { + return false; + } + + int neighbourSourceChannel = currDestChannel; + for (const auto &[sbNode, setting] : settings) { + TileID tile = {sbNode.col, sbNode.row}; + if ((tile == neighbourTile) && + (setting.src.bundle == neighbourSourceBundle) && + (setting.src.channel == neighbourSourceChannel)) { + for (const auto &[bundle, channel] : setting.dsts) { + if (findPathToDest(settings, neighbourTile, bundle, channel, finalTile, + finalDestBundle, finalDestChannel)) { + return true; + } + } + } + } + return false; } -bool AMDAIEPathfinderPass::reconnectConnectOps(const OpBuilder &builder, - SwitchboxOp sw, - ConnectOp problemConnect, - bool isIncomingToSW, - WireBundle problemBundle, - int problemChan, int emptyChan) { - bool hasEmptyChannelSlot = true; - bool foundCandidateForFixup = false; - ConnectOp candidate; - if (isIncomingToSW) { - for (ConnectOp connect : sw.getOps()) { - if (connect.getDestBundle() == problemBundle && - connect.getDestChannel() == problemChan) { - candidate = connect; - foundCandidateForFixup = true; +void AIEPathfinderPass::runOnPacketFlow(DeviceOp device, OpBuilder &builder) { + ConversionTarget target(getContext()); + + // Map from a port and flowID to + DenseMap, SmallVector> packetFlows; + SmallVector, 4> slavePorts; + DenseMap, int> slaveAMSels; + // Map from a port to + DenseMap keepPktHeaderAttr; + + for (auto tileOp : device.getOps()) { + int col = tileOp.colIndex(); + int row = tileOp.rowIndex(); + tiles[{col, row}] = tileOp; + } + + // The logical model of all the switchboxes. + DenseMap, 8>> switchboxes; + for (PacketFlowOp pktFlowOp : device.getOps()) { + Region &r = pktFlowOp.getPorts(); + Block &b = r.front(); + int flowID = pktFlowOp.IDInt(); + Port srcPort, destPort; + TileOp srcTile, destTile; + TileID srcCoords, destCoords; + + for (Operation &Op : b.getOperations()) { + if (auto pktSource = dyn_cast(Op)) { + srcTile = dyn_cast(pktSource.getTile().getDefiningOp()); + srcPort = pktSource.port(); + srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; + } else if (auto pktDest = dyn_cast(Op)) { + destTile = dyn_cast(pktDest.getTile().getDefiningOp()); + destPort = pktDest.port(); + destCoords = {destTile.colIndex(), destTile.rowIndex()}; + // Assign "keep_pkt_header flag" + if (pktFlowOp->hasAttr("keep_pkt_header")) + keepPktHeaderAttr[{destTile, destPort}] = + StringAttr::get(Op.getContext(), "true"); + SwitchboxNode srcSB = analyzer.pathfinder->getSwitchboxNode( + {srcCoords.col, srcCoords.row}); + if (PathEndPoint srcPoint = {srcSB, srcPort}; + !analyzer.processedFlows[srcPoint]) { + SwitchSettings settings = analyzer.flowSolutions[srcPoint]; + // add connections for all the Switchboxes in SwitchSettings + for (const auto &[curr, setting] : settings) { + for (const auto &[bundle, channel] : setting.dsts) { + TileID currTile = {curr.col, curr.row}; + // reject false broadcast + if (!findPathToDest(settings, currTile, bundle, channel, + destCoords, destPort.bundle, + destPort.channel)) + continue; + Connect connect = {{setting.src.bundle, setting.src.channel}, + {bundle, channel}}; + if (std::find(switchboxes[currTile].begin(), + switchboxes[currTile].end(), + std::pair{connect, flowID}) == + switchboxes[currTile].end()) + switchboxes[currTile].push_back({connect, flowID}); + } + } + } } - if (connect.getDestBundle() == problemBundle && - connect.getDestChannel() == emptyChan) { - hasEmptyChannelSlot = false; + } + } + + LLVM_DEBUG(llvm::dbgs() << "Check switchboxes\n"); + + for (const auto &[tileId, connects] : switchboxes) { + int col = tileId.col; + int row = tileId.row; + Operation *tileOp = getOrCreateTile(builder, col, row); + LLVM_DEBUG(llvm::dbgs() << "***switchbox*** " << col << " " << row << '\n'); + for (const auto &[conn, flowID] : connects) { + Port sourcePort = conn.src; + Port destPort = conn.dst; + auto sourceFlow = + std::make_pair(std::make_pair(tileOp, sourcePort), flowID); + packetFlows[sourceFlow].push_back({tileOp, destPort}); + slavePorts.push_back(sourceFlow); + LLVM_DEBUG(llvm::dbgs() << "flowID " << flowID << ':' + << stringifyWireBundle(sourcePort.bundle) << " " + << sourcePort.channel << " -> " + << stringifyWireBundle(destPort.bundle) << " " + << destPort.channel << "\n"); + } + } + + // amsel() + // masterset() + // packetrules() + // rule() + + // Compute arbiter assignments. Each arbiter has four msels. + // Therefore, the number of "logical" arbiters is 6 x 4 = 24 + // A master port can only be associated with one arbiter + + // A map from Tile and master selectValue to the ports targetted by that + // master select. + DenseMap, SmallVector> masterAMSels; + + // Count of currently used logical arbiters for each tile. + DenseMap amselValues; + int numMsels = 4; + int numArbiters = 6; + + std::vector, SmallVector>> + sortedPacketFlows(packetFlows.begin(), packetFlows.end()); + + // To get determinsitic behaviour + std::sort(sortedPacketFlows.begin(), sortedPacketFlows.end(), + [](const auto &lhs, const auto &rhs) { + auto lhsFlowID = lhs.first.second; + auto rhsFlowID = rhs.first.second; + return lhsFlowID < rhsFlowID; + }); + + // Check all multi-cast flows (same source, same ID). They should be + // assigned the same arbiter and msel so that the flow can reach all the + // destination ports at the same time For destination ports that appear in + // different (multicast) flows, it should have a different + // value pair for each flow + for (const auto &packetFlow : sortedPacketFlows) { + // The Source Tile of the flow + Operation *tileOp = packetFlow.first.first.first; + if (amselValues.count(tileOp) == 0) amselValues[tileOp] = 0; + + // arb0: 6*0, 6*1, 6*2, 6*3 + // arb1: 6*0+1, 6*1+1, 6*2+1, 6*3+1 + // arb2: 6*0+2, 6*1+2, 6*2+2, 6*3+2 + // arb3: 6*0+3, 6*1+3, 6*2+3, 6*3+3 + // arb4: 6*0+4, 6*1+4, 6*2+4, 6*3+4 + // arb5: 6*0+5, 6*1+5, 6*2+5, 6*3+5 + + int amselValue = amselValues[tileOp]; + assert(amselValue < numArbiters && "Could not allocate new arbiter!"); + + // Find existing arbiter assignment + // If there is an assignment of an arbiter to a master port before, we + // assign all the master ports here with the same arbiter but different + // msel + bool foundMatchedDest = false; + for (const auto &map : masterAMSels) { + if (map.first.first != tileOp) continue; + amselValue = map.first.second; + + // check if same destinations + SmallVector ports(masterAMSels[{tileOp, amselValue}]); + if (ports.size() != packetFlow.second.size()) continue; + + bool matched = true; + for (auto dest : packetFlow.second) { + if (Port port = dest.second; + std::find(ports.begin(), ports.end(), port) == ports.end()) { + matched = false; + break; + } + } + + if (matched) { + foundMatchedDest = true; + break; } } - } else { - for (ConnectOp connect : sw.getOps()) { - if (connect.getSourceBundle() == problemBundle && - connect.getSourceChannel() == problemChan) { - candidate = connect; - foundCandidateForFixup = true; + + if (!foundMatchedDest) { + bool foundAMSelValue = false; + for (int a = 0; a < numArbiters; a++) { + for (int i = 0; i < numMsels; i++) { + amselValue = a + i * numArbiters; + if (masterAMSels.count({tileOp, amselValue}) == 0) { + foundAMSelValue = true; + break; + } + } + + if (foundAMSelValue) break; } - if (connect.getSourceBundle() == problemBundle && - connect.getSourceChannel() == emptyChan) { - hasEmptyChannelSlot = false; + + for (auto dest : packetFlow.second) { + Port port = dest.second; + masterAMSels[{tileOp, amselValue}].push_back(port); } } + + slaveAMSels[packetFlow.first] = amselValue; + amselValues[tileOp] = amselValue % numArbiters; } - if (foundCandidateForFixup && hasEmptyChannelSlot) { - WireBundle problemBundleOpposite = problemBundle == WireBundle::North - ? WireBundle::South - : WireBundle::North; - // Found empty channel slot, perform reroute - if (isIncomingToSW) { - replaceConnectOpWithNewDest(builder, candidate, problemBundle, emptyChan); - replaceConnectOpWithNewSource(builder, problemConnect, - problemBundleOpposite, emptyChan); - } else { - replaceConnectOpWithNewSource(builder, candidate, problemBundle, - emptyChan); - replaceConnectOpWithNewDest(builder, problemConnect, - problemBundleOpposite, emptyChan); + + // Compute the master set IDs + // A map from a switchbox output port to the number of that port. + DenseMap> mastersets; + for (const auto &[physPort, ports] : masterAMSels) { + Operation *tileOp = physPort.first; + assert(tileOp); + int amselValue = physPort.second; + for (auto port : ports) { + PhysPort physPort = {tileOp, port}; + mastersets[physPort].push_back(amselValue); } - return true; } - return false; -} -// Replace connect op -ConnectOp AMDAIEPathfinderPass::replaceConnectOpWithNewDest( - OpBuilder builder, ConnectOp connect, WireBundle newBundle, - int newChannel) { - builder.setInsertionPoint(connect); - auto newOp = builder.create( - builder.getUnknownLoc(), connect.getSourceBundle(), - connect.getSourceChannel(), newBundle, newChannel); - connect.erase(); - return newOp; + LLVM_DEBUG(llvm::dbgs() << "CHECK mastersets\n"); +#ifndef NDEBUG + for (const auto &[physPort, values] : mastersets) { + Operation *tileOp = physPort.first; + WireBundle bundle = physPort.second.bundle; + int channel = physPort.second.channel; + assert(tileOp); + auto tile = dyn_cast(tileOp); + LLVM_DEBUG(llvm::dbgs() + << "master " << tile << " " << stringifyWireBundle(bundle) + << " : " << channel << '\n'); + for (auto value : values) + LLVM_DEBUG(llvm::dbgs() << "amsel: " << value << '\n'); + } +#endif + + // Compute mask values + // Merging as many stream flows as possible + // The flows must originate from the same source port and have different IDs + // Two flows can be merged if they share the same destinations + SmallVector, 4>, 4> slaveGroups; + SmallVector, 4> workList(slavePorts); + while (!workList.empty()) { + auto slave1 = workList.pop_back_val(); + Port slavePort1 = slave1.first.second; + + bool foundgroup = false; + for (auto &group : slaveGroups) { + auto slave2 = group.front(); + if (Port slavePort2 = slave2.first.second; slavePort1 != slavePort2) + continue; + + bool matched = true; + auto dests1 = packetFlows[slave1]; + auto dests2 = packetFlows[slave2]; + if (dests1.size() != dests2.size()) continue; + + for (auto dest1 : dests1) { + if (std::find(dests2.begin(), dests2.end(), dest1) == dests2.end()) { + matched = false; + break; + } + } + + if (matched) { + group.push_back(slave1); + foundgroup = true; + break; + } + } + + if (!foundgroup) { + SmallVector, 4> group({slave1}); + slaveGroups.push_back(group); + } + } + + DenseMap, int> slaveMasks; + for (const auto &group : slaveGroups) { + // Iterate over all the ID values in a group + // If bit n-th (n <= 5) of an ID value differs from bit n-th of another ID + // value, the bit position should be "don't care", and we will set the + // mask bit of that position to 0 + int mask[5] = {-1, -1, -1, -1, -1}; + for (auto port : group) { + int ID = port.second; + for (int i = 0; i < 5; i++) { + if (mask[i] == -1) + mask[i] = ID >> i & 0x1; + else if (mask[i] != (ID >> i & 0x1)) + mask[i] = 2; // found bit difference --> mark as "don't care" + } + } + + int maskValue = 0; + for (int i = 4; i >= 0; i--) { + if (mask[i] == 2) // don't care + mask[i] = 0; + else + mask[i] = 1; + maskValue = (maskValue << 1) + mask[i]; + } + for (auto port : group) slaveMasks[port] = maskValue; + } + +#ifndef NDEBUG + LLVM_DEBUG(llvm::dbgs() << "CHECK Slave Masks\n"); + for (auto map : slaveMasks) { + auto port = map.first.first; + auto tile = dyn_cast(port.first); + WireBundle bundle = port.second.bundle; + int channel = port.second.channel; + int ID = map.first.second; + int mask = map.second; + + LLVM_DEBUG(llvm::dbgs() + << "Port " << tile << " " << stringifyWireBundle(bundle) << " " + << channel << '\n'); + LLVM_DEBUG(llvm::dbgs() << "Mask " + << "0x" << llvm::Twine::utohexstr(mask) << '\n'); + LLVM_DEBUG(llvm::dbgs() << "ID " + << "0x" << llvm::Twine::utohexstr(ID) << '\n'); + for (int i = 0; i < 31; i++) { + if ((i & mask) == (ID & mask)) + LLVM_DEBUG(llvm::dbgs() << "matches flow ID " + << "0x" << llvm::Twine::utohexstr(i) << '\n'); + } + } +#endif + + // Realize the routes in MLIR + for (auto map : tiles) { + Operation *tileOp = map.second; + auto tile = dyn_cast(tileOp); + + // Create a switchbox for the routes and insert inside it. + builder.setInsertionPointAfter(tileOp); + SwitchboxOp swbox = getOrCreateSwitchbox(builder, tile); + SwitchboxOp::ensureTerminator(swbox.getConnections(), builder, + builder.getUnknownLoc()); + Block &b = swbox.getConnections().front(); + builder.setInsertionPoint(b.getTerminator()); + + std::vector amselOpNeededVector(32); + for (const auto &map : mastersets) { + if (tileOp != map.first.first) continue; + + for (auto value : map.second) { + amselOpNeededVector[value] = true; + } + } + // Create all the amsel Ops + DenseMap amselOps; + for (int i = 0; i < 32; i++) { + if (amselOpNeededVector[i]) { + int arbiterID = i % numArbiters; + int msel = i / numArbiters; + auto amsel = + builder.create(builder.getUnknownLoc(), arbiterID, msel); + amselOps[i] = amsel; + } + } + // Create all the master set Ops + // First collect the master sets for this tile. + SmallVector tileMasters; + for (const auto &map : mastersets) { + if (tileOp != map.first.first) continue; + tileMasters.push_back(map.first.second); + } + // Sort them so we get a reasonable order + std::sort(tileMasters.begin(), tileMasters.end()); + for (auto tileMaster : tileMasters) { + WireBundle bundle = tileMaster.bundle; + int channel = tileMaster.channel; + SmallVector msels = mastersets[{tileOp, tileMaster}]; + SmallVector amsels; + for (auto msel : msels) { + assert(amselOps.count(msel) == 1); + amsels.push_back(amselOps[msel]); + } + + auto msOp = builder.create(builder.getUnknownLoc(), + builder.getIndexType(), bundle, + channel, amsels); + if (auto pktFlowAttrs = keepPktHeaderAttr[{tileOp, tileMaster}]) + msOp->setAttr("keep_pkt_header", pktFlowAttrs); + } + + // Generate the packet rules + DenseMap slaveRules; + for (auto group : slaveGroups) { + builder.setInsertionPoint(b.getTerminator()); + + auto port = group.front().first; + if (tileOp != port.first) continue; + + WireBundle bundle = port.second.bundle; + int channel = port.second.channel; + auto slave = port.second; + + int mask = slaveMasks[group.front()]; + int ID = group.front().second & mask; + + // Verify that we actually map all the ID's correctly. +#ifndef NDEBUG + for (auto slave : group) assert((slave.second & mask) == ID); +#endif + Value amsel = amselOps[slaveAMSels[group.front()]]; + + PacketRulesOp packetrules; + if (slaveRules.count(slave) == 0) { + packetrules = builder.create(builder.getUnknownLoc(), + bundle, channel); + PacketRulesOp::ensureTerminator(packetrules.getRules(), builder, + builder.getUnknownLoc()); + slaveRules[slave] = packetrules; + } else + packetrules = slaveRules[slave]; + + Block &rules = packetrules.getRules().front(); + builder.setInsertionPoint(rules.getTerminator()); + builder.create(builder.getUnknownLoc(), mask, ID, amsel); + } + } + + AMDAIEDeviceModel targetModel = + getDeviceModel(static_cast(device.getDevice())); + + // Add support for shimDMA + // From shimDMA to BLI: 1) shimDMA 0 --> North 3 + // 2) shimDMA 1 --> North 7 + // From BLI to shimDMA: 1) North 2 --> shimDMA 0 + // 2) North 3 --> shimDMA 1 + + for (auto switchbox : make_early_inc_range(device.getOps())) { + auto retVal = switchbox->getOperand(0); + auto tileOp = retVal.getDefiningOp(); + + // Check if it is a shim Tile + if (!targetModel.isShimNOCTile(tileOp.getCol(), tileOp.getRow())) continue; + + // Check if the switchbox is empty + if (&switchbox.getBody()->front() == switchbox.getBody()->getTerminator()) + continue; + + Region &r = switchbox.getConnections(); + Block &b = r.front(); + + // Find if the corresponding shimmux exsists or not + int shimExist = 0; + ShimMuxOp shimOp; + for (auto shimmux : device.getOps()) { + if (shimmux.getTile() == tileOp) { + shimExist = 1; + shimOp = shimmux; + break; + } + } + + for (Operation &Op : b.getOperations()) { + if (auto pktrules = dyn_cast(Op)) { + // check if there is MM2S DMA in the switchbox of the 0th row + if (pktrules.getSourceBundle() == WireBundle::DMA) { + // If there is, then it should be put into the corresponding shimmux + // If shimmux not defined then create shimmux + if (!shimExist) { + builder.setInsertionPointAfter(tileOp); + shimOp = builder.create(builder.getUnknownLoc(), tileOp); + Region &r1 = shimOp.getConnections(); + Block *b1 = builder.createBlock(&r1); + builder.setInsertionPointToEnd(b1); + builder.create(builder.getUnknownLoc()); + shimExist = 1; + } + + Region &r0 = shimOp.getConnections(); + Block &b0 = r0.front(); + builder.setInsertionPointToStart(&b0); + + pktrules.setSourceBundle(WireBundle::South); + if (pktrules.getSourceChannel() == 0) { + pktrules.setSourceChannel(3); + builder.create(builder.getUnknownLoc(), WireBundle::DMA, + 0, WireBundle::North, 3); + } + if (pktrules.getSourceChannel() == 1) { + pktrules.setSourceChannel(7); + builder.create(builder.getUnknownLoc(), WireBundle::DMA, + 1, WireBundle::North, 7); + } + } + } + + if (auto mtset = dyn_cast(Op)) { + // check if there is S2MM DMA in the switchbox of the 0th row + if (mtset.getDestBundle() == WireBundle::DMA) { + // If there is, then it should be put into the corresponding shimmux + // If shimmux not defined then create shimmux + if (!shimExist) { + builder.setInsertionPointAfter(tileOp); + shimOp = builder.create(builder.getUnknownLoc(), tileOp); + Region &r1 = shimOp.getConnections(); + Block *b1 = builder.createBlock(&r1); + builder.setInsertionPointToEnd(b1); + builder.create(builder.getUnknownLoc()); + shimExist = 1; + } + + Region &r0 = shimOp.getConnections(); + Block &b0 = r0.front(); + builder.setInsertionPointToStart(&b0); + + mtset.setDestBundle(WireBundle::South); + if (mtset.getDestChannel() == 0) { + mtset.setDestChannel(2); + builder.create(builder.getUnknownLoc(), + WireBundle::North, 2, WireBundle::DMA, 0); + } + if (mtset.getDestChannel() == 1) { + mtset.setDestChannel(3); + builder.create(builder.getUnknownLoc(), + WireBundle::North, 3, WireBundle::DMA, 1); + } + } + } + } + } + + RewritePatternSet patterns(&getContext()); + + if (!clKeepFlowOp) + patterns.add>(device.getContext()); + + if (failed(applyPartialConversion(device, target, std::move(patterns)))) + signalPassFailure(); } -ConnectOp AMDAIEPathfinderPass::replaceConnectOpWithNewSource( - OpBuilder builder, ConnectOp connect, WireBundle newBundle, - int newChannel) { - builder.setInsertionPoint(connect); - auto newOp = builder.create(builder.getUnknownLoc(), newBundle, - newChannel, connect.getDestBundle(), - connect.getDestChannel()); - connect.erase(); - return newOp; +void AIEPathfinderPass::runOnOperation() { + // create analysis pass with routing graph for entire device + LLVM_DEBUG(llvm::dbgs() << "---Begin AIEPathfinderPass---\n"); + + DeviceOp d = getOperation(); + if (failed(analyzer.runAnalysis(d))) return signalPassFailure(); + OpBuilder builder = OpBuilder::atBlockEnd(d.getBody()); + + if (clRouteCircuit) runOnFlow(d, builder); + if (clRoutePacket) runOnPacketFlow(d, builder); } -SwitchboxOp AMDAIEPathfinderPass::getSwitchbox(DeviceOp &d, int col, int row) { +SwitchboxOp AIEPathfinderPass::getSwitchbox(DeviceOp &d, int col, int row) { SwitchboxOp output = nullptr; d.walk([&](SwitchboxOp swBox) { if (swBox.colIndex() == col && swBox.rowIndex() == row) { @@ -1442,8 +2085,11 @@ SwitchboxOp AMDAIEPathfinderPass::getSwitchbox(DeviceOp &d, int col, int row) { return output; } +} // namespace mlir::iree_compiler::AMDAIE + +namespace mlir::iree_compiler::AMDAIE { std::unique_ptr> createAMDAIEPathfinderPass() { - return std::make_unique(); + return std::make_unique(); } void registerAMDAIERoutePathfinderFlows() { @@ -1451,5 +2097,4 @@ void registerAMDAIERoutePathfinderFlows() { return createAMDAIEPathfinderPass(); }); } - } // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/Passes.h b/compiler/plugins/target/AMD-AIE/aie/Passes.h index 3204a53bb..107ab7cbf 100644 --- a/compiler/plugins/target/AMD-AIE/aie/Passes.h +++ b/compiler/plugins/target/AMD-AIE/aie/Passes.h @@ -11,6 +11,13 @@ #include "mlir/Pass/Pass.h" namespace mlir::iree_compiler::AMDAIE { + +struct AIERoutePathfinderFlowsOptions { + bool clRouteCircuit = true; + bool clRoutePacket = true; + bool clKeepFlowOp = false; +}; + std::unique_ptr> createAMDAIEAssignBufferAddressesBasicPass(); std::unique_ptr> diff --git a/compiler/plugins/target/AMD-AIE/aie/test/unit_simple.mlir b/compiler/plugins/target/AMD-AIE/aie/test/unit_simple.mlir index dd47270aa..912c33e0a 100644 --- a/compiler/plugins/target/AMD-AIE/aie/test/unit_simple.mlir +++ b/compiler/plugins/target/AMD-AIE/aie/test/unit_simple.mlir @@ -1,5 +1,5 @@ -// RUN: iree-opt --amdaie-create-pathfinder-flows %s | FileCheck %s +// RUN: iree-opt --amdaie-create-pathfinder-flows="route-circuit=true route-packet=false" %s | FileCheck %s // CHECK-LABEL: aie.device(xcvc1902) { // CHECK: %[[TILE_0_1:.*]] = aie.tile(0, 1) diff --git a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc index f83b12189..69029ecf5 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc +++ b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc @@ -59,6 +59,47 @@ bool isSouth(uint8_t srcCol, uint8_t srcRow, uint8_t dstCol, uint8_t dstRow) { } } // namespace +namespace MLIRAIELegacy { +using mlir::iree_compiler::AMDAIE::StrmSwPortType; +namespace VC1902TargetModel { +int columns(); +int rows(); +bool isShimNOCTile(int col, int row); +bool isShimPLTile(int col, int row); +bool isShimNOCorPLTile(int col, int row); +uint32_t getNumDestSwitchboxConnections(int col, int row, + StrmSwPortType bundle); +uint32_t getNumSourceSwitchboxConnections(int col, int row, + StrmSwPortType bundle); +uint32_t getNumDestShimMuxConnections(int col, int row, StrmSwPortType bundle); +uint32_t getNumSourceShimMuxConnections(int col, int row, + StrmSwPortType bundle); +bool isCoreTile(int col, int row); +bool isMemTile(int col, int row); +bool isLegalTileConnection(int col, int row, StrmSwPortType srcBundle, + int srcChan, StrmSwPortType dstBundle, int dstChan); +} // namespace VC1902TargetModel + +namespace VE2802TargetModel { +int columns(); +int rows(); +bool isShimNOCTile(int col, int row); +bool isShimPLTile(int col, int row); +bool isShimNOCorPLTile(int col, int row); +uint32_t getNumDestSwitchboxConnections(int col, int row, + StrmSwPortType bundle); +uint32_t getNumSourceSwitchboxConnections(int col, int row, + StrmSwPortType bundle); +uint32_t getNumDestShimMuxConnections(int col, int row, StrmSwPortType bundle); +uint32_t getNumSourceShimMuxConnections(int col, int row, + StrmSwPortType bundle); +bool isCoreTile(int col, int row); +bool isMemTile(int col, int row); +bool isLegalTileConnection(int col, int row, StrmSwPortType srcBundle, + int srcChan, StrmSwPortType dstBundle, int dstChan); +} // namespace VE2802TargetModel +} // namespace MLIRAIELegacy + namespace mlir::iree_compiler::AMDAIE { std::string to_string(const StrmSwPortType &value) { @@ -73,12 +114,35 @@ std::string to_string(const StrmSwPortType &value) { STRINGIFY_ENUM_CASE(StrmSwPortType::EAST) STRINGIFY_ENUM_CASE(StrmSwPortType::TRACE) STRINGIFY_ENUM_CASE(StrmSwPortType::UCTRLR) + STRINGIFY_ENUM_CASE(StrmSwPortType::NOC) + STRINGIFY_ENUM_CASE(StrmSwPortType::PLIO) STRINGIFY_ENUM_CASE(StrmSwPortType::SS_PORT_TYPE_MAX) } llvm::report_fatal_error("Unhandled StrmSwPortType case"); } +std::string to_string(const ::StrmSwPortType &value) { + using StrmSwPortType = ::StrmSwPortType; + switch (value) { + STRINGIFY_ENUM_CASE(StrmSwPortType::CORE) + STRINGIFY_ENUM_CASE(StrmSwPortType::DMA) + STRINGIFY_ENUM_CASE(StrmSwPortType::CTRL) + STRINGIFY_ENUM_CASE(StrmSwPortType::FIFO) + STRINGIFY_ENUM_CASE(StrmSwPortType::SOUTH) + STRINGIFY_ENUM_CASE(StrmSwPortType::WEST) + STRINGIFY_ENUM_CASE(StrmSwPortType::NORTH) + STRINGIFY_ENUM_CASE(StrmSwPortType::EAST) + STRINGIFY_ENUM_CASE(StrmSwPortType::TRACE) + STRINGIFY_ENUM_CASE(StrmSwPortType::UCTRLR) + // STRINGIFY_ENUM_CASE(::StrmSwPortType::NOC) + // STRINGIFY_ENUM_CASE(::StrmSwPortType::PLIO) + STRINGIFY_ENUM_CASE(::StrmSwPortType::SS_PORT_TYPE_MAX) + } + + llvm::report_fatal_error("Unhandled StrmSwPortType case"); +} + std::string to_string(const AieRC &value) { switch (value) { STRINGIFY_ENUM_CASE(AieRC::XAIE_OK) @@ -146,7 +210,7 @@ AMDAIEDeviceModel::AMDAIEDeviceModel( uint8_t aieGen, uint64_t baseAddr, uint8_t colShift, uint8_t rowShift, uint8_t devNColumns, uint8_t devNRows, uint8_t memTileRowStart, uint8_t nMemTileRows, uint8_t nShimTileRows, int partitionNumCols, - int partitionStartCol, bool aieSim, bool xaieDebug) + int partitionStartCol, bool aieSim, bool xaieDebug, AMDAIEDevice device) : configPtr{.AieGen = aieGen, .BaseAddr = baseAddr, .ColShift = colShift, @@ -163,7 +227,8 @@ AMDAIEDeviceModel::AMDAIEDeviceModel( .AieTileNumRows = static_cast(devNRows - nMemTileRows - nShimTileRows), .PartProp = {}}, - devInst{} { + devInst{}, + device(device) { uint64_t partBaseAddr; if (aieGen == XAIE_DEV_GEN_AIE) partBaseAddr = XAIE1_PARTITION_BASE_ADDR; @@ -201,9 +266,19 @@ AMDAIEDeviceModel::AMDAIEDeviceModel( TRY_XAIE_API_FATAL_ERROR(XAie_TurnEccOff, &devInst); } -int AMDAIEDeviceModel::rows() const { return configPtr.NumRows; } +int AMDAIEDeviceModel::rows() const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::rows(); + else + return configPtr.NumRows; +} -int AMDAIEDeviceModel::columns() const { return configPtr.NumCols; } +int AMDAIEDeviceModel::columns() const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::columns(); + else + return configPtr.NumCols; +} // TODO(max): these are buried somewhere in aie-rt... uint32_t AMDAIEDeviceModel::getMemSouthBaseAddress() const { @@ -239,10 +314,27 @@ bool AMDAIEDeviceModel::isMemTile(uint8_t col, uint8_t row) const { } bool AMDAIEDeviceModel::isShimNOCTile(uint8_t col, uint8_t row) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::isShimNOCTile(col, row); + if (device == AMDAIEDevice::xcve2802) + return MLIRAIELegacy::VE2802TargetModel::isShimNOCTile(col, row); + return getTileType(col, row) == AMDAIETileType::SHIMNOC; } +bool AMDAIEDeviceModel::isShimNOCorPLTile(uint8_t col, uint8_t row) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::isShimNOCorPLTile(col, row); + if (device == AMDAIEDevice::xcve2802) + return MLIRAIELegacy::VE2802TargetModel::isShimNOCorPLTile(col, row); + return isShimNOCTile(col, row) || isShimPLTile(col, row); +} + bool AMDAIEDeviceModel::isShimPLTile(uint8_t col, uint8_t row) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::isShimPLTile(col, row); + if (device == AMDAIEDevice::xcve2802) + return MLIRAIELegacy::VE2802TargetModel::isShimPLTile(col, row); return getTileType(col, row) == AMDAIETileType::SHIMPL; } @@ -360,14 +452,15 @@ bool AMDAIEDeviceModel::isLegalMemtileConnection(uint8_t col, uint8_t row, StrmSwPortType dstBundle, uint8_t dstChan) const { // TODO(max): this isn't correct but for agreement with mlir-aie... - if (srcBundle == dstBundle and srcBundle != DMA) return true; + if (srcBundle == dstBundle and srcBundle != StrmSwPortType::DMA) return true; assert(isMemTile(col, row) && "expected memtile"); AMDAIETileType tileType = getTileType(col, row); assert(tileType == AMDAIETileType::MEMTILE && "expected memtile"); const XAie_StrmMod *strmMod = devInst.DevProp.DevMod[static_cast(tileType)].StrmSw; - AieRC RC = strmMod->PortVerify(/*slave*/ srcBundle, srcChan, - /*master*/ dstBundle, dstChan); + AieRC RC = strmMod->PortVerify( + /*slave*/ static_cast<::StrmSwPortType>(srcBundle), srcChan, + /*master*/ static_cast<::StrmSwPortType>(dstBundle), dstChan); if (RC != XAIE_OK) { LLVM_DEBUG(llvm::dbgs() << "PortVerify failed with " << RC << "\n"); LLVM_DEBUG(SHOW_ARGS(llvm::dbgs(), col, row, srcBundle, (int)srcChan, @@ -381,29 +474,99 @@ bool AMDAIEDeviceModel::isLegalMemtileConnection(uint8_t col, uint8_t row, // source <-> slave and dest <-> master uint32_t AMDAIEDeviceModel::getNumSourceSwitchboxConnections( uint8_t col, uint8_t row, StrmSwPortType bundle) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::getNumSourceSwitchboxConnections( + col, row, bundle); + AMDAIETileType tileType = getTileType(col, row); // not sure if this makes sense but agrees with mlir-aie - if ((bundle == NORTH && row == rows() - 1) || (bundle == WEST && col == 0) || - (bundle == EAST && col == columns() - 1) || + if ((bundle == StrmSwPortType::NORTH && row == rows() - 1) || + (bundle == StrmSwPortType::WEST && col == 0) || + (bundle == StrmSwPortType::EAST && col == columns() - 1) || tileType == AMDAIETileType::MAX) return 0; const XAie_StrmMod *strmMod = devInst.DevProp.DevMod[static_cast(tileType)].StrmSw; - return strmMod->SlvConfig[bundle].NumPorts; + return strmMod->SlvConfig[static_cast<::StrmSwPortType>(bundle)].NumPorts; } uint32_t AMDAIEDeviceModel::getNumDestSwitchboxConnections( uint8_t col, uint8_t row, StrmSwPortType bundle) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::getNumDestSwitchboxConnections( + col, row, bundle); + AMDAIETileType tileType = getTileType(col, row); // not sure if this makes sense but agrees with mlir-aie - if ((bundle == NORTH && row == rows() - 1) || (bundle == WEST && col == 0) || - (bundle == EAST && col == columns() - 1) || + if ((bundle == StrmSwPortType::NORTH && row == rows() - 1) || + (bundle == StrmSwPortType::WEST && col == 0) || + (bundle == StrmSwPortType::EAST && col == columns() - 1) || tileType == AMDAIETileType::MAX) return 0; const XAie_StrmMod *strmMod = devInst.DevProp.DevMod[static_cast(tileType)].StrmSw; - return strmMod->MstrConfig[bundle].NumPorts; + return strmMod->MstrConfig[static_cast<::StrmSwPortType>(bundle)].NumPorts; +} + +uint32_t AMDAIEDeviceModel::getNumShimMuxConnections( + uint8_t col, uint8_t row, StrmSwPortType bundle) const { + if (isShimNOCorPLTile(col, row)) switch (bundle) { + case StrmSwPortType::DMA: + return 2; + case StrmSwPortType::NOC: + return 4; + case StrmSwPortType::PLIO: + return 8; + case StrmSwPortType::SOUTH: + return 6; // Connection to the south port of the stream switch + default: + return 0; + } + return 0; +} + +uint32_t AMDAIEDeviceModel::getNumSourceShimMuxConnections( + uint8_t col, uint8_t row, StrmSwPortType bundle) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::getNumSourceShimMuxConnections( + col, row, bundle); + if (device == AMDAIEDevice::xcve2802) + return MLIRAIELegacy::VE2802TargetModel::getNumSourceShimMuxConnections( + col, row, bundle); + assert(device == AMDAIEDevice::npu1_4col && "expected npu1_4col"); + return getNumShimMuxConnections(col, row, bundle); +} + +uint32_t AMDAIEDeviceModel::getNumDestShimMuxConnections( + uint8_t col, uint8_t row, StrmSwPortType bundle) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::getNumDestShimMuxConnections( + col, row, bundle); + if (device == AMDAIEDevice::xcve2802) + return MLIRAIELegacy::VE2802TargetModel::getNumDestShimMuxConnections( + col, row, bundle); + assert(device == AMDAIEDevice::npu1_4col && "expected npu1_4col"); + return getNumShimMuxConnections(col, row, bundle); +} + +bool AMDAIEDeviceModel::isLegalTileConnection(int col, int row, + StrmSwPortType srcBundle, + int srcChan, + StrmSwPortType dstBundle, + int dstChan) const { + if (device == AMDAIEDevice::xcvc1902) + return MLIRAIELegacy::VC1902TargetModel::isLegalTileConnection( + col, row, srcBundle, srcChan, dstBundle, dstChan); + if (device == AMDAIEDevice::xcve2802) + return MLIRAIELegacy::VE2802TargetModel::isLegalTileConnection( + col, row, srcBundle, srcChan, dstBundle, dstChan); + if (device == AMDAIEDevice::npu1_4col) + return _isLegalTileConnection(col, row, srcBundle, srcChan, dstBundle, + dstChan); + llvm::report_fatal_error( + llvm::Twine("isLegalTileConnection unsupported for device: ") + + stringifyAMDAIEDevice(device)); } uint32_t AMDAIEDeviceModel::getColumnShift() const { @@ -414,14 +577,15 @@ uint32_t AMDAIEDeviceModel::getRowShift() const { return configPtr.RowShift; } struct AMDAIEDeviceModel getDeviceModel(AMDAIEDevice device) { switch (device) { case AMDAIEDevice::xcvc1902: - return AMDAIEDeviceModel( - XAIE_DEV_GEN_AIE, XAIE1_BASE_ADDR, XAIE1_COL_SHIFT, XAIE1_ROW_SHIFT, - XAIE1_NUM_COLS, XAIE1_NUM_ROWS, XAIE1_MEM_TILE_ROW_START, - XAIE1_MEM_TILE_NUM_ROWS, - // mlir-aie disagrees with aie-rt here - /*nShimTileRows*/ 0, - /*partitionNumCols*/ 50, - /*partitionStartCol*/ 0, /*aieSim*/ false, /*xaieDebug*/ false); + return AMDAIEDeviceModel(XAIE_DEV_GEN_AIE, XAIE1_BASE_ADDR, + XAIE1_COL_SHIFT, XAIE1_ROW_SHIFT, XAIE1_NUM_COLS, + XAIE1_NUM_ROWS, XAIE1_MEM_TILE_ROW_START, + XAIE1_MEM_TILE_NUM_ROWS, + // mlir-aie disagrees with aie-rt here + /*nShimTileRows*/ 0, + /*partitionNumCols*/ 50, + /*partitionStartCol*/ 0, /*aieSim*/ false, + /*xaieDebug*/ false, device); case AMDAIEDevice::xcve2302: return AMDAIEDeviceModel(XAIE_DEV_GEN_AIEML, XAIEML_BASE_ADDR, XAIEML_COL_SHIFT, XAIEML_ROW_SHIFT, @@ -431,7 +595,7 @@ struct AMDAIEDeviceModel getDeviceModel(AMDAIEDevice device) { /*partitionNumCols*/ 17, /*partitionStartCol*/ 0, /*aieSim*/ false, - /*xaieDebug*/ false); + /*xaieDebug*/ false, device); case AMDAIEDevice::xcve2802: return AMDAIEDeviceModel(XAIE_DEV_GEN_AIEML, XAIEML_BASE_ADDR, XAIEML_COL_SHIFT, XAIEML_ROW_SHIFT, @@ -440,7 +604,7 @@ struct AMDAIEDeviceModel getDeviceModel(AMDAIEDevice device) { XAIEML_MEM_TILE_NUM_ROWS, XAIEML_SHIM_NUM_ROWS, /*partitionNumCols*/ 38, /*partitionStartCol*/ 0, - /*aieSim*/ false, /*xaieDebug*/ false); + /*aieSim*/ false, /*xaieDebug*/ false, device); case AMDAIEDevice::npu1: case AMDAIEDevice::npu1_1col: case AMDAIEDevice::npu1_2col: @@ -476,9 +640,589 @@ struct AMDAIEDeviceModel getDeviceModel(AMDAIEDevice device) { XAIE2IPU_ROW_SHIFT, XAIE2IPU_NUM_COLS, XAIE2IPU_NUM_ROWS, XAIE2IPU_MEM_TILE_ROW_START, XAIE2IPU_MEM_TILE_NUM_ROWS, XAIE2IPU_SHIM_NUM_ROWS, partitionNumCols, partitionStartCol, - /*aieSim*/ false, /*xaieDebug*/ false); + /*aieSim*/ false, /*xaieDebug*/ false, device); } llvm::report_fatal_error("Unhandled AMDAIEDevice case"); } } // namespace mlir::iree_compiler::AMDAIE + +namespace MLIRAIELegacy { +namespace VC1902TargetModel { +llvm::SmallDenseSet nocColumns = {2, 3, 6, 7, 10, 11, 18, 19, + 26, 27, 34, 35, 42, 43, 46, 47}; + +int columns() { return 50; } + +int rows() { return 9; /* One Shim row and 8 CORE rows. */ } + +bool isShimNOCTile(int col, int row) { + return row == 0 && nocColumns.contains(col); +} + +bool isShimPLTile(int col, int row) { + return row == 0 && !nocColumns.contains(col); +} + +bool isShimNOCorPLTile(int col, int row) { + return isShimNOCTile(col, row) || isShimPLTile(col, row); +} + +uint32_t getNumDestSwitchboxConnections(int col, int row, + StrmSwPortType bundle) { + if (isShimNOCTile(col, row) || isShimPLTile(col, row)) switch (bundle) { + case StrmSwPortType::FIFO: + return 2; + case StrmSwPortType::NORTH: + return 6; + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 6; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::CTRL: + return isShimNOCTile(col, row) ? 1 : 0; + default: + return 0; + } + + switch (bundle) { + case StrmSwPortType::CORE: + case StrmSwPortType::DMA: + case StrmSwPortType::FIFO: + return 2; + case StrmSwPortType::NORTH: { + if (row == rows() - 1) return 0; + return 6; + } + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 4; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::CTRL: + return 1; + default: + return 0; + } +} + +uint32_t getNumSourceSwitchboxConnections(int col, int row, + StrmSwPortType bundle) { + if (isShimNOCTile(col, row) || isShimPLTile(col, row)) switch (bundle) { + case StrmSwPortType::FIFO: + return 2; + case StrmSwPortType::NORTH: + return 4; + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 8; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::TRACE: + return 1; + case StrmSwPortType::CTRL: + return isShimNOCTile(col, row) ? 1 : 0; + default: + return 0; + } + + switch (bundle) { + case StrmSwPortType::CORE: + case StrmSwPortType::DMA: + case StrmSwPortType::FIFO: + return 2; + case StrmSwPortType::NORTH: { + if (row == rows() - 1) return 0; + return 4; + } + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 6; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::TRACE: + return 2; + case StrmSwPortType::CTRL: + return 1; + default: + return 0; + } +} +uint32_t getNumDestShimMuxConnections(int col, int row, StrmSwPortType bundle) { + if (isShimNOCorPLTile(col, row)) switch (bundle) { + case StrmSwPortType::DMA: + return 2; + case StrmSwPortType::NOC: + return 4; + case StrmSwPortType::PLIO: + return 6; + case StrmSwPortType::SOUTH: + return 8; // Connection to the south port of the stream switch + default: + return 0; + } + return 0; +} +uint32_t getNumSourceShimMuxConnections(int col, int row, + StrmSwPortType bundle) { + if (isShimNOCorPLTile(col, row)) switch (bundle) { + case StrmSwPortType::DMA: + return 2; + case StrmSwPortType::NOC: + return 4; + case StrmSwPortType::PLIO: + return 8; + case StrmSwPortType::SOUTH: + return 6; // Connection to the south port of the stream switch + default: + return 0; + } + return 0; +} + +bool isCoreTile(int col, int row) { return row > 0; } +bool isMemTile(int col, int row) { return false; } + +bool isLegalTileConnection(int col, int row, StrmSwPortType srcBundle, + int srcChan, StrmSwPortType dstBundle, int dstChan) { + // Check Channel Id within the range + if (srcChan >= int(getNumSourceSwitchboxConnections(col, row, srcBundle))) + return false; + if (dstChan >= int(getNumDestSwitchboxConnections(col, row, dstBundle))) + return false; + + // Memtile + if (isMemTile(col, row)) { + return false; + } + // Shimtile + else if (isShimNOCorPLTile(col, row)) { + if (srcBundle == StrmSwPortType::TRACE) + return dstBundle == StrmSwPortType::SOUTH; + else + return true; + } + // Coretile + else if (isCoreTile(col, row)) { + if (srcBundle == StrmSwPortType::TRACE) + return dstBundle == StrmSwPortType::SOUTH; + else + return true; + } + return false; +} +} // namespace VC1902TargetModel + +namespace VE2802TargetModel { +llvm::SmallDenseSet nocColumns = {2, 3, 6, 7, 14, 15, + 22, 23, 30, 31, 34, 35}; + +bool isShimNOCTile(int col, int row) { + return row == 0 && nocColumns.contains(col); +} + +bool isShimPLTile(int col, int row) { + return row == 0 && !nocColumns.contains(col); +} + +bool isShimNOCorPLTile(int col, int row) { + return isShimNOCTile(col, row) || isShimPLTile(col, row); +} + +int columns() { return 38; } + +int rows() { return 11; /* One Shim row, 2 memtile rows, and 8 Core rows. */ } + +bool isCoreTile(int col, int row) { return row > 2; } + +bool isMemTile(int col, int row) { return row == 1 || row == 2; } + +uint32_t getNumDestShimMuxConnections(int col, int row, StrmSwPortType bundle) { + if (isShimNOCorPLTile(col, row)) switch (bundle) { + case StrmSwPortType::DMA: + return 2; + case StrmSwPortType::NOC: + return 4; + case StrmSwPortType::PLIO: + return 6; + case StrmSwPortType::SOUTH: + return 8; // Connection to the south port of the stream switch + default: + return 0; + } + + return 0; +} + +uint32_t getNumSourceShimMuxConnections(int col, int row, + StrmSwPortType bundle) { + if (isShimNOCorPLTile(col, row)) switch (bundle) { + case StrmSwPortType::DMA: + return 2; + case StrmSwPortType::NOC: + return 4; + case StrmSwPortType::PLIO: + return 8; + case StrmSwPortType::SOUTH: + return 6; // Connection to the south port of the stream switch + default: + return 0; + } + + return 0; +} + +uint32_t getNumDestSwitchboxConnections(int col, int row, + StrmSwPortType bundle) { + if (isMemTile(col, row)) switch (bundle) { + case StrmSwPortType::DMA: + case StrmSwPortType::NORTH: + return 6; + case StrmSwPortType::SOUTH: + return 4; + case StrmSwPortType::CTRL: + return 1; + default: + return 0; + } + + if (isShimNOCTile(col, row) || isShimPLTile(col, row)) switch (bundle) { + case StrmSwPortType::FIFO: + return 1; + case StrmSwPortType::NORTH: + return 6; + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 6; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::CTRL: + return isShimNOCTile(col, row) ? 1 : 0; + default: + return 0; + } + + switch (bundle) { + case StrmSwPortType::CORE: + return 1; + case StrmSwPortType::DMA: + return 2; + case StrmSwPortType::FIFO: + return 1; + case StrmSwPortType::NORTH: { + if (row == rows() - 1) return 0; + return 6; + } + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 4; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::CTRL: + return 1; + default: + return 0; + } +} + +uint32_t getNumSourceSwitchboxConnections(int col, int row, + StrmSwPortType bundle) { + if (isMemTile(col, row)) switch (bundle) { + case StrmSwPortType::DMA: + return 6; + case StrmSwPortType::NORTH: + return 4; + case StrmSwPortType::SOUTH: + return 6; + case StrmSwPortType::TRACE: + case StrmSwPortType::CTRL: + return 1; + default: + return 0; + } + + if (isShimNOCTile(col, row) || isShimPLTile(col, row)) switch (bundle) { + case StrmSwPortType::FIFO: + return 1; + case StrmSwPortType::NORTH: + return 4; + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 8; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::TRACE: + return 1; + case StrmSwPortType::CTRL: + return isShimNOCTile(col, row) ? 1 : 0; + default: + return 0; + } + + // compute/core tile + switch (bundle) { + case StrmSwPortType::CORE: + return 1; + case StrmSwPortType::DMA: + return 2; + case StrmSwPortType::FIFO: + return 1; + case StrmSwPortType::NORTH: { + if (row == rows() - 1) return 0; + return 4; + } + case StrmSwPortType::WEST: { + if (col == 0) return 0; + return 4; + } + case StrmSwPortType::SOUTH: + return 6; + case StrmSwPortType::EAST: { + if (col == columns() - 1) return 0; + return 4; + } + case StrmSwPortType::TRACE: + // Port 0: core trace. Port 1: memory trace. + return 2; + case StrmSwPortType::CTRL: + return 1; + default: + return 0; + } +} + +bool isLegalTileConnection(int col, int row, StrmSwPortType srcBundle, + int srcChan, StrmSwPortType dstBundle, int dstChan) { + // Check Channel Id within the range + if (srcChan >= int(getNumSourceSwitchboxConnections(col, row, srcBundle))) + return false; + if (dstChan >= int(getNumDestSwitchboxConnections(col, row, dstBundle))) + return false; + + // Lambda function to check if a bundle is in a list + auto isBundleInList = [](StrmSwPortType bundle, + std::initializer_list bundles) { + return std::find(bundles.begin(), bundles.end(), bundle) != bundles.end(); + }; + + // Memtile + if (isMemTile(col, row)) { + if (srcBundle == StrmSwPortType::DMA) { + if (dstBundle == StrmSwPortType::DMA) return srcChan == dstChan; + if (isBundleInList(dstBundle, + {StrmSwPortType::CTRL, StrmSwPortType::SOUTH, + StrmSwPortType::NORTH})) + return true; + } + if (srcBundle == StrmSwPortType::CTRL) { + if (dstBundle == StrmSwPortType::DMA) return dstChan == 5; + if (isBundleInList(dstBundle, + {StrmSwPortType::SOUTH, StrmSwPortType::NORTH})) + return true; + } + if (isBundleInList(srcBundle, + {StrmSwPortType::SOUTH, StrmSwPortType::NORTH})) { + if (isBundleInList(dstBundle, + {StrmSwPortType::DMA, StrmSwPortType::CTRL})) + return true; + if (isBundleInList(dstBundle, + {StrmSwPortType::SOUTH, StrmSwPortType::NORTH})) + return srcChan == dstChan; + } + if (srcBundle == StrmSwPortType::TRACE) { + if (dstBundle == StrmSwPortType::DMA) return dstChan == 5; + if (dstBundle == StrmSwPortType::SOUTH) return true; + } + } + // Shimtile + else if (isShimNOCorPLTile(col, row)) { + if (srcBundle == StrmSwPortType::CTRL) + return dstBundle != StrmSwPortType::CTRL; + if (isBundleInList(srcBundle, + {StrmSwPortType::FIFO, StrmSwPortType::SOUTH})) + return isBundleInList( + dstBundle, + {StrmSwPortType::CTRL, StrmSwPortType::FIFO, StrmSwPortType::SOUTH, + StrmSwPortType::WEST, StrmSwPortType::NORTH, StrmSwPortType::EAST}); + if (isBundleInList(srcBundle, {StrmSwPortType::WEST, StrmSwPortType::NORTH, + StrmSwPortType::EAST})) + return (srcBundle == dstBundle) + ? (srcChan == dstChan) + : isBundleInList( + dstBundle, + {StrmSwPortType::CTRL, StrmSwPortType::FIFO, + StrmSwPortType::SOUTH, StrmSwPortType::WEST, + StrmSwPortType::NORTH, StrmSwPortType::EAST}); + if (srcBundle == StrmSwPortType::TRACE) { + if (isBundleInList(dstBundle, + {StrmSwPortType::FIFO, StrmSwPortType::SOUTH})) + return true; + if (isBundleInList(dstBundle, + {StrmSwPortType::WEST, StrmSwPortType::EAST})) + return dstChan == 0; + } + } + // Coretile + else if (isCoreTile(col, row)) { + if (isBundleInList(srcBundle, + {StrmSwPortType::DMA, StrmSwPortType::FIFO, + StrmSwPortType::SOUTH, StrmSwPortType::WEST, + StrmSwPortType::NORTH, StrmSwPortType::EAST})) + if (isBundleInList(dstBundle, + {StrmSwPortType::CORE, StrmSwPortType::DMA, + StrmSwPortType::CTRL, StrmSwPortType::FIFO, + StrmSwPortType::SOUTH, StrmSwPortType::WEST, + StrmSwPortType::NORTH, StrmSwPortType::EAST})) + return (srcBundle == dstBundle) ? (srcChan == dstChan) : true; + if (srcBundle == StrmSwPortType::CORE) + return dstBundle != StrmSwPortType::CORE; + if (srcBundle == StrmSwPortType::CTRL) + return dstBundle != StrmSwPortType::CTRL && + dstBundle != StrmSwPortType::DMA; + if (srcBundle == StrmSwPortType::TRACE) { + if (dstBundle == StrmSwPortType::DMA) return dstChan == 0; + if (isBundleInList(dstBundle, + {StrmSwPortType::FIFO, StrmSwPortType::SOUTH})) + return true; + } + } + return false; +} + +} // namespace VE2802TargetModel +} // namespace MLIRAIELegacy + +bool mlir::iree_compiler::AMDAIE::AMDAIEDeviceModel::_isLegalTileConnection( + int col, int row, StrmSwPortType srcBundle, int srcChan, + StrmSwPortType dstBundle, int dstChan) const { + // Check Channel Id within the range + if (srcChan >= int(getNumSourceSwitchboxConnections(col, row, srcBundle))) + return false; + if (dstChan >= int(getNumDestSwitchboxConnections(col, row, dstBundle))) + return false; + + // Lambda function to check if a bundle is in a list + auto isBundleInList = [](StrmSwPortType bundle, + std::initializer_list bundles) { + return std::find(bundles.begin(), bundles.end(), bundle) != bundles.end(); + }; + + // Memtile + if (isMemTile(col, row)) { + if (srcBundle == StrmSwPortType::DMA) { + if (dstBundle == StrmSwPortType::DMA) return srcChan == dstChan; + if (isBundleInList(dstBundle, + {StrmSwPortType::CTRL, StrmSwPortType::SOUTH, + StrmSwPortType::NORTH})) + return true; + } + if (srcBundle == StrmSwPortType::CTRL) { + if (dstBundle == StrmSwPortType::DMA) return dstChan == 5; + if (isBundleInList(dstBundle, + {StrmSwPortType::SOUTH, StrmSwPortType::NORTH})) + return true; + } + if (isBundleInList(srcBundle, + {StrmSwPortType::SOUTH, StrmSwPortType::NORTH})) { + if (isBundleInList(dstBundle, + {StrmSwPortType::DMA, StrmSwPortType::CTRL})) + return true; + if (isBundleInList(dstBundle, + {StrmSwPortType::SOUTH, StrmSwPortType::NORTH})) + return srcChan == dstChan; + } + if (srcBundle == StrmSwPortType::TRACE) { + if (dstBundle == StrmSwPortType::DMA) return dstChan == 5; + if (dstBundle == StrmSwPortType::SOUTH) return true; + } + } + // Shimtile + else if (isShimNOCorPLTile(col, row)) { + if (srcBundle == StrmSwPortType::CTRL) + return dstBundle != StrmSwPortType::CTRL; + if (isBundleInList(srcBundle, + {StrmSwPortType::FIFO, StrmSwPortType::SOUTH})) + return isBundleInList( + dstBundle, + {StrmSwPortType::CTRL, StrmSwPortType::FIFO, StrmSwPortType::SOUTH, + StrmSwPortType::WEST, StrmSwPortType::NORTH, StrmSwPortType::EAST}); + if (isBundleInList(srcBundle, {StrmSwPortType::WEST, StrmSwPortType::NORTH, + StrmSwPortType::EAST})) + return (srcBundle == dstBundle) + ? (srcChan == dstChan) + : isBundleInList( + dstBundle, + {StrmSwPortType::CTRL, StrmSwPortType::FIFO, + StrmSwPortType::SOUTH, StrmSwPortType::WEST, + StrmSwPortType::NORTH, StrmSwPortType::EAST}); + if (srcBundle == StrmSwPortType::TRACE) { + if (isBundleInList(dstBundle, + {StrmSwPortType::FIFO, StrmSwPortType::SOUTH})) + return true; + if (isBundleInList(dstBundle, + {StrmSwPortType::WEST, StrmSwPortType::EAST})) + return dstChan == 0; + } + } + // Coretile + else if (isCoreTile(col, row)) { + if (isBundleInList(srcBundle, + {StrmSwPortType::DMA, StrmSwPortType::FIFO, + StrmSwPortType::SOUTH, StrmSwPortType::WEST, + StrmSwPortType::NORTH, StrmSwPortType::EAST})) + if (isBundleInList(dstBundle, + {StrmSwPortType::CORE, StrmSwPortType::DMA, + StrmSwPortType::CTRL, StrmSwPortType::FIFO, + StrmSwPortType::SOUTH, StrmSwPortType::WEST, + StrmSwPortType::NORTH, StrmSwPortType::EAST})) + return (srcBundle == dstBundle) ? (srcChan == dstChan) : true; + if (srcBundle == StrmSwPortType::CORE) + return dstBundle != StrmSwPortType::CORE; + if (srcBundle == StrmSwPortType::CTRL) + return dstBundle != StrmSwPortType::CTRL && + dstBundle != StrmSwPortType::DMA; + if (srcBundle == StrmSwPortType::TRACE) { + if (dstBundle == StrmSwPortType::DMA) return dstChan == 0; + if (isBundleInList(dstBundle, + {StrmSwPortType::FIFO, StrmSwPortType::SOUTH})) + return true; + } + } + return false; +} \ No newline at end of file diff --git a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h index 14814331f..1b5b2174a 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h +++ b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h @@ -45,7 +45,7 @@ struct TileLoc { int col, row; TileLoc(int col, int row) : col(col), row(row) {} - TileLoc() = delete; + TileLoc() = default; TileLoc(XAie_LocType loc) : col(loc.Col), row(loc.Row) {} operator XAie_LocType() const { return XAie_TileLoc(col, row); } @@ -96,6 +96,22 @@ enum class AMDAIEDmaProp : uint8_t { MAX = sizeof(struct XAie_DmaMod) }; +enum class StrmSwPortType : uint8_t { + CORE, + DMA, + CTRL, + FIFO, + SOUTH, + WEST, + NORTH, + EAST, + TRACE, + UCTRLR, + NOC, + PLIO, + SS_PORT_TYPE_MAX +}; + /* * This struct is meant to be a thin wrapper around aie-rt, which provides * the canonical representation/metadata for AIE devices; attributes like number @@ -122,8 +138,8 @@ struct AMDAIEDeviceModel { uint8_t devNColumns, uint8_t devNRows, uint8_t memTileRowStart, uint8_t nMemTileRows, uint8_t nShimTileRows, int partitionNumCols, - int partitionStartCol, bool aieSim, - bool xaieDebug); + int partitionStartCol, bool aieSim, bool xaieDebug, + AMDAIEDevice device); int rows() const; int columns() const; @@ -133,6 +149,7 @@ struct AMDAIEDeviceModel { bool isMemTile(uint8_t col, uint8_t row) const; bool isShimNOCTile(uint8_t col, uint8_t row) const; bool isShimPLTile(uint8_t col, uint8_t row) const; + bool isShimNOCorPLTile(uint8_t col, uint8_t row) const; /// Retrieve a DMA properpty for the specified tile type. template @@ -176,13 +193,28 @@ struct AMDAIEDeviceModel { StrmSwPortType bundle) const; uint32_t getNumDestSwitchboxConnections(uint8_t col, uint8_t row, StrmSwPortType bundle) const; + uint32_t getNumSourceShimMuxConnections(uint8_t col, uint8_t row, + StrmSwPortType bundle) const; + uint32_t getNumDestShimMuxConnections(uint8_t col, uint8_t row, + StrmSwPortType bundle) const; bool isLegalMemtileConnection(uint8_t col, uint8_t row, StrmSwPortType srcBundle, uint8_t srcChan, StrmSwPortType dstBundle, uint8_t dstChan) const; + bool isLegalTileConnection(int col, int row, StrmSwPortType srcBundle, + int srcChan, StrmSwPortType dstBundle, + int dstChan) const; uint32_t getColumnShift() const; uint32_t getRowShift() const; + + private: + AMDAIEDevice device; + uint32_t getNumShimMuxConnections(uint8_t col, uint8_t row, + StrmSwPortType bundle) const; + bool _isLegalTileConnection(int col, int row, StrmSwPortType srcBundle, + int srcChan, StrmSwPortType dstBundle, + int dstChan) const; }; struct AMDAIEDeviceModel getDeviceModel(AMDAIEDevice device); @@ -195,6 +227,7 @@ struct AMDAIEDeviceModel getDeviceModel(AMDAIEDevice device); _(AMDAIEDmaProp) \ _(AieRC) \ _(StrmSwPortType) \ + _(::StrmSwPortType) \ _(TileLoc) \ _(XAie_LocType) \ _(XAie_Lock) \ @@ -213,6 +246,7 @@ TO_STRINGS(TO_STRING) _(OSTREAM_OP_, mlir::iree_compiler::AMDAIE::TileLoc) \ _(OSTREAM_OP_, AieRC) \ _(OSTREAM_OP_, StrmSwPortType) \ + _(OSTREAM_OP_, ::StrmSwPortType) \ _(OSTREAM_OP_, XAie_LocType) \ _(OSTREAM_OP_, XAie_Lock) \ _(OSTREAM_OP_, XAie_Packet) diff --git a/runtime/src/iree-amd-aie/aie_runtime/test/DeviceModelTest.cpp b/runtime/src/iree-amd-aie/aie_runtime/test/DeviceModelTest.cpp index c071e8821..cea5b1c3e 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/test/DeviceModelTest.cpp +++ b/runtime/src/iree-amd-aie/aie_runtime/test/DeviceModelTest.cpp @@ -30,6 +30,7 @@ extern "C" { using namespace mlir::iree_compiler::AMDAIE; namespace { +using mlir::iree_compiler::AMDAIE::StrmSwPortType; const std::map _STRM_SW_PORT_TYPE_TO_WIRE_BUNDLE = { {StrmSwPortType::CORE, xilinx::AIE::WireBundle::Core}, @@ -81,7 +82,7 @@ class AMDAIENPUDeviceModelParameterizedAllPairsTimesAllPairsNPU4ColTestFixture class AMDAIENPUDeviceModelParameterizedSixTupleNPU4ColTestFixture : public AMDAIENPUDeviceModelParameterizedTupleTestNPU4ColFixture< - int, int, int, int, int, int> {}; + int, int, StrmSwPortType, StrmSwPortType, int, int> {}; TEST(SameNumRowsCols_NPU1, Test0) { AMDAIEDeviceModel deviceModel = @@ -253,18 +254,18 @@ const std::map, std::tuple, std::less<>> NumSourceSwitchboxConnectionsFails{ // c, r, port, deviceModelNumSrc, targetModelNumSrc - {{0, 0, TRACE}, {2, 1}}, - // trace - {{1, 0, TRACE}, {2, 1}}, - {{2, 0, TRACE}, {2, 1}}, - {{3, 0, TRACE}, {2, 1}}, - {{4, 0, TRACE}, {2, 1}}, + {{0, 0, StrmSwPortType::TRACE}, {2, 1}}, + // traceStrmSwPortType:: + {{1, 0, StrmSwPortType::TRACE}, {2, 1}}, + {{2, 0, StrmSwPortType::TRACE}, {2, 1}}, + {{3, 0, StrmSwPortType::TRACE}, {2, 1}}, + {{4, 0, StrmSwPortType::TRACE}, {2, 1}}, // east - {{3, 0, EAST}, {4, 0}}, - {{3, 2, EAST}, {4, 0}}, - {{3, 3, EAST}, {4, 0}}, - {{3, 4, EAST}, {4, 0}}, - {{3, 5, EAST}, {4, 0}}}; + {{3, 0, StrmSwPortType::EAST}, {4, 0}}, + {{3, 2, StrmSwPortType::EAST}, {4, 0}}, + {{3, 3, StrmSwPortType::EAST}, {4, 0}}, + {{3, 4, StrmSwPortType::EAST}, {4, 0}}, + {{3, 5, StrmSwPortType::EAST}, {4, 0}}}; TEST_P( AMDAIENPUDeviceModelParameterizedAllPairsTimesAllSwitchesNPU4ColTestFixture, @@ -290,11 +291,11 @@ const std::map, std::tuple, std::less<>> NumDestSwitchboxConnectionsFails{ // c, r, port, deviceModelNumSrc, targetModelNumSrc - {{3, 0, EAST}, {4, 0}}, - {{3, 2, EAST}, {4, 0}}, - {{3, 3, EAST}, {4, 0}}, - {{3, 4, EAST}, {4, 0}}, - {{3, 5, EAST}, {4, 0}}}; + {{3, 0, StrmSwPortType::EAST}, {4, 0}}, + {{3, 2, StrmSwPortType::EAST}, {4, 0}}, + {{3, 3, StrmSwPortType::EAST}, {4, 0}}, + {{3, 4, StrmSwPortType::EAST}, {4, 0}}, + {{3, 5, StrmSwPortType::EAST}, {4, 0}}}; TEST_P( AMDAIENPUDeviceModelParameterizedAllPairsTimesAllSwitchesNPU4ColTestFixture, @@ -350,7 +351,7 @@ const std::vector> MEMTILE_CONNECTIVITY = { TEST_P(AMDAIENPUDeviceModelParameterizedMemtileConnectivityNPU4ColTestFixture, VerifyAIERTAIE2MemTileConnectivity) { auto [slavePhyPort, masterPhyPort] = GetParam(); - StrmSwPortType slaveLogicalPortType, masterLogicalPortType; + ::StrmSwPortType slaveLogicalPortType, masterLogicalPortType; uint8_t slaveLogicalPortNum, masterLogicalPortNum; XAie_LocType tileLoc = XAie_TileLoc(/*col=*/3, /*row=*/1); @@ -374,23 +375,24 @@ TEST_P(AMDAIENPUDeviceModelParameterizedMemtileConnectivityNPU4ColTestFixture, } // mlir-aie reports true when it should be false -const std::set> IsLegalMemtileConnectionFails{ - // srcPort, srcChan, dstPort, dstChan - {CTRL, 0, DMA, 0}, - // trace - {TRACE, 0, CTRL, 0}, - {TRACE, 0, DMA, 0}, - {TRACE, 0, DMA, 1}, - {TRACE, 0, DMA, 2}, - {TRACE, 0, DMA, 3}, - {TRACE, 0, DMA, 4}, - {TRACE, 0, NORTH, 0}, - {TRACE, 0, NORTH, 1}, - {TRACE, 0, NORTH, 2}, - {TRACE, 0, NORTH, 3}, - {TRACE, 0, NORTH, 4}, - {TRACE, 0, NORTH, 5}, -}; +const std::set> + IsLegalMemtileConnectionFails{ + // srcPort, srcChan, dstPort, dstChan + {StrmSwPortType::CTRL, 0, StrmSwPortType::DMA, 0}, + // trace + {StrmSwPortType::TRACE, 0, StrmSwPortType::CTRL, 0}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::DMA, 0}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::DMA, 1}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::DMA, 2}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::DMA, 3}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::DMA, 4}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::NORTH, 0}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::NORTH, 1}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::NORTH, 2}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::NORTH, 3}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::NORTH, 4}, + {StrmSwPortType::TRACE, 0, StrmSwPortType::NORTH, 5}, + }; TEST_P(AMDAIENPUDeviceModelParameterizedSixTupleNPU4ColTestFixture, IsLegalMemtileConnection) { @@ -399,12 +401,13 @@ TEST_P(AMDAIENPUDeviceModelParameterizedSixTupleNPU4ColTestFixture, // TODO(max): maybe there's a way in gtest for the generators to be // parameterized? - if ((srcStrmSwPortType == CTRL || destStrmSwPortType == CTRL) && + if ((srcStrmSwPortType == StrmSwPortType::CTRL || + destStrmSwPortType == StrmSwPortType::CTRL) && (srcChan > 0 || dstChan > 0)) return; - if (srcStrmSwPortType == TRACE && srcChan > 0) return; - if (srcStrmSwPortType == NORTH && srcChan > 3) return; - if (destStrmSwPortType == SOUTH && srcChan > 3) return; + if (srcStrmSwPortType == StrmSwPortType::TRACE && srcChan > 0) return; + if (srcStrmSwPortType == StrmSwPortType::NORTH && srcChan > 3) return; + if (destStrmSwPortType == StrmSwPortType::SOUTH && srcChan > 3) return; auto srcSw = static_cast(srcStrmSwPortType); auto srcWireB = STRM_SW_PORT_TYPE_TO_WIRE_BUNDLE(srcSw); @@ -416,8 +419,8 @@ TEST_P(AMDAIENPUDeviceModelParameterizedSixTupleNPU4ColTestFixture, auto targetModelIsLegal = targetModel.isLegalMemtileConnection( srcWireB, srcChan, destWireb, dstChan); - if ((srcStrmSwPortType == DMA && destStrmSwPortType == DMA && - srcChan != dstChan) || + if ((srcStrmSwPortType == StrmSwPortType::DMA && + destStrmSwPortType == StrmSwPortType::DMA && srcChan != dstChan) || IsLegalMemtileConnectionFails.count( {srcStrmSwPortType, srcChan, destStrmSwPortType, dstChan})) EXPECT_NE(deviceModelIsLegal, targetModelIsLegal) @@ -455,10 +458,14 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( AllPairsTimesAllSwitchesTests, AMDAIENPUDeviceModelParameterizedAllPairsTimesAllSwitchesNPU4ColTestFixture, - ::testing::Combine(::testing::Range(0, NPU1_4COL_NUM_COLS), - ::testing::Range(0, NPU1_4COL_NUM_ROWS), - ::testing::Values(CORE, DMA, CTRL, FIFO, SOUTH, WEST, - NORTH, EAST, TRACE))); + ::testing::Combine( + ::testing::Range(0, NPU1_4COL_NUM_COLS), + ::testing::Range(0, NPU1_4COL_NUM_ROWS), + ::testing::Values(StrmSwPortType::CORE, StrmSwPortType::DMA, + StrmSwPortType::CTRL, StrmSwPortType::FIFO, + StrmSwPortType::SOUTH, StrmSwPortType::WEST, + StrmSwPortType::NORTH, StrmSwPortType::EAST, + StrmSwPortType::TRACE))); INSTANTIATE_TEST_SUITE_P( VerifyAIERTAIE2MemTileConnectivity, @@ -470,8 +477,12 @@ INSTANTIATE_TEST_SUITE_P( #define MAX_CHANNELS 6 // Figure 6-9: Stream-switch ports and connectivity matrix -const std::vector legalSlaves{DMA, CTRL, SOUTH, NORTH, TRACE}; -const std::vector legalMasters{DMA, CTRL, SOUTH, NORTH}; +const std::vector legalSlaves{ + StrmSwPortType::DMA, StrmSwPortType::CTRL, StrmSwPortType::SOUTH, + StrmSwPortType::NORTH, StrmSwPortType::TRACE}; +const std::vector legalMasters{ + StrmSwPortType::DMA, StrmSwPortType::CTRL, StrmSwPortType::SOUTH, + StrmSwPortType::NORTH}; INSTANTIATE_TEST_SUITE_P( IsLegalMemtileConnectionTests, diff --git a/runtime/src/iree-amd-aie/aie_runtime/test/utest.cxx b/runtime/src/iree-amd-aie/aie_runtime/test/utest.cxx index 669d52fda..951ee6281 100755 --- a/runtime/src/iree-amd-aie/aie_runtime/test/utest.cxx +++ b/runtime/src/iree-amd-aie/aie_runtime/test/utest.cxx @@ -20,7 +20,8 @@ int main(int argc, char** argv) { XAIE2IPU_ROW_SHIFT, XAIE2IPU_NUM_COLS, XAIE2IPU_NUM_ROWS, XAIE2IPU_MEM_TILE_ROW_START, XAIE2IPU_MEM_TILE_NUM_ROWS, XAIE2IPU_SHIM_NUM_ROWS, partitionNumCols, partitionStartCol, - /*aieSim*/ false, /*xaieDebug*/ false); + /*aieSim*/ false, /*xaieDebug*/ false, + mlir::iree_compiler::AMDAIE::AMDAIEDevice::npu1_4col); XAie_LocType tile00 = {.Row = 0, .Col = col}; XAie_LocType tile01 = {.Row = 1, .Col = col}; XAie_LocType tile02 = {.Row = 2, .Col = col};