Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pytorch geometric #216

Merged
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion programl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
to_bytes,
to_string,
)
from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx
from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx, to_pyg
from programl.util.py.runfiles_path import runfiles_path
from programl.version import PROGRAML_VERSION

Expand Down Expand Up @@ -84,6 +84,7 @@
"to_dot",
"to_json",
"to_networkx",
"to_pyg",
"to_string",
"UnsupportedCompiler",
]
1 change: 1 addition & 0 deletions programl/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ networkx>=2.4
numpy>=1.19.3
protobuf>=3.13.0,<4.21.0
torch>=1.8.0
torch_geometric==2.4.0
tqdm>=4.38.0
121 changes: 121 additions & 0 deletions programl/transform_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
"""
import json
import subprocess
import torch
from typing import Any, Dict, Iterable, Optional, Union

import dgl
import networkx as nx
from dgl.heterograph import DGLHeteroGraph
from torch_geometric.data import HeteroData
from networkx.readwrite import json_graph as nx_json

from programl.exceptions import GraphTransformError
Expand Down Expand Up @@ -258,3 +260,122 @@ def _run_one(graph: ProgramGraph) -> str:
if isinstance(graphs, ProgramGraph):
return _run_one(graphs)
return execute(_run_one, graphs, executor, chunksize)


def to_pyg(
graphs: Union[ProgramGraph, Iterable[ProgramGraph]],
timeout: int = 300,
vocabulary: Optional[Dict[str, int]] = None,
executor: Optional[ExecutorLike] = None,
chunksize: Optional[int] = None,
) -> Union[HeteroData, Iterable[HeteroData]]:
"""Convert one or more Program Graphs to Pytorch-Geometrics's HeteroData.
This graphs can be used as input for any deep learning model built with
Pytorch-Geometric:

https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html

:param graphs: A Program Graph, or a sequence of Program Graphs.

:param timeout: The maximum number of seconds to wait for an individual
graph conversion before raising an error. If multiple inputs are
provided, this timeout is per-input.

:param vocabulary: A dictionary containing ProGraML's vocabulary, where the
keys are the text attribute of the nodes and the values their respective
indexes.

:param executor: An executor object, with method :code:`submit(callable,
*args, **kwargs)` and returning a Future-like object with methods
:code:`done() -> bool` and :code:`result() -> float`. The executor role
is to dispatch the execution of the jobs locally/on a cluster/with
multithreading depending on the implementation. Eg:
:code:`concurrent.futures.ThreadPoolExecutor`. Defaults to single
threaded execution. This is only used when multiple inputs are given.

:param chunksize: The number of inputs to read and process at a time. A
larger chunksize improves parallelism but increases memory consumption
as more inputs must be stored in memory. This is only used when multiple
inputs are given.

:return: A HeteroData graph when a single input is provided, else an
iterable sequence of HeteroData graphs.
"""

def _run_one(graph: ProgramGraph) -> HeteroData:
# 4 lists, one per edge type
# (control, data, call and type edges)
adjacencies = [[], [], [], []]
edge_positions = [[], [], [], []]

# Create the adjacency lists and the positions
for edge in graph.edge:
adjacencies[edge.flow].append([edge.source, edge.target])
edge_positions[edge.flow].append(edge.position)

# Store the text attributes
node_text_list = []
node_full_text_list = []

# Store the text and full text attributes
for node in graph.node:
node_text = node_full_text = node.text

if (
node.features
and node.features.feature["full_text"].bytes_list.value
):
node_full_text = node.features.feature["full_text"].bytes_list.value[0]

node_text_list.append(node_text)
node_full_text_list.append(node_full_text)


vocab_ids = None
if vocabulary is not None:
vocab_ids = [
vocabulary.get(node.text, len(vocabulary.keys()))
for node in graph.node
]

# Pass from list to tensor
adjacencies = [torch.tensor(adj_flow_type) for adj_flow_type in adjacencies]
edge_positions = [torch.tensor(edge_pos_flow_type) for edge_pos_flow_type in edge_positions]

if vocabulary is not None:
vocab_ids = torch.tensor(vocab_ids)

# Create the graph structure
hetero_graph = HeteroData()

# Vocabulary index of each node
hetero_graph['nodes']['text'] = node_text_list
hetero_graph['nodes']['full_text'] = node_full_text_list
hetero_graph['nodes'].x = vocab_ids

# Add the adjacency lists
hetero_graph['nodes', 'control', 'nodes'].edge_index = (
adjacencies[0].t().contiguous() if adjacencies[0].nelement() > 0 else torch.tensor([[], []])
)
hetero_graph['nodes', 'data', 'nodes'].edge_index = (
adjacencies[1].t().contiguous() if adjacencies[1].nelement() > 0 else torch.tensor([[], []])
)
hetero_graph['nodes', 'call', 'nodes'].edge_index = (
adjacencies[2].t().contiguous() if adjacencies[2].nelement() > 0 else torch.tensor([[], []])
)
hetero_graph['nodes', 'type', 'nodes'].edge_index = (
adjacencies[3].t().contiguous() if adjacencies[3].nelement() > 0 else torch.tensor([[], []])
)

# Add the edge positions
hetero_graph['nodes', 'control', 'nodes'].edge_attr = edge_positions[0]
hetero_graph['nodes', 'data', 'nodes'].edge_attr = edge_positions[1]
hetero_graph['nodes', 'call', 'nodes'].edge_attr = edge_positions[2]
hetero_graph['nodes', 'type', 'nodes'].edge_attr = edge_positions[3]

return hetero_graph

if isinstance(graphs, ProgramGraph):
return _run_one(graphs)

return execute(_run_one, graphs, executor, chunksize)
11 changes: 11 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,14 @@ py_test(
"//tests/plugins",
],
)

py_test(
name = "to_pyg_test",
srcs = ["to_pyg_test.py"],
shard_count = 8,
deps = [
"//programl",
"//tests:test_main",
"//tests/plugins",
],
)
148 changes: 148 additions & 0 deletions tests/to_pyg_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright 2019-2020 the ProGraML authors.
#
# Contact Chris Cummins <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures.thread import ThreadPoolExecutor

import networkx as nx
import pytest

import programl as pg
from torch_geometric.data import HeteroData
from tests.test_main import main

pytest_plugins = ["tests.plugins.llvm_program_graph"]


@pytest.fixture(scope="session")
def graph() -> pg.ProgramGraph:
return pg.from_cpp("int A() { return 0; }")

@pytest.fixture(scope="session")
def graph2() -> pg.ProgramGraph:
return pg.from_cpp("int B() { return 1; }")

@pytest.fixture(scope="session")
def graph3() -> pg.ProgramGraph:
return pg.from_cpp("int B(int x) { return x + 1; }")

def assert_equal_graphs(
graph1: HeteroData,
graph2: HeteroData,
equality: bool = True
):
if equality:
assert graph1['nodes']['full_text'] == graph2['nodes']['full_text']

assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)
assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)
assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)
assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)

else:
text_different = graph1['nodes']['full_text'] != graph2['nodes']['full_text']

control_edges_different = not graph1['nodes', 'control', 'nodes'].edge_index.equal(
graph2['nodes', 'control', 'nodes'].edge_index
)
data_edges_different = not graph1['nodes', 'data', 'nodes'].edge_index.equal(
graph2['nodes', 'data', 'nodes'].edge_index
)
call_edges_different = not graph1['nodes', 'call', 'nodes'].edge_index.equal(
graph2['nodes', 'call', 'nodes'].edge_index
)
type_edges_different = not graph1['nodes', 'type', 'nodes'].edge_index.equal(
graph2['nodes', 'type', 'nodes'].edge_index
)

assert (
text_different
or control_edges_different
or data_edges_different
or call_edges_different
or type_edges_different
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may have misunderstood, but I believe you could simplify this to just a helper function that runs all your equality checks like:

def graphs_are_equal(
    graph1: HeteroData,
    graph2: HeteroData,
):
    return (
        (graph1['nodes']['full_text'] == graph2['nodes']['full_text'])
        and (graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)))
        and (graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index))
        # ...
    )

then in your tests:

assert graphs_are_equal(G1, G2)
assert not graphs_are_equal(G2, G3)

would that work?


def test_to_pyg_simple_graph(graph: pg.ProgramGraph):
graphs = list(pg.to_pyg([graph]))
assert len(graphs) == 1
assert isinstance(graphs[0], HeteroData)

def test_to_pyg_simple_graph_single_input(graph: pg.ProgramGraph):
pyg_graph = pg.to_pyg(graph)
assert isinstance(pyg_graph, HeteroData)

def test_to_pyg_different_two_different_inputs(
graph: pg.ProgramGraph,
graph2: pg.ProgramGraph,
):
pyg_graph = pg.to_pyg(graph)
pyg_graph2 = pg.to_pyg(graph2)

# Ensure that the graphs are different
assert_equal_graphs(pyg_graph, pyg_graph2, equality=False)

def test_to_pyg_different_inputs(
graph: pg.ProgramGraph,
graph2: pg.ProgramGraph,
graph3: pg.ProgramGraph
):
pyg_graph = pg.to_pyg(graph)
pyg_graph2 = pg.to_pyg(graph2)
pyg_graph3 = pg.to_pyg(graph3)

# Ensure that the graphs are different
assert_equal_graphs(pyg_graph, pyg_graph2, equality=False)
assert_equal_graphs(pyg_graph, pyg_graph3, equality=False)
assert_equal_graphs(pyg_graph2, pyg_graph3, equality=False)

def test_to_pyg_two_inputs(graph: pg.ProgramGraph):
graphs = list(pg.to_pyg([graph, graph]))
assert len(graphs) == 2
assert_equal_graphs(graphs[0], graphs[1], equality=True)

def test_to_pyg_generator(graph: pg.ProgramGraph):
graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3))
assert len(graphs) == 10
for x in graphs[1:]:
assert_equal_graphs(graphs[0], x, equality=True)

def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
with ThreadPoolExecutor() as executor:
graphs = list(
pg.to_pyg((graph for _ in range(10)), chunksize=3, executor=executor)
)
assert len(graphs) == 10
for x in graphs[1:]:
assert_equal_graphs(graphs[0], x, equality=True)


def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph):
graphs = list(pg.to_pyg([llvm_program_graph]))

num_nodes = len(graphs[0]['nodes']['text'])
num_control_edges = graphs[0]['nodes', 'control', 'nodes'].edge_index.size(1)
num_data_edges = graphs[0]['nodes', 'data', 'nodes'].edge_index.size(1)
num_call_edges = graphs[0]['nodes', 'call', 'nodes'].edge_index.size(1)
num_type_edges = graphs[0]['nodes', 'type', 'nodes'].edge_index.size(1)
num_edges = num_control_edges + num_data_edges + num_call_edges + num_type_edges

assert len(graphs) == 1
assert isinstance(graphs[0], HeteroData)
assert num_nodes == len(llvm_program_graph.node)
assert num_edges <= len(llvm_program_graph.edge)


if __name__ == "__main__":
main()