Skip to content

Commit

Permalink
Merge pull request #20 from arangoml/make-torch-optional
Browse files Browse the repository at this point in the history
make `torch` optional
  • Loading branch information
Alex Geenen authored May 30, 2024
2 parents 95e1182 + 3f0e2f5 commit 5827a76
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "phenolrs"
version = "0.4.1"
version = "0.4.2"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ classifiers = [
]
dependencies = [
"numpy",
"torch",
"torch-geometric",
"python-arango"
]

Expand All @@ -22,6 +20,10 @@ tests = [
"pytest",
"arango-datasets"
]
torch = [
"torch",
"torch-geometric",
]
dynamic = ["version"]

[tool.maturin]
Expand Down
22 changes: 18 additions & 4 deletions python/phenolrs/pyg_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import typing

import numpy as np
import torch
from torch_geometric.data import Data, HeteroData

from phenolrs import PhenolError
from phenolrs.numpy_loader import NumpyLoader

try:
import torch
from torch_geometric.data import Data, HeteroData

TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False


class PygLoader:
@staticmethod
Expand All @@ -20,7 +26,11 @@ def load_into_pyg_data(
tls_cert: typing.Any | None = None,
parallelism: int | None = None,
batch_size: int | None = None,
) -> tuple[Data, dict[str, dict[str, int]], dict[str, dict[int, str]]]:
) -> tuple["Data", dict[str, dict[str, int]], dict[str, dict[int, str]]]:
if not TORCH_AVAILABLE:
m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" # noqa: E501
raise ImportError(m)

if "vertexCollections" not in metagraph:
raise PhenolError("vertexCollections not found in metagraph")
if "edgeCollections" not in metagraph:
Expand Down Expand Up @@ -99,7 +109,11 @@ def load_into_pyg_heterodata(
tls_cert: typing.Any | None = None,
parallelism: int | None = None,
batch_size: int | None = None,
) -> tuple[HeteroData, dict[str, dict[str, int]], dict[str, dict[int, str]]]:
) -> tuple["HeteroData", dict[str, dict[str, int]], dict[str, dict[int, str]]]:
if not TORCH_AVAILABLE:
m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" # noqa: E501
raise ImportError(m)

if "vertexCollections" not in metagraph:
raise PhenolError("vertexCollections not found in metagraph")
if "edgeCollections" not in metagraph:
Expand Down

0 comments on commit 5827a76

Please sign in to comment.