diff --git a/project/cfg_utils.py b/project/cfg_utils.py
new file mode 100644
index 000000000..fde6cadf5
--- /dev/null
+++ b/project/cfg_utils.py
@@ -0,0 +1,26 @@
+from pyformlang.cfg import CFG, Production, Epsilon
+
+
+def cfg_to_weak_normal_form(cfg: CFG) -> CFG:
+    """Convert a context-free grammar to weak Chomsky normal form.
+
+    The result keeps every production of the Chomsky normal form of ``cfg``
+    and adds ``A -> epsilon`` for every symbol ``A`` that is nullable in the
+    original grammar (pyformlang documents that its normal form does not
+    generate the empty word).
+
+    Args:
+        cfg: The grammar to convert.
+
+    Returns:
+        A new grammar that reuses the start symbol of ``cfg``.
+    """
+    wcnf_productions = set(cfg.to_normal_form().productions)
+
+    # filtering=False asks pyformlang to keep the Epsilon in the body.
+    wcnf_productions.update(
+        Production(symbol, [Epsilon()], filtering=False)
+        for symbol in cfg.get_nullable_symbols()
+    )
+
+    return CFG(start_symbol=cfg.start_symbol, productions=wcnf_productions)
diff --git a/project/cfpq.py b/project/cfpq.py
new file mode 100644
index 000000000..367c3d450
--- /dev/null
+++ b/project/cfpq.py
@@ -0,0 +1,86 @@
+from typing import Optional
+
+from pyformlang.cfg import CFG, Terminal
+import networkx as nx
+
+from project.cfg_utils import cfg_to_weak_normal_form
+
+
+def hellings_based_cfpq(
+    cfg: CFG,
+    graph: nx.DiGraph,
+    start_nodes: Optional[set[int]] = None,
+    final_nodes: Optional[set[int]] = None,
+) -> set[tuple[int, int]]:
+    """Answer a context-free path query with the Hellings algorithm.
+
+    A pair ``(u, v)`` is returned when the closure derives the start symbol
+    of the grammar for a path from ``u`` to ``v``; edge words are read from
+    the ``"label"`` edge attribute.
+
+    Args:
+        cfg: Query grammar; converted to weak Chomsky normal form first.
+        graph: Directed graph whose edges carry a ``"label"`` attribute.
+        start_nodes: Allowed path sources; ``None`` or an empty set allows
+            every node.
+        final_nodes: Allowed path targets; ``None`` or an empty set allows
+            every node.
+
+    Returns:
+        The set of ``(source, target)`` node pairs matching the query.
+    """
+    wcnf = cfg_to_weak_normal_form(cfg)
+    productions = wcnf.productions
+    nullable_symbols = wcnf.get_nullable_symbols()
+
+    # Index two-symbol bodies once so the closure loop does a dictionary
+    # lookup instead of rescanning every production for each pair of triples.
+    heads_by_body = {}
+    for production in productions:
+        if len(production.body) == 2:
+            body = tuple(production.body)
+            heads_by_body.setdefault(body, set()).add(production.head)
+
+    # Every derived triple (nonterminal, source, target).
+    known = set()
+
+    # An edge u -> v yields (A, u, v) for each production mentioning its label.
+    for u, v, label in graph.edges(data="label"):
+        terminal = Terminal(label)
+        for production in productions:
+            if terminal in production.body:
+                known.add((production.head, u, v))
+
+    # A nullable nonterminal derives the empty path at every node.
+    for node in graph.nodes:
+        for symbol in nullable_symbols:
+            known.add((symbol, node, node))
+
+    # Triples whose combinations with `known` have not been explored yet.
+    pending = set(known)
+
+    while pending:
+        symbol, src, dst = pending.pop()
+        # Collected separately because `known` must not grow while iterated.
+        derived = set()
+        for other, other_src, other_dst in known:
+            if dst == other_src:
+                # This triple's path followed by the other one.
+                for head in heads_by_body.get((symbol, other), ()):
+                    derived.add((head, src, other_dst))
+            if src == other_dst:
+                # The other triple's path followed by this one.
+                for head in heads_by_body.get((other, symbol), ()):
+                    derived.add((head, other_src, dst))
+        derived -= known
+        pending |= derived
+        known |= derived
+
+    start_symbol = wcnf.start_symbol
+    return {
+        (src, dst)
+        for symbol, src, dst in known
+        if symbol == start_symbol
+        and (not start_nodes or src in start_nodes)
+        and (not final_nodes or dst in final_nodes)
+    }
diff --git a/tests/autotests/rpq_template_test.py b/tests/autotests/rpq_template_test.py
index f57e27a2a..e98e78975 100644
--- a/tests/autotests/rpq_template_test.py
+++ b/tests/autotests/rpq_template_test.py
@@ -6,7 +6,7 @@
 from typing import Callable, Iterable
 
 try:
-    from project.task4 import ms_bfs_based_rpq
+    from project.rpq import ms_bfs_based_rpq
 except ImportError:
     pass
 
diff --git a/tests/autotests/test_task06.py b/tests/autotests/test_task06.py
index 01b5339f5..d0276ca4f 100644
--- a/tests/autotests/test_task06.py
+++ b/tests/autotests/test_task06.py
@@ -9,7 +9,7 @@
 
 # Fix import statements in try block to run tests
 try:
-    from project.task6 import hellings_based_cfpq
+    from project.cfpq import hellings_based_cfpq
 except ImportError:
     pytestmark = pytest.mark.skip("Task 6 is not ready to test!")
 
@@ -30,8 +30,8 @@ def test_different_grammars_hellings(self, graph, eq_grammars):
 
 def test_cfg_to_weak_normal_form_exists():
     try:
-        import project.task6
+        import project.cfg_utils
 
-        assert "cfg_to_weak_normal_form" in dir(project.task6)
+        assert "cfg_to_weak_normal_form" in dir(project.cfg_utils)
     except NameError:
         assert False