Skip to content

Commit

Permalink
Fix: Not trying to add edges in case their either to or from vertices… (
Browse files Browse the repository at this point in the history
#38)

* Fix: Do not try to add edges when either their `to` or `from` vertex collection is not known via the metagraph definition

* fmt

* new: `test_dblp_pyg`

* fix: bump version

---------

Co-authored-by: Anthony Mahanna <[email protected]>
  • Loading branch information
hkernbach and aMahanna authored Sep 18, 2024
1 parent 9f2cc5f commit 585d5f6
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "phenolrs"
version = "0.5.7"
version = "0.5.8"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "maturin"

[project]
name = "phenolrs"
version = "0.5.7"
version = "0.5.8"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Rust",
Expand Down
3 changes: 3 additions & 0 deletions python/phenolrs/pyg/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def load_into_pyg_heterodata(
for col in features_by_col.keys():
col_mapping = vertex_cols_source_to_output[col]
for feature in features_by_col[col].keys():
if feature == "@collection_name":
continue

target_name = col_mapping[feature]
result = torch.from_numpy(
features_by_col[col][feature].astype(np.float64)
Expand Down
10 changes: 10 additions & 0 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def load_imdb(imdb_db_name: str, connection_information: Dict[str, Any]) -> None
load_dataset("IMDB_PLATFORM", imdb_db_name, connection_information)


@pytest.fixture(scope="module")
def load_dblp(dblp_db_name: str, connection_information: Dict[str, Any]) -> None:
    """Module-scoped fixture that loads the ``"DBLP"`` dataset.

    Delegates to ``load_dataset`` with the dataset identifier ``"DBLP"``,
    the target database name supplied by the ``dblp_db_name`` fixture, and
    the shared ``connection_information`` mapping. Returns nothing; tests
    request this fixture only for its side effect.
    """
    dataset_name = "DBLP"
    load_dataset(dataset_name, dblp_db_name, connection_information)


@pytest.fixture(scope="module")
def abide_db_name() -> str:
    """Module-scoped fixture providing the database name ``"abide"``."""
    return "abide"
Expand All @@ -75,6 +80,11 @@ def imdb_db_name() -> str:
return "imdb"


@pytest.fixture(scope="module")
def dblp_db_name() -> str:
    """Module-scoped fixture providing the database name ``"dblp"``."""
    return "dblp"


@pytest.fixture(scope="module")
def custom_graph_db_name() -> str:
    """Module-scoped fixture providing the database name ``"custom_graph"``."""
    return "custom_graph"
Expand Down
88 changes: 88 additions & 0 deletions python/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,94 @@ def test_imdb_pyg(
assert edges["edge_index"].shape == (2, 100000)


def test_dblp_pyg(
    load_dblp: None,
    dblp_db_name: str,
    connection_information: dict[str, str],
) -> None:
    """Load the DBLP graph into PyG ``HeteroData`` and check its contents.

    Two metagraphs are loaded: one that lists the ``conference`` vertex
    collection (without any features) and one that omits it entirely. Both
    loads are expected to yield the same node types, edge types, feature
    shapes, key/index mapping sizes and edge counts, i.e. edges touching
    ``conference`` must not show up in either result.
    """

    def featured_collections() -> dict[str, dict[str, str]]:
        # Fresh dicts on every call so the two metagraphs share no state.
        return {name: {"x": "x"} for name in ("author", "paper", "term")}

    with_conference = {
        "vertexCollections": {**featured_collections(), "conference": {}},
        "edgeCollections": {"to": {}},
    }
    without_conference = {
        "vertexCollections": featured_collections(),
        "edgeCollections": {"to": {}},
    }

    # Both loads are performed up front, before any assertion is evaluated.
    loaded = [
        PygLoader.load_into_pyg_heterodata(
            dblp_db_name,
            metagraph,
            [connection_information["url"]],
            username=connection_information["username"],
            password=connection_information["password"],
        )
        for metagraph in (with_conference, without_conference)
    ]

    # Expected shape of the "x" feature matrix per vertex collection.
    expected_x_shapes = {
        "author": (4057, 334),
        "paper": (14328, 4231),
        "term": (7723, 50),
    }
    # Expected number of edges (second dimension of edge_index) per edge type.
    expected_edge_counts = {
        ("author", "to", "paper"): 19645,
        ("paper", "to", "author"): 19645,
        ("term", "to", "paper"): 85810,
        ("paper", "to", "term"): 85810,
    }

    for data, adb_key_to_ind, ind_to_adb_key in loaded:
        assert isinstance(data, HeteroData)
        assert set(data.node_types) == set(expected_x_shapes)
        assert set(data.edge_types) == set(expected_edge_counts)

        for collection, shape in expected_x_shapes.items():
            assert data[collection]["x"].shape == shape

        for collection, (vertex_count, _) in expected_x_shapes.items():
            assert (
                len(adb_key_to_ind[collection])
                == len(ind_to_adb_key[collection])
                == vertex_count
            )

        for edge_type, edge_count in expected_edge_counts.items():
            assert data[edge_type]["edge_index"].shape == (2, edge_count)


def test_abide_numpy(
load_abide: None, abide_db_name: str, connection_information: dict[str, str]
) -> None:
Expand Down
22 changes: 20 additions & 2 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::hash::Hash;
use std::sync::{Arc, RwLock};

use anyhow::{anyhow, Result};
use log::warn;

#[derive(Eq, Hash, PartialEq, Clone, Copy, Ord, PartialOrd, Debug)]
pub struct VertexHash(u64);
Expand Down Expand Up @@ -785,6 +786,23 @@ impl Graph for NumpyGraph {
(col.to_string(), key[1..].to_string())
};

// if either from_col or to_col is not part of the metagraph definition,
// we will not add it as an edge
if !self.cols_to_keys_to_inds.contains_key(&from_col) {
warn!(
"Skipping edge from {} to {} as {} is not part of the metagraph",
from_col, to_col, from_col
);
return Ok(());
}
if !self.cols_to_keys_to_inds.contains_key(&to_col) {
warn!(
"Skipping edge from {} to {} as {} is not part of the metagraph",
from_col, to_col, to_col
);
return Ok(());
}

debug_assert!(field_names.contains(&String::from("@collection_name")));
let col_name_position = field_names
.iter()
Expand All @@ -803,11 +821,11 @@ impl Graph for NumpyGraph {
let from_col_keys = self
.cols_to_keys_to_inds
.get(&from_col)
.ok_or_else(|| anyhow!("Unable to get keys from for {:?}", &from_col))?;
.ok_or_else(|| anyhow!("Unable to get keys `from` for {:?}", &from_col))?;
let to_col_keys = self
.cols_to_keys_to_inds
.get(&to_col)
.ok_or_else(|| anyhow!("Unable to get keys to for {:?}", &to_col))?;
.ok_or_else(|| anyhow!("Unable to get keys `to` for {:?}", &to_col))?;
let cur_coo = self
.coo_by_from_edge_to
.get_mut(&key_tup)
Expand Down

0 comments on commit 585d5f6

Please sign in to comment.