Skip to content

Commit

Permalink
working on UTF-8 automaton decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
koniksedy committed Nov 14, 2024
1 parent 5c6eadf commit 7761928
Show file tree
Hide file tree
Showing 5 changed files with 349 additions and 107 deletions.
17 changes: 8 additions & 9 deletions include/mata/nfa/nfa.hh
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,14 @@ public:
*/
Nfa& trim(StateRenaming* state_renaming = nullptr);

/**
* @brief Decodes automaton from UTF-8 encoding. Method removes unreachable states from delta.
*
* @return Decoded automaton.
*/

Nfa decode_utf8() const;

/**
* @brief Returns vector ret where ret[q] is the length of the shortest path from any initial state to q
*/
Expand Down Expand Up @@ -742,15 +750,6 @@ Nfa somewhat_simple_revert(const Nfa& aut);
// Removing epsilon transitions
Nfa remove_epsilon(const Nfa& aut, Symbol epsilon = EPSILON);

/**
* @brief Decodes automaton from UTF-8 encoding. Method removes unreachable states from delta.
*
* @param[in] aut Automaton to decode.
* @return Decoded automaton.
*/

Nfa decode_utf8(const Nfa& aut);

/** Encodes a vector of strings (each corresponding to one symbol) into a
* @c Word instance
*/
Expand Down
159 changes: 158 additions & 1 deletion src/nfa/nfa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ bool Nfa::is_flat() const {
mata::nfa::Nfa::TarjanDiscoverCallback callback {};
callback.scc_discover = [&](const std::vector<mata::nfa::State>& scc, const std::vector<mata::nfa::State>& tarjan_stack) -> bool {
(void)tarjan_stack;

for(const mata::nfa::State& st : scc) {
bool one_input_visited = false;
for (const mata::nfa::SymbolPost& sp : this->delta[st]) {
Expand Down Expand Up @@ -648,3 +648,160 @@ Nfa& Nfa::unite_nondet_with(const mata::nfa::Nfa& aut) {

return *this;
}

Nfa Nfa::decode_utf8() const {
// // Decodes UTF-8 like transitions starting from the given state.
// auto decode_utf8_trans = [&](const State state, const uint8_t first_byte) -> std::vector<SymbolPost> {
// // Determine the length of the UTF-8 prefix
// const size_t prefix_len = (first_byte >> 5 == 0b110) ? 3 :
// (first_byte >> 4 == 0b1110) ? 4 :
// (first_byte >> 3 == 0b11110) ? 5 : 0;
// assert(prefix_len > 0);
// uint8_t first_byte_data = first_byte & (0xff >> (prefix_len));
// size_t max_depth = prefix_len - 2;

// std::vector<SymbolPost> result;
// std::stack<std::tuple<State, Symbol, uint8_t>> worklist;
// worklist.push({state, first_byte_data, 0});
// // Inner limited depth DFS - combines multiple transitions into a single UTF-8 symbol
// while (!worklist.empty()) {
// std::tuple<State, Symbol, uint8_t> elem = worklist.top();
// worklist.pop();
// State src = std::get<0>(elem);
// Symbol symbol = std::get<1>(elem);
// uint8_t depth = std::get<2>(elem);
// assert(depth < max_depth);
// depth++;

// for (const SymbolPost &symbol_post : this->delta[src]) {
// const uint8_t symbol_prefix = static_cast<uint8_t>(symbol_post.symbol & 0xc0);
// assert(symbol_prefix == 0x80);
// const uint8_t symbol_data = static_cast<uint8_t>(symbol_post.symbol & 0x7f);
// symbol = (symbol << 6) | symbol_data;

// if (depth == max_depth) {
// // This is the last byte of the UTF-8 symbol.
// result.push_back(SymbolPost{symbol, symbol_post.targets});
// } else {
// // This is an intermediate byte of the UTF-8 symbol. Continue the DFS.
// for (State target : symbol_post.targets) {
// worklist.push({target, symbol, depth});
// }
// }
// }
// }

// return result;
// };

// const size_t num_of_states{ this->num_of_states() };
// Nfa result{ num_of_states, StateSet{this->initial}, StateSet{this->final} };
// mata::BoolVector used(num_of_states, false);

// std::stack<State> worklist;
// for (State state: this->initial) {
// worklist.push(state);
// used[state] = true;
// }

// // Outer DFS - traverses the automaton transitions
// while (!worklist.empty()) {
// State src = worklist.top();
// worklist.pop();
// StatePost &result_state_post = result.delta.mutable_state_post(src);
// for (const SymbolPost &symbol_post: this->delta[src]) {
// Symbol symbol = symbol_post.symbol;
// if (symbol & 0x80) {
// // It is an UTF-8 symbol
// const uint8_t first_byte = static_cast<uint8_t>(symbol);
// for (const State target: symbol_post.targets) {
// for (const SymbolPost &symbol_post_decoded: decode_utf8_trans(target, first_byte)) {
// // Insert decoded transitions
// result_state_post.insert(std::move(symbol_post_decoded));
// // Add targets to the worklist
// for (State target_decoded: symbol_post_decoded.targets) {
// if (used[target_decoded]) {
// continue;
// }
// used[target_decoded] = true;
// worklist.push(target_decoded);
// }
// }
// }
// } else {
// // It is standard ASCII symbol <0;127>
// result_state_post.insert(SymbolPost{symbol, symbol_post.targets});
// for (State target: symbol_post.targets) {
// if (used[target]) {
// continue;
// }
// used[target] = true;
// worklist.push(target);
// }
// }
// }
// }

Nfa result{ this->num_of_states(), StateSet{this->initial}, StateSet{this->final} };
BoolVector used(this->num_of_states(), false);
std::stack<State> worklist;

auto push_state_set = [&](const StateSet& set) {
for (State state: set) {
worklist.push(state);
used[state] = true;
}
};

push_state_set(StateSet{this->initial});
while (!worklist.empty()) {
State q1 = worklist.top();
worklist.pop();

// 1st Byte
for (const SymbolPost &sp1: this->delta[q1]) {
const Symbol s1 = sp1.symbol;
if ((s1 & 0x80) == 0x00) {
result.delta.add(q1, s1, sp1.targets);
push_state_set(sp1.targets);
continue;
}
// 2nd Byte
for (const State q2: sp1.targets) {
for (const SymbolPost &sp2: this->delta[q2]) {
const Symbol s2 = sp2.symbol;
if ((s1 & 0xE0) == 0xC0) {
assert((s2 & 0xC0) == 0x80);
result.delta.add(q1, ((s1 & 0x1F) << 6) | (s2 & 0x3F), sp2.targets);
push_state_set(sp2.targets);
continue;
}
// 3rd Byte
for (const State q3: sp2.targets) {
for (const SymbolPost &sp3: this->delta[q3]) {
const Symbol s3 = sp3.symbol;
if ((s1 & 0xF0) == 0xE0) {
assert((s3 & 0xC0) == 0x80);
result.delta.add(q1, ((s1 & 0x0F) << 12) | ((s2 & 0x3F) << 6) | (s3 & 0x3F), sp3.targets);
push_state_set(sp3.targets);
continue;
}
// 4th Byte
for (const State q4: sp3.targets) {
for (const SymbolPost &sp4: this->delta[q4]) {
const Symbol s4 = sp4.symbol;
assert((s1 & 0xF8) == 0xF0);
assert((s4 & 0xC0) == 0x80);
result.delta.add(q1, ((s1 & 0x07) << 18) | ((s2 & 0x3F) << 12) | ((s3 & 0x3F) << 6) | (s4 & 0x3F), sp4.targets);
push_state_set(sp4.targets);
}
}
}
}
}
}
}
}

return result;
}
96 changes: 0 additions & 96 deletions src/nfa/operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1468,99 +1468,3 @@ std::optional<mata::Word> mata::nfa::get_word_from_lang_difference(const Nfa & n
return nfa_lang_difference.final.empty();
}).get_word();
}

Nfa mata::nfa::decode_utf8(const Nfa &aut) {
// Decodes UTF-8 like transitions starting from the given state.
auto decode_utf8_trans = [&](const State state, const uint8_t first_byte) -> std::vector<SymbolPost> {
// Determine the length of the UTF-8 prefix
const size_t prefix_len = (first_byte >> 5 == 0b110) ? 3 :
(first_byte >> 4 == 0b1110) ? 4 :
(first_byte >> 3 == 0b11110) ? 5 : 0;
assert(prefix_len > 0);
uint8_t first_byte_data = first_byte & (0xff >> (prefix_len));
size_t max_depth = prefix_len - 2;

std::vector<SymbolPost> result;
std::stack<std::tuple<State, Symbol, uint8_t>> worklist;
worklist.push({state, first_byte_data, 0});
// Inner limited depth DFS - combines multiple transitions into a single UTF-8 symbol
while (!worklist.empty()) {
std::tuple<State, Symbol, uint8_t> elem = worklist.top();
worklist.pop();
State src = std::get<0>(elem);
Symbol symbol = std::get<1>(elem);
uint8_t depth = std::get<2>(elem);
assert(depth < max_depth);
depth++;

for (const SymbolPost &symbol_post : aut.delta[src]) {
const uint8_t symbol_prefix = static_cast<uint8_t>(symbol_post.symbol & 0xc0);
assert(symbol_prefix == 0x80);
const uint8_t symbol_data = static_cast<uint8_t>(symbol_post.symbol & 0x7f);
symbol = (symbol << 6) | symbol_data;

if (depth == max_depth) {
// This is the last byte of the UTF-8 symbol.
result.push_back(SymbolPost{symbol, symbol_post.targets});
} else {
// This is an intermediate byte of the UTF-8 symbol. Continue the DFS.
for (State target : symbol_post.targets) {
worklist.push({target, symbol, depth});
}
}
}
}

return result;
};

const size_t num_of_states{ aut.num_of_states() };
Nfa result{ num_of_states, StateSet{aut.initial}, StateSet{aut.final} };
mata::BoolVector used(num_of_states, false);

std::stack<State> worklist;
for (State state: aut.initial) {
worklist.push(state);
used[state] = true;
}

// Outer DFS - traverses the automaton transitions
while (!worklist.empty()) {
State src = worklist.top();
worklist.pop();
StatePost &result_state_post = result.delta.mutable_state_post(src);
for (const SymbolPost &symbol_post: aut.delta[src]) {
Symbol symbol = symbol_post.symbol;
if (symbol & 0x80) {
// It is an UTF-8 symbol
const uint8_t first_byte = static_cast<uint8_t>(symbol);
for (const State target: symbol_post.targets) {
for (const SymbolPost &symbol_post_decoded: decode_utf8_trans(target, first_byte)) {
// Insert decoded transitions
result_state_post.insert(std::move(symbol_post_decoded));
// Add targets to the worklist
for (State target_decoded: symbol_post_decoded.targets) {
if (used[target_decoded]) {
continue;
}
used[target_decoded] = true;
worklist.push(target_decoded);
}
}
}
} else {
// It is standard ASCII symbol <0;127>
result_state_post.insert(SymbolPost{symbol, symbol_post.targets});
for (State target: symbol_post.targets) {
if (used[target]) {
continue;
}
used[target] = true;
worklist.push(target);
}
}
}
}

return result;
}
Loading

0 comments on commit 7761928

Please sign in to comment.