Skip to content

Commit

Permalink
Merge pull request #18 from arangoml/MLP-641
Browse files Browse the repository at this point in the history
MLP-641 | return reverse mapping
  • Loading branch information
Alex Geenen authored May 16, 2024
2 parents 644f2a8 + 0e75d13 commit d0a86bf
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "phenolrs"
version = "0.3.0"
version = "0.4.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
[project.optional-dependencies]
tests = [
"pytest",
"arango-datasets"
]
dynamic = ["version"]

Expand Down
20 changes: 12 additions & 8 deletions python/phenolrs/numpy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def load_graph_to_numpy(
dict[str, dict[str, npt.NDArray[np.float64]]],
dict[typing.Tuple[str, str, str], npt.NDArray[np.float64]],
dict[str, dict[str, int]],
dict[str, dict[int, str]],
dict[str, dict[str, str]],
]:
# TODO: replace with pydantic validation
Expand Down Expand Up @@ -83,18 +84,21 @@ def load_graph_to_numpy(
for e_col_name, entries in metagraph["edgeCollections"].items()
]

features_by_col, coo_map, col_to_adb_id_to_ind = graph_to_pyg_format(
{
"database": database,
"vertex_collections": vertex_collections,
"edge_collections": edge_collections,
"configuration": {"database_config": db_config_options},
}
features_by_col, coo_map, col_to_adb_key_to_ind, col_to_ind_to_adb_key = (
graph_to_pyg_format(
{
"database": database,
"vertex_collections": vertex_collections,
"edge_collections": edge_collections,
"configuration": {"database_config": db_config_options},
}
)
)

return (
features_by_col,
coo_map,
col_to_adb_id_to_ind,
col_to_adb_key_to_ind,
col_to_ind_to_adb_key,
vertex_cols_source_to_output,
)
1 change: 1 addition & 0 deletions python/phenolrs/phenolrs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def graph_to_pyg_format(request: dict[str, typing.Any]) -> typing.Tuple[
dict[str, dict[str, npt.NDArray[np.float64]]],
dict[typing.Tuple[str, str, str], npt.NDArray[np.float64]],
dict[str, dict[str, int]],
dict[str, dict[int, str]],
]: ...

class PhenolError(Exception): ...
14 changes: 8 additions & 6 deletions python/phenolrs/pyg_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def load_into_pyg_data(
tls_cert: typing.Any | None = None,
parallelism: int | None = None,
batch_size: int | None = None,
) -> tuple[Data, dict[str, dict[str, int]]]:
) -> tuple[Data, dict[str, dict[str, int]], dict[str, dict[int, str]]]:
if "vertexCollections" not in metagraph:
raise PhenolError("vertexCollections not found in metagraph")
if "edgeCollections" not in metagraph:
Expand All @@ -46,7 +46,8 @@ def load_into_pyg_data(
(
features_by_col,
coo_map,
col_to_adb_id_to_ind,
col_to_adb_key_to_ind,
col_to_ind_to_adb_key,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
database,
Expand Down Expand Up @@ -85,7 +86,7 @@ def load_into_pyg_data(
if result.numel() > 0:
data["edge_index"] = result

return data, col_to_adb_id_to_ind
return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key

@staticmethod
def load_into_pyg_heterodata(
Expand All @@ -98,7 +99,7 @@ def load_into_pyg_heterodata(
tls_cert: typing.Any | None = None,
parallelism: int | None = None,
batch_size: int | None = None,
) -> tuple[HeteroData, dict[str, dict[str, int]]]:
) -> tuple[HeteroData, dict[str, dict[str, int]], dict[str, dict[int, str]]]:
if "vertexCollections" not in metagraph:
raise PhenolError("vertexCollections not found in metagraph")
if "edgeCollections" not in metagraph:
Expand All @@ -112,7 +113,8 @@ def load_into_pyg_heterodata(
(
features_by_col,
coo_map,
col_to_adb_id_to_ind,
col_to_adb_key_to_ind,
col_to_ind_to_adb_key,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
database,
Expand Down Expand Up @@ -142,4 +144,4 @@ def load_into_pyg_heterodata(
if result.numel() > 0:
data[(from_name, edge_col_name, to_name)].edge_index = result

return data, col_to_adb_id_to_ind
return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key
42 changes: 34 additions & 8 deletions python/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@ def test_phenol_abide_hetero(
password=connection_information["password"],
)

data, col_to_adb_id_to_ind = result
data, col_to_adb_key_to_ind, col_to_ind_to_adb_key = result
assert isinstance(data, HeteroData)
assert data["Subjects"]["x"].shape == (871, 2000)
assert len(col_to_adb_id_to_ind["Subjects"]) == 871
assert (
len(col_to_adb_key_to_ind["Subjects"])
== len(col_to_ind_to_adb_key["Subjects"])
== 871
)

assert data[("Subjects", "medical_affinity_graph", "Subjects")][
"edge_index"
].shape == (2, 606770)

# Metagraph variation
result = PygLoader.load_into_pyg_heterodata(
Expand All @@ -39,10 +47,18 @@ def test_phenol_abide_hetero(
password=connection_information["password"],
)

data, col_to_adb_id_to_ind = result
data, col_to_adb_key_to_ind, col_to_ind_to_adb_key = result
assert isinstance(data, HeteroData)
assert data["Subjects"]["x"].shape == (871, 2000)
assert len(col_to_adb_id_to_ind["Subjects"]) == 871
assert (
len(col_to_adb_key_to_ind["Subjects"])
== len(col_to_ind_to_adb_key["Subjects"])
== 871
)

assert data[("Subjects", "medical_affinity_graph", "Subjects")][
"edge_index"
].shape == (2, 606770)


def test_phenol_abide_numpy(
Expand All @@ -51,7 +67,8 @@ def test_phenol_abide_numpy(
(
features_by_col,
coo_map,
col_to_adb_id_to_ind,
col_to_adb_key_to_ind,
col_to_ind_to_adb_key,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
Expand All @@ -69,13 +86,18 @@ def test_phenol_abide_numpy(
2,
606770,
)
assert len(col_to_adb_id_to_ind["Subjects"]) == 871
assert (
len(col_to_adb_key_to_ind["Subjects"])
== len(col_to_ind_to_adb_key["Subjects"])
== 871
)
assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}

(
features_by_col,
coo_map,
col_to_adb_id_to_ind,
col_to_adb_key_to_ind,
col_to_ind_to_adb_key,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
Expand All @@ -90,5 +112,9 @@ def test_phenol_abide_numpy(

assert features_by_col["Subjects"]["brain_fmri_features"].shape == (871, 2000)
assert len(coo_map) == 0
assert len(col_to_adb_id_to_ind["Subjects"]) == 871
assert (
len(col_to_adb_key_to_ind["Subjects"])
== len(col_to_ind_to_adb_key["Subjects"])
== 871
)
assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}
2 changes: 1 addition & 1 deletion src/arangodb/aql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ fn build_aql_query(collection_description: &CollectionDescription, is_edge: bool
let identifiers = if is_edge {
"_to: doc._to,\n_from: doc._from,\n"
} else {
""
"_key: doc._key,\n"
};
let query = format!(
"
Expand Down
27 changes: 14 additions & 13 deletions src/graphs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,20 @@ impl Graph {
to_id: Vec<u8>,
_data: Vec<u8>,
) -> Result<()> {
// build up the coo representation
let from_col: String = String::from_utf8({
let (from_col, from_key) = {
let s = String::from_utf8(from_id.clone()).expect("_from to be a string");
let id_split = s.find('/').unwrap();
(&s[0..id_split]).into()
})
.unwrap();
let to_col: String = String::from_utf8({
let id_split = s.find('/').expect("Invalid format for _from");
let (col, key) = s.split_at(id_split);
(col.to_string(), key[1..].to_string())
};

let (to_col, to_key) = {
let s = String::from_utf8(to_id.clone()).expect("_to to be a string");
let id_split = s.find('/').unwrap();
(&s[0..id_split]).into()
})
.unwrap();
let id_split = s.find('/').expect("Invalid format for _to");
let (col, key) = s.split_at(id_split);
(col.to_string(), key[1..].to_string())
};

let key_tup = (
String::from_utf8(col_name).unwrap(),
from_col.clone(),
Expand All @@ -223,8 +224,8 @@ impl Graph {
.coo_by_from_edge_to
.get_mut(&key_tup)
.ok_or_else(|| anyhow!("Unable to get COO from to for {:?}", &key_tup))?;
let from_col_id = from_col_keys.get(&String::from_utf8(from_id).unwrap());
let to_col_id = to_col_keys.get(&String::from_utf8(to_id).unwrap());
let from_col_id = from_col_keys.get(&from_key);
let to_col_id = to_col_keys.get(&to_key);
if let (Some(from_id), Some(to_id)) = (from_col_id, to_col_id) {
cur_coo[0].push(*from_id);
cur_coo[1].push(*to_id);
Expand Down
20 changes: 17 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;

#[cfg(not(test))]
type PygCompatible<'a> = (&'a PyDict, &'a PyDict, &'a PyDict);
type PygCompatible<'a> = (&'a PyDict, &'a PyDict, &'a PyDict, &'a PyDict);

#[cfg(not(test))]
create_exception!(phenolrs, PhenolError, PyException);
Expand All @@ -31,18 +31,32 @@ create_exception!(phenolrs, PhenolError, PyException);
#[cfg(not(test))]
fn graph_to_pyg_format(py: Python, request: DataLoadRequest) -> PyResult<PygCompatible> {
let graph = load::retrieve::get_arangodb_graph(request).map_err(PhenolError::new_err)?;

let col_to_features = construct::construct_col_to_features(
convert_nested_features_map(graph.cols_to_features),
py,
)?;

let coo_by_from_edge_to = construct::construct_coo_by_from_edge_to(
convert_coo_edge_map(graph.coo_by_from_edge_to),
py,
)?;

let cols_to_keys_to_inds =
construct::construct_cols_to_keys_to_inds(graph.cols_to_keys_to_inds, py)?;
construct::construct_cols_to_keys_to_inds(graph.cols_to_keys_to_inds.clone(), py)?;

let cols_to_inds_to_keys =
construct::construct_cols_to_inds_to_keys(graph.cols_to_keys_to_inds, py)?;

println!("Finished retrieval!");
let res = (col_to_features, coo_by_from_edge_to, cols_to_keys_to_inds);

let res = (
col_to_features,
coo_by_from_edge_to,
cols_to_keys_to_inds,
cols_to_inds_to_keys,
);

Ok(res)
}

Expand Down
6 changes: 4 additions & 2 deletions src/load/receive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,11 @@ pub fn receive_vertices(
Ok(val) => val,
};
let id = &v["_id"];
let key = &v["_key"];
match id {
Value::String(i) => {
let mut buf = vec![];
buf.extend_from_slice(i[..].as_bytes());
buf.extend_from_slice(key.as_str().unwrap().as_bytes());
vertex_keys.push(buf);
if current_vertex_col.is_none() {
let pos = i.find('/').unwrap();
Expand Down Expand Up @@ -255,10 +256,11 @@ pub fn receive_vertices(
};
for v in values.result.into_iter() {
let id = &v["_id"];
let key = &v["_key"];
match id {
Value::String(i) => {
let mut buf = vec![];
buf.extend_from_slice(i[..].as_bytes());
buf.extend_from_slice(key.as_str().unwrap().as_bytes());
vertex_keys.push(buf);
if current_vertex_col.is_none() {
let pos = i.find('/').unwrap();
Expand Down
16 changes: 16 additions & 0 deletions src/output/construct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,19 @@ pub fn construct_cols_to_keys_to_inds(
.for_each(|item| dict.set_item(item.0, item.1).unwrap());
Ok(dict)
}

/// Builds the reverse lookup of the per-collection key → index mapping as a
/// nested Python dict.
///
/// The outer dict is keyed by the outer map's `String` key; each inner dict
/// maps the `usize` value of the corresponding inner map back to its `String`
/// key (i.e. index → key).
///
/// If two keys of one inner map share the same index, only one of them is
/// kept in the resulting inner dict (later insertions overwrite earlier ones
/// and `HashMap` iteration order is unspecified).
///
/// # Errors
///
/// Returns the underlying `PyErr` if inserting into any of the Python dicts
/// fails, instead of panicking.
#[cfg(not(test))]
pub fn construct_cols_to_inds_to_keys(
    input: HashMap<String, HashMap<String, usize>>,
    py: Python,
) -> PyResult<&PyDict> {
    let dict = PyDict::new(py);
    // Plain `for` loops (rather than `for_each` closures) so that `?` can
    // propagate a failed `set_item` to the caller as a Python exception.
    for (col_name, inner_map) in &input {
        let inner_dict = PyDict::new(py);
        for (key, ind) in inner_map {
            // Reverse direction: the index becomes the dict key.
            inner_dict.set_item(ind, key)?;
        }
        dict.set_item(col_name, inner_dict)?;
    }
    Ok(dict)
}

0 comments on commit d0a86bf

Please sign in to comment.