Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 142 additions & 70 deletions passes/techmap/extract_fa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

#include "kernel/yosys.h"
#include "kernel/sigtools.h"
#include "backends/rtlil/rtlil_backend.h"
#include "kernel/consteval.h"
#include <cstdio>
#include <mutex>
#include <thread>

USING_YOSYS_NAMESPACE
PRIVATE_NAMESPACE_BEGIN
Expand Down Expand Up @@ -66,9 +70,6 @@ struct ExtractFaWorker
dict<tuple<SigBit, SigBit>, dict<int, pool<SigBit>>> func2;
dict<tuple<SigBit, SigBit, SigBit>, dict<int, pool<SigBit>>> func3;

int count_func2;
int count_func3;

struct func2_and_info_t {
bool inv_a, inv_b, inv_y;
};
Expand All @@ -77,6 +78,26 @@ struct ExtractFaWorker
bool inv_a, inv_b, inv_c, inv_y;
};

struct Counters {
int count_func2;
int count_func3;
};

struct ThreadData {
size_t start;
size_t end;
Counters counters;
ExtractFaWorker* instance;
std::stringstream log_buffer;
std::vector<RTLIL::IdString> ports = {ID::A, ID::B, ID::C, ID::D};
pool<tuple<SigBit, SigBit>> tl_xorxnor2;
pool<tuple<SigBit, SigBit, SigBit>> tl_xorxnor3;
pool<tuple<tuple<SigBit, SigBit>, int, SigBit>> tl_func_2;
pool<tuple<tuple<SigBit, SigBit, SigBit>,int, SigBit>> tl_func_3;
};

std::mutex consteval_mtx;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The declaration of this mutex is far removed from the data that is actually protected by it. If raw mutexes are used at all that should be right next to each other.

When using a mutex like this, there is also nothing preventing or even hinting at an issue when introducing new ce accesses that are not protected by the appropriate lock guard. This makes it way too easy to introduce bugs that can be very hard to debug. For that reason I'm inclined to require use of higher level primitives within passes.

We could e.g. add our own Mutex<T> that combines a std::mutex mutex and a T value and only provides access via a lock method that passes out our own MutexGuard<T> that combines a std::lock_guard guard and a T &value ensuring that you only get to access the shared value while you hold the lock. (Unless you explicitly store a reference elsewhere, but that's always a hazard and not specific to multi-threading, making it somewhat easier to spot.) This is more or less the same API that Rust provides but of course there's nothing that stops this approach from being implemented in C++. The first example I found is folly's Synchronized, which also goes into a bit more detail motivating the use of this API over what std::mutex provides.


dict<int, func2_and_info_t> func2_and_info;
dict<int, func3_maj_info_t> func3_maj_info;

Expand Down Expand Up @@ -153,7 +174,7 @@ struct ExtractFaWorker
}
}

void check_partition(SigBit root, pool<SigBit> &leaves)
void check_partition(SigBit root, pool<SigBit> &leaves, ThreadData& data)
{
if (config.enable_ha && GetSize(leaves) == 2)
{
Expand All @@ -163,35 +184,38 @@ struct ExtractFaWorker
SigBit B = SigSpec(leaves)[1];

int func = 0;
for (int i = 0; i < 4; i++)
{
bool a_value = (i & 1) != 0;
bool b_value = (i & 2) != 0;
std::lock_guard lock(consteval_mtx);
for (int i = 0; i < 4; i++)
{
bool a_value = (i & 1) != 0;
bool b_value = (i & 2) != 0;

ce.push();
ce.set(A, a_value ? State::S1 : State::S0);
ce.set(B, b_value ? State::S1 : State::S0);
SigSpec sig = root;

ce.push();
ce.set(A, a_value ? State::S1 : State::S0);
ce.set(B, b_value ? State::S1 : State::S0);
if (!ce.eval(sig)) {
ce.pop();
return;
}

SigSpec sig = root;
if (sig == State::S1)
func |= 1 << i;

if (!ce.eval(sig)) {
ce.pop();
return;
}

if (sig == State::S1)
func |= 1 << i;

ce.pop();
}

// log("%04d %s %s -> %s\n", bindec(func), log_signal(A), log_signal(B), log_signal(root));

if (func == xor2_func || func == xnor2_func)
xorxnor2.insert(tuple<SigBit, SigBit>(A, B));
data.tl_xorxnor2.insert(tuple<SigBit, SigBit>(A, B));

count_func2++;
func2[tuple<SigBit, SigBit>(A, B)][func].insert(root);
data.counters.count_func2++;
data.tl_func_2.insert(
tuple<tuple<SigBit, SigBit>, int, SigBit>(tuple<SigBit, SigBit>(A, B), func, root)
);
}

if (config.enable_fa && GetSize(leaves) == 3)
Expand All @@ -203,52 +227,70 @@ struct ExtractFaWorker
SigBit C = SigSpec(leaves)[2];

int func = 0;
for (int i = 0; i < 8; i++)
{
bool a_value = (i & 1) != 0;
bool b_value = (i & 2) != 0;
bool c_value = (i & 4) != 0;

ce.push();
ce.set(A, a_value ? State::S1 : State::S0);
ce.set(B, b_value ? State::S1 : State::S0);
ce.set(C, c_value ? State::S1 : State::S0);

SigSpec sig = root;
std::lock_guard lock(consteval_mtx);
for (int i = 0; i < 8; i++)
{
bool a_value = (i & 1) != 0;
bool b_value = (i & 2) != 0;
bool c_value = (i & 4) != 0;

ce.push();
ce.set(A, a_value ? State::S1 : State::S0);
ce.set(B, b_value ? State::S1 : State::S0);
ce.set(C, c_value ? State::S1 : State::S0);
SigSpec sig = root;

if (!ce.eval(sig)) {
ce.pop();
return;
}
if (sig == State::S1)
func |= 1 << i;

if (!ce.eval(sig)) {
ce.pop();
return;
}

if (sig == State::S1)
func |= 1 << i;

ce.pop();
}

// log("%08d %s %s %s -> %s\n", bindec(func), log_signal(A), log_signal(B), log_signal(C), log_signal(root));

if (func == xor3_func || func == xnor3_func)
xorxnor3.insert(tuple<SigBit, SigBit, SigBit>(A, B, C));
data.tl_xorxnor3.insert(tuple<SigBit, SigBit, SigBit>(A, B, C));

count_func3++;
func3[tuple<SigBit, SigBit, SigBit>(A, B, C)][func].insert(root);
data.counters.count_func3++;
data.tl_func_3.insert(
tuple<tuple<SigBit, SigBit, SigBit>, int, SigBit>(tuple<SigBit, SigBit, SigBit>(A, B, C), func, root)
);
}
}
void partition_log_cache(std::stringstream& stream, int depth, SigBit signal, bool format_depth=false) {
std::stringstream buf;
RTLIL_BACKEND::dump_sigspec(buf, signal, true);
if(format_depth) {
// at most, this is going to take in maxdepth spaces + 2 brackets + 2 numbers + space + \0
std::vector<char> spacer_buffer;
spacer_buffer.resize(config.maxdepth + 6);
snprintf(spacer_buffer.data(), config.maxdepth + 6, "%*s[%d] ", config.maxdepth-depth, "", depth);
stream << spacer_buffer.data();
}

stream << " " << buf.str();
if(format_depth)
stream << ":";
}

void find_partitions(SigBit root, pool<SigBit> &leaves, pool<pool<SigBit>> &cache, int maxdepth, int maxbreadth)
void find_partitions(SigBit root, pool<SigBit> &leaves, pool<pool<SigBit>> &cache, int maxdepth, int maxbreadth, ThreadData& data)
{
if (cache.count(leaves))
return;

// log("%*s[%d] %s:", 20-maxdepth, "", maxdepth, log_signal(root));
// for (auto bit : leaves)
// log(" %s", log_signal(bit));
// log("\n");
partition_log_cache(data.log_buffer, maxdepth, root, true);
for (auto bit : leaves)
partition_log_cache(data.log_buffer, maxdepth, bit);
data.log_buffer << "\n";

cache.insert(leaves);
check_partition(root, leaves);
check_partition(root, leaves, data);

if (maxdepth == 0)
return;
Expand All @@ -262,7 +304,8 @@ struct ExtractFaWorker
pool<SigBit> new_leaves = leaves;

new_leaves.erase(bit);
for (auto port : {ID::A, ID::B, ID::C, ID::D}) {

for (auto port : data.ports) {
if (!cell->hasPort(port))
continue;
auto bit = sigmap(SigBit(cell->getPort(port)));
Expand All @@ -274,7 +317,7 @@ struct ExtractFaWorker
if (GetSize(new_leaves) > maxbreadth)
continue;

find_partitions(root, new_leaves, cache, maxdepth-1, maxbreadth);
find_partitions(root, new_leaves, cache, maxdepth-1, maxbreadth, data);
}
}

Expand All @@ -290,29 +333,55 @@ struct ExtractFaWorker
void run()
{
log("Extracting full/half adders from %s:\n", log_id(module));
const size_t num_threads = std::thread::hardware_concurrency();
std::vector<std::thread> threads;
std::vector<ThreadData> thread_data(num_threads);

for (auto it : driver)
{
if (it.second->type.in(ID($_BUF_), ID($_NOT_)))
continue;
size_t total_elements = driver.size();
size_t thread_elements = total_elements / num_threads;
for (size_t i = 0; i < num_threads; ++i) {
thread_data[i].start = i * thread_elements;
thread_data[i].end = (i == num_threads - 1) ? total_elements : (i + 1) * thread_elements;
thread_data[i].instance = this;

threads.emplace_back([&data = thread_data[i]]() {
auto& driver = data.instance->driver;
auto& config = data.instance->config;

SigBit root = it.first;
pool<SigBit> leaves = { root };
pool<pool<SigBit>> cache;
for (size_t i = data.start; i < data.end; ++i) {
const auto& it = *driver.element(i);
if (it.second->type.in(ID($_BUF_), ID($_NOT_)))
continue;

if (config.verbose)
log(" checking %s\n", log_signal(it.first));
SigBit root = it.first;
pool<SigBit> leaves = { root };
pool<pool<SigBit>> cache;

count_func2 = 0;
count_func3 = 0;
if (config.verbose)
log(" checking %s\n", log_signal(it.first));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even read-only access for RTLIL data structures currently isn't thread safe. There are a lot more places where worker thread RTLIL access happens, but I picked this as the most obvious one. Until we are able to change that, only the main thread can access RTLIL. To ensure we can maintain this, we also require that this is made obvious by not handing RTLIL references to code running on worker threads in the first place. See #5266 (comment) for a recent discussion on the requirements for adding multi-threaded code to Yosys and the corresponding PR for an example of what is currently possible.

I also think if we are adding multi-threading, we should prefer using work queues to dynamically balance the workload instead of statically splitting it like this PR currently does. The PR I linked to introduces some primitives for this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I left the above comment, there was a discussion on providing thread-safe read-only RTLIL access. While the thread starts with a proposal for a thread-safe alternative API, the conclusion was that we will try to make all const RTLIL methods thread-safe eventually. I'm not sure how long it will take to get us there, but compared to what I've been asking for above, that should significantly lower the barrier for adding parallel processing to passes.


find_partitions(root, leaves, cache, config.maxdepth, config.maxbreadth);
data.instance->find_partitions(root, leaves, cache, config.maxdepth, config.maxbreadth, data);
// log("%s", log_buffer.str().c_str());

if (config.verbose && count_func2 > 0)
log(" extracted %d two-input functions\n", count_func2);
if (config.verbose && data.counters.count_func2 > 0)
log(" extracted %d two-input functions\n", data.counters.count_func2);

if (config.verbose && data.counters.count_func3 > 0)
log(" extracted %d three-input functions\n", data.counters.count_func3);
}
});
}

if (config.verbose && count_func3 > 0)
log(" extracted %d three-input functions\n", count_func3);
for (size_t i = 0; i < num_threads; ++i) {
threads[i].join();
for(auto& x3 : thread_data[i].tl_xorxnor3)
xorxnor3.insert(x3);
for(auto& x2 : thread_data[i].tl_xorxnor2)
xorxnor2.insert(x2);
for(auto& f3 : thread_data[i].tl_func_3)
func3[get<0>(f3)][get<1>(f3)].insert(get<2>(f3));
for(auto& f2 : thread_data[i].tl_func_2)
func2[get<0>(f2)][get<1>(f2)].insert(get<2>(f2));
}

for (auto &key : xorxnor3)
Expand Down Expand Up @@ -341,10 +410,13 @@ struct ExtractFaWorker
int func = it.first;
auto f3i = it.second;

int xor_cnt, xnor_cnt;
xor_cnt = func3.at(key).count(xor3_func);
xnor_cnt = func3.at(key).count(xnor3_func);
if (func3.at(key).count(func) == 0)
continue;

if (func3.at(key).count(xor3_func) == 0 && func3.at(key).count(xnor3_func) != 0) {
if (xor_cnt == 0 && xnor_cnt != 0) {
f3i.inv_a = !f3i.inv_a;
f3i.inv_b = !f3i.inv_b;
f3i.inv_c = !f3i.inv_c;
Expand Down Expand Up @@ -413,13 +485,13 @@ struct ExtractFaWorker
}

bool invert_y = f3i.inv_a ^ f3i.inv_b ^ f3i.inv_c;
if (func3.at(key).count(xor3_func)) {
if (xor_cnt) {
SigBit YY = invert_xy ^ invert_y ? module->NotGate(NEW_ID, Y) : Y;
for (auto bit : func3.at(key).at(xor3_func))
assign_new_driver(bit, YY);
}

if (func3.at(key).count(xnor3_func)) {
if (xnor_cnt) {
SigBit YY = invert_xy ^ invert_y ? Y : module->NotGate(NEW_ID, Y);
for (auto bit : func3.at(key).at(xnor3_func))
assign_new_driver(bit, YY);
Expand Down