diff --git a/src/graph_scheduler/utilities.py b/src/graph_scheduler/utilities.py index fda1d367..12debc39 100644 --- a/src/graph_scheduler/utilities.py +++ b/src/graph_scheduler/utilities.py @@ -4,8 +4,12 @@ import weakref from typing import Dict, Hashable, Set +import networkx as nx -__all__ = ['disable_debug_logging', 'enable_debug_logging'] +__all__ = [ + 'dependency_dict_to_networkx_digraph', 'disable_debug_logging', + 'enable_debug_logging', 'networkx_graph_to_dependency_dict', +] logger = logging.getLogger(__name__) @@ -146,3 +150,37 @@ def disable_debug_logging(level: int = logging.WARNING): f'Restoring root logger level to {logging.getLevelName(level)}' ) root_logger.setLevel(level) + + +def dependency_dict_to_networkx_digraph(graph: typing_graph_dependency_dict) -> nx.DiGraph: + """ + Converts a graph in dependency dict form to a networkx DiGraph + + Args: + graph: a graph in dependency dict form + + Returns: + networkx.DiGraph + """ + return nx.DiGraph(graph).reverse() + + +def networkx_graph_to_dependency_dict(graph: nx.Graph) -> typing_graph_dependency_dict: + """ + Converts a networkx Graph to a graph in dependency dict form + + Args: + graph: a networkx.Graph + + Returns: + a graph in dependency dict form + """ + res_graph = {} + for sender, receivers in graph.adj.items(): + if sender not in res_graph: + res_graph[sender] = set() + for rec in receivers: + if rec not in res_graph: + res_graph[rec] = set() + res_graph[rec].add(sender) + return res_graph