diff --git a/Cargo.lock b/Cargo.lock index 015a3be..866b60e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -922,7 +922,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "phenolrs" -version = "0.5.7" +version = "0.5.8" dependencies = [ "anyhow", "arangors-graph-exporter", diff --git a/Cargo.toml b/Cargo.toml index 4dc7542..65f8286 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "phenolrs" -version = "0.5.7" +version = "0.5.8" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/pyproject.toml b/pyproject.toml index 104a193..8c7bb3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "phenolrs" -version = "0.5.7" +version = "0.5.8" requires-python = ">=3.10" classifiers = [ "Programming Language :: Rust", diff --git a/python/phenolrs/pyg/loader.py b/python/phenolrs/pyg/loader.py index 8d829e0..5977e13 100644 --- a/python/phenolrs/pyg/loader.py +++ b/python/phenolrs/pyg/loader.py @@ -154,6 +154,9 @@ def load_into_pyg_heterodata( for col in features_by_col.keys(): col_mapping = vertex_cols_source_to_output[col] for feature in features_by_col[col].keys(): + if feature == "@collection_name": + continue + target_name = col_mapping[feature] result = torch.from_numpy( features_by_col[col][feature].astype(np.float64) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 53b8511..21644a2 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -65,6 +65,11 @@ def load_imdb(imdb_db_name: str, connection_information: Dict[str, Any]) -> None load_dataset("IMDB_PLATFORM", imdb_db_name, connection_information) +@pytest.fixture(scope="module") +def load_dblp(dblp_db_name: str, connection_information: Dict[str, Any]) -> None: + load_dataset("DBLP", dblp_db_name, connection_information) + + @pytest.fixture(scope="module") def abide_db_name() -> str: return "abide" @@ -75,6 +80,11 @@ def imdb_db_name() -> str: return "imdb" +@pytest.fixture(scope="module") +def dblp_db_name() -> str: + return "dblp" + + @pytest.fixture(scope="module") def custom_graph_db_name() -> str: return "custom_graph" diff --git a/python/tests/test_all.py b/python/tests/test_all.py index 14397a2..2e60fff 100644 --- a/python/tests/test_all.py +++ b/python/tests/test_all.py @@ -108,6 +108,94 @@ def test_imdb_pyg( assert edges["edge_index"].shape == (2, 100000) +def test_dblp_pyg( + load_dblp: None, + dblp_db_name: str, + connection_information: dict[str, str], +) -> None: + metagraph_1 = { + "vertexCollections": { + "author": {"x": "x"}, + "paper": {"x": "x"}, + "term": {"x": "x"}, + "conference": {}, + }, + "edgeCollections": { + "to": {}, + }, + } + + result_1 = PygLoader.load_into_pyg_heterodata( + dblp_db_name, + metagraph_1, + [connection_information["url"]], + username=connection_information["username"], + password=connection_information["password"], + ) + + metagraph_2 = { + "vertexCollections": { + "author": {"x": "x"}, + "paper": {"x": "x"}, + "term": {"x": "x"}, + }, + "edgeCollections": { + "to": {}, + }, + } + + result_2 = PygLoader.load_into_pyg_heterodata( + dblp_db_name, + metagraph_2, + [connection_information["url"]], + username=connection_information["username"], + password=connection_information["password"], + ) + + for result in [result_1, result_2]: + data, col_to_adb_key_to_ind, col_to_ind_to_adb_key = result + + assert isinstance(data, HeteroData) + assert set(data.node_types) == {"author", "paper", "term"} + assert set(data.edge_types) == { + ("term", "to", "paper"), + ("author", "to", "paper"), + ("paper", "to", "term"), + ("paper", "to", "author"), + } + assert data["author"]["x"].shape == (4057, 334) + assert data["paper"]["x"].shape == (14328, 4231) + assert data["term"]["x"].shape == (7723, 50) + + assert ( + len(col_to_adb_key_to_ind["author"]) + == len(col_to_ind_to_adb_key["author"]) + == 4057 + ) + assert ( + len(col_to_adb_key_to_ind["paper"]) + == len(col_to_ind_to_adb_key["paper"]) + == 14328 + ) + assert ( + len(col_to_adb_key_to_ind["term"]) + == len(col_to_ind_to_adb_key["term"]) + == 7723 + ) + + edges = data[("author", "to", "paper")] + assert edges["edge_index"].shape == (2, 19645) + + edges = data[("paper", "to", "author")] + assert edges["edge_index"].shape == (2, 19645) + + edges = data[("term", "to", "paper")] + assert edges["edge_index"].shape == (2, 85810) + + edges = data[("paper", "to", "term")] + assert edges["edge_index"].shape == (2, 85810) + + def test_abide_numpy( load_abide: None, abide_db_name: str, connection_information: dict[str, str] ) -> None: diff --git a/src/graph.rs b/src/graph.rs index 09338b9..89c2d5b 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -4,6 +4,7 @@ use std::hash::Hash; use std::sync::{Arc, RwLock}; use anyhow::{anyhow, Result}; +use log::warn; #[derive(Eq, Hash, PartialEq, Clone, Copy, Ord, PartialOrd, Debug)] pub struct VertexHash(u64); @@ -785,6 +786,23 @@ impl Graph for NumpyGraph { (col.to_string(), key[1..].to_string()) }; + // if either from_col or to_col is not part of the metagraph definition, + // we will not add it as an edge + if !self.cols_to_keys_to_inds.contains_key(&from_col) { + warn!( + "Skipping edge from {} to {} as {} is not part of the metagraph", + from_col, to_col, from_col + ); + return Ok(()); + } + if !self.cols_to_keys_to_inds.contains_key(&to_col) { + warn!( + "Skipping edge from {} to {} as {} is not part of the metagraph", + from_col, to_col, to_col + ); + return Ok(()); + } + debug_assert!(field_names.contains(&String::from("@collection_name"))); let col_name_position = field_names .iter() @@ -803,11 +821,11 @@ impl Graph for NumpyGraph { let from_col_keys = self .cols_to_keys_to_inds .get(&from_col) - .ok_or_else(|| anyhow!("Unable to get keys from for {:?}", &from_col))?; + .ok_or_else(|| anyhow!("Unable to get keys `from` for {:?}", &from_col))?; let to_col_keys = self .cols_to_keys_to_inds .get(&to_col) - .ok_or_else(|| anyhow!("Unable to get keys to for {:?}", &to_col))?; + .ok_or_else(|| anyhow!("Unable to get keys `to` for {:?}", &to_col))?; let cur_coo = self .coo_by_from_edge_to .get_mut(&key_tup)