Skip to content

Commit

Permalink
Imdb test case (#34)
Browse files Browse the repository at this point in the history
* add `imdb` test

* fmt

* fmt?

* fix test name

* cleanup

* fix imdb

* add failing test cases for imdb

* cleanup

* new: `test_imdb_networkx` (failing)

* fix: skip field if null

* fix: `test_imdb_pyg`

* fix lint

* fix lint (again)

* more assertions

* add extra assertion

* temp: raise value error

* temp: panic

* attempt: bump version
  • Loading branch information
aMahanna authored Aug 28, 2024
1 parent 90bacf8 commit 1a11a97
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 67 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "phenolrs"
version = "0.5.5"
version = "0.5.6"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "maturin"

[project]
name = "phenolrs"
version = "0.5.5"
version = "0.5.6"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Rust",
Expand Down
1 change: 1 addition & 0 deletions python/phenolrs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .phenolrs import * # noqa: F403


__doc__ = phenolrs.__doc__ # type: ignore[name-defined] # noqa: F405
if hasattr(phenolrs, "__all__"): # type: ignore[name-defined] # noqa: F405
__all__ = phenolrs.__all__ # type: ignore[name-defined] # noqa: F405
2 changes: 1 addition & 1 deletion python/phenolrs/networkx/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
DiGraphAdjDict,
DstIndices,
EdgeIndices,
EdgeValuesDict,
GraphAdjDict,
MultiDiGraphAdjDict,
MultiGraphAdjDict,
NodeDict,
SrcIndices,
EdgeValuesDict,
)


Expand Down
21 changes: 12 additions & 9 deletions python/phenolrs/numpy/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,18 @@ def load_graph_to_numpy(
for e_col_name, entries in metagraph["edgeCollections"].items()
]

features_by_col, coo_map, col_to_adb_key_to_ind, col_to_ind_to_adb_key = (
graph_to_numpy_format(
{
"vertex_collections": vertex_collections,
"edge_collections": edge_collections,
"database_config": db_config_options,
"load_config": load_config_options,
}
)
(
features_by_col,
coo_map,
col_to_adb_key_to_ind,
col_to_ind_to_adb_key,
) = graph_to_numpy_format(
{
"vertex_collections": vertex_collections,
"edge_collections": edge_collections,
"database_config": db_config_options,
"load_config": load_config_options,
}
)

return (
Expand Down
2 changes: 1 addition & 1 deletion python/phenolrs/phenolrs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ from .networkx.typings import (
DiGraphAdjDict,
DstIndices,
EdgeIndices,
EdgeValuesDict,
GraphAdjDict,
MultiDiGraphAdjDict,
MultiGraphAdjDict,
NodeDict,
SrcIndices,
EdgeValuesDict,
)
from .numpy.typings import (
ArangoCollectionToArangoKeyToIndex,
Expand Down
10 changes: 4 additions & 6 deletions python/phenolrs/pyg/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,11 @@ def load_into_pyg_data(
raise PhenolError("edgeCollections must map to non-empty dictionary")

if len(metagraph["vertexCollections"]) > 1:
raise PhenolError(
"More than one vertex collection specified for homogeneous dataset"
)
m = "More than one vertex collection specified for homogeneous dataset"
raise PhenolError(m)
if len(metagraph["edgeCollections"]) > 1:
raise PhenolError(
"More than one edge collection specified for homogeneous dataset"
)
m = "More than one edge collection specified for homogeneous dataset"
raise PhenolError(m)

v_col_spec_name = list(metagraph["vertexCollections"].keys())[0]
v_col_spec = list(metagraph["vertexCollections"].values())[0]
Expand Down
36 changes: 24 additions & 12 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,47 @@ def connection_information() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def load_abide(abide_db_name: str, connection_information: Dict[str, Any]) -> None:
def load_dataset(
dataset: str, db_name: str, connection_information: Dict[str, Any]
) -> None:
client = arango.ArangoClient(connection_information["url"])
sys_db = client.db(
"_system",
username=connection_information["username"],
password=connection_information["password"],
)

if not sys_db.has_database(abide_db_name):
sys_db.delete_database(abide_db_name, ignore_missing=True)
sys_db.create_database(abide_db_name)
abide_db = client.db(
abide_db_name,
if not sys_db.has_database(db_name):
sys_db.create_database(db_name)
db = client.db(
db_name,
username=connection_information["username"],
password=connection_information["password"],
)
dsets = Datasets(abide_db)
dsets.load("ABIDE")
dsets = Datasets(db)
dsets.load(dataset)


@pytest.fixture(scope="module")
def load_abide(abide_db_name: str, connection_information: Dict[str, Any]) -> None:
load_dataset("ABIDE", abide_db_name, connection_information)


@pytest.fixture(scope="module")
def load_imdb(imdb_db_name: str, connection_information: Dict[str, Any]) -> None:
load_dataset("IMDB_PLATFORM", imdb_db_name, connection_information)


@pytest.fixture(scope="module")
def abide_db_name() -> str:
return "abide"


@pytest.fixture(scope="module")
def imdb_db_name() -> str:
return "imdb"


@pytest.fixture(scope="module")
def custom_graph_db_name() -> str:
return "custom_graph"
Expand All @@ -77,7 +92,6 @@ def load_line_graph(
)

if not sys_db.has_database(custom_graph_db_name):
sys_db.delete_database(custom_graph_db_name, ignore_missing=True)
sys_db.create_database(custom_graph_db_name)
custom_graph_db = client.db(
custom_graph_db_name,
Expand Down Expand Up @@ -114,7 +128,6 @@ def load_karate(karate_db_name: str, connection_information: Dict[str, Any]) ->
)

if not sys_db.has_database(karate_db_name):
sys_db.delete_database(karate_db_name, ignore_missing=True)
sys_db.create_database(karate_db_name)
karate_db = client.db(
karate_db_name,
Expand Down Expand Up @@ -152,7 +165,6 @@ def load_multigraph(
)

if not sys_db.has_database(multigraph_db_name):
sys_db.delete_database(multigraph_db_name, ignore_missing=True)
sys_db.create_database(multigraph_db_name)
multigraph_db = client.db(
multigraph_db_name,
Expand Down
Loading

0 comments on commit 1a11a97

Please sign in to comment.