From 68f886e7353ad2dd906168fa9f5736a7adc4d120 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Thu, 9 May 2024 12:04:37 +0200 Subject: [PATCH 01/12] Add to_pyg() method --- programl/transform_ops.py | 58 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/programl/transform_ops.py b/programl/transform_ops.py index 6c76734b..7e7e0d2c 100644 --- a/programl/transform_ops.py +++ b/programl/transform_ops.py @@ -18,11 +18,13 @@ """ import json import subprocess +import torch from typing import Any, Dict, Iterable, Optional, Union import dgl import networkx as nx from dgl.heterograph import DGLHeteroGraph +from torch_geometric.data import HeteroData from networkx.readwrite import json_graph as nx_json from programl.exceptions import GraphTransformError @@ -258,3 +260,59 @@ def _run_one(graph: ProgramGraph) -> str: if isinstance(graphs, ProgramGraph): return _run_one(graphs) return execute(_run_one, graphs, executor, chunksize) + +def to_pyg( + graphs: Union[ProgramGraph, Iterable[ProgramGraph]], + timeout: int = 300, + vocabulary: Dict[str, int] = None, + executor: Optional[ExecutorLike] = None, + chunksize: Optional[int] = None, +) -> Union[HeteroData, Iterable[HeteroData]]: + + def _run_one(graph: ProgramGraph) -> HeteroData: + # 3 lists, one per edge type + # (control, data and call edges) + adjacencies = [[], [], []] + edge_positions = [[], [], []] + + # Create the adjacency lists + for edge in graph.edge: + adjacencies[edge.flow].append([edge.source, edge.target]) + edge_positions[edge.flow].append(edge.position) + + vocab_ids = None + if vocabulary is not None: + vocab_ids = [ + vocabulary.get(node.text, len(vocabulary.keys())) + for node in graph.node + ] + + # Pass from list to tensor + adjacencies = [torch.tensor(adj_flow_type) for adj_flow_type in adjacencies] + edge_positions = [torch.tensor(edge_pos_flow_type) for edge_pos_flow_type in edge_positions] + + if vocabulary is not None: + vocab_ids = torch.tensor(vocab_ids) + + # Create the graph structure + hetero_graph = HeteroData() + + # Vocabulary index of each node + hetero_graph['nodes'].x = vocab_ids + + # Add the adjacency lists + hetero_graph['nodes', 'control', 'nodes'].edge_index = adjacencies[0].t().contiguous() + hetero_graph['nodes', 'data', 'nodes'].edge_index = adjacencies[1].t().contiguous() + hetero_graph['nodes', 'call', 'nodes'].edge_index = adjacencies[2].t().contiguous() + + # Add the edge positions + hetero_graph['nodes', 'control', 'nodes'].edge_attr = edge_positions[0] + hetero_graph['nodes', 'data', 'nodes'].edge_attr = edge_positions[1] + hetero_graph['nodes', 'call', 'nodes'].edge_attr = edge_positions[2] + + return hetero_graph + + if isinstance(graphs, ProgramGraph): + return _run_one(graphs) + + return execute(_run_one, graphs, executor, chunksize) From 45a758c17fa6bdbf570ed8e63927d98a82f81527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Fri, 10 May 2024 18:00:40 +0200 Subject: [PATCH 02/12] Document to_pyg() function --- programl/transform_ops.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/programl/transform_ops.py b/programl/transform_ops.py index 7e7e0d2c..64f97530 100644 --- a/programl/transform_ops.py +++ b/programl/transform_ops.py @@ -268,6 +268,34 @@ def to_pyg( executor: Optional[ExecutorLike] = None, chunksize: Optional[int] = None, ) -> Union[HeteroData, Iterable[HeteroData]]: + """Convert one or more Program Graphs to Pytorch-Geometrics's HeteroData. + This graphs can be used as input for any deep learning model built with + Pytorch-Geometric: + + https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html + + :param graphs: A Program Graph, or a sequence of Program Graphs. + + :param timeout: The maximum number of seconds to wait for an individual + graph conversion before raising an error. If multiple inputs are + provided, this timeout is per-input. + + :param executor: An executor object, with method :code:`submit(callable, + *args, **kwargs)` and returning a Future-like object with methods + :code:`done() -> bool` and :code:`result() -> float`. The executor role + is to dispatch the execution of the jobs locally/on a cluster/with + multithreading depending on the implementation. Eg: + :code:`concurrent.futures.ThreadPoolExecutor`. Defaults to single + threaded execution. This is only used when multiple inputs are given. + + :param chunksize: The number of inputs to read and process at a time. A + larger chunksize improves parallelism but increases memory consumption + as more inputs must be stored in memory. This is only used when multiple + inputs are given. + + :return: A HeteroData graph when a single input is provided, else an + iterable sequence of HeteroData graphs. + """ def _run_one(graph: ProgramGraph) -> HeteroData: # 3 lists, one per edge type From e392140c340293863ef67f519f4695fe8f00b4a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Sat, 11 May 2024 11:08:31 +0200 Subject: [PATCH 03/12] Update requirements --- programl/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/programl/requirements.txt b/programl/requirements.txt index db4d6937..3900bc9a 100644 --- a/programl/requirements.txt +++ b/programl/requirements.txt @@ -5,4 +5,5 @@ networkx>=2.4 numpy>=1.19.3 protobuf>=3.13.0,<4.21.0 torch>=1.8.0 +torch_geometric==2.4.0 tqdm>=4.38.0 From e11c7d16b3057632469fb5526c3cf4550c1dc11b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Thu, 16 May 2024 10:34:33 +0200 Subject: [PATCH 04/12] Store the text of the nodes, the type edges and the ensure that the correct format is kept even with empty lists --- programl/transform_ops.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/programl/transform_ops.py b/programl/transform_ops.py index 64f97530..831189f3 100644 --- a/programl/transform_ops.py +++ b/programl/transform_ops.py @@ -261,6 +261,7 @@ def _run_one(graph: ProgramGraph) -> str: return _run_one(graphs) return execute(_run_one, graphs, executor, chunksize) + def to_pyg( graphs: Union[ProgramGraph, Iterable[ProgramGraph]], timeout: int = 300, @@ -300,14 +301,16 @@ def to_pyg( def _run_one(graph: ProgramGraph) -> HeteroData: # 3 lists, one per edge type # (control, data and call edges) - adjacencies = [[], [], []] - edge_positions = [[], [], []] + adjacencies = [[], [], [], []] + edge_positions = [[], [], [], []] # Create the adjacency lists for edge in graph.edge: adjacencies[edge.flow].append([edge.source, edge.target]) edge_positions[edge.flow].append(edge.position) + node_text = [node.text for node in graph.node] + vocab_ids = None if vocabulary is not None: vocab_ids = [ @@ -326,17 +329,28 @@ def _run_one(graph: ProgramGraph) -> HeteroData: hetero_graph = HeteroData() # Vocabulary index of each node + hetero_graph['nodes']['text'] = node_text hetero_graph['nodes'].x = vocab_ids # Add the adjacency lists - hetero_graph['nodes', 'control', 'nodes'].edge_index = adjacencies[0].t().contiguous() - hetero_graph['nodes', 'data', 'nodes'].edge_index = adjacencies[1].t().contiguous() - hetero_graph['nodes', 'call', 'nodes'].edge_index = adjacencies[2].t().contiguous() + hetero_graph['nodes', 'control', 'nodes'].edge_index = ( + adjacencies[0].t().contiguous() if adjacencies[0].nelement() > 0 else torch.tensor([[], []]) + ) + hetero_graph['nodes', 'data', 'nodes'].edge_index = ( + adjacencies[1].t().contiguous() if adjacencies[1].nelement() > 0 else torch.tensor([[], []]) + ) + hetero_graph['nodes', 'call', 'nodes'].edge_index = ( + adjacencies[2].t().contiguous() if adjacencies[2].nelement() > 0 else torch.tensor([[], []]) + ) + hetero_graph['nodes', 'type', 'nodes'].edge_index = ( + adjacencies[3].t().contiguous() if adjacencies[3].nelement() > 0 else torch.tensor([[], []]) + ) # Add the edge positions hetero_graph['nodes', 'control', 'nodes'].edge_attr = edge_positions[0] hetero_graph['nodes', 'data', 'nodes'].edge_attr = edge_positions[1] hetero_graph['nodes', 'call', 'nodes'].edge_attr = edge_positions[2] + hetero_graph['nodes', 'type', 'nodes'].edge_attr = edge_positions[3] return hetero_graph From 80b22c4f7dd9a5d49eeb689a792d57a2d5ebf932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Thu, 16 May 2024 10:35:26 +0200 Subject: [PATCH 05/12] Add to_pyg() to the module --- programl/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/programl/__init__.py b/programl/__init__.py index c049afb2..065068ab 100644 --- a/programl/__init__.py +++ b/programl/__init__.py @@ -52,7 +52,7 @@ to_bytes, to_string, ) -from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx +from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx, to_pyg from programl.util.py.runfiles_path import runfiles_path from programl.version import PROGRAML_VERSION @@ -84,6 +84,7 @@ "to_dot", "to_json", "to_networkx", + "to_pyg", "to_string", "UnsupportedCompiler", ] From 9c89f0deb41c466e7bafbe8a7d9e16b14b9c3105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Thu, 16 May 2024 10:37:00 +0200 Subject: [PATCH 06/12] Add tests for to_pyg() method --- tests/to_pyg_test.py | 82 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/to_pyg_test.py diff --git a/tests/to_pyg_test.py b/tests/to_pyg_test.py new file mode 100644 index 00000000..ac99b318 --- /dev/null +++ b/tests/to_pyg_test.py @@ -0,0 +1,82 @@ + +from concurrent.futures.thread import ThreadPoolExecutor + +import networkx as nx +import pytest + +import programl as pg +from torch_geometric.data import HeteroData +from tests.test_main import main + +pytest_plugins = ["tests.plugins.llvm_program_graph"] + + +@pytest.fixture(scope="session") +def graph() -> pg.ProgramGraph: + return pg.from_cpp("int A() { return 0; }") + +def assert_equal_graphs( + graph1: HeteroData, + graph2: HeteroData +): + assert graph1['nodes']['text'] == graph2['nodes']['text'] + + assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index) + assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index) + assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index) + assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index) + +def test_to_pyg_simple_graph(graph: pg.ProgramGraph): + graphs = list(pg.to_pyg([graph])) + assert len(graphs) == 1 + assert isinstance(graphs[0], HeteroData) + + +def test_to_pyg_simple_graph_single_input(graph: pg.ProgramGraph): + pyg_graph = pg.to_pyg(graph) + assert isinstance(pyg_graph, HeteroData) + + +def test_to_pyg_two_inputs(graph: pg.ProgramGraph): + graphs = list(pg.to_pyg([graph, graph])) + assert len(graphs) == 2 + assert_equal_graphs(graphs[0], graphs[1]) + +def test_to_pyg_generator(graph: pg.ProgramGraph): + graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3)) + assert len(graphs) == 10 + for x in graphs[1:]: + assert_equal_graphs(graphs[0], x) + + +def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): + with ThreadPoolExecutor() as executor: + graphs = list( + pg.to_pyg((graph for _ in range(10)), chunksize=3, executor=executor) + ) + assert len(graphs) == 10 + for x in graphs[1:]: + assert_equal_graphs(graphs[0], x) + + +def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph): + graphs = list(pg.to_pyg([llvm_program_graph])) + + num_nodes = len(graphs[0]['nodes']['text']) + num_control_edges = graphs[0]['nodes', 'control', 'nodes'].edge_index.size(1) + num_data_edges = graphs[0]['nodes', 'data', 'nodes'].edge_index.size(1) + num_call_edges = graphs[0]['nodes', 'call', 'nodes'].edge_index.size(1) + num_type_edges = graphs[0]['nodes', 'type', 'nodes'].edge_index.size(1) + num_edges = num_control_edges + num_data_edges + num_call_edges + num_type_edges + + assert len(graphs) == 1 + assert isinstance(graphs[0], HeteroData) + assert num_nodes == len(llvm_program_graph.node) + assert num_edges <= len(llvm_program_graph.edge) + + +if __name__ == "__main__": + main() + + + From adfb51fb33ca32ebede6e5f6ca0420d2d580c20a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Thu, 16 May 2024 10:37:31 +0200 Subject: [PATCH 07/12] Include to_pyg() tests --- tests/BUILD | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index 211ec444..e522a56d 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -134,3 +134,14 @@ py_test( "//tests/plugins", ], ) + +py_test( + name = "to_pyg_test", + srcs = ["to_pyg_test.py"], + shard_count = 8, + deps = [ + "//programl", + "//tests:test_main", + "//tests/plugins", + ], +) From 1874e783867b096bba72655e6d95d6c3cfe7a1e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Thu, 16 May 2024 11:03:01 +0200 Subject: [PATCH 08/12] Update documentation of to_pyg() --- programl/transform_ops.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/programl/transform_ops.py b/programl/transform_ops.py index 831189f3..71fbc346 100644 --- a/programl/transform_ops.py +++ b/programl/transform_ops.py @@ -265,7 +265,7 @@ def _run_one(graph: ProgramGraph) -> str: def to_pyg( graphs: Union[ProgramGraph, Iterable[ProgramGraph]], timeout: int = 300, - vocabulary: Dict[str, int] = None, + vocabulary: Optional[Dict[str, int]] = None, executor: Optional[ExecutorLike] = None, chunksize: Optional[int] = None, ) -> Union[HeteroData, Iterable[HeteroData]]: @@ -281,6 +281,10 @@ def to_pyg( graph conversion before raising an error. If multiple inputs are provided, this timeout is per-input. + :param vocabulary: A dictionary containing ProGraML's vocabulary, where the + keys are the text attribute of the nodes and the values their respective + indexes. + :param executor: An executor object, with method :code:`submit(callable, *args, **kwargs)` and returning a Future-like object with methods :code:`done() -> bool` and :code:`result() -> float`. The executor role @@ -299,16 +303,17 @@ def to_pyg( """ def _run_one(graph: ProgramGraph) -> HeteroData: - # 3 lists, one per edge type - # (control, data and call edges) + # 4 lists, one per edge type + # (control, data, call and type edges) adjacencies = [[], [], [], []] edge_positions = [[], [], [], []] - # Create the adjacency lists + # Create the adjacency lists and the positions for edge in graph.edge: adjacencies[edge.flow].append([edge.source, edge.target]) edge_positions[edge.flow].append(edge.position) + # Store the text attributes node_text = [node.text for node in graph.node] vocab_ids = None From 988efda6e96bc7a1b960f488e206f0b1bdeab717 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Sat, 18 May 2024 10:25:29 +0200 Subject: [PATCH 09/12] Add license --- tests/to_pyg_test.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/to_pyg_test.py b/tests/to_pyg_test.py index ac99b318..824eed62 100644 --- a/tests/to_pyg_test.py +++ b/tests/to_pyg_test.py @@ -1,4 +1,18 @@ - +# Copyright 2019-2020 the ProGraML authors. +# +# Contact Chris Cummins . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from concurrent.futures.thread import ThreadPoolExecutor import networkx as nx From 37b0fe2eef3e213e994da03c69124254c464a3de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Mon, 20 May 2024 07:52:11 +0200 Subject: [PATCH 10/12] Store the full text attribute of the nodes --- programl/transform_ops.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/programl/transform_ops.py b/programl/transform_ops.py index 71fbc346..bc9d86f7 100644 --- a/programl/transform_ops.py +++ b/programl/transform_ops.py @@ -314,7 +314,22 @@ def _run_one(graph: ProgramGraph) -> HeteroData: edge_positions[edge.flow].append(edge.position) # Store the text attributes - node_text = [node.text for node in graph.node] + node_text_list = [] + node_full_text_list = [] + + # Store the text and full text attributes + for node in graph.node: + node_text = node_full_text = node.text + + if ( + node.features + and node.features.feature["full_text"].bytes_list.value + ): + node_full_text = node.features.feature["full_text"].bytes_list.value[0] + + node_text_list.append(node_text) + node_full_text_list.append(node_full_text) + vocab_ids = None if vocabulary is not None: @@ -334,7 +349,8 @@ def _run_one(graph: ProgramGraph) -> HeteroData: hetero_graph = HeteroData() # Vocabulary index of each node - hetero_graph['nodes']['text'] = node_text + hetero_graph['nodes']['text'] = node_text_list + hetero_graph['nodes']['full_text'] = node_full_text_list hetero_graph['nodes'].x = vocab_ids # Add the adjacency lists From 2a93fe66f9300d10641a48fe90c2b9e74112c589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Mon, 20 May 2024 07:54:02 +0200 Subject: [PATCH 11/12] Add tests for different graphs --- tests/to_pyg_test.py | 82 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 67 insertions(+), 15 deletions(-) diff --git a/tests/to_pyg_test.py b/tests/to_pyg_test.py index 824eed62..c42f2380 100644 --- a/tests/to_pyg_test.py +++ b/tests/to_pyg_test.py @@ -29,39 +29,94 @@ def graph() -> pg.ProgramGraph: return pg.from_cpp("int A() { return 0; }") +@pytest.fixture(scope="session") +def graph2() -> pg.ProgramGraph: + return pg.from_cpp("int B() { return 1; }") + +@pytest.fixture(scope="session") +def graph3() -> pg.ProgramGraph: + return pg.from_cpp("int B(int x) { return x + 1; }") + def assert_equal_graphs( graph1: HeteroData, - graph2: HeteroData + graph2: HeteroData, + equality: bool = True ): - assert graph1['nodes']['text'] == graph2['nodes']['text'] + if equality: + assert graph1['nodes']['full_text'] == graph2['nodes']['full_text'] + + assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index) + assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index) + assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index) + assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index) + + else: + text_different = graph1['nodes']['full_text'] != graph2['nodes']['full_text'] - assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index) - assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index) - assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index) - assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index) + control_edges_different = not graph1['nodes', 'control', 'nodes'].edge_index.equal( + graph2['nodes', 'control', 'nodes'].edge_index + ) + data_edges_different = not graph1['nodes', 'data', 'nodes'].edge_index.equal( + graph2['nodes', 'data', 'nodes'].edge_index + ) + call_edges_different = not graph1['nodes', 'call', 'nodes'].edge_index.equal( + graph2['nodes', 'call', 'nodes'].edge_index + ) + type_edges_different = not graph1['nodes', 'type', 'nodes'].edge_index.equal( + graph2['nodes', 'type', 'nodes'].edge_index + ) + + assert ( + text_different + or control_edges_different + or data_edges_different + or call_edges_different + or type_edges_different + ) def test_to_pyg_simple_graph(graph: pg.ProgramGraph): graphs = list(pg.to_pyg([graph])) assert len(graphs) == 1 assert isinstance(graphs[0], HeteroData) - def test_to_pyg_simple_graph_single_input(graph: pg.ProgramGraph): pyg_graph = pg.to_pyg(graph) assert isinstance(pyg_graph, HeteroData) +def test_to_pyg_different_two_different_inputs( + graph: pg.ProgramGraph, + graph2: pg.ProgramGraph, +): + pyg_graph = pg.to_pyg(graph) + pyg_graph2 = pg.to_pyg(graph2) + + # Ensure that the graphs are different + assert_equal_graphs(pyg_graph, pyg_graph2, equality=False) + +def test_to_pyg_different_inputs( + graph: pg.ProgramGraph, + graph2: pg.ProgramGraph, + graph3: pg.ProgramGraph +): + pyg_graph = pg.to_pyg(graph) + pyg_graph2 = pg.to_pyg(graph2) + pyg_graph3 = pg.to_pyg(graph3) + + # Ensure that the graphs are different + assert_equal_graphs(pyg_graph, pyg_graph2, equality=False) + assert_equal_graphs(pyg_graph, pyg_graph3, equality=False) + assert_equal_graphs(pyg_graph2, pyg_graph3, equality=False) def test_to_pyg_two_inputs(graph: pg.ProgramGraph): graphs = list(pg.to_pyg([graph, graph])) assert len(graphs) == 2 - assert_equal_graphs(graphs[0], graphs[1]) + assert_equal_graphs(graphs[0], graphs[1], equality=True) def test_to_pyg_generator(graph: pg.ProgramGraph): graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3)) assert len(graphs) == 10 for x in graphs[1:]: - assert_equal_graphs(graphs[0], x) - + assert_equal_graphs(graphs[0], x, equality=True) def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): with ThreadPoolExecutor() as executor: @@ -70,7 +125,7 @@ def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): ) assert len(graphs) == 10 for x in graphs[1:]: - assert_equal_graphs(graphs[0], x) + assert_equal_graphs(graphs[0], x, equality=True) def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph): @@ -90,7 +145,4 @@ def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph): if __name__ == "__main__": - main() - - - + main() \ No newline at end of file From a066ffb2f971388fe6867613a01a55ac62c3bdb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Mon, 20 May 2024 21:07:21 +0200 Subject: [PATCH 12/12] Refactor code for determining if two graphs are equal --- tests/to_pyg_test.py | 60 +++++++++++++------------------------------- 1 file changed, 17 insertions(+), 43 deletions(-) diff --git a/tests/to_pyg_test.py b/tests/to_pyg_test.py index c42f2380..56cdfe57 100644 --- a/tests/to_pyg_test.py +++ b/tests/to_pyg_test.py @@ -37,42 +37,17 @@ def graph2() -> pg.ProgramGraph: def graph3() -> pg.ProgramGraph: return pg.from_cpp("int B(int x) { return x + 1; }") -def assert_equal_graphs( +def graphs_are_equal( graph1: HeteroData, graph2: HeteroData, - equality: bool = True ): - if equality: - assert graph1['nodes']['full_text'] == graph2['nodes']['full_text'] - - assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index) - assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index) - assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index) - assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index) - - else: - text_different = graph1['nodes']['full_text'] != graph2['nodes']['full_text'] - - control_edges_different = not graph1['nodes', 'control', 'nodes'].edge_index.equal( - graph2['nodes', 'control', 'nodes'].edge_index - ) - data_edges_different = not graph1['nodes', 'data', 'nodes'].edge_index.equal( - graph2['nodes', 'data', 'nodes'].edge_index - ) - call_edges_different = not graph1['nodes', 'call', 'nodes'].edge_index.equal( - graph2['nodes', 'call', 'nodes'].edge_index - ) - type_edges_different = not graph1['nodes', 'type', 'nodes'].edge_index.equal( - graph2['nodes', 'type', 'nodes'].edge_index - ) - - assert ( - text_different - or control_edges_different - or data_edges_different - or call_edges_different - or type_edges_different - ) + return ( + (graph1['nodes']['full_text'] == graph2['nodes']['full_text']) + and (graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)) + and (graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)) + and (graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)) + and (graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)) + ) def test_to_pyg_simple_graph(graph: pg.ProgramGraph): graphs = list(pg.to_pyg([graph])) @@ -90,8 +65,8 @@ def test_to_pyg_different_two_different_inputs( pyg_graph = pg.to_pyg(graph) pyg_graph2 = pg.to_pyg(graph2) - # Ensure that the graphs are different - assert_equal_graphs(pyg_graph, pyg_graph2, equality=False) + # Ensure that the graphs are different + assert not graphs_are_equal(pyg_graph, pyg_graph2) def test_to_pyg_different_inputs( graph: pg.ProgramGraph, @@ -102,21 +77,21 @@ def test_to_pyg_different_inputs( pyg_graph2 = pg.to_pyg(graph2) pyg_graph3 = pg.to_pyg(graph3) - # Ensure that the graphs are different - assert_equal_graphs(pyg_graph, pyg_graph2, equality=False) - assert_equal_graphs(pyg_graph, pyg_graph3, equality=False) - assert_equal_graphs(pyg_graph2, pyg_graph3, equality=False) + # Ensure that the graphs are different + assert not graphs_are_equal(pyg_graph, pyg_graph2) + assert not graphs_are_equal(pyg_graph, pyg_graph3) + assert not graphs_are_equal(pyg_graph2, pyg_graph3) def test_to_pyg_two_inputs(graph: pg.ProgramGraph): graphs = list(pg.to_pyg([graph, graph])) assert len(graphs) == 2 - assert_equal_graphs(graphs[0], graphs[1], equality=True) + assert graphs_are_equal(graphs[0], graphs[1]) def test_to_pyg_generator(graph: pg.ProgramGraph): graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3)) assert len(graphs) == 10 for x in graphs[1:]: - assert_equal_graphs(graphs[0], x, equality=True) + assert graphs_are_equal(graphs[0], x) def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): with ThreadPoolExecutor() as executor: @@ -125,8 +100,7 @@ def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): ) assert len(graphs) == 10 for x in graphs[1:]: - assert_equal_graphs(graphs[0], x, equality=True) - + assert graphs_are_equal(graphs[0], x) def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph): graphs = list(pg.to_pyg([llvm_program_graph]))