-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/pytorch geometric #216
Feature/pytorch geometric #216
Conversation
…orrect format is kept even with empty lists
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many thanks @igabirondo16! The new to_pyg()
function looks good, though since the CI needs updating I have only validated it by reading the code.
I left a couple of nitpick change requests, nothing major.
Cheers,
Chris
Hi @ChrisCummins , thanks for accepting my changes! I have added the license and the tests for different graphs. For correctly ensuring that two graphs are different, instead of comparing the |
tests/to_pyg_test.py
Outdated
def assert_equal_graphs( | ||
graph1: HeteroData, | ||
graph2: HeteroData, | ||
equality: bool = True | ||
): | ||
if equality: | ||
assert graph1['nodes']['full_text'] == graph2['nodes']['full_text'] | ||
|
||
assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index) | ||
assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index) | ||
assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index) | ||
assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index) | ||
|
||
else: | ||
text_different = graph1['nodes']['full_text'] != graph2['nodes']['full_text'] | ||
|
||
control_edges_different = not graph1['nodes', 'control', 'nodes'].edge_index.equal( | ||
graph2['nodes', 'control', 'nodes'].edge_index | ||
) | ||
data_edges_different = not graph1['nodes', 'data', 'nodes'].edge_index.equal( | ||
graph2['nodes', 'data', 'nodes'].edge_index | ||
) | ||
call_edges_different = not graph1['nodes', 'call', 'nodes'].edge_index.equal( | ||
graph2['nodes', 'call', 'nodes'].edge_index | ||
) | ||
type_edges_different = not graph1['nodes', 'type', 'nodes'].edge_index.equal( | ||
graph2['nodes', 'type', 'nodes'].edge_index | ||
) | ||
|
||
assert ( | ||
text_different | ||
or control_edges_different | ||
or data_edges_different | ||
or call_edges_different | ||
or type_edges_different | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I may have misunderstood, but I believe you could simplify this to just a helper function that runs all your equality checks like:
def graphs_are_equal(
graph1: HeteroData,
graph2: HeteroData,
):
return (
(graph1['nodes']['full_text'] == graph2['nodes']['full_text'])
and (graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)))
and (graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index))
# ...
)
then in your tests:
assert graphs_are_equal(G1, G2)
assert not graphs_are_equal(G2, G3)
would that work?
Ops! You are totally right! I have corrected the code, now is shorter and easier to understand. |
Great stuff, thanks for your contribution @igabirondo16! |
Fixes issue #174.
Changes summary:
to_pyg()
method based ontorch_geometric.data.HeteroData
.to_pyg
in the python module.