diff --git a/Cargo.lock b/Cargo.lock index e3775c1..b28300a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,7 +806,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "phenolrs" -version = "0.3.0" +version = "0.4.0" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 2b89370..1f95220 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "phenolrs" -version = "0.3.0" +version = "0.4.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/pyproject.toml b/pyproject.toml index 2dc5dc8..63d188b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ [project.optional-dependencies] tests = [ "pytest", + "arango-datasets" ] dynamic = ["version"] diff --git a/python/phenolrs/numpy_loader.py b/python/phenolrs/numpy_loader.py index e05ad6c..0ca901a 100644 --- a/python/phenolrs/numpy_loader.py +++ b/python/phenolrs/numpy_loader.py @@ -22,6 +22,7 @@ def load_graph_to_numpy( dict[str, dict[str, npt.NDArray[np.float64]]], dict[typing.Tuple[str, str, str], npt.NDArray[np.float64]], dict[str, dict[str, int]], + dict[str, dict[int, str]], dict[str, dict[str, str]], ]: # TODO: replace with pydantic validation @@ -83,18 +84,21 @@ def load_graph_to_numpy( for e_col_name, entries in metagraph["edgeCollections"].items() ] - features_by_col, coo_map, col_to_adb_id_to_ind = graph_to_pyg_format( - { - "database": database, - "vertex_collections": vertex_collections, - "edge_collections": edge_collections, - "configuration": {"database_config": db_config_options}, - } + features_by_col, coo_map, col_to_adb_key_to_ind, col_to_ind_to_adb_key = ( + graph_to_pyg_format( + { + "database": database, + "vertex_collections": vertex_collections, + "edge_collections": edge_collections, + "configuration": {"database_config": db_config_options}, + } + ) ) return ( features_by_col, coo_map, - col_to_adb_id_to_ind, + col_to_adb_key_to_ind, + col_to_ind_to_adb_key, vertex_cols_source_to_output, ) diff --git a/python/phenolrs/phenolrs.pyi b/python/phenolrs/phenolrs.pyi index 3baf9fc..c4ad824 100644 --- a/python/phenolrs/phenolrs.pyi +++ b/python/phenolrs/phenolrs.pyi @@ -7,6 +7,7 @@ def graph_to_pyg_format(request: dict[str, typing.Any]) -> typing.Tuple[ dict[str, dict[str, npt.NDArray[np.float64]]], dict[typing.Tuple[str, str, str], npt.NDArray[np.float64]], dict[str, dict[str, int]], + dict[str, dict[int, str]], ]: ... class PhenolError(Exception): ... diff --git a/python/phenolrs/pyg_loader.py b/python/phenolrs/pyg_loader.py index 3e81d53..977176d 100644 --- a/python/phenolrs/pyg_loader.py +++ b/python/phenolrs/pyg_loader.py @@ -20,7 +20,7 @@ def load_into_pyg_data( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> tuple[Data, dict[str, dict[str, int]]]: + ) -> tuple[Data, dict[str, dict[str, int]], dict[str, dict[int, str]]]: if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: @@ -46,7 +46,8 @@ def load_into_pyg_data( ( features_by_col, coo_map, - col_to_adb_id_to_ind, + col_to_adb_key_to_ind, + col_to_ind_to_adb_key, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( database, @@ -85,7 +86,7 @@ def load_into_pyg_data( if result.numel() > 0: data["edge_index"] = result - return data, col_to_adb_id_to_ind + return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key @staticmethod def load_into_pyg_heterodata( @@ -98,7 +99,7 @@ def load_into_pyg_heterodata( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> tuple[HeteroData, dict[str, dict[str, int]]]: + ) -> tuple[HeteroData, dict[str, dict[str, int]], dict[str, dict[int, str]]]: if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: @@ -112,7 +113,8 @@ def load_into_pyg_heterodata( ( features_by_col, coo_map, - col_to_adb_id_to_ind, + col_to_adb_key_to_ind, + col_to_ind_to_adb_key, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( database, @@ -142,4 +144,4 @@ def load_into_pyg_heterodata( if result.numel() > 0: data[(from_name, edge_col_name, to_name)].edge_index = result - return data, col_to_adb_id_to_ind + return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key diff --git a/python/tests/test_all.py b/python/tests/test_all.py index 4b34d6c..771692f 100644 --- a/python/tests/test_all.py +++ b/python/tests/test_all.py @@ -20,10 +20,18 @@ def test_phenol_abide_hetero( password=connection_information["password"], ) - data, col_to_adb_id_to_ind = result + data, col_to_adb_key_to_ind, col_to_ind_to_adb_key = result assert isinstance(data, HeteroData) assert data["Subjects"]["x"].shape == (871, 2000) - assert len(col_to_adb_id_to_ind["Subjects"]) == 871 + assert ( + len(col_to_adb_key_to_ind["Subjects"]) + == len(col_to_ind_to_adb_key["Subjects"]) + == 871 + ) + + assert data[("Subjects", "medical_affinity_graph", "Subjects")][ + "edge_index" + ].shape == (2, 606770) # Metagraph variation result = PygLoader.load_into_pyg_heterodata( @@ -39,10 +47,18 @@ def test_phenol_abide_hetero( password=connection_information["password"], ) - data, col_to_adb_id_to_ind = result + data, col_to_adb_key_to_ind, col_to_ind_to_adb_key = result assert isinstance(data, HeteroData) assert data["Subjects"]["x"].shape == (871, 2000) - assert len(col_to_adb_id_to_ind["Subjects"]) == 871 + assert ( + len(col_to_adb_key_to_ind["Subjects"]) + == len(col_to_ind_to_adb_key["Subjects"]) + == 871 + ) + + assert data[("Subjects", "medical_affinity_graph", "Subjects")][ + "edge_index" + ].shape == (2, 606770) def test_phenol_abide_numpy( @@ -51,7 +67,8 @@ def test_phenol_abide_numpy( ( features_by_col, coo_map, - col_to_adb_id_to_ind, + col_to_adb_key_to_ind, + col_to_ind_to_adb_key, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( connection_information["dbName"], @@ -69,13 +86,18 @@ def test_phenol_abide_numpy( 2, 606770, ) - assert len(col_to_adb_id_to_ind["Subjects"]) == 871 + assert ( + len(col_to_adb_key_to_ind["Subjects"]) + == len(col_to_ind_to_adb_key["Subjects"]) + == 871 + ) assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}} ( features_by_col, coo_map, - col_to_adb_id_to_ind, + col_to_adb_key_to_ind, + col_to_ind_to_adb_key, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( connection_information["dbName"], @@ -90,5 +112,9 @@ def test_phenol_abide_numpy( assert features_by_col["Subjects"]["brain_fmri_features"].shape == (871, 2000) assert len(coo_map) == 0 - assert len(col_to_adb_id_to_ind["Subjects"]) == 871 + assert ( + len(col_to_adb_key_to_ind["Subjects"]) + == len(col_to_ind_to_adb_key["Subjects"]) + == 871 + ) assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}} diff --git a/src/arangodb/aql.rs b/src/arangodb/aql.rs index c6310c4..b2cc99e 100644 --- a/src/arangodb/aql.rs +++ b/src/arangodb/aql.rs @@ -246,7 +246,7 @@ fn build_aql_query(collection_description: &CollectionDescription, is_edge: bool let identifiers = if is_edge { "_to: doc._to,\n_from: doc._from,\n" } else { - "" + "_key: doc._key,\n" }; let query = format!( " diff --git a/src/graphs.rs b/src/graphs.rs index f1ef508..2640ba3 100644 --- a/src/graphs.rs +++ b/src/graphs.rs @@ -189,19 +189,20 @@ impl Graph { to_id: Vec, _data: Vec, ) -> Result<()> { - // build up the coo representation - let from_col: String = String::from_utf8({ + let (from_col, from_key) = { let s = String::from_utf8(from_id.clone()).expect("_from to be a string"); - let id_split = s.find('/').unwrap(); - (&s[0..id_split]).into() - }) - .unwrap(); - let to_col: String = String::from_utf8({ + let id_split = s.find('/').expect("Invalid format for _from"); + let (col, key) = s.split_at(id_split); + (col.to_string(), key[1..].to_string()) + }; + + let (to_col, to_key) = { let s = String::from_utf8(to_id.clone()).expect("_to to be a string"); - let id_split = s.find('/').unwrap(); - (&s[0..id_split]).into() - }) - .unwrap(); + let id_split = s.find('/').expect("Invalid format for _to"); + let (col, key) = s.split_at(id_split); + (col.to_string(), key[1..].to_string()) + }; + let key_tup = ( String::from_utf8(col_name).unwrap(), from_col.clone(), @@ -223,8 +224,8 @@ impl Graph { .coo_by_from_edge_to .get_mut(&key_tup) .ok_or_else(|| anyhow!("Unable to get COO from to for {:?}", &key_tup))?; - let from_col_id = from_col_keys.get(&String::from_utf8(from_id).unwrap()); - let to_col_id = to_col_keys.get(&String::from_utf8(to_id).unwrap()); + let from_col_id = from_col_keys.get(&from_key); + let to_col_id = to_col_keys.get(&to_key); if let (Some(from_id), Some(to_id)) = (from_col_id, to_col_id) { cur_coo[0].push(*from_id); cur_coo[1].push(*to_id); diff --git a/src/lib.rs b/src/lib.rs index 29eb679..ddae9d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; #[cfg(not(test))] -type PygCompatible<'a> = (&'a PyDict, &'a PyDict, &'a PyDict); +type PygCompatible<'a> = (&'a PyDict, &'a PyDict, &'a PyDict, &'a PyDict); #[cfg(not(test))] create_exception!(phenolrs, PhenolError, PyException); @@ -31,18 +31,32 @@ create_exception!(phenolrs, PhenolError, PyException); #[cfg(not(test))] fn graph_to_pyg_format(py: Python, request: DataLoadRequest) -> PyResult { let graph = load::retrieve::get_arangodb_graph(request).map_err(PhenolError::new_err)?; + let col_to_features = construct::construct_col_to_features( convert_nested_features_map(graph.cols_to_features), py, )?; + let coo_by_from_edge_to = construct::construct_coo_by_from_edge_to( convert_coo_edge_map(graph.coo_by_from_edge_to), py, )?; + let cols_to_keys_to_inds = - construct::construct_cols_to_keys_to_inds(graph.cols_to_keys_to_inds, py)?; + construct::construct_cols_to_keys_to_inds(graph.cols_to_keys_to_inds.clone(), py)?; + + let cols_to_inds_to_keys = + construct::construct_cols_to_inds_to_keys(graph.cols_to_keys_to_inds, py)?; + println!("Finished retrieval!"); - let res = (col_to_features, coo_by_from_edge_to, cols_to_keys_to_inds); + + let res = ( + col_to_features, + coo_by_from_edge_to, + cols_to_keys_to_inds, + cols_to_inds_to_keys, + ); + Ok(res) } diff --git a/src/load/receive.rs b/src/load/receive.rs index e61a394..669eca7 100644 --- a/src/load/receive.rs +++ b/src/load/receive.rs @@ -189,10 +189,11 @@ pub fn receive_vertices( Ok(val) => val, }; let id = &v["_id"]; + let key = &v["_key"]; match id { Value::String(i) => { let mut buf = vec![]; - buf.extend_from_slice(i[..].as_bytes()); + buf.extend_from_slice(key.as_str().unwrap().as_bytes()); vertex_keys.push(buf); if current_vertex_col.is_none() { let pos = i.find('/').unwrap(); @@ -255,10 +256,11 @@ pub fn receive_vertices( }; for v in values.result.into_iter() { let id = &v["_id"]; + let key = &v["_key"]; match id { Value::String(i) => { let mut buf = vec![]; - buf.extend_from_slice(i[..].as_bytes()); + buf.extend_from_slice(key.as_str().unwrap().as_bytes()); vertex_keys.push(buf); if current_vertex_col.is_none() { let pos = i.find('/').unwrap(); diff --git a/src/output/construct.rs b/src/output/construct.rs index 83338a8..23a3bfb 100644 --- a/src/output/construct.rs +++ b/src/output/construct.rs @@ -43,3 +43,19 @@ pub fn construct_cols_to_keys_to_inds( .for_each(|item| dict.set_item(item.0, item.1).unwrap()); Ok(dict) } + +#[cfg(not(test))] +pub fn construct_cols_to_inds_to_keys( + input: HashMap>, + py: Python, +) -> PyResult<&PyDict> { + let dict = PyDict::new(py); + input.iter().for_each(|(col_name, inner_map)| { + let inner_dict = PyDict::new(py); + inner_map.iter().for_each(|(key, value)| { + inner_dict.set_item(value, key).unwrap(); + }); + dict.set_item(col_name, inner_dict).unwrap(); + }); + Ok(dict) +}