diff --git a/README.md b/README.md index 97cf427..1838690 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,14 @@ pip install fastgraphml #### Example Homogneous Graphs ```python -from fastgraphml.graph_embeddings import SAGE, GAT +from fastgraphml.graph_embeddings import get_sage_model, get_gat_model from fastgraphml.graph_embeddings import downstream_tasks from fastgraphml import Datasets from arango import ArangoClient +SAGE = get_sage_model() +GAT = get_gat_model() + # Initialize the ArangoDB client. client = ArangoClient("http://127.0.0.1:8529") db = client.db('_system', username='root', password='') @@ -66,12 +69,15 @@ embeddings = model.get_embeddings() # get embeddings #### Example Heterogeneous Graphs ```python -from fastgraphml.graph_embeddings import METAPATH2VEC, DMGI +from fastgraphml.graph_embeddings import get_metapath2vec_model, get_dmgi_model from fastgraphml.graph_embeddings import downstream_tasks from fastgraphml import Datasets from arango import ArangoClient +METAPATH2VEC = get_metapath2vec_model() +DMGI = get_dmgi_model() + # Initialize the ArangoDB client. client = ArangoClient("http://127.0.0.1:8529") db = client.db('_system', username='root') @@ -107,10 +113,13 @@ embeddings = model.get_embeddings() # get embeddings ### Use Case 2: Generates Graph Embeddings using PyG graphs: ```python -from fastgraphml.graph_embeddings import SAGE, GAT +from fastgraphml.graph_embeddings import get_sage_model, get_gat_model from fastgraphml.graph_embeddings import downstream_tasks from torch_geometric.datasets import Planetoid +SAGE = get_sage_model() +GAT = get_gat_model() + # load pyg dataset dataset = Planetoid(root='/tmp/Cora', name='Cora') data = dataset[0] diff --git a/fastgraphml/__init__.py b/fastgraphml/__init__.py index a26e2a4..d1a0a30 100644 --- a/fastgraphml/__init__.py +++ b/fastgraphml/__init__.py @@ -1,8 +1,18 @@ -from arango_datasets.datasets import Datasets +def get_dmgi_model(): + from .graph_embeddings.models.dmgi import DMGI + return DMGI -from fastgraphml.graph_embeddings.models.dmgi import DMGI -from fastgraphml.graph_embeddings.models.gat import GAT -from fastgraphml.graph_embeddings.models.graph_sage import SAGE -from fastgraphml.graph_embeddings.models.metapath2vec import METAPATH2VEC +def get_gat_model(): + from .graph_embeddings.models.gat import GAT + return GAT + +def get_sage_model(): + from .graph_embeddings.models.graph_sage import SAGE + return SAGE + +def get_metapath2vec_model(): + from .graph_embeddings.models.metapath2vec import METAPATH2VEC + return METAPATH2VEC + +__all__ = ["get_dmgi_model", "get_gat_model", "get_sage_model", "get_metapath2vec_model"] -__all__ = ["DMGI", "GAT", "SAGE", "METAPATH2VEC", "Datasets"] diff --git a/fastgraphml/graph_embeddings/__init__.py b/fastgraphml/graph_embeddings/__init__.py index a9279a3..187e387 100644 --- a/fastgraphml/graph_embeddings/__init__.py +++ b/fastgraphml/graph_embeddings/__init__.py @@ -1,6 +1,18 @@ -from fastgraphml.graph_embeddings.models.dmgi import DMGI -from fastgraphml.graph_embeddings.models.gat import GAT -from fastgraphml.graph_embeddings.models.graph_sage import SAGE -from fastgraphml.graph_embeddings.models.metapath2vec import METAPATH2VEC +def get_dmgi_model(): + from .models.dmgi import DMGI + return DMGI + +def get_gat_model(): + from .models.gat import GAT + return GAT + +def get_sage_model(): + from .models.graph_sage import SAGE + return SAGE + +def get_metapath2vec_model(): + from .models.metapath2vec import METAPATH2VEC + return METAPATH2VEC + +__all__ = ["get_dmgi_model", "get_gat_model", "get_sage_model", "get_metapath2vec_model"] -__all__ = ["DMGI", "GAT", "SAGE", "METAPATH2VEC"]