Skip to content

Commit

Permalink
[HW] Add Passes: hw-expunge-module, hw-tree-shake
Browse files Browse the repository at this point in the history
  • Loading branch information
CircuitCoder committed Jan 26, 2025
1 parent b951ce7 commit 83e57c4
Show file tree
Hide file tree
Showing 6 changed files with 437 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/circt/Dialect/HW/HWPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
std::unique_ptr<mlir::Pass> createFooWiresPass();
std::unique_ptr<mlir::Pass> createHWAggregateToCombPass();
std::unique_ptr<mlir::Pass> createHWExpungeModulePass();
std::unique_ptr<mlir::Pass> createHWTreeShakePass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
32 changes: 32 additions & 0 deletions include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> {
let constructor = "circt::hw::createVerifyInnerRefNamespacePass()";
}

def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> {
let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level.";
let description = [{
This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level.
The newly generated ports are by default named as <instance_path>__<port_name>. During a naming conflict, an warning would be genreated,
and an random suffix would be added to the <instance_path> part.

For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path.
}];
let constructor = "circt::hw::createHWExpungeModulePass()";

let options = [
ListOption<"modules", "modules", "std::string",
"Comma separated list of module names to be removed from the hierarchy.">,
ListOption<"portPrefixes", "port-prefixes", "std::string",
"Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as <module>:<instance-path>=<prefix>. Only affect the top-most level module of the instance path.">,
];
}

def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> {
let summary = "Remove unused modules.";
let description = [{
This pass removes all modules besides a specified list of modules and their transitive dependencies.
}];
let constructor = "circt::hw::createHWTreeShakePass()";

let options = [
ListOption<"keep", "keep", "std::string",
"Comma separated list of module names to be kept.">,
];
}

/**
* Tutorial Pass, doesn't do anything interesting
*/
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/HW/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ add_circt_dialect_library(CIRCTHWTransforms
VerifyInnerRefNamespace.cpp
FlattenModules.cpp
FooWires.cpp
HWExpungeModule.cpp
HWTreeShake.cpp

DEPENDS
CIRCTHWTransformsIncGen
Expand Down
314 changes: 314 additions & 0 deletions lib/Dialect/HW/Transforms/HWExpungeModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWPasses.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Dialect/HW/HWInstanceGraph.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ImmutableList.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Support/Regex.h"
#include <numeric>

namespace circt {
namespace hw {
#define GEN_PASS_DEF_HWEXPUNGEMODULE
#include "circt/Dialect/HW/Passes.h.inc"
} // namespace hw
} // namespace circt

namespace {
struct HWExpungeModulePass
: circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
void runOnOperation() override;
};

struct InstPathSeg {
llvm::StringRef seg;

InstPathSeg(llvm::StringRef seg) : seg(seg) {}
const llvm::StringRef &getSeg() const { return seg; }
operator llvm::StringRef() const { return seg; }

void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddString(seg); }
};
using InstPath = llvm::ImmutableList<InstPathSeg>;
std::string defaultPrefix(InstPath path) {
std::string accum;
while (!path.isEmpty()) {
accum += path.getHead().getSeg();
accum += "_";
path = path.getTail();
}
accum += "_";
return std::move(accum);
}

// The regex for port prefix specification
// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so
// manually parse the spec This parser may accept identifiers with invalid
// characters

std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>,
std::string>
parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) {
auto [l, r] = in.split("=");
if (r == "")
return "No '=' found in input";
auto [ll, lr] = l.split(":");
if (lr == "")
return "No ':' found before '='";
llvm::SmallVector<llvm::StringRef, 4> segs;
while (lr != "") {
auto [seg, rest] = lr.split(".");
segs.push_back(seg);
lr = rest;
}
InstPath path;
for (auto &seg : llvm::reverse(segs))
path = listFac.add(seg, path);
return std::make_tuple(ll, path, r);
}
} // namespace

void HWExpungeModulePass::runOnOperation() {
auto root = getOperation();
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
root.walk(
[&](circt::hw::HWModuleLike mod) { allModules[mod.getName()] = mod; });

// The instance graph. We only use this graph to traverse the hierarchy in post order.
// The order does not change throught out the operation, onlygets weakened, but still valid.
// So we keep this cached instance graph throughout the pass.
auto &instanceGraph = getAnalysis<circt::hw::InstanceGraph>();

// Instance path.
InstPath::Factory pathFactory;

// Process port prefix specifications
// (Module name, Instance path) -> Prefix
llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef>
designatedPrefixes;
bool containsFailure = false;
for (const auto &raw : portPrefixes) {
auto matched = parsePrefixSpec(raw, pathFactory);
if (std::holds_alternative<std::string>(matched)) {
llvm::errs() << "Invalid port prefix specification: " << raw << "\n";
llvm::errs() << "Error: " << std::get<std::string>(matched) << "\n";
containsFailure = true;
continue;
}

auto [module, path, prefix] =
std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(
matched);
if (!allModules.contains(module)) {
llvm::errs() << "Module not found in port prefix specification: "
<< module << "\n";
llvm::errs() << "From specification: " << raw << "\n";
containsFailure = true;
continue;
}

// Skip checking instance paths' existence. Non-existent paths are ignored
designatedPrefixes.insert({{module, path}, prefix});
}

if (containsFailure)
return signalPassFailure();

// Instance path * prefix name
using ReplacedDescendent = std::pair<InstPath, std::string>;
// This map holds the expunged descendents of a module
llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>>
expungedDescendents;
for (auto &expunging : this->modules) {
// Clear expungedDescendents
for (auto &it : expungedDescendents)
it.getSecond().clear();

auto expungingMod = allModules.lookup(expunging);
if (!expungingMod)
continue; // Ignored missing modules
auto expungingModTy = expungingMod.getHWModuleType();
auto expungingModPorts = expungingModTy.getPorts();

auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod,
const std::string &prefix,
auto genOutput, auto emitInput) {
mlir::OpBuilder builder(mod);
// Create ports using *REVERSE* direction of their definitions
for (auto &port : expungingModPorts) {
auto defaultName = prefix + port.name.getValue();
auto finalName = defaultName;
if (port.dir == circt::hw::PortInfo::Input) {
auto val = genOutput(port);
assert(val.getType() == port.type);
mod.appendOutput(finalName, val);
} else if (port.dir == circt::hw::PortInfo::Output) {
auto [_, arg] = mod.appendInput(finalName, port.type);
emitInput(port, arg);
}
}
};

for (auto &instGraphNode : llvm::post_order(&instanceGraph)) {
// Skip extmodule and intmodule because they cannot contain anything
circt::hw::HWModuleOp processing =
llvm::dyn_cast_if_present<circt::hw::HWModuleOp>(
instGraphNode->getModule().getOperation());
if (!processing)
continue;

std::optional<decltype(expungedDescendents.lookup("")) *>
outerExpDescHold = {};
auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) {
if (!outerExpDescHold.has_value())
outerExpDescHold = {
&expungedDescendents.insert({processing.getName(), {}})
.first->getSecond()};
return **outerExpDescHold;
};

mlir::OpBuilder outerBuilder(processing);

processing.walk([&](circt::hw::InstanceOp inst) {
auto instName = inst.getInstanceName();
auto instMod = allModules.lookup(inst.getModuleName());

if (instMod.getOutputNames().size() != inst.getResults().size() ||
instMod.getNumInputPorts() != inst.getInputs().size()) {
// Module have been modified during this pass, create new instances
assert(instMod.getNumOutputPorts() >= inst.getResults().size());
assert(instMod.getNumInputPorts() >= inst.getInputs().size());

auto instModInTypes = instMod.getInputTypes();

llvm::SmallVector<mlir::Value> newInputs;
newInputs.reserve(instMod.getNumInputPorts());

outerBuilder.setInsertionPointAfter(inst);

// Appended inputs are at the end of the input list
for (size_t i = 0; i < instMod.getNumInputPorts(); ++i) {
mlir::Value input;
if (i < inst.getNumInputPorts()) {
input = inst.getInputs()[i];
if (auto existingName = inst.getInputName(i))
assert(existingName == instMod.getInputName(i));
} else {
input =
outerBuilder
.create<mlir::UnrealizedConversionCastOp>(
inst.getLoc(), instModInTypes[i], mlir::ValueRange{})
.getResult(0);
}
newInputs.push_back(input);
}

auto newInst = outerBuilder.create<circt::hw::InstanceOp>(
inst.getLoc(), instMod, inst.getInstanceNameAttr(), newInputs,
inst.getParameters(),
inst.getInnerSym().value_or<circt::hw::InnerSymAttr>({}));

for (size_t i = 0; i < inst.getNumResults(); ++i)
assert(inst.getOutputName(i) == instMod.getOutputName(i));
inst.replaceAllUsesWith(
newInst.getResults().slice(0, inst.getNumResults()));
inst.erase();
inst = newInst;
}

llvm::StringMap<mlir::Value> instOMap;
llvm::StringMap<mlir::Value> instIMap;
assert(instMod.getOutputNames().size() == inst.getResults().size());
for (auto [oname, oval] :
llvm::zip(instMod.getOutputNames(), inst.getResults()))
instOMap[llvm::cast<mlir::StringAttr>(oname).getValue()] = oval;
assert(instMod.getInputNames().size() == inst.getInputs().size());
for (auto [iname, ival] :
llvm::zip(instMod.getInputNames(), inst.getInputs()))
instIMap[llvm::cast<mlir::StringAttr>(iname).getValue()] = ival;

// Get outer expunged descendent first because it may modify the map and
// invalidate iterators.
auto &outerExpDesc = getOuterExpDesc();
auto instExpDesc = expungedDescendents.find(inst.getModuleName());

if (inst.getModuleName() == expunging) {
// Handle the directly expunged module
// input maps also useful for directly expunged instance

auto singletonPath = pathFactory.create(instName);

auto designatedPrefix =
designatedPrefixes.find({processing.getName(), singletonPath});
std::string prefix = designatedPrefix != designatedPrefixes.end()
? designatedPrefix->getSecond().str()
: (instName + "__").str();

// Port name collision is still possible, but current relying on MLIR
// to automatically rename input arguments.
// TODO: name collision detect

createPortsOn(
processing, prefix,
[&](circt::hw::ModulePort port) {
// Generate output for outer module, so input for us
return instIMap.at(port.name);
},
[&](circt::hw::ModulePort port, mlir::Value val) {
// Generated input for outer module, replace inst results
assert(instOMap.contains(port.name));
instOMap[port.name].replaceAllUsesWith(val);
});

outerExpDesc.emplace_back(singletonPath, prefix);

assert(instExpDesc == expungedDescendents.end() ||
instExpDesc->getSecond().size() == 0);
inst.erase();
} else if (instExpDesc != expungedDescendents.end()) {
// Handle all transitive descendents
if (instExpDesc->second.size() == 0)
return;
llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
for (const auto &exp : instExpDesc->second) {
auto newPath = pathFactory.add(instName, exp.first);
auto designatedPrefix =
designatedPrefixes.find({processing.getName(), newPath});
std::string prefix = designatedPrefix != designatedPrefixes.end()
? designatedPrefix->getSecond().str()
: defaultPrefix(newPath);

// TODO: name collision detect

createPortsOn(
processing, prefix,
[&](circt::hw::ModulePort port) {
// Generate output for outer module, directly forward from
// inner inst
return instOMap.at((exp.second + port.name.getValue()).str());
},
[&](circt::hw::ModulePort port, mlir::Value val) {
// Generated input for outer module, replace inst results.
// The operand in question has to be an backedge
auto in =
instIMap.at((exp.second + port.name.getValue()).str());
auto inDef = in.getDefiningOp();
assert(llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
in.replaceAllUsesWith(val);
inDef->erase();
});

outerExpDesc.emplace_back(newPath, prefix);
}
}
});
}
}
}

std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass() {
return std::make_unique<HWExpungeModulePass>();
}
Loading

0 comments on commit 83e57c4

Please sign in to comment.