Skip to content

Commit

Permalink
[HW] Add Passes: hw-expunge-module, hw-tree-shake
Browse files Browse the repository at this point in the history
  • Loading branch information
CircuitCoder committed Jan 12, 2025
1 parent 0d35c61 commit 1b1428a
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/circt/Dialect/HW/HWPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
std::unique_ptr<mlir::Pass> createFooWiresPass();
std::unique_ptr<mlir::Pass> createHWExpungeModulePass();
std::unique_ptr<mlir::Pass> createHWTreeShakePass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
32 changes: 32 additions & 0 deletions include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> {
let constructor = "circt::hw::createVerifyInnerRefNamespacePass()";
}

def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> {
let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level.";
let description = [{
This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level.
The newly generated ports are by default named as <instance_path>__<port_name>. During a naming conflict, an warning would be genreated,
and an random suffix would be added to the <instance_path> part.

For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path.
}];
let constructor = "circt::hw::createHWExpungeModulePass()";

let options = [
ListOption<"modules", "modules", "std::string",
"Comma separated list of module names to be removed from the hierarchy.">,
ListOption<"portPrefixes", "port-prefixes", "std::string",
"Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as <module>:<instance-path>=<prefix>. Only affect the top-most level module of the instance path.">,
];
}

def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> {
let summary = "Remove unused modules.";
let description = [{
This pass removes all modules besides a specified list of modules and their transitive dependencies.
}];
let constructor = "circt::hw::createHWTreeShakePass()";

let options = [
ListOption<"keep", "keep", "std::string",
"Comma separated list of module names to be kept.">,
];
}

/**
* Tutorial Pass, doesn't do anything interesting
*/
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/HW/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ add_circt_dialect_library(CIRCTHWTransforms
VerifyInnerRefNamespace.cpp
FlattenModules.cpp
FooWires.cpp
HWExpungeModule.cpp
HWTreeShake.cpp

DEPENDS
CIRCTHWTransformsIncGen
Expand Down
297 changes: 297 additions & 0 deletions lib/Dialect/HW/Transforms/HWExpungeModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWPasses.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ImmutableList.h"
#include "llvm/Support/Regex.h"

namespace circt {
namespace hw {
#define GEN_PASS_DEF_HWEXPUNGEMODULE
#include "circt/Dialect/HW/Passes.h.inc"
} // namespace hw
} // namespace circt

namespace {
struct HWExpungeModulePass : circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
void runOnOperation() override;
};

struct InstPathSeg {
llvm::StringRef seg;

InstPathSeg(llvm::StringRef seg) : seg(seg) {}
const llvm::StringRef& getSeg() const { return seg; }
operator llvm::StringRef() const { return seg; }

void Profile(llvm::FoldingSetNodeID &ID) const {
ID.AddString(seg);
}
};
using InstPath = llvm::ImmutableList<InstPathSeg>;
std::string defaultPrefix(InstPath path) {
std::string accum;
while(!path.isEmpty()) {
accum += path.getHead().getSeg();
accum += "_";
path = path.getTail();
}
accum += "_";
return std::move(accum);
}

// The regex for port prefix specification
// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so manually parse the spec
// This parser may accept identifiers with invalid characters

std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>, std::string> parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) {
auto [l, r] = in.split("=");
if (r == "") return "No '=' found in input";
auto [ll, lr] = l.split(":");
if (lr == "") return "No ':' found before '='";
llvm::SmallVector<llvm::StringRef, 4> segs;
while(lr != "") {
auto [seg, rest] = lr.split(".");
segs.push_back(seg);
lr = rest;
}
InstPath path;
for(auto &seg : llvm::reverse(segs))
path = listFac.add(seg, path);
return std::make_tuple(ll, path, r);
}
} // namespace

void HWExpungeModulePass::runOnOperation() {
auto root = getOperation();
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
root.walk([&](circt::hw::HWModuleLike mod) {
allModules[mod.getName()] = mod;
});
// Reverse-topo order generated by module instances.
// The topo sorted order does not change throught out the operation. It only gets weakened, but
// still valid.
llvm::SmallVector<circt::hw::HWModuleLike> revTopo;
llvm::DenseSet<mlir::StringRef> visited;
auto visit = [&allModules, &visited, &revTopo](auto &self, circt::hw::HWModuleLike mod) -> void {
if(visited.contains(mod.getName())) return;
visited.insert(mod.getName());

mod.walk([&](circt::hw::InstanceOp inst) {
auto instModName = inst.getModuleName();
auto instMod = allModules.lookup(instModName);
if(!instMod) inst.emitError("Unknown module ") << instModName;
self(self, instMod);
});

// Reverse topo order, insert on stack pop
revTopo.push_back(mod);
};
for(auto &[_, mod] : allModules) visit(visit, mod);
assert(revTopo.size() == allModules.size());

// Instance path.
InstPath::Factory pathFactory;

// Process port prefix specifications
// (Module name, Instance path) -> Prefix
llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef> designatedPrefixes;
bool containsFailure = false;
for(const auto &raw : portPrefixes) {
auto matched = parsePrefixSpec(raw, pathFactory);
if(std::holds_alternative<std::string>(matched)) {
llvm::errs() << "Invalid port prefix specification: " << raw << "\n";
llvm::errs() << "Error: " << std::get<std::string>(matched) << "\n";
containsFailure = true;
continue;
}

auto [module, path, prefix] = std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(matched);
if(!allModules.contains(module)) {
llvm::errs() << "Module not found in port prefix specification: " << module << "\n";
llvm::errs() << "From specification: " << raw << "\n";
containsFailure = true;
continue;
}

// Skip checking instance paths' existence. Non-existent paths are ignored
designatedPrefixes.insert({{module, path}, prefix});
}

if(containsFailure) return signalPassFailure();

// Instance path * prefix name
using ReplacedDescendent = std::pair<InstPath, std::string>;
// This map holds the expunged descendents of a module
llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>> expungedDescendents;
for(auto &expunging : this->modules) {
// Clear expungedDescendents
for(auto &it : expungedDescendents)
it.getSecond().clear();

auto expungingMod = allModules.lookup(expunging);
if(!expungingMod) continue; // Ignored missing modules
auto expungingModTy = expungingMod.getHWModuleType();
auto expungingModPorts = expungingModTy.getPorts();

auto createPortsOn = [&expungingModPorts](
circt::hw::HWModuleOp mod,
const std::string &prefix,
auto genOutput,
auto emitInput
) {
mlir::OpBuilder builder(mod);
// Create ports using *REVERSE* direction of their definitions
for(auto &port : expungingModPorts) {
auto defaultName = prefix + port.name.getValue();
auto finalName = defaultName;
if(port.dir == circt::hw::PortInfo::Input) {
auto val = genOutput(port);
assert(val.getType() == port.type);
mod.appendOutput(finalName, val);
} else if(port.dir == circt::hw::PortInfo::Output) {
auto [_, arg] = mod.appendInput(finalName, port.type);
emitInput(port, arg);
}
}
};

for(auto &processingRaw : revTopo) {
// Skip extmodule and intmodule because they cannot contain anything
if(!llvm::isa<circt::hw::HWModuleOp>(processingRaw)) continue;
circt::hw::HWModuleOp processing = llvm::cast<circt::hw::HWModuleOp>(processingRaw);

std::optional<decltype(expungedDescendents.lookup("")) *> outerExpDescHold = {};
auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) {
if(!outerExpDescHold.has_value())
outerExpDescHold = { &expungedDescendents.insert({processing.getName(), {}}).first->getSecond() };
return **outerExpDescHold;
};

mlir::OpBuilder outerBuilder(processing);

processing.walk([&](circt::hw::InstanceOp inst) {
auto instName = inst.getInstanceName();
auto instMod = allModules.lookup(inst.getModuleName());

if(
instMod.getOutputNames().size() != inst.getResults().size()
|| instMod.getNumInputPorts() != inst.getInputs().size()
) {
// Module have been modified during this pass, create new instances
assert(instMod.getNumOutputPorts() >= inst.getResults().size());
assert(instMod.getNumInputPorts() >= inst.getInputs().size());

auto instModInTypes = instMod.getInputTypes();

llvm::SmallVector<mlir::Value> newInputs;
newInputs.reserve(instMod.getNumInputPorts());

outerBuilder.setInsertionPointAfter(inst);

// Appended inputs are at the end of the input list
for(size_t i = 0; i < instMod.getNumInputPorts(); ++i) {
mlir::Value input;
if(i < inst.getNumInputPorts()) {
input = inst.getInputs()[i];
if(auto existingName = inst.getInputName(i))
assert(existingName == instMod.getInputName(i));
} else {
input = outerBuilder.create<mlir::UnrealizedConversionCastOp>(
inst.getLoc(), instModInTypes[i], mlir::ValueRange{}).getResult(0);
}
newInputs.push_back(input);
}

auto newInst = outerBuilder.create<circt::hw::InstanceOp>(
inst.getLoc(),
instMod,
inst.getInstanceNameAttr(),
newInputs,
inst.getParameters(),
inst.getInnerSym().value_or<circt::hw::InnerSymAttr>({})
);

for(size_t i = 0; i < inst.getNumResults(); ++i)
assert(inst.getOutputName(i) == instMod.getOutputName(i));
inst.replaceAllUsesWith(newInst.getResults().slice(0, inst.getNumResults()));
inst.erase();
inst = newInst;
}

llvm::DenseMap<llvm::StringRef, mlir::Value> instOMap;
llvm::DenseMap<llvm::StringRef, mlir::Value> instIMap;
assert(instMod.getOutputNames().size() == inst.getResults().size());
for(auto [oname, oval] : llvm::zip(instMod.getOutputNames(), inst.getResults()))
instOMap[llvm::cast<mlir::StringAttr>(oname).getValue()] = oval;
assert(instMod.getInputNames().size() == inst.getInputs().size());
for(auto [iname, ival] : llvm::zip(instMod.getInputNames(), inst.getInputs()))
instIMap[llvm::cast<mlir::StringAttr>(iname).getValue()] = ival;

// Get outer expunged descendent first because it may modify the map and invalidate iterators.
auto &outerExpDesc = getOuterExpDesc();
auto instExpDesc = expungedDescendents.find(inst.getModuleName());

if(inst.getModuleName() == expunging) {
// Handle the directly expunged module
// input maps also useful for directly expunged instance

auto singletonPath = pathFactory.create(instName);

auto designatedPrefix = designatedPrefixes.find({processing.getName(), singletonPath});
std::string prefix = designatedPrefix != designatedPrefixes.end() ? designatedPrefix->getSecond().str() : (instName + "__").str();

// Port name collision is still possible, but current relying on MLIR
// to automatically rename input arguments.
// TODO: name collision detect

createPortsOn(processing, prefix, [&](circt::hw::ModulePort port) {
// Generate output for outer module, so input for us
return instIMap.at(port.name);
}, [&](circt::hw::ModulePort port, mlir::Value val) {
// Generated input for outer module, replace inst results
assert(instOMap.contains(port.name));
instOMap[port.name].replaceAllUsesWith(val);
});

outerExpDesc.emplace_back(singletonPath, prefix);

assert(instExpDesc == expungedDescendents.end() || instExpDesc->getSecond().size() == 0);
inst.erase();
} else if(instExpDesc != expungedDescendents.end()) {
// Handle all transitive descendents
if(instExpDesc->second.size() == 0) return;
llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
for(const auto &exp : instExpDesc->second) {
auto newPath = pathFactory.add(instName, exp.first);
auto designatedPrefix = designatedPrefixes.find({processing.getName(), newPath});
std::string prefix = designatedPrefix != designatedPrefixes.end() ? designatedPrefix->getSecond().str() : defaultPrefix(newPath);

// TODO: name collision detect

createPortsOn(processing, prefix, [&](circt::hw::ModulePort port) {
// Generate output for outer module, directly forward from inner inst
return instOMap.at((exp.second + port.name.getValue()).str());
}, [&](circt::hw::ModulePort port, mlir::Value val) {
// Generated input for outer module, replace inst results.
// The operand in question has to be an backedge
auto in = instIMap.at((exp.second + port.name.getValue()).str());
auto inDef = in.getDefiningOp();
assert(llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
in.replaceAllUsesWith(val);
inDef->erase();
});

outerExpDesc.emplace_back(newPath, prefix);
}
}
});
}
}
}

std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass() {
return std::make_unique<HWExpungeModulePass>();
}
Loading

0 comments on commit 1b1428a

Please sign in to comment.