diff --git a/CHANGELOG.md b/CHANGELOG.md
index 271acd853b5d..221f02fcd2e8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
+- Added various GRetriever Architecture Benchmarking examples ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+- Added `profiler.nvtxit` with some examples ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+- Added `loader.RagQueryLoader` with Remote Backend Example ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+- Added `data.LargeGraphIndexer` ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+
### Changed
### Deprecated
diff --git a/docs/source/_figures/flowchart.svg b/docs/source/_figures/flowchart.svg
new file mode 100644
index 000000000000..188d37b14f41
--- /dev/null
+++ b/docs/source/_figures/flowchart.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/_figures/multihop_example.svg b/docs/source/_figures/multihop_example.svg
new file mode 100644
index 000000000000..4925dcb9713d
--- /dev/null
+++ b/docs/source/_figures/multihop_example.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/_figures/remote_backend.svg b/docs/source/_figures/remote_backend.svg
new file mode 100644
index 000000000000..c5791f0a95de
--- /dev/null
+++ b/docs/source/_figures/remote_backend.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/advanced/rag.rst b/docs/source/advanced/rag.rst
new file mode 100644
index 000000000000..2d8513fd56e4
--- /dev/null
+++ b/docs/source/advanced/rag.rst
@@ -0,0 +1,614 @@
+Working with LLM RAG in Pytorch Geometric
+=========================================
+
+This series aims to provide a starting point and for
+multi-step LLM Retrieval Augmented Generation
+(RAG) using Graph Neural Networks.
+
+Motivation
+----------
+
+As Large Language Models (LLMs) quickly grow to dominate industry, they
+are increasingly being deployed at scale in use cases that require very
+specific contextual expertise. LLMs often struggle with these cases out
+of the box, as they will hallucinate answers that are not included in
+their training data. At the same time, many business already have large
+graph databases full of important context that can provide important
+domain-specific context to reduce hallucination and improve answer
+fidelity for LLMs. Graph Neural Networks (GNNs) provide a means for
+efficiently encoding this contextual information into the model, which
+can help LLMs to better understand and generate answers. Hence, theres
+an open research question as to how to effectively use GNN encodings
+efficiently for this purpose, that the tooling provided here can help
+investigate.
+
+Architecture
+------------
+
+To model the use-case of RAG from a large knowledge graph of millions of
+nodes, we present the following architecture:
+
+
+
+
+
+.. figure:: ../_figures/flowchart.svg
+ :align: center
+ :width: 100%
+
+
+
+Graph RAG as shown in the diagram above follows the following order of
+operations:
+
+0. To start, not pictured here, there must exist a large knowledge graph
+ that exists as a source of truth. The nodes and edges of this
+ knowledge graph
+
+During inference time, RAG implementations that follow this architecture
+are composed of the following steps:
+
+1. Tokenize and encode the query using the LLM Encoder
+2. Retrieve a subgraph of the larger knowledge graph (KG) relevant to
+ the query and encode it using a GNN
+3. Jointly embed the GNN embedding with the LLM embedding
+4. Utilize LLM Decoder to decode joint embedding and generate a response
+
+
+
+
+Encoding a Large Knowledge Graph
+================================
+
+To start, a Large Knowledge Graph needs to be created from triplets or
+multiple subgraphs in a dataset.
+
+Example 1: Building from Already Existing Datasets
+--------------------------------------------------
+
+In most RAG scenarios, the subset of the information corpus that gets
+retrieved is crucial for whether the appropriate response to the LLM.
+The same is true for GNN based RAG. For example, consider the
+WebQSPDataset.
+
+.. code:: python
+
+ from torch_geometric.datasets import WebQSPDataset
+
+ num_questions = 100
+ ds = WebQSPDataset('small_sample', limit=num_questions)
+
+
+WebQSP is a dataset that is based off of a subset of the Freebase
+Knowledge Graph, which is an open-source knowledge graph formerly
+maintained by Google. For each question-answer pair in the dataset, a
+subgraph was chosen based on a Semantic SPARQL search on the larger
+knowledge graph, to provide relevent context on finding the answer. So
+each entry in the dataset consists of:
+- A question to be answered
+- The answer
+- A knowledge graph subgraph of Freebase that has the context
+needed to answer the question.
+
+.. code:: python
+
+ ds.raw_dataset
+
+ >>> Dataset({
+ features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
+ num_rows: 100
+ })
+
+
+
+.. code:: python
+
+ ds.raw_dataset[0]
+
+
+ >>> {'id': 'WebQTrn-0',
+ 'question': 'what is the name of justin bieber brother',
+ 'answer': ['Jaxon Bieber'],
+ 'q_entity': ['Justin Bieber'],
+ 'a_entity': ['Jaxon Bieber'],
+ 'graph': [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'],
+ ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'],
+ ...],
+ 'choices': []}
+
+
+
+Although this dataset can be trained on as-is, a couple problems emerge
+from doing so:
+1. A retrieval algorithm needs to be implemented and
+executed during inference time, that might not appropriately correspond
+to the algorithm that was used to generate the dataset subgraphs.
+1. The dataset as is not stored computationally efficiently, as there will
+exist many duplicate nodes and edges that are shared between the
+questions.
+
+As a result, it makes sense in this scenario to be able to encode all
+the entries into a large knowledge graph, so that duplicate nodes and
+edges can be avoided, and so that alternative retrieval algorithms can
+be tried. We can do this with the LargeGraphIndexer class:
+
+.. code:: python
+
+ from torch_geometric.data import LargeGraphIndexer, Data, get_features_for_triplets_groups
+ from torch_geometric.nn.nlp import SentenceTransformer
+ import time
+ import torch
+ import tqdm
+ from itertools import chain
+ import networkx as nx
+
+.. code:: python
+
+ raw_dataset_graphs = [[tuple(trip) for trip in graph] for graph in ds.raw_dataset['graph']]
+ print(raw_dataset_graphs[0][:10])
+
+ >>> [('P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'), ('1Club.FM: Power', 'broadcast.content.artist', 'P!nk'), ...]
+
+
+To show the benefits of this indexer in action, we will use the
+following model to encode this sample of graphs using LargeGraphIndexer,
+along with naively.
+
+.. code:: python
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SentenceTransformer(model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+
+First, we compare the clock times of encoding using both methods.
+
+.. code:: python
+
+ # Indexing question-by-question
+ dataset_graphs_embedded = []
+ start = time.time()
+ for graph in tqdm.tqdm(raw_dataset_graphs):
+ nodes_map = dict()
+ edges_map = dict()
+ edge_idx_base = []
+
+ for src, edge, dst in graph:
+ # Collect nodes
+ if src not in nodes_map:
+ nodes_map[src] = len(nodes_map)
+ if dst not in nodes_map:
+ nodes_map[dst] = len(nodes_map)
+
+ # Collect edge types
+ if edge not in edges_map:
+ edges_map[edge] = len(edges_map)
+
+ # Record edge
+ edge_idx_base.append((nodes_map[src], edges_map[edge], nodes_map[dst]))
+
+ # Encode nodes and edges
+ sorted_nodes = list(sorted(nodes_map.keys(), key=lambda x: nodes_map[x]))
+ sorted_edges = list(sorted(edges_map.keys(), key=lambda x: edges_map[x]))
+
+ x = model.encode(sorted_nodes, batch_size=256)
+ edge_attrs_map = model.encode(sorted_edges, batch_size=256)
+
+ edge_attrs = []
+ edge_idx = []
+ for trip in edge_idx_base:
+ edge_attrs.append(edge_attrs_map[trip[1]])
+ edge_idx.append([trip[0], trip[2]])
+
+ dataset_graphs_embedded.append(Data(x=x, edge_index=torch.tensor(edge_idx).T, edge_attr=torch.stack(edge_attrs, dim=0)))
+
+
+ print(time.time()-start)
+
+ >>> 121.68579435348511
+
+
+
+.. code:: python
+
+ # Using LargeGraphIndexer to make one large knowledge graph
+ from torch_geometric.data.large_graph_indexer import EDGE_RELATION
+
+ start = time.time()
+ all_triplets_together = chain.from_iterable(raw_dataset_graphs)
+ # Index as one large graph
+ print('Indexing...')
+ indexer = LargeGraphIndexer.from_triplets(all_triplets_together)
+
+ # first the nodes
+ unique_nodes = indexer.get_unique_node_features()
+ node_encs = model.encode(unique_nodes, batch_size=256)
+ indexer.add_node_feature(new_feature_name='x', new_feature_vals=node_encs)
+
+ # then the edges
+ unique_edges = indexer.get_unique_edge_features(feature_name=EDGE_RELATION)
+ edge_attr = model.encode(unique_edges, batch_size=256)
+ indexer.add_edge_feature(new_feature_name="edge_attr", new_feature_vals=edge_attr, map_from_feature=EDGE_RELATION)
+
+ ckpt_time = time.time()
+ whole_knowledge_graph = indexer.to_data(node_feature_name='x', edge_feature_name='edge_attr')
+ whole_graph_done = time.time()
+ print(f"Time to create whole knowledge_graph: {whole_graph_done-start}")
+
+ # Compute this to make sure we're comparing like to like on final time printout
+ whole_graph_diff = whole_graph_done-ckpt_time
+
+ # retrieve subgraphs
+ print('Retrieving Subgraphs...')
+ dataset_graphs_embedded_largegraphindexer = [graph for graph in tqdm.tqdm(get_features_for_triplets_groups(indexer=indexer, triplet_groups=raw_dataset_graphs), total=num_questions)]
+ print(time.time()-start-whole_graph_diff)
+
+ >>> Indexing...
+ >>> Time to create whole knowledge_graph: 114.01080107688904
+ >>> Retrieving Subgraphs...
+ >>> 114.66037964820862
+
+
+The large graph indexer allows us to compute the entire knowledge graph
+from a series of samples, so that new retrieval methods can also be
+tested on the entire graph. We will see this attempted in practice later
+on.
+
+It’s worth noting that, although the times are relatively similar right
+now, the speedup with largegraphindexer will be much higher as the size
+of the knowledge graph grows. This is due to the speedup being a factor
+of the number of unique nodes and edges in the graph.
+
+
+We expect the two results to be functionally identical, with the
+differences being due to floating point jitter.
+
+.. code:: python
+
+ def results_are_close_enough(ground_truth: Data, new_method: Data, thresh=.8):
+ def _sorted_tensors_are_close(tensor1, tensor2):
+ return torch.all(torch.isclose(tensor1.sort(dim=0)[0], tensor2.sort(dim=0)[0]).float().mean(axis=1) > thresh)
+ def _graphs_are_same(tensor1, tensor2):
+ return nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor1.T)) == nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor2.T))
+ return _sorted_tensors_are_close(ground_truth.x, new_method.x) \
+ and _sorted_tensors_are_close(ground_truth.edge_attr, new_method.edge_attr) \
+ and _graphs_are_same(ground_truth.edge_index, new_method.edge_index)
+
+
+ all_results_match = True
+ for old_graph, new_graph in tqdm.tqdm(zip(dataset_graphs_embedded, dataset_graphs_embedded_largegraphindexer), total=num_questions):
+ all_results_match &= results_are_close_enough(old_graph, new_graph)
+ all_results_match
+
+ >>> True
+
+
+
+When scaled up to the entire dataset, we see a 2x speedup with indexing
+this way on the WebQSP Dataset.
+
+Example 2: Building a new Dataset from Questions and an already-existing Knowledge Graph
+----------------------------------------------------------------------------------------
+
+Motivation
+~~~~~~~~~~
+
+One potential application of knowledge graph structural encodings is
+capturing the relationships between different entities that are multiple
+hops apart. This can be challenging for an LLM to recognize from
+prepended graph information. Here’s a motivating example (credit to
+@Rishi Puri):
+
+
+.. figure:: ../_figures/multihop_example.svg
+ :align: center
+ :width: 100%
+
+
+
+In this example, the question can only be answered by reasoning about
+the relationships between the entities in the knowledge graph.
+
+Building a Multi-Hop QA Dataset
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To see an example of encoding a large knowledge graph starting from an
+existing set of triplets, check out the multi-hop example in
+`examples/llm_plus_gnn/multihop_rag`.
+
+Question: How do we extract a contextual subgraph for a given query?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The chosen retrieval algorithm is a critical component in the pipeline
+for affecting RAG performance. In the next section, we will
+demonstrate a naive method of retrieval for a large knowledge graph.
+
+
+Retrieval Algorithms and Scaling Retrieval
+==========================================
+
+Motivation
+----------
+
+When building a RAG Pipeline for inference, the retrieval component is
+important for the following reasons:
+1. A given algorithm for retrieving subgraph context can have a
+marked effect on the hallucination rate of the responses in the model
+2. A given retrieval algorithm needs to be able to scale to larger
+graphs of millions of nodes and edges in order to be practical for production.
+
+In this section, we will explore how to construct a RAG retrieval
+algorithm from a given subgraph, and conduct some experiments to
+evaluate its runtime performance.
+
+We want to do so in-line with Pytorch Geometric’s in-house framework for
+remote backends:
+
+
+.. figure:: ../_figures/remote_2.png
+ :align: center
+ :width: 100%
+
+
+
+As seen here, the GraphStore is used to store the neighbor relations
+between the nodes of the graph, whereas the FeatureStore is used to
+store the node and edge features in the graph.
+
+Let’s start by loading in a knowledge graph dataset for the sake of our
+experiment:
+
+.. code:: python
+
+ from torch_geometric.data import LargeGraphIndexer
+ from torch_geometric.datasets import WebQSPDataset
+ from itertools import chain
+
+ ds = WebQSPDataset(root='demo', limit=10)
+
+Let’s set up our set of questions and graph triplets:
+
+.. code:: python
+
+ questions = ds.raw_dataset['question']
+ questions
+
+ >>> ['what is the name of justin bieber brother',
+ 'what character did natalie portman play in star wars',
+ 'what country is the grand bahama island in',
+ 'what kind of money to take to bahamas',
+ 'what character did john noble play in lord of the rings',
+ 'who does joakim noah play for',
+ 'where are the nfl redskins from',
+ 'where did saki live',
+ 'who did draco malloy end up marrying',
+ 'which countries border the us']
+
+
+ ds.raw_dataset[:10]['graph'][0][:10]
+
+
+ >>> [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'],
+ ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'],
+ ['Somebody to Love', 'music.recording.contributions', 'm.0rqp4h0'],
+ ['Rudolph Valentino', 'freebase.valuenotation.is_reviewed', 'Place of birth'],
+ ['Ice Cube', 'broadcast.artist.content', '.977 The Hits Channel'],
+ ['Colbie Caillat', 'broadcast.artist.content', 'Hot Wired Radio'],
+ ['Stephen Melton', 'people.person.nationality', 'United States of America'],
+ ['Record producer',
+ 'music.performance_role.regular_performances',
+ 'm.012m1vf1'],
+ ['Justin Bieber', 'award.award_winner.awards_won', 'm.0yrkc0l'],
+ ['1.FM Top 40', 'broadcast.content.artist', 'Geri Halliwell']]
+
+
+ all_triplets = chain.from_iterable((row['graph'] for row in ds.raw_dataset))
+
+With these questions and triplets, we want to:
+1. Consolidate all the relations in these triplets into a Knowledge Graph
+2. Create a FeatureStore that encodes all the nodes and edges in the knowledge graph
+3. Create a GraphStore that encodes all the edge indices in the knowledge graph
+
+In order to create a remote backend, we need to define a FeatureStore
+and GraphStore locally, as well as a method for initializing its state
+from triplets. The code methods used in this tutorial can be found in
+`examples/llm_plus_gnn`.
+
+.. code:: python
+
+ from torch_geometric.datasets.web_qsp_dataset import preprocess_triplet
+ from rag_construction_utils import create_remote_backend_from_triplets, RemoteGraphBackendLoader
+
+ # We define this GraphStore to sample the neighbors of a node locally.
+ # Ideally for a real remote backend, this interface would be replaced with an API to a Graph DB, such as Neo4j.
+ from rag_graph_store import NeighborSamplingRAGGraphStore
+
+ # We define this FeatureStore to encode the nodes and edges locally, and perform appoximate KNN when indexing.
+ # Ideally for a real remote backend, this interface would be replaced with an API to a vector DB, such as Pinecone.
+ from rag_feature_store import SentenceTransformerFeatureStore
+
+.. code:: python
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SentenceTransformer(model_name="sentence-transformers/all-roberta-large-v1").to(device)
+
+ backend_loader: RemoteGraphBackendLoader = create_remote_backend_from_triplets(
+ triplets=all_triplets, # All the triplets to insert into the backend
+ node_embedding_model=model, # Embedding model to process triplets with
+ node_method_to_call="encode", # This method will encode the nodes/edges with 'model.encode' in this case.
+ path="backend", # Save path
+ pre_transform=preprocess_triplet, # Preprocessing function to apply to triplets before invoking embedding model.
+ node_method_kwargs={"batch_size": 256}, # Keyword arguments to pass to the node_method_to_call.
+ graph_db=NeighborSamplingRAGGraphStore, # Graph Store to use
+ feature_db=SentenceTransformerFeatureStore # Feature Store to use
+ )
+ # This loader saves a copy of the processed data locally to be transformed into a graphstore and featurestore when load() is called.
+ feature_store, graph_store = backend_loader.load()
+
+Now that we have initialized our remote backends, we can now retrieve
+from them using a Loader to query the backends, as shown in this
+diagram:
+
+
+.. figure:: ../_figures/remote_3.png
+ :align: center
+ :width: 100%
+
+
+
+.. code:: python
+
+ from torch_geometric.loader import RAGQueryLoader
+
+ query_loader = RAGQueryLoader(
+ data=(feature_store, graph_store), # Remote Rag Graph Store and Feature Store
+ # Arguments to pass into the seed node/edge retrieval methods for the FeatureStore.
+ # In this case, it's k for the KNN on the nodes and edges.
+ seed_nodes_kwargs={"k_nodes": 10}, seed_edges_kwargs={"k_edges": 10},
+ # Arguments to pass into the GraphStore's Neighbor sampling method.
+ # In this case, the GraphStore implements a NeighborLoader, so it takes the same arguments.
+ sampler_kwargs={"num_neighbors": [40]*3},
+ # Arguments to pass into the FeatureStore's feature loading method.
+ loader_kwargs={},
+ # An optional local transform that can be applied on the returned subgraph.
+ local_filter=None,
+ )
+
+To make better sense of this loader’s arguments, let’s take a closer
+look at the retrieval process for a remote backend:
+
+
+.. figure:: ../_figures/remote_backend.svg
+ :align: center
+ :width: 100%
+
+
+
+As we see here, there are 3 important steps to any remote backend
+procedure for graphs:
+1. Retrieve the seed nodes and edges to begin our retrieval process from.
+2. Traverse the graph neighborhood of the seed nodes/edges to gather local context.
+3. Fetch the features associated with the subgraphs obtained from the traversal.
+
+We can see that our Query Loader construction allows us to specify
+unique hyperparameters for each unique step in this retrieval.
+
+Now we can submit our queries to the remote backend to retrieve our
+subgraphs:
+
+.. code:: python
+
+ sub_graphs = []
+ for q in tqdm.tqdm(questions):
+ sub_graphs.append(query_loader.query(q))
+
+
+ sub_graphs[0]
+
+ >>> Data(x=[2251, 1024], edge_index=[2, 7806], edge_attr=[7806, 1024], node_idx=[2251], edge_idx=[7806])
+
+
+
+These subgraphs are now retrieved using a different retrieval method
+when compared to the original WebQSP dataset. Can we compare the
+properties of this method to the original WebQSPDataset’s retrieval
+method? Let’s compare some basics properties of the subgraphs:
+
+.. code:: python
+
+ def _eidx_helper(subg: Data, ground_truth: Data):
+ subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+ if isinstance(subg_eidx, torch.Tensor):
+ subg_eidx = subg_eidx.tolist()
+ if isinstance(gt_eidx, torch.Tensor):
+ gt_eidx = gt_eidx.tolist()
+ subg_e = set(subg_eidx)
+ gt_e = set(gt_eidx)
+ return subg_e, gt_e
+ def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ total_e = set(range(num_edges))
+ tp = len(subg_e & gt_e)
+ tn = len(total_e-(subg_e | gt_e))
+ return (tp+tn)/num_edges
+ def check_retrieval_precision(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(subg_e)
+ def check_retrieval_recall(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(gt_e)
+
+
+ ground_truth_graphs = get_features_for_triplets_groups(ds.indexer, (d['graph'] for d in ds.raw_dataset), pre_transform=preprocess_triplet)
+ num_edges = len(ds.indexer._edges)
+
+
+ for subg, ground_truth in tqdm.tqdm(zip((query_loader.query(q) for q in questions), ground_truth_graphs)):
+ print(f"Size: {len(subg.x)}, Ground Truth Size: {len(ground_truth.x)}, Accuracy: {check_retrieval_accuracy(subg, ground_truth, num_edges)}, Precision: {check_retrieval_precision(subg, ground_truth)}, Recall: {check_retrieval_recall(subg, ground_truth)}")
+
+ >>> Size: 2193, Ground Truth Size: 1709, Accuracy: 0.6636780705203827, Precision: 0.22923807012918535, Recall: 0.1994037381034285
+ >>> Size: 2682, Ground Truth Size: 1251, Accuracy: 0.7158736400576746, Precision: 0.10843513670738801, Recall: 0.22692963233503774
+ >>> Size: 2087, Ground Truth Size: 1285, Accuracy: 0.7979813868134749, Precision: 0.0547879177377892, Recall: 0.15757855822550831
+ >>> Size: 2975, Ground Truth Size: 1988, Accuracy: 0.6956088609254162, Precision: 0.14820555621795636, Recall: 0.21768826619964973
+ >>> Size: 2594, Ground Truth Size: 633, Accuracy: 0.78849128326124, Precision: 0.04202616198163095, Recall: 0.2032301480484522
+ >>> Size: 2462, Ground Truth Size: 1044, Accuracy: 0.7703499803381832, Precision: 0.07646643109540636, Recall: 0.19551861221539574
+ >>> Size: 2011, Ground Truth Size: 1382, Accuracy: 0.7871804954777821, Precision: 0.10117783355860205, Recall: 0.13142713819914723
+ >>> Size: 2011, Ground Truth Size: 1052, Accuracy: 0.802831301612269, Precision: 0.06452691407556001, Recall: 0.16702726092600606
+ >>> Size: 2892, Ground Truth Size: 1012, Accuracy: 0.7276182985974571, Precision: 0.10108615156751419, Recall: 0.20860927152317882
+ >>> Size: 1817, Ground Truth Size: 1978, Accuracy: 0.7530475815965395, Precision: 0.1677807486631016, Recall: 0.11696178937558248
+
+
+
+Note that, since we’re only comparing the results of 10 graphs here,
+this retrieval algorithm is not taking into account the full corpus of
+nodes in the dataset. If you want to see a full example, look at
+``rag_generate.py``, or ``rag_generate_multihop.py`` These examples
+generate datasets for the entirety of the WebQSP dataset, or the
+WikiData Multihop datasets that are discussed in Section 0.
+
+Evaluating Runtime Performance
+------------------------------
+
+Pytorch Geometric provides multiple methods for evalutaing runtime
+performance. In this notebook, we utilize NVTX to profile the different
+components of our RAG Query Loader.
+
+The method ``nvtxit`` allows for profiling the utilization and timings
+of any methods that get wrapped by it in a Python script.
+
+To see an example of this, check out
+``nvtx_examples/nvtx_rag_backend_example.py``.
+
+This script mirrors this notebook’s functionality, but notably, it
+includes the following code snippet:
+
+.. code:: python
+
+ # Patch FeatureStore and GraphStore
+
+ SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_nodes)
+ SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_edges)
+ SentenceTransformerFeatureStore.load_subgraph = nvtxit()(SentenceTransformerFeatureStore.load_subgraph)
+ NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()(NeighborSamplingRAGGraphStore.sample_subgraph)
+ rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query)
+
+Importantly, this snippet wraps the methods of FeatureStore, GraphStore,
+and the Query method from QueryLoader so that it will be recognized as a
+unique frame in NVTX.
+
+This can be executed by the included shell script ``nvtx_run.sh``:
+
+.. code:: bash
+
+ ...
+
+ # Get the base name of the Python file
+ python_file=$(basename "$1")
+
+ # Run nsys profile on the Python file
+ nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1"
+
+ echo "Profile data saved as profile_${python_file%.py}.nsys-rep"
+
+The generated resulting ``.nsys-rep`` file can be visualized using tools
+like Nsight Systems or Nsight Compute, that can show the relative
+timings of the FeatureStore, GraphStore, and QueryLoader methods.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 1c67eeddeec6..02315461c6ec 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -44,6 +44,7 @@ In addition, it consists of easy-to-use mini-batch loaders for operating on many
advanced/remote
advanced/graphgym
advanced/cpu_affinity
+ advanced/rag
.. toctree::
:maxdepth: 1
diff --git a/examples/README.md b/examples/README.md
index 64fc2dffdcdd..af324f09f79d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -14,9 +14,9 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see
- [`ogbn_papers_100m.py`](./ogbn_papers_100m.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset, containing approximately ~1.6B edges.
- [`ogbn_papers_100m_cugraph.py`](./ogbn_papers_100m_cugraph.py) shows how to accelerate the `ogbn-papers100m` workflow using [CuGraph](https://github.com/rapidsai/cugraph).
-For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).
+For examples on co-training LLM with GNN, see examples and README under [`examples/llm_plus_gnn`](./llm_plus_gnn).
-For examples on scaling PyG up via multi-GPUs, see the examples under [`examples/multi_gpu`](./multi_gpu).
+For examples on using `torch.compile`, see examples and README under [`examples/compile`](./compile).
For examples on working with heterogeneous data, see the examples under [`examples/hetero`](./hetero).
diff --git a/examples/llm/README.md b/examples/llm/README.md
index f1f01428d991..5177e501d610 100644
--- a/examples/llm/README.md
+++ b/examples/llm/README.md
@@ -1,5 +1,8 @@
# Examples for Co-training LLMs and GNNs
-| Example | Description |
-| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
+| Example | Description |
+| -------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
+| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
+| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA |
+| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py
index 48b012917553..2844eb11ff3b 100644
--- a/examples/llm/g_retriever.py
+++ b/examples/llm/g_retriever.py
@@ -8,25 +8,52 @@
`pip install datasets transformers pcst_fast sentencepiece accelerate`
"""
import argparse
+import gc
import math
-import os.path as osp
+import multiprocessing as mp
import re
import time
+from typing import Any, Callable, Dict, List, Type
import pandas as pd
import torch
+import torch.nn as nn
from torch import Tensor
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from torch_geometric import seed_everything
+from torch_geometric.data import Dataset
from torch_geometric.datasets import WebQSPDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GAT, GRetriever
from torch_geometric.nn.nlp import LLM
-def compute_metrics(eval_output):
+def _detect_hallucinate(inp):
+ pred, label = inp
+ try:
+ split_pred = pred.split('[/s]')[0].strip().split('|')
+ correct_hit = len(re.findall(split_pred[0], label)) > 0
+ correct_hit = correct_hit or any(
+ [label_i in pred.lower() for label_i in label.split('|')])
+ hallucination = not correct_hit
+ return hallucination
+ except: # noqa
+ return "skip"
+
+
+def detect_hallucinate(pred_batch, label_batch):
+ r"""An approximation for the unsolved task of detecting hallucinations.
+ We define a hallucination as an output that contains no instances of
+ acceptable label.
+ """
+ with mp.Pool(len(pred_batch)) as p:
+ res = p.map(_detect_hallucinate, zip(pred_batch, label_batch))
+ return res
+
+
+def compute_accuracy(eval_output) -> float:
df = pd.concat([pd.DataFrame(d) for d in eval_output])
all_hit = []
all_precision = []
@@ -68,6 +95,12 @@ def compute_metrics(eval_output):
print(f'Recall: {recall:.4f}')
print(f'F1: {f1:.4f}')
+ return hit
+
+
+def compute_n_parameters(model: torch.nn.Module) -> int:
+ return sum([p.numel() for p in model.parameters() if p.requires_grad])
+
def save_params_dict(model, save_path):
state_dict = model.state_dict()
@@ -112,6 +145,9 @@ def train(
lr,
checkpointing=False,
tiny_llama=False,
+ model=None,
+ dataset=None,
+ model_save_name=None,
):
def adjust_learning_rate(param_group, LR, epoch):
# Decay the learning rate with half-cycle cosine after warmup
@@ -127,13 +163,19 @@ def adjust_learning_rate(param_group, LR, epoch):
return lr
start_time = time.time()
- path = osp.dirname(osp.realpath(__file__))
- path = osp.join(path, '..', '..', 'data', 'WebQSPDataset')
- train_dataset = WebQSPDataset(path, split='train')
- val_dataset = WebQSPDataset(path, split='val')
- test_dataset = WebQSPDataset(path, split='test')
-
seed_everything(42)
+ if dataset is None:
+ dataset = WebQSPDataset()
+ gc.collect()
+ elif not isinstance(dataset, Dataset) and callable(dataset):
+ dataset = dataset()
+ gc.collect()
+ idx_split = dataset.split_idxs
+
+ # Step 1: Build Node Classification Dataset
+ train_dataset = [dataset[i] for i in idx_split['train']]
+ val_dataset = [dataset[i] for i in idx_split['val']]
+ test_dataset = [dataset[i] for i in idx_split['test']]
train_loader = DataLoader(train_dataset, batch_size=batch_size,
drop_last=True, pin_memory=True, shuffle=True)
@@ -142,24 +184,28 @@ def adjust_learning_rate(param_group, LR, epoch):
test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
drop_last=False, pin_memory=True, shuffle=False)
- gnn = GAT(
- in_channels=1024,
- hidden_channels=hidden_channels,
- out_channels=1024,
- num_layers=num_gnn_layers,
- heads=4,
- )
- if tiny_llama:
- llm = LLM(
- model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
- num_params=1,
+ if model is None:
+ gc.collect()
+ gnn = GAT(
+ in_channels=1024,
+ hidden_channels=hidden_channels,
+ out_channels=1024,
+ num_layers=num_gnn_layers,
+ heads=4,
)
- model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)
- else:
- llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
- model = GRetriever(llm=llm, gnn=gnn)
+ if tiny_llama:
+ llm = LLM(
+ model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
+ num_params=1,
+ )
+ model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)
+ else:
+ llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
+ model = GRetriever(llm=llm, gnn=gnn)
+
+ if model_save_name is None:
+ model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm'
- model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm'
params = [p for _, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW([
{
@@ -240,10 +286,376 @@ def adjust_learning_rate(param_group, LR, epoch):
eval_output.append(eval_data)
progress_bar_test.update(1)
+ # Step 6 Post-processing & compute metrics
compute_metrics(eval_output)
print(f"Total Training Time: {time.time() - start_time:2f}s")
save_params_dict(model, f'{model_save_name}.pt')
torch.save(eval_output, f'{model_save_name}_eval_outs.pt')
+ print("Done!")
+ return prep_time, dataset, eval_output
+
+
+def _eval_hallucinations_on_loader(outs, loader, eval_batch_size):
+ model_save_list = []
+ model_preds = []
+ for out in outs:
+ model_preds += out['pred']
+ for i, batch in enumerate(loader):
+ correct_answer = batch.label
+
+ model_pred = model_preds[i * eval_batch_size:(i + 1) * eval_batch_size]
+ model_hallucinates = detect_hallucinate(model_pred, correct_answer)
+ model_save_list += [tup for tup in zip(model_pred, model_hallucinates)]
+ return model_save_list
+
+
+def benchmark_models(models: List[Type[nn.Module]], model_names: List[str],
+ model_kwargs: List[Dict[str, Any]], dataset: Dataset,
+ lr: float, epochs: int, batch_size: int,
+ eval_batch_size: int, loss_fn: Callable,
+ inference_fn: Callable, skip_LLMs: bool = True,
+ tiny_llama: bool = False, checkpointing: bool = True,
+ force: bool = False, root_dir='.'):
+ """Utility function for creating a model benchmark for GRetriever that
+ grid searches over hyperparameters. Produces a DataFrame containing
+ metrics for each model.
+
+ Args:
+ models (List[Type[nn.Module]]): Models to be benchmarked.
+ model_names (List[str]): Name of save files for model checkpoints
+ model_kwargs (List[Dict[str, Any]]): Parameters to use for each
+ particular model.
+ dataset (Dataset): Input dataset to train on.
+ lr (float): Learning rate
+ epochs (int): Number of epochs
+ batch_size (int): Batch size for training
+ eval_batch_size (int): Batch size for eval. Also determines
+ hallucination detection concurrancy.
+ loss_fn (Callable): Loss function
+ inference_fn (Callable): Inference function
+ skip_LLMs (bool, optional): Whether to skip LLM-only runs.
+ Defaults to True.
+ tiny_llama (bool, optional): Whether to use tiny llama as LLM.
+ Defaults to False.
+ checkpointing (bool, optional): Whether to checkpoint models.
+ Defaults to True.
+ force (bool, optional): Whether to rerun already existing results.
+ Defaults to False.
+ root_dir (str, optional): Dir to save results and checkpoints in.
+ Defaults to '.'.
+ """
+ model_log: Dict[str, Dict[str, Any]] = dict()
+ idx_split = dataset.split_idxs
+ test_dataset = [dataset[i] for i in idx_split['test']]
+ loader = DataLoader(test_dataset, batch_size=eval_batch_size,
+ drop_last=False, pin_memory=True, shuffle=False)
+
+ if not skip_LLMs:
+ if tiny_llama:
+ pure_llm = LLM(
+ model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
+ num_params=1,
+ )
+ else:
+ pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf",
+ num_params=7)
+
+ if force or not path.exists(root_dir + "/pure_llm_model_log.pt"):
+ model_log["pure_llm"] = dict()
+
+ pure_preds = []
+ for batch in tqdm(loader):
+ pure_llm_preds = pure_llm.inference(batch.question, batch.desc,
+ max_tokens=256)
+ pure_preds += pure_llm_preds
+ pure_preds = [{"pred": pred} for pred in pure_preds]
+
+ model_log["pure_llm"]["preds"] = pure_preds
+ model_log["pure_llm"]["hallucinates_list"] = \
+ _eval_hallucinations_on_loader(pure_preds, loader,
+ eval_batch_size)
+ model_log["pure_llm"]["n_params"] = compute_n_parameters(pure_llm)
+ torch.save(model_log["pure_llm"],
+ root_dir + "/pure_llm_model_log.pt")
+ else:
+ model_log["pure_llm"] = \
+ torch.load(root_dir+"/pure_llm_model_log.pt")
+
+ # LORA
+ if force or not path.exists(root_dir + "/tuned_llm_model_log.pt"):
+ model_log["tuned_llm"] = dict()
+ since = time.time()
+ gc.collect()
+ prep_time, _, lora_eval_outs = train(since, epochs, None, None,
+ batch_size, eval_batch_size,
+ lr, loss_fn, inference_fn,
+ model=pure_llm,
+ dataset=dataset)
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ gc.collect()
+ e2e_time = round(time.time() - since, 2)
+ model_log["tuned_llm"]["prep_time"] = prep_time
+ model_log["tuned_llm"]["e2e_time"] = e2e_time
+ model_log["tuned_llm"]["eval_output"] = lora_eval_outs
+ print("E2E time (e2e_time) =", e2e_time, "seconds")
+ print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds")
+
+ model_log["tuned_llm"]["hallucinates_list"] = \
+ _eval_hallucinations_on_loader(lora_eval_outs, loader,
+ eval_batch_size)
+ model_log["tuned_llm"]["n_params"] = compute_n_parameters(pure_llm)
+ torch.save(model_log["tuned_llm"],
+ root_dir + "/tuned_llm_model_log.pt")
+ else:
+ model_log["tuned_llm"] = \
+ torch.load(root_dir+"/tuned_llm_model_log.pt")
+
+ del pure_llm
+ gc.collect()
+
+ # All other models
+ for name, Model, kwargs in zip(model_names, models, model_kwargs):
+ model_log[name] = dict()
+ train_model = True
+ if path.exists(root_dir + f"/{name}.pt") and not force:
+ print(f"Model {name} appears to already exist.")
+ print("Would you like to retrain?")
+ train_model = str(input("(y/n):")).lower() == "y"
+
+ if train_model:
+ since = time.time()
+ gc.collect()
+ model = Model(**kwargs)
+ prep_time, _, model_eval_outs = train(
+ since=since, num_epochs=epochs, hidden_channels=None,
+ num_gnn_layers=None, batch_size=batch_size,
+ eval_batch_size=eval_batch_size, lr=lr, loss_fn=loss_fn,
+ inference_fn=inference_fn, checkpointing=checkpointing,
+ tiny_llama=tiny_llama, dataset=dataset,
+ model_save_name=root_dir + '/' + name, model=model)
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ gc.collect()
+ e2e_time = round(time.time() - since, 2)
+ model_log[name]["prep_time"] = prep_time
+ model_log[name]["e2e_time"] = e2e_time
+ model_log[name]["eval_output"] = model_eval_outs
+ print("E2E time (e2e_time) =", e2e_time, "seconds")
+ print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds")
+ model_log[name]["n_params"] = compute_n_parameters(model)
+ del model
+ gc.collect()
+ else:
+ model_eval_outs = torch.load(root_dir + f"/{name}_eval_outs.pt")
+
+ # Calculate Hallucinations
+ skip_hallucination_detection = False
+
+ if path.exists(root_dir + f"/{name}_model_log.pt") and not force:
+ print(f"Saved outputs for {name} have been found.")
+ print("Would you like to redo?")
+ user_input = str(input("(y/n):")).lower()
+ skip_hallucination_detection = user_input != "y"
+
+ if not skip_hallucination_detection:
+ model_save_list = _eval_hallucinations_on_loader(
+ model_eval_outs, loader, eval_batch_size)
+
+ model_log[name]["hallucinates_list"] = model_save_list
+ torch.save(model_log[name], root_dir + f"/{name}_model_log.pt")
+ else:
+ model_log[name]["hallucinates_list"] = \
+ torch.load(
+ root_dir+f"/{name}_model_log.pt"
+ )["hallucinates_list"]
+
+ hal_dict = {
+ k: [tup[1] for tup in v["hallucinates_list"]]
+ for (k, v) in model_log.items()
+ }
+ hallucinates_df = pd.DataFrame(hal_dict).astype(str)
+ hallucinates_df = hallucinates_df.apply(pd.Series.value_counts).transpose()
+ hallucinates_df['e2e_time'] = pd.Series(
+ {k: v.get('e2e_time')
+ for (k, v) in model_log.items()})
+ hallucinates_df['n_params'] = pd.Series(
+ {k: v.get('n_params')
+ for (k, v) in model_log.items()})
+ print(hallucinates_df)
+ hallucinates_df.to_csv(root_dir + "/hallucinates_df.csv", index=False)
+
+
+def minimal_demo(gnn_llm_eval_outs, dataset, lr, epochs, batch_size,
+ eval_batch_size, loss_fn, inference_fn,
+ skip_pretrained_LLM=False, tiny_llama=False):
+ if not skip_pretrained_LLM:
+ print("First comparing against a pretrained LLM...")
+ # Step 1: Define a single batch size test loader
+ idx_split = dataset.split_idxs
+ test_dataset = [dataset[i] for i in idx_split['test']]
+ # batch size 1 loader for simplicity
+ loader = DataLoader(test_dataset, batch_size=eval_batch_size,
+ drop_last=False, pin_memory=True, shuffle=False)
+ if tiny_llama:
+ pure_llm = LLM(
+ model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
+ num_params=1,
+ )
+ else:
+ pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf",
+ num_params=7)
+ if path.exists("demo_save_dict.pt"):
+ print("Saved outputs for the first step of the demo found.")
+ print("Would you like to redo?")
+ user_input = str(input("(y/n):")).lower()
+ skip_step_one = user_input == "n"
+ else:
+ skip_step_one = False
+
+ if not skip_step_one:
+ gnn_llm_hallucin_sum = 0
+ pure_llm_hallucin_sum = 0
+ gnn_save_list = []
+ untuned_llm_save_list = []
+ gnn_llm_preds = []
+ for out in gnn_llm_eval_outs:
+ gnn_llm_preds += out['pred']
+ if skip_pretrained_LLM:
+ print("Checking GNN+LLM for hallucinations...")
+ else:
+ print(
+ "Checking pretrained LLM vs trained GNN+LLM for hallucinations..." # noqa
+ )
+ for i, batch in enumerate(tqdm(loader)):
+ question = batch.question
+ correct_answer = batch.label
+
+ gnn_llm_pred = gnn_llm_preds[i * eval_batch_size:(i + 1) *
+ eval_batch_size]
+ gnn_llm_hallucinates = detect_hallucinate(gnn_llm_pred,
+ correct_answer)
+ gnn_save_list += [
+ tup for tup in zip(gnn_llm_pred, gnn_llm_hallucinates)
+ ]
+
+ if not skip_pretrained_LLM:
+ # GNN+LLM only using 32 tokens to answer.
+ # Allow more output tokens for untrained LLM
+ pure_llm_pred = pure_llm.inference(batch.question, batch.desc,
+ max_tokens=256)
+ pure_llm_hallucinates = detect_hallucinate(
+ pure_llm_pred, correct_answer)
+ else:
+ pure_llm_pred = [''] * len(gnn_llm_hallucinates)
+ pure_llm_hallucinates = [False] * len(gnn_llm_hallucinates)
+ untuned_llm_save_list += [
+ tup for tup in zip(pure_llm_pred, pure_llm_hallucinates)
+ ]
+
+ for gnn_llm_hal, pure_llm_hal in zip(gnn_llm_hallucinates,
+ pure_llm_hallucinates):
+ if gnn_llm_hal == "skip" or pure_llm_hal == "skip": # noqa
+ # skipping when hallucination is hard to eval
+ continue
+ gnn_llm_hallucin_sum += int(gnn_llm_hal)
+ pure_llm_hallucin_sum += int(pure_llm_hal)
+ if not skip_pretrained_LLM:
+ print("Total Pure LLM Hallucinations:", pure_llm_hallucin_sum)
+ print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum)
+ percent = 100.0 * round(
+ 1 - (gnn_llm_hallucin_sum / pure_llm_hallucin_sum), 2)
+ print(f"GNN reduces pretrained LLM hallucinations by: ~{percent}%")
+ print("Note: hallucinations detected by regex hence the ~")
+ print("Now we see how the LLM compares when finetuned...")
+ print("Saving outputs of GNN+LLM and pretrained LLM...")
+ save_dict = {
+ "gnn_save_list": gnn_save_list,
+ "untuned_llm_save_list": untuned_llm_save_list,
+ "gnn_llm_hallucin_sum": gnn_llm_hallucin_sum,
+ "pure_llm_hallucin_sum": pure_llm_hallucin_sum
+ }
+ torch.save(save_dict, "demo_save_dict.pt")
+ print("Done!")
+ else:
+ save_dict = torch.load("demo_save_dict.pt")
+ gnn_save_list = save_dict["gnn_save_list"]
+ untuned_llm_save_list = save_dict["untuned_llm_save_list"]
+ gnn_llm_hallucin_sum = save_dict["gnn_llm_hallucin_sum"]
+ pure_llm_hallucin_sum = save_dict["pure_llm_hallucin_sum"]
+
+ trained_llm_hallucin_sum = 0
+ untuned_llm_hallucin_sum = pure_llm_hallucin_sum
+ final_prnt_str = ""
+ if path.exists("llm.pt") and path.exists("llm_eval_outs.pt"):
+ print("Existing finetuned LLM found.")
+ print("Would you like to retrain?")
+ user_input = str(input("(y/n):")).lower()
+ retrain = user_input == "y"
+ else:
+ retrain = True
+ if retrain:
+ print("Finetuning LLM...")
+ since = time.time()
+ _, _, pure_llm_eval_outputs = train(since, epochs, None, None,
+ batch_size, eval_batch_size, lr,
+ loss_fn, inference_fn,
+ model=pure_llm, dataset=dataset)
+ e2e_time = round(time.time() - since, 2)
+ print("E2E time (e2e_time) =", e2e_time, "seconds")
+ else:
+ pure_llm_eval_outputs = torch.load("llm_eval_outs.pt")
+ pure_llm_preds = []
+ for out in pure_llm_eval_outputs:
+ pure_llm_preds += out['pred']
+ print("Final comparison between all models...")
+ for i, batch in enumerate(tqdm(loader)):
+ question = batch.question
+ correct_answer = batch.label
+ gnn_llm_pred, gnn_llm_hallucinates = list(
+ zip(*gnn_save_list[i * eval_batch_size:(i + 1) * eval_batch_size]))
+ untuned_llm_pred, untuned_llm_hallucinates = list(
+ zip(*untuned_llm_save_list[i * eval_batch_size:(i + 1) *
+ eval_batch_size]))
+ pure_llm_pred = pure_llm_preds[i * eval_batch_size:(i + 1) *
+ eval_batch_size]
+ pure_llm_hallucinates = detect_hallucinate(pure_llm_pred,
+ correct_answer)
+ for j in range(len(gnn_llm_pred)):
+ if skip_pretrained_LLM:
+ # we did not check the untrained LLM, so do not decide to demo
+ # based on this.
+ # HACK
+ untuned_llm_hallucinates = {j: True}
+ if gnn_llm_hallucinates[j] == "skip" or untuned_llm_hallucinates[
+ j] == "skip" or pure_llm_hallucinates[j] == "skip":
+ continue
+ trained_llm_hallucin_sum += int(pure_llm_hallucinates[j])
+ if untuned_llm_hallucinates[j] and pure_llm_hallucinates[
+ j] and not gnn_llm_hallucinates[j]: # noqa
+ final_prnt_str += "Prompt: '" + question[j] + "'\n"
+ final_prnt_str += "Label: '" + correct_answer[j] + "'\n"
+ if not skip_pretrained_LLM:
+ final_prnt_str += "Untuned LLM Output: '" \
+ + untuned_llm_pred[j] + "'\n" # noqa
+ final_prnt_str += "Tuned LLM Output: '" + pure_llm_pred[
+ j] + "'\n"
+ final_prnt_str += "GNN+LLM Output: '" + gnn_llm_pred[j] + "'\n"
+ final_prnt_str += "\n" + "#" * 20 + "\n\n"
+ if not skip_pretrained_LLM:
+ print("Total untuned LLM Hallucinations:", untuned_llm_hallucin_sum)
+ print("Total tuned LLM Hallucinations:", trained_llm_hallucin_sum)
+ print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum)
+ if not skip_pretrained_LLM:
+ percent = 100.0 * round(
+ 1 - (gnn_llm_hallucin_sum / untuned_llm_hallucin_sum), 2)
+ print(f"GNN reduces untuned LLM hallucinations by: ~{percent}%")
+ tuned_percent = 100.0 * round(
+ 1 - (gnn_llm_hallucin_sum / trained_llm_hallucin_sum), 2)
+ print(f"GNN reduces tuned LLM hallucinations by: ~{tuned_percent}%")
+ print("Note: hallucinations detected by regex hence the ~")
+ print("Potential instances where GNN solves the hallucinations of LLM:")
+ print(final_prnt_str)
if __name__ == '__main__':
@@ -256,6 +668,9 @@ def adjust_learning_rate(param_group, LR, epoch):
parser.add_argument('--eval_batch_size', type=int, default=16)
parser.add_argument('--checkpointing', action='store_true')
parser.add_argument('--tiny_llama', action='store_true')
+ parser.add_argument(
+ "--skip_pretrained_llm_eval", action="store_true",
+ help="This flag will skip the evaluation of the pretrained LLM.")
args = parser.parse_args()
start_time = time.time()
diff --git a/examples/llm/g_retriever_utils/README.md b/examples/llm/g_retriever_utils/README.md
new file mode 100644
index 000000000000..e072e6746b7c
--- /dev/null
+++ b/examples/llm/g_retriever_utils/README.md
@@ -0,0 +1,11 @@
+# Examples for LLM and GNN co-training
+
+| Example | Description |
+| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend |
+| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend |
+| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore |
+| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) |
+| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevent architecture parameters and datasets. |
+
+NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes.
diff --git a/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
new file mode 100644
index 000000000000..76148cfc09e5
--- /dev/null
+++ b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
@@ -0,0 +1,105 @@
+"""Used to benchmark the performance of an untuned/fine tuned LLM against
+GRetriever with various architectures and layer depths.
+"""
+# %%
+import argparse
+import sys
+
+import torch
+
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.nn.models import GAT, MLP, GRetriever
+
+sys.path.append('..')
+from g_retriever import ( # noqa: E402
+ benchmark_models,
+ get_loss,
+ inference_step,
+)
+
+# %%
+parser = argparse.ArgumentParser(description="""Benchmarker for GRetriever
+NOTE: Evaluating with smaller samples may result in poorer performance for the trained models compared to untrained models."""
+ )
+parser.add_argument("--hidden_channels", type=int, default=1024)
+parser.add_argument("--learning_rate", type=float, default=1e-5)
+parser.add_argument("--epochs", type=int, default=2)
+parser.add_argument("--batch_size", type=int, default=8)
+parser.add_argument("--eval_batch_size", type=int, default=16)
+parser.add_argument("--tiny_llama", action='store_true')
+
+parser.add_argument("--dataset_path", type=str, required=False)
+# Default to WebQSP split
+parser.add_argument("--num_train", type=int, default=2826)
+parser.add_argument("--num_val", type=int, default=246)
+parser.add_argument("--num_test", type=int, default=1628)
+
+args = parser.parse_args()
+
+# %%
+hidden_channels = args.hidden_channels
+lr = args.learning_rate
+epochs = args.epochs
+batch_size = args.batch_size
+eval_batch_size = args.eval_batch_size
+
+# %%
+if not args.dataset_path:
+ ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True)
+else:
+ # We just assume that the size of the dataset accomodates the
+ # train/val/test split, because checking may be expensive.
+ dataset = torch.load(args.dataset_path)
+
+ class MockDataset:
+ """Utility class to patch the fields in WebQSPDataset used by
+ GRetriever.
+ """
+ def __init__(self) -> None:
+ pass
+
+ @property
+ def split_idxs(self) -> dict:
+ # Imitates the WebQSP split method
+ return {
+ "train":
+ torch.arange(args.num_train),
+ "val":
+ torch.arange(args.num_val) + args.num_train,
+ "test":
+ torch.arange(args.num_test) + args.num_train + args.num_val,
+ }
+
+ def __getitem__(self, idx: int):
+ return dataset[idx]
+
+ ds = MockDataset()
+
+# %%
+model_names = []
+model_classes = []
+model_kwargs = []
+model_type = ["GAT", "MLP"]
+models = {"GAT": GAT, "MLP": MLP}
+# Use to vary the depth of the GNN model
+num_layers = [4]
+# Use to vary the number of LLM tokens reserved for GNN output
+num_tokens = [1]
+for m_type in model_type:
+ for n_layer in num_layers:
+ for n_tokens in num_tokens:
+ model_names.append(f"{m_type}_{n_layer}_{n_tokens}")
+ model_classes.append(GRetriever)
+ kwargs = dict(gnn_hidden_channels=hidden_channels,
+ num_gnn_layers=n_layer, gnn_to_use=models[m_type],
+ mlp_out_tokens=n_tokens)
+ if args.tiny_llama:
+ kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1'
+ kwargs['mlp_out_dim'] = 2048
+ kwargs['num_llm_params'] = 1
+ model_kwargs.append(kwargs)
+
+# %%
+benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs,
+ batch_size, eval_batch_size, get_loss, inference_step,
+ skip_LLMs=False, tiny_llama=args.tiny_llama, force=True)
diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py
new file mode 100644
index 000000000000..0f1c0e1b87ec
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_backend_utils.py
@@ -0,0 +1,224 @@
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+ Protocol,
+ Tuple,
+ Type,
+ runtime_checkable,
+)
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+
+from torch_geometric.data import (
+ FeatureStore,
+ GraphStore,
+ LargeGraphIndexer,
+ TripletLike,
+)
+from torch_geometric.data.large_graph_indexer import EDGE_RELATION
+from torch_geometric.distributed import (
+ LocalFeatureStore,
+ LocalGraphStore,
+ Partitioner,
+)
+from torch_geometric.typing import EdgeType, NodeType
+
+RemoteGraphBackend = Tuple[FeatureStore, GraphStore]
+
+# TODO: Make everything compatible with Hetero graphs aswell
+
+
+# Adapted from LocalGraphStore
+@runtime_checkable
+class ConvertableGraphStore(Protocol):
+ @classmethod
+ def from_data(
+ cls,
+ edge_id: Tensor,
+ edge_index: Tensor,
+ num_nodes: int,
+ is_sorted: bool = False,
+ ) -> GraphStore:
+ ...
+
+ @classmethod
+ def from_hetero_data(
+ cls,
+ edge_id_dict: Dict[EdgeType, Tensor],
+ edge_index_dict: Dict[EdgeType, Tensor],
+ num_nodes_dict: Dict[NodeType, int],
+ is_sorted: bool = False,
+ ) -> GraphStore:
+ ...
+
+ @classmethod
+ def from_partition(cls, root: str, pid: int) -> GraphStore:
+ ...
+
+
+# Adapted from LocalFeatureStore
+@runtime_checkable
+class ConvertableFeatureStore(Protocol):
+ @classmethod
+ def from_data(
+ cls,
+ node_id: Tensor,
+ x: Optional[Tensor] = None,
+ y: Optional[Tensor] = None,
+ edge_id: Optional[Tensor] = None,
+ edge_attr: Optional[Tensor] = None,
+ ) -> FeatureStore:
+ ...
+
+ @classmethod
+ def from_hetero_data(
+ cls,
+ node_id_dict: Dict[NodeType, Tensor],
+ x_dict: Optional[Dict[NodeType, Tensor]] = None,
+ y_dict: Optional[Dict[NodeType, Tensor]] = None,
+ edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
+ edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
+ ) -> FeatureStore:
+ ...
+
+ @classmethod
+ def from_partition(cls, root: str, pid: int) -> FeatureStore:
+ ...
+
+
+class RemoteDataType(Enum):
+ DATA = auto()
+ PARTITION = auto()
+
+
+@dataclass
+class RemoteGraphBackendLoader:
+ """Utility class to load triplets into a RAG Backend."""
+ path: str
+ datatype: RemoteDataType
+ graph_store_type: Type[ConvertableGraphStore]
+ feature_store_type: Type[ConvertableFeatureStore]
+
+ def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
+ if self.datatype == RemoteDataType.DATA:
+ data_obj = torch.load(self.path)
+ graph_store = self.graph_store_type.from_data(
+ edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
+ num_nodes=data_obj.num_nodes)
+ feature_store = self.feature_store_type.from_data(
+ node_id=data_obj['node_id'], x=data_obj.x,
+ edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
+ elif self.datatype == RemoteDataType.PARTITION:
+ if pid is None:
+ assert pid is not None, \
+ "Partition ID must be defined for loading from a " \
+ + "partitioned store."
+ graph_store = self.graph_store_type.from_partition(self.path, pid)
+ feature_store = self.feature_store_type.from_partition(
+ self.path, pid)
+ else:
+ raise NotImplementedError
+ return (feature_store, graph_store)
+
+
+# TODO: make profilable
+def create_remote_backend_from_triplets(
+ triplets: Iterable[TripletLike], node_embedding_model: Module,
+ edge_embedding_model: Module | None = None,
+ graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
+ feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
+ node_method_to_call: str = "forward",
+ edge_method_to_call: str | None = None,
+ pre_transform: Callable[[TripletLike], TripletLike] | None = None,
+ path: str = '', n_parts: int = 1,
+ node_method_kwargs: Optional[Dict[str, Any]] = None,
+ edge_method_kwargs: Optional[Dict[str, Any]] = None
+) -> RemoteGraphBackendLoader:
+ """Utility function that can be used to create a RAG Backend from triplets.
+
+ Args:
+ triplets (Iterable[TripletLike]): Triplets to load into the RAG
+ Backend.
+ node_embedding_model (Module): Model to embed nodes into a feature
+ space.
+ edge_embedding_model (Module | None, optional): Model to embed edges
+ into a feature space. Defaults to the node model.
+ graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
+ use. Defaults to LocalGraphStore.
+ feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
+ class to use. Defaults to LocalFeatureStore.
+ node_method_to_call (str, optional): method to call for embeddings on
+ the node model. Defaults to "forward".
+ edge_method_to_call (str | None, optional): method to call for
+ embeddings on the edge model. Defaults to the node method.
+ pre_transform (Callable[[TripletLike], TripletLike] | None, optional):
+ optional preprocessing function for triplets. Defaults to None.
+ path (str, optional): path to save resulting stores. Defaults to ''.
+ n_parts (int, optional): Number of partitons to store in.
+ Defaults to 1.
+ node_method_kwargs (Optional[Dict[str, Any]], optional): args to pass
+ into node encoding method. Defaults to None.
+ edge_method_kwargs (Optional[Dict[str, Any]], optional): args to pass
+ into edge encoding method. Defaults to None.
+
+ Returns:
+ RemoteGraphBackendLoader: Loader to load RAG backend from disk or
+ memory.
+ """
+ # Will return attribute errors for missing attributes
+ if not issubclass(graph_db, ConvertableGraphStore):
+ getattr(graph_db, "from_data")
+ getattr(graph_db, "from_hetero_data")
+ getattr(graph_db, "from_partition")
+ elif not issubclass(feature_db, ConvertableFeatureStore):
+ getattr(feature_db, "from_data")
+ getattr(feature_db, "from_hetero_data")
+ getattr(feature_db, "from_partition")
+
+ # Resolve callable methods
+ node_method_kwargs = node_method_kwargs \
+ if node_method_kwargs is not None else dict()
+
+ edge_embedding_model = edge_embedding_model \
+ if edge_embedding_model is not None else node_embedding_model
+ edge_method_to_call = edge_method_to_call \
+ if edge_method_to_call is not None else node_method_to_call
+ edge_method_kwargs = edge_method_kwargs \
+ if edge_method_kwargs is not None else node_method_kwargs
+
+ # These will return AttributeErrors if they don't exist
+ node_model = getattr(node_embedding_model, node_method_to_call)
+ edge_model = getattr(edge_embedding_model, edge_method_to_call)
+
+ indexer = LargeGraphIndexer.from_triplets(triplets,
+ pre_transform=pre_transform)
+
+ node_feats = node_model(indexer.get_node_features(), **node_method_kwargs)
+ indexer.add_node_feature('x', node_feats)
+
+ edge_feats = edge_model(
+ indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
+ **edge_method_kwargs)
+ indexer.add_edge_feature(new_feature_name="edge_attr",
+ new_feature_vals=edge_feats,
+ map_from_feature=EDGE_RELATION)
+
+ data = indexer.to_data(node_feature_name='x',
+ edge_feature_name='edge_attr')
+
+ if n_parts == 1:
+ torch.save(data, path)
+ return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
+ feature_db)
+ else:
+ partitioner = Partitioner(data=data, num_parts=n_parts, root=path)
+ partitioner.generate_partition()
+ return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION,
+ graph_db, feature_db)
diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py
new file mode 100644
index 000000000000..e01e9e59bb88
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_feature_store.py
@@ -0,0 +1,189 @@
+import gc
+from collections.abc import Iterable, Iterator
+from typing import Any, Dict, Optional, Type, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from torchmetrics.functional import pairwise_cosine_similarity
+
+from torch_geometric.data import Data, HeteroData
+from torch_geometric.distributed import LocalFeatureStore
+from torch_geometric.nn.nlp import SentenceTransformer
+from torch_geometric.nn.pool import ApproxMIPSKNNIndex
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+# NOTE: Only compatible with Homogeneous graphs for now
+class KNNRAGFeatureStore(LocalFeatureStore):
+ def __init__(self, enc_model: Type[Module],
+ model_kwargs: Optional[Dict[str,
+ Any]] = None, *args, **kwargs):
+ self.device = torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu")
+ self.enc_model = enc_model(*args, **kwargs).to(self.device)
+ self.enc_model.eval()
+ self.model_kwargs = \
+ model_kwargs if model_kwargs is not None else dict()
+ super().__init__()
+
+ @property
+ def x(self) -> Tensor:
+ return self.get_tensor(group_name=None, attr_name='x')
+
+ @property
+ def edge_attr(self) -> Tensor:
+ return self.get_tensor(group_name=(None, None), attr_name='edge_attr')
+
+ def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes:
+ result = next(self._retrieve_seed_nodes_batch([query], k_nodes))
+ gc.collect()
+ torch.cuda.empty_cache()
+ return result
+
+ def _retrieve_seed_nodes_batch(self, query: Iterable[Any],
+ k_nodes: int) -> Iterator[InputNodes]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ query_enc = self.enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+ prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device))
+ topk = min(k_nodes, len(self.x))
+ for q in prizes:
+ _, indices = torch.topk(q, topk, largest=True)
+ yield indices
+
+ def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges:
+ result = next(self._retrieve_seed_edges_batch([query], k_edges))
+ gc.collect()
+ torch.cuda.empty_cache()
+ return result
+
+ def _retrieve_seed_edges_batch(self, query: Iterable[Any],
+ k_edges: int) -> Iterator[InputEdges]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ query_enc = self.enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+
+ prizes = pairwise_cosine_similarity(query_enc,
+ self.edge_attr.to(self.device))
+ topk = min(k_edges, len(self.edge_attr))
+ for q in prizes:
+ _, indices = torch.topk(q, topk, largest=True)
+ yield indices
+
+ def load_subgraph(
+ self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+ ) -> Union[Data, HeteroData]:
+
+ if isinstance(sample, HeteroSamplerOutput):
+ raise NotImplementedError
+
+ # NOTE: torch_geometric.loader.utils.filter_custom_store can be used
+ # here if it supported edge features
+ node_id = sample.node
+ edge_id = sample.edge
+ edge_index = torch.stack((sample.row, sample.col), dim=0)
+ x = self.x[node_id]
+ edge_attr = self.edge_attr[edge_id]
+
+ return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
+ node_idx=node_id, edge_idx=edge_id)
+
+
+# TODO: Refactor because composition >> inheritance
+
+
+def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor,
+ device: torch.device, batch_size: int = 2**20):
+ """Add new features to the existing KNN index in batches.
+
+ Args:
+ knn_index (ApproxMIPSKNNIndex): Index to add features to.
+ emb (Tensor): Embeddings to add.
+ device (torch.device): Device to store in
+ batch_size (int, optional): Batch size to iterate by.
+ Defaults to 2**20, which equates to 4GB if working with
+ 1024 dim floats.
+ """
+ for i in range(0, emb.size(0), batch_size):
+ if emb.size(0) - i >= batch_size:
+ emb_batch = emb[i:i + batch_size].to(device)
+ else:
+ emb_batch = emb[i:].to(device)
+ knn_index.add(emb_batch)
+
+
+class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore):
+ def __init__(self, enc_model: Type[Module],
+ model_kwargs: Optional[Dict[str,
+ Any]] = None, *args, **kwargs):
+ # TODO: Add kwargs for approx KNN to parameters here.
+ super().__init__(enc_model, model_kwargs, *args, **kwargs)
+ self.node_knn_index = None
+ self.edge_knn_index = None
+
+ def _retrieve_seed_nodes_batch(self, query: Iterable[Any],
+ k_nodes: int) -> Iterator[InputNodes]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ enc_model = self.enc_model.to(self.device)
+ query_enc = enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+ del enc_model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if self.node_knn_index is None:
+ self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100,
+ num_cells_to_visit=100,
+ bits_per_vector=4)
+ # Need to add in batches to avoid OOM
+ _add_features_to_knn_index(self.node_knn_index, self.x,
+ self.device)
+
+ output = self.node_knn_index.search(query_enc, k=k_nodes)
+ yield from output.index
+
+ def _retrieve_seed_edges_batch(self, query: Iterable[Any],
+ k_edges: int) -> Iterator[InputEdges]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ enc_model = self.enc_model.to(self.device)
+ query_enc = enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+ del enc_model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if self.edge_knn_index is None:
+ self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100,
+ num_cells_to_visit=100,
+ bits_per_vector=4)
+ # Need to add in batches to avoid OOM
+ _add_features_to_knn_index(self.edge_knn_index, self.edge_attr,
+ self.device)
+
+ output = self.edge_knn_index.search(query_enc, k=k_edges)
+ yield from output.index
+
+
+# TODO: These two classes should be refactored
+class SentenceTransformerFeatureStore(KNNRAGFeatureStore):
+ def __init__(self, *args, **kwargs):
+ kwargs['model_name'] = kwargs.get(
+ 'model_name', 'sentence-transformers/all-roberta-large-v1')
+ super().__init__(SentenceTransformer, *args, **kwargs)
+
+
+class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore):
+ def __init__(self, *args, **kwargs):
+ kwargs['model_name'] = kwargs.get(
+ 'model_name', 'sentence-transformers/all-roberta-large-v1')
+ super().__init__(SentenceTransformer, *args, **kwargs)
diff --git a/examples/llm/g_retriever_utils/rag_generate.py b/examples/llm/g_retriever_utils/rag_generate.py
new file mode 100644
index 000000000000..c6895b453b0c
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_generate.py
@@ -0,0 +1,137 @@
+# %%
+import argparse
+from itertools import chain
+from typing import Tuple
+
+import pandas as pd
+import torch
+import tqdm
+from rag_backend_utils import create_remote_backend_from_triplets
+from rag_feature_store import SentenceTransformerFeatureStore
+from rag_graph_store import NeighborSamplingRAGGraphStore
+
+from torch_geometric.data import Data
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.datasets.web_qsp_dataset import (
+ preprocess_triplet,
+ retrieval_via_pcst,
+)
+from torch_geometric.loader import RAGQueryLoader
+from torch_geometric.nn.nlp import SentenceTransformer
+
+# %%
+parser = argparse.ArgumentParser(description="""Generate new WebQSP subgraphs
+NOTE: Evaluating with smaller samples may result in poorer performance for the trained models compared to untrained models."""
+ )
+# TODO: Add more arguments for configuring rag params
+parser.add_argument("--use_pcst", action="store_true")
+parser.add_argument("--num_samples", type=int, default=4700)
+parser.add_argument("--out_file", default="subg_results.pt")
+args = parser.parse_args()
+
+# %%
+ds = WebQSPDataset("dataset", limit=args.num_samples, verbose=True,
+ force_reload=True)
+
+# %%
+triplets = chain.from_iterable(d['graph'] for d in ds.raw_dataset)
+
+# %%
+questions = ds.raw_dataset['question']
+
+# %%
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = SentenceTransformer(
+ model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+# %%
+fs, gs = create_remote_backend_from_triplets(
+ triplets=triplets, node_embedding_model=model,
+ node_method_to_call="encode", path="backend",
+ pre_transform=preprocess_triplet, node_method_kwargs={
+ "batch_size": 256
+ }, graph_db=NeighborSamplingRAGGraphStore,
+ feature_db=SentenceTransformerFeatureStore).load()
+
+# %%
+
+
+def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
+ topk_e: int = 3,
+ cost_e: float = 0.5) -> Tuple[Data, str]:
+ q_emb = model.encode(query)
+ textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+ textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+ out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
+ textual_edges, topk, topk_e, cost_e)
+ out_graph["desc"] = desc
+ return out_graph
+
+
+def apply_retrieval_with_text(graph: Data, query: str) -> Tuple[Data, str]:
+ textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+ textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+ desc = (
+ textual_nodes.to_csv(index=False) + "\n" +
+ textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"]))
+ graph["desc"] = desc
+ return graph
+
+
+transform = apply_retrieval_via_pcst \
+ if args.use_pcst else apply_retrieval_with_text
+
+query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5},
+ seed_edges_kwargs={"k_edges": 5},
+ sampler_kwargs={"num_neighbors": [50] * 2},
+ local_filter=transform)
+
+
+# %%
+# Accuracy Metrics to be added to Profiler
+def _eidx_helper(subg: Data, ground_truth: Data):
+ subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+ if isinstance(subg_eidx, torch.Tensor):
+ subg_eidx = subg_eidx.tolist()
+ if isinstance(gt_eidx, torch.Tensor):
+ gt_eidx = gt_eidx.tolist()
+ subg_e = set(subg_eidx)
+ gt_e = set(gt_eidx)
+ return subg_e, gt_e
+
+
+def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ total_e = set(range(num_edges))
+ tp = len(subg_e & gt_e)
+ tn = len(total_e - (subg_e | gt_e))
+ return (tp + tn) / num_edges
+
+
+def check_retrieval_precision(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(subg_e)
+
+
+def check_retrieval_recall(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(gt_e)
+
+
+# %%
+retrieval_stats = {"precision": [], "recall": [], "accuracy": []}
+subgs = []
+node_len = []
+edge_len = []
+for subg in tqdm.tqdm(query_loader.query(q) for q in questions):
+ subgs.append(subg)
+ node_len.append(subg['x'].shape[0])
+ edge_len.append(subg['edge_attr'].shape[0])
+
+for i, subg in enumerate(subgs):
+ subg['question'] = questions[i]
+ subg['label'] = ds[i]['label']
+
+pd.DataFrame.from_dict(retrieval_stats).to_csv(
+ args.out_file.split('.')[0] + '_metadata.csv')
+torch.save(subgs, args.out_file)
diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py
new file mode 100644
index 000000000000..48473f287233
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_graph_store.py
@@ -0,0 +1,107 @@
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.data import FeatureStore
+from torch_geometric.distributed import LocalGraphStore
+from torch_geometric.sampler import (
+ HeteroSamplerOutput,
+ NeighborSampler,
+ NodeSamplerInput,
+ SamplerOutput,
+)
+from torch_geometric.sampler.neighbor_sampler import NumNeighborsType
+from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes
+
+
+class NeighborSamplingRAGGraphStore(LocalGraphStore):
+ def __init__(self, feature_store: Optional[FeatureStore] = None,
+ num_neighbors: NumNeighborsType = [1], **kwargs):
+ self.feature_store = feature_store
+ self._num_neighbors = num_neighbors
+ self.sample_kwargs = kwargs
+ self._sampler_is_initialized = False
+ super().__init__()
+
+ def _init_sampler(self):
+ if self.feature_store is None:
+ raise AttributeError("Feature store not registered yet.")
+ self.sampler = NeighborSampler(data=(self.feature_store, self),
+ num_neighbors=self._num_neighbors,
+ **self.sample_kwargs)
+ self._sampler_is_initialized = True
+
+ def register_feature_store(self, feature_store: FeatureStore):
+ self.feature_store = feature_store
+ self._sampler_is_initialized = False
+
+ def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool:
+ ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
+ self._sampler_is_initialized = False
+ return ret
+
+ @property
+ def edge_index(self):
+ return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)
+
+ def put_edge_index(self, edge_index: EdgeTensorType, *args,
+ **kwargs) -> bool:
+ ret = super().put_edge_index(edge_index, *args, **kwargs)
+ # HACK
+ self.edge_idx_args = args
+ self.edge_idx_kwargs = kwargs
+ self._sampler_is_initialized = False
+ return ret
+
+ @property
+ def num_neighbors(self):
+ return self._num_neighbors
+
+ @num_neighbors.setter
+ def num_neighbors(self, num_neighbors: NumNeighborsType):
+ self._num_neighbors = num_neighbors
+ if hasattr(self, 'sampler'):
+ self.sampler.num_neighbors = num_neighbors
+
+ def sample_subgraph(
+ self, seed_nodes: InputNodes, seed_edges: InputEdges,
+ num_neighbors: Optional[NumNeighborsType] = None
+ ) -> Union[SamplerOutput, HeteroSamplerOutput]:
+ """Sample the graph starting from the given nodes and edges using the
+ in-built NeighborSampler.
+
+ Args:
+ seed_nodes (InputNodes): Seed nodes to start sampling from.
+ seed_edges (InputEdges): Seed edges to start sampling from.
+ num_neighbors (Optional[NumNeighborsType], optional): Parameters
+ to determine how many hops and number of neighbors per hop.
+ Defaults to None.
+
+ Returns:
+ Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput
+ for the input.
+ """
+ if not self._sampler_is_initialized:
+ self._init_sampler()
+ if num_neighbors is not None:
+ self.num_neighbors = num_neighbors
+
+ # FIXME: Right now, only input nodes/edges as tensors are be supported
+ if not isinstance(seed_nodes, Tensor):
+ raise NotImplementedError
+ if not isinstance(seed_edges, Tensor):
+ raise NotImplementedError
+ device = seed_nodes.device
+
+ # TODO: Call sample_from_edges for seed_edges
+ # Turning them into nodes for now.
+ seed_edges = self.edge_index.to(device).T[seed_edges.to(
+ device)].reshape(-1)
+ seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0)
+
+ seed_nodes = seed_nodes.unique().contiguous()
+ node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
+ out = self.sampler.sample_from_nodes(node_sample_input)
+
+ return out
diff --git a/examples/llm/multihop_rag/README.md b/examples/llm/multihop_rag/README.md
new file mode 100644
index 000000000000..ff43b16a2c05
--- /dev/null
+++ b/examples/llm/multihop_rag/README.md
@@ -0,0 +1,9 @@
+# Examples for LLM and GNN co-training
+
+| Example | Description |
+| -------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
+| [`multihop_download.sh`](./multihop_download.sh) | Downloads all the components of the multihop dataset. |
+| [`multihop_preprocess.py`](./multihop_preprocess.py) | Preprocesses the dataset to pair questions/answers with components in the knowledge graph. Contains documentation to describe the process. |
+| [`rag_generate_multihop.py`](./rag_generate_multihop.py) | Utilizes the sample remote backend in [`g_retriever_utils`](../g_retriever_utils/) to generate subgraphs for the multihop dataset. |
+
+NOTE: Performance of GRetriever on this dataset has not been evaluated.
diff --git a/examples/llm/multihop_rag/multihop_download.sh b/examples/llm/multihop_rag/multihop_download.sh
new file mode 100644
index 000000000000..3c1970d39440
--- /dev/null
+++ b/examples/llm/multihop_rag/multihop_download.sh
@@ -0,0 +1,12 @@
+#!/bin/sh
+
+# Wikidata5m
+
+wget -O "wikidata5m_alias.tar.gz" "https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz"
+tar -xvf "wikidata5m_alias.tar.gz"
+wget -O "wikidata5m_all_triplet.txt.gz" "https://www.dropbox.com/s/563omb11cxaqr83/wikidata5m_all_triplet.txt.gz"
+gzip -d "wikidata5m_all_triplet.txt.gz" -f
+
+# 2Multihopqa
+wget -O "data_ids_april7.zip" "https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip"
+unzip -o "data_ids_april7.zip"
diff --git a/examples/llm/multihop_rag/multihop_preprocess.py b/examples/llm/multihop_rag/multihop_preprocess.py
new file mode 100644
index 000000000000..46052bdf1b15
--- /dev/null
+++ b/examples/llm/multihop_rag/multihop_preprocess.py
@@ -0,0 +1,276 @@
+"""Example workflow for downloading and assembling a multihop QA dataset."""
+
+import argparse
+import json
+from subprocess import call
+
+import pandas as pd
+import torch
+import tqdm
+
+from torch_geometric.data import LargeGraphIndexer
+
+# %% [markdown]
+# # Encoding A Large Knowledge Graph Part 2
+
+# %% [markdown]
+# In this notebook, we will continue where we left off by building a new
+# multi-hop QA dataset based on Wikidata.
+
+# %% [markdown]
+# ## Example 2: Building a new Dataset from Questions and an already-existing
+# Knowledge Graph
+
+# %% [markdown]
+# ### Motivation
+
+# %% [markdown]
+# One potential application of knowledge graph structural encodings is
+# capturing the relationships between different entities that are multiple
+# hops apart. This can be challenging for an LLM to recognize from prepended
+# graph information. Here's a motivating example (credit to @Rishi Puri):
+
+# %% [markdown]
+# In this example, the question can only be answered by reasoning about the
+# relationships between the entities in the knowledge graph.
+
+# %% [markdown]
+# ### Building a Multi-Hop QA Dataset
+
+# %% [markdown]
+# To start, we need to download the raw data of a knowledge graph.
+# In this case, we use WikiData5M
+# ([Wang et al]
+# (https://paperswithcode.com/paper/kepler-a-unified-model-for-knowledge)).
+# Here we download the raw triplets and their entity codes. Information about
+# this dataset can be found
+# [here](https://deepgraphlearning.github.io/project/wikidata5m).
+
+# %% [markdown]
+# The following download contains the ID to plaintext mapping for all the
+# entities and relations in the knowledge graph:
+
+rv = call("./multihop_download.sh")
+
+# %% [markdown]
+# To start, we are going to preprocess the knowledge graph to substitute each
+# of the entity/relation codes with their plaintext aliases. This makes it
+# easier to use a pre-trained textual encoding model to create triplet
+# embeddings, as such a model likely won't understand how to properly embed
+# the entity codes.
+
+# %%
+
+# %%
+parser = argparse.ArgumentParser(description="Preprocess wikidata5m")
+parser.add_argument("--n_triplets", type=int, default=-1)
+args = parser.parse_args()
+
+# %%
+# Substitute entity codes with their aliases
+# Picking the first alias for each entity (rather arbitrarily)
+alias_map = {}
+rel_alias_map = {}
+for line in open('wikidata5m_entity.txt'):
+ parts = line.strip().split('\t')
+ entity_id = parts[0]
+ aliases = parts[1:]
+ alias_map[entity_id] = aliases[0]
+for line in open('wikidata5m_relation.txt'):
+ parts = line.strip().split('\t')
+ relation_id = parts[0]
+ relation_name = parts[1]
+ rel_alias_map[relation_id] = relation_name
+
+# %%
+full_graph = []
+missing_total = 0
+total = 0
+limit = None if args.n_triplets == -1 else args.n_triplets
+i = 0
+
+for line in tqdm.tqdm(open('wikidata5m_all_triplet.txt')):
+ if limit is not None and i >= limit:
+ break
+ src, rel, dst = line.strip().split('\t')
+ if src not in alias_map:
+ missing_total += 1
+ if dst not in alias_map:
+ missing_total += 1
+ if rel not in rel_alias_map:
+ missing_total += 1
+ total += 3
+ full_graph.append([
+ alias_map.get(src, src),
+ rel_alias_map.get(rel, rel),
+ alias_map.get(dst, dst)
+ ])
+ i += 1
+print(f"Missing aliases: {missing_total}/{total}")
+
+# %% [markdown]
+# Now `full_graph` represents the knowledge graph triplets in
+# understandable plaintext.
+
+# %% [markdown]
+# Next, we need a set of multi-hop questions that the Knowledge Graph will
+# provide us with context for. We utilize a subset of
+# [HotPotQA](https://hotpotqa.github.io/)
+# ([Yang et. al.](https://arxiv.org/pdf/1809.09600)) called
+# [2WikiMultiHopQA](https://github.com/Alab-NII/2wikimultihop)
+# ([Ho et. al.](https://aclanthology.org/2020.coling-main.580.pdf)),
+# which includes a subgraph of entities that serve as the ground truth
+# justification for answering each multi-hop question:
+
+# %%
+with open('train.json') as f:
+ train_data = json.load(f)
+train_df = pd.DataFrame(train_data)
+train_df['split_type'] = 'train'
+
+with open('dev.json') as f:
+ dev_data = json.load(f)
+dev_df = pd.DataFrame(dev_data)
+dev_df['split_type'] = 'dev'
+
+with open('test.json') as f:
+ test_data = json.load(f)
+test_df = pd.DataFrame(test_data)
+test_df['split_type'] = 'test'
+
+df = pd.concat([train_df, dev_df, test_df])
+
+# %% [markdown]
+# Now we need to extract the subgraphs
+
+# %%
+df['graph_size'] = df['evidences_id'].apply(lambda row: len(row))
+
+# %% [markdown]
+# (Optional) We take only questions where the evidence graph is greater than
+# 0. (Note: this gets rid of the test set):
+
+# %%
+# df = df[df['graph_size'] > 0]
+
+# %%
+refined_df = df[[
+ '_id', 'question', 'answer', 'split_type', 'evidences_id', 'type',
+ 'graph_size'
+]]
+
+# %% [markdown]
+# Checkpoint:
+
+# %%
+refined_df.to_csv('wikimultihopqa_refined.csv', index=False)
+
+# %% [markdown]
+# Now we need to check that all the entities mentioned in the question/answer
+# set are also present in the Wikidata graph:
+
+# %%
+relation_map = {}
+with open('wikidata5m_relation.txt') as f:
+ for line in tqdm.tqdm(f):
+ parts = line.strip().split('\t')
+ for i in range(1, len(parts)):
+ if parts[i] not in relation_map:
+ relation_map[parts[i]] = []
+ relation_map[parts[i]].append(parts[0])
+
+# %%
+entity_set = set()
+with open('wikidata5m_entity.txt') as f:
+ for line in tqdm.tqdm(f):
+ entity_set.add(line.strip().split('\t')[0])
+
+# %%
+missing_entities = set()
+missing_entity_idx = set()
+for i, row in enumerate(refined_df.itertuples()):
+ for trip in row.evidences_id:
+ entities = trip[0], trip[2]
+ for entity in entities:
+ if entity not in entity_set:
+ # print(
+ # f'The following entity was not found in the KG: {entity}'
+ # )
+ missing_entities.add(entity)
+ missing_entity_idx.add(i)
+
+# %% [markdown]
+# Right now, we drop the missing entity entries. Additional preprocessing can
+# be done here to resolve the entity/relation collisions, but that is out of
+# the scope for this notebook.
+
+# %%
+# missing relations are ok, but missing entities cannot be mapped to
+# plaintext, so they should be dropped.
+refined_df.reset_index(inplace=True, drop=True)
+
+# %%
+cleaned_df = refined_df.drop(missing_entity_idx)
+
+# %% [markdown]
+# Now we save the resulting graph and questions/answers dataset:
+
+# %%
+cleaned_df.to_csv('wikimultihopqa_cleaned.csv', index=False)
+
+# %%
+
+# %%
+torch.save(full_graph, 'wikimultihopqa_full_graph.pt')
+
+# %% [markdown]
+# ### Question: How do we extract a contextual subgraph for a given query?
+
+# %% [markdown]
+# The chosen retrieval algorithm is a critical component in the pipeline for
+# affecting RAG performance. In the next section (1), we will demonstrate a
+# naive method of retrieval for a large knowledge graph, and how to apply it
+# to this dataset along with WebQSP.
+
+# %% [markdown]
+# ### Preparing a Textualized Graph for LLM
+
+# %% [markdown]
+# For now however, we need to prepare the graph data to be used as a plaintext
+# prefix to the LLM. In order to do this, we want to prompt the LLM to use the
+# unique nodes, and unique edge triplets of a given subgraph. In order to do
+# this, we prepare a unique indexed node df and edge df for the knowledge
+# graph now. This process occurs trivially with the LargeGraphIndexer:
+
+# %%
+
+# %%
+indexer = LargeGraphIndexer.from_triplets(full_graph)
+
+# %%
+# Node DF
+textual_nodes = pd.DataFrame.from_dict(
+ {"node_attr": indexer.get_node_features()})
+textual_nodes["node_id"] = textual_nodes.index
+textual_nodes = textual_nodes[["node_id", "node_attr"]]
+
+# %% [markdown]
+# Notice how LargeGraphIndexer ensures that there are no duplicate indices:
+
+# %%
+# Edge DF
+textual_edges = pd.DataFrame(indexer.get_edge_features(),
+ columns=["src", "edge_attr", "dst"])
+textual_edges["src"] = [indexer._nodes[h] for h in textual_edges["src"]]
+textual_edges["dst"] = [indexer._nodes[h] for h in textual_edges["dst"]]
+
+# %% [markdown]
+# Note: The edge table refers to each node by its index in the node table.
+# We will see how this gets utilized later when indexing a subgraph.
+
+# %% [markdown]
+# Now we can save the result
+
+# %%
+textual_nodes.to_csv('wikimultihopqa_textual_nodes.csv', index=False)
+textual_edges.to_csv('wikimultihopqa_textual_edges.csv', index=False)
diff --git a/examples/llm/multihop_rag/rag_generate_multihop.py b/examples/llm/multihop_rag/rag_generate_multihop.py
new file mode 100644
index 000000000000..de93a9e75dd1
--- /dev/null
+++ b/examples/llm/multihop_rag/rag_generate_multihop.py
@@ -0,0 +1,88 @@
+# %%
+import argparse
+import sys
+from typing import Tuple
+
+import pandas as pd
+import torch
+import tqdm
+
+from torch_geometric.data import Data
+from torch_geometric.datasets.web_qsp_dataset import (
+ preprocess_triplet,
+ retrieval_via_pcst,
+)
+from torch_geometric.loader import RAGQueryLoader
+from torch_geometric.nn.nlp import SentenceTransformer
+
+sys.path.append('..')
+
+from g_retriever_utils.rag_backend_utils import \
+ create_remote_backend_from_triplets # noqa: E402
+from g_retriever_utils.rag_feature_store import \
+ SentenceTransformerApproxFeatureStore # noqa: E402
+from g_retriever_utils.rag_graph_store import \
+ NeighborSamplingRAGGraphStore # noqa: E402
+
+# %%
+parser = argparse.ArgumentParser(
+ description="Generate new multihop dataset for rag")
+# TODO: Add more arguments for configuring rag params
+parser.add_argument("--num_samples", type=int)
+args = parser.parse_args()
+
+# %%
+triplets = torch.load('wikimultihopqa_full_graph.pt')
+
+# %%
+df = pd.read_csv('wikimultihopqa_cleaned.csv')
+questions = df['question'][:args.num_samples]
+labels = df['answer'][:args.num_samples]
+
+# %%
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = SentenceTransformer(
+ model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+# %%
+fs, gs = create_remote_backend_from_triplets(
+ triplets=triplets, node_embedding_model=model,
+ node_method_to_call="encode", path="backend",
+ pre_transform=preprocess_triplet, node_method_kwargs={
+ "batch_size": 256
+ }, graph_db=NeighborSamplingRAGGraphStore,
+ feature_db=SentenceTransformerApproxFeatureStore).load()
+
+# %%
+
+all_textual_nodes = pd.read_csv('wikimultihopqa_textual_nodes.csv')
+all_textual_edges = pd.read_csv('wikimultihopqa_textual_edges.csv')
+
+
+def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
+ topk_e: int = 3,
+ cost_e: float = 0.5) -> Tuple[Data, str]:
+ q_emb = model.encode(query)
+ textual_nodes = all_textual_nodes.iloc[graph["node_idx"]].reset_index()
+ textual_edges = all_textual_edges.iloc[graph["edge_idx"]].reset_index()
+ out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
+ textual_edges, topk, topk_e, cost_e)
+ out_graph["desc"] = desc
+ return out_graph
+
+
+# %%
+query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10},
+ seed_edges_kwargs={"k_edges": 10},
+ sampler_kwargs={"num_neighbors": [40] * 3},
+ local_filter=apply_retrieval_via_pcst)
+
+# %%
+subgs = []
+for q, l in tqdm.tqdm(zip(questions, labels)):
+ subg = query_loader.query(q)
+ subg['question'] = q
+ subg['label'] = l
+ subgs.append(subg)
+
+torch.save(subgs, 'subg_results.pt')
diff --git a/examples/llm/nvtx_examples/README.md b/examples/llm/nvtx_examples/README.md
new file mode 100644
index 000000000000..aa4f070d9824
--- /dev/null
+++ b/examples/llm/nvtx_examples/README.md
@@ -0,0 +1,7 @@
+# Examples for LLM and GNN co-training
+
+| Example | Description |
+| -------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
+| [`nvtx_run.sh`](./nvtx_run.sh) | Runs nsys profiler on a given Python file that contains NVTX calls. |
+| [`nvtx_rag_backend_example.py`](./nvtx_rag_backend_example.py) | Example script for nsys profiling a RAG Backend such as that used in [`rag_generate.py`](../g_retriever_utils/rag_generate.py). |
+| [`nvtx_webqsp_example.py`](./nvtx_webqsp_example.py) | Example script for nsys profiling the WebQSP dataset. |
diff --git a/examples/llm/nvtx_examples/nvtx_rag_backend_example.py b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py
new file mode 100644
index 000000000000..b30e34b8c7b1
--- /dev/null
+++ b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py
@@ -0,0 +1,144 @@
+# %%
+import argparse
+import sys
+from itertools import chain
+from typing import Tuple
+
+import torch
+
+from torch_geometric.data import Data, get_features_for_triplets_groups
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.datasets.web_qsp_dataset import (
+ preprocess_triplet,
+ retrieval_via_pcst,
+)
+from torch_geometric.loader import rag_loader
+from torch_geometric.nn.nlp import SentenceTransformer
+from torch_geometric.profile.nvtx import nvtxit
+
+sys.path.append('..')
+from g_retriever_utils.rag_backend_utils import \
+ create_remote_backend_from_triplets # noqa: E402
+from g_retriever_utils.rag_feature_store import \
+ SentenceTransformerFeatureStore # noqa: E402
+from g_retriever_utils.rag_graph_store import \
+ NeighborSamplingRAGGraphStore # noqa: E402
+
+# %%
+# Patch FeatureStore and GraphStore
+
+SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()(
+ SentenceTransformerFeatureStore.retrieve_seed_nodes)
+SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()(
+ SentenceTransformerFeatureStore.retrieve_seed_edges)
+SentenceTransformerFeatureStore.load_subgraph = nvtxit()(
+ SentenceTransformerFeatureStore.load_subgraph)
+NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()(
+ NeighborSamplingRAGGraphStore.sample_subgraph)
+rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query)
+
+# %%
+ds = WebQSPDataset("small_ds_1", force_reload=True, limit=10)
+
+# %%
+triplets = list(chain.from_iterable(d['graph'] for d in ds.raw_dataset))
+
+# %%
+questions = ds.raw_dataset['question']
+
+# %%
+ground_truth_graphs = get_features_for_triplets_groups(
+ ds.indexer, (d['graph'] for d in ds.raw_dataset),
+ pre_transform=preprocess_triplet)
+num_edges = len(ds.indexer._edges)
+
+# %%
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to(
+ device)
+
+# %%
+fs, gs = create_remote_backend_from_triplets(
+ triplets=triplets, node_embedding_model=model,
+ node_method_to_call="encode", path="backend",
+ pre_transform=preprocess_triplet, node_method_kwargs={
+ "batch_size": 256
+ }, graph_db=NeighborSamplingRAGGraphStore,
+ feature_db=SentenceTransformerFeatureStore).load()
+
+# %%
+
+
+@nvtxit()
+def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
+ topk_e: int = 3,
+ cost_e: float = 0.5) -> Tuple[Data, str]:
+ q_emb = model.encode(query)
+ textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+ textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+ out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
+ textual_edges, topk, topk_e, cost_e)
+ out_graph["desc"] = desc
+ return graph
+
+
+# %%
+query_loader = rag_loader.RAGQueryLoader(
+ data=(fs, gs), seed_nodes_kwargs={"k_nodes":
+ 10}, seed_edges_kwargs={"k_edges": 10},
+ sampler_kwargs={"num_neighbors":
+ [40] * 10}, local_filter=apply_retrieval_via_pcst)
+
+
+# %%
+# Accuracy Metrics to be added to Profiler
+def _eidx_helper(subg: Data, ground_truth: Data):
+ subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+ if isinstance(subg_eidx, torch.Tensor):
+ subg_eidx = subg_eidx.tolist()
+ if isinstance(gt_eidx, torch.Tensor):
+ gt_eidx = gt_eidx.tolist()
+ subg_e = set(subg_eidx)
+ gt_e = set(gt_eidx)
+ return subg_e, gt_e
+
+
+def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ total_e = set(range(num_edges))
+ tp = len(subg_e & gt_e)
+ tn = len(total_e - (subg_e | gt_e))
+ return (tp + tn) / num_edges
+
+
+def check_retrieval_precision(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(subg_e)
+
+
+def check_retrieval_recall(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(gt_e)
+
+
+# %%
+
+
+@nvtxit()
+def _run_eval():
+ for subg, gt in zip((query_loader.query(q) for q in questions),
+ ground_truth_graphs):
+ print(check_retrieval_accuracy(subg, gt, num_edges),
+ check_retrieval_precision(subg, gt),
+ check_retrieval_recall(subg, gt))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
+ args = parser.parse_args()
+ if args.capture_torch_kernels:
+ with torch.autograd.profiler.emit_nvtx():
+ _run_eval()
+ else:
+ _run_eval()
diff --git a/examples/llm/nvtx_examples/nvtx_run.sh b/examples/llm/nvtx_examples/nvtx_run.sh
new file mode 100644
index 000000000000..4c6fce7c8224
--- /dev/null
+++ b/examples/llm/nvtx_examples/nvtx_run.sh
@@ -0,0 +1,27 @@
+#!/bin/sh
+
+# Check if the user provided a Python file
+if [ -z "$1" ]; then
+ echo "Usage: $0 "
+ exit 1
+fi
+
+# Check if the provided file exists
+if [[ ! -f "$1" ]]; then
+ echo "Error: File '$1' does not exist."
+ exit 1
+fi
+
+# Check if the provided file is a Python file
+if [[ ! "$1" == *.py ]]; then
+ echo "Error: '$1' is not a Python file."
+ exit 1
+fi
+
+# Get the base name of the Python file
+python_file=$(basename "$1")
+
+# Run nsys profile on the Python file
+nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1"
+
+echo "Profile data saved as profile_${python_file%.py}.nsys-rep"
diff --git a/examples/llm/nvtx_examples/nvtx_webqsp_example.py b/examples/llm/nvtx_examples/nvtx_webqsp_example.py
new file mode 100644
index 000000000000..a1e9611ee5cc
--- /dev/null
+++ b/examples/llm/nvtx_examples/nvtx_webqsp_example.py
@@ -0,0 +1,22 @@
+import argparse
+
+import torch
+
+from torch_geometric.datasets import web_qsp_dataset
+from torch_geometric.profile import nvtxit
+
+# Apply Patches
+web_qsp_dataset.retrieval_via_pcst = nvtxit()(
+ web_qsp_dataset.retrieval_via_pcst)
+web_qsp_dataset.WebQSPDataset.process = nvtxit()(
+ web_qsp_dataset.WebQSPDataset.process)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
+ args = parser.parse_args()
+ if args.capture_torch_kernels:
+ with torch.autograd.profiler.emit_nvtx():
+ ds = web_qsp_dataset.WebQSPDataset('baseline', limit=10)
+ else:
+ ds = web_qsp_dataset.WebQSPDataset('baseline', limit=10)
diff --git a/test/data/test_large_graph_indexer.py b/test/data/test_large_graph_indexer.py
new file mode 100644
index 000000000000..840299ba4934
--- /dev/null
+++ b/test/data/test_large_graph_indexer.py
@@ -0,0 +1,170 @@
+import random
+import string
+from typing import List
+
+import networkx as nx
+import torch
+
+from torch_geometric.data import (
+ Data,
+ LargeGraphIndexer,
+ TripletLike,
+ get_features_for_triplets,
+)
+from torch_geometric.data.large_graph_indexer import (
+ EDGE_PID,
+ EDGE_RELATION,
+ NODE_PID,
+)
+
+# create possible nodes and edges for graph
+strkeys = string.ascii_letters + string.digits
+NODE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(1000)})
+EDGE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(50)})
+
+
+def featurize(s: str) -> int:
+ return int.from_bytes(s.encode(), 'little')
+
+
+def sample_triplets(amount: int = 1) -> List[TripletLike]:
+ trips = []
+ for i in range(amount):
+ h, t = random.sample(NODE_POOL, k=2)
+ r = random.sample(EDGE_POOL, k=1)[0]
+ trips.append(tuple([h, r, t]))
+ return trips
+
+
+def preprocess_triplet(triplet: TripletLike) -> TripletLike:
+ h, r, t = triplet
+ return h.lower(), r, t.lower()
+
+
+def test_basic_collate():
+ graphs = [sample_triplets(1000) for i in range(2)]
+
+ indexer_0 = LargeGraphIndexer.from_triplets(
+ graphs[0], pre_transform=preprocess_triplet)
+ indexer_1 = LargeGraphIndexer.from_triplets(
+ graphs[1], pre_transform=preprocess_triplet)
+
+ big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1])
+
+ assert len(indexer_0._nodes) + len(
+ indexer_1._nodes) - len(indexer_0._nodes.keys()
+ & indexer_1._nodes.keys()) == len(
+ big_indexer._nodes)
+ assert len(indexer_0._edges) + len(
+ indexer_1._edges) - len(indexer_0._edges.keys()
+ & indexer_1._edges.keys()) == len(
+ big_indexer._edges)
+
+ assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes)
+ assert len(set(big_indexer._edges.values())) == len(big_indexer._edges)
+
+ for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()):
+ assert big_indexer.node_attr[NODE_PID][
+ big_indexer._nodes[node]] == node
+
+
+def test_large_graph_index():
+ graphs = [sample_triplets(1000) for i in range(100)]
+
+ # Preprocessing of trips lowercases nodes but not edges
+ node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL}
+ edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL}
+
+ def encode_graph_from_trips(triplets: List[TripletLike]) -> Data:
+ seen_nodes = dict()
+ edge_attrs = list()
+ edge_idx = []
+ for trip in triplets:
+ trip = preprocess_triplet(trip)
+ h, r, t = trip
+ seen_nodes[h] = len(
+ seen_nodes) if h not in seen_nodes else seen_nodes[h]
+ seen_nodes[t] = len(
+ seen_nodes) if t not in seen_nodes else seen_nodes[t]
+ edge_attrs.append(edge_feature_vecs[r])
+ edge_idx.append((seen_nodes[h], seen_nodes[t]))
+
+ x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()])
+ edge_idx = torch.LongTensor(edge_idx).T
+ edge_attrs = torch.Tensor(edge_attrs)
+ return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs)
+
+ naive_graph_ds = [
+ encode_graph_from_trips(triplets=trips) for trips in graphs
+ ]
+
+ indexer = LargeGraphIndexer.collate([
+ LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet)
+ for g in graphs
+ ])
+ indexer_nodes = indexer.get_unique_node_features()
+ indexer_node_vals = torch.Tensor(
+ [node_feature_vecs[n] for n in indexer_nodes])
+ indexer_edges = indexer.get_unique_edge_features(
+ feature_name=EDGE_RELATION)
+ indexer_edge_vals = torch.Tensor(
+ [edge_feature_vecs[e] for e in indexer_edges])
+ indexer.add_node_feature('x', indexer_node_vals)
+ indexer.add_edge_feature('edge_attr', indexer_edge_vals,
+ map_from_feature=EDGE_RELATION)
+ large_graph_ds = [
+ get_features_for_triplets(indexer=indexer, triplets=g,
+ node_feature_name='x',
+ edge_feature_name='edge_attr',
+ pre_transform=preprocess_triplet)
+ for g in graphs
+ ]
+
+ for ds in large_graph_ds:
+ assert NODE_PID in ds
+ assert EDGE_PID in ds
+ assert "node_idx" in ds
+ assert "edge_idx" in ds
+
+ def results_are_close_enough(ground_truth: Data, new_method: Data,
+ thresh=.99):
+ def _sorted_tensors_are_close(tensor1, tensor2):
+ return torch.all(
+ torch.isclose(tensor1.sort()[0],
+ tensor2.sort()[0]) > thresh)
+
+ def _graphs_are_same(tensor1, tensor2):
+ return nx.weisfeiler_lehman_graph_hash(nx.Graph(
+ tensor1.T)) == nx.weisfeiler_lehman_graph_hash(
+ nx.Graph(tensor2.T))
+ return _sorted_tensors_are_close(
+ ground_truth.x, new_method.x) \
+ and _sorted_tensors_are_close(
+ ground_truth.edge_attr, new_method.edge_attr) \
+ and _graphs_are_same(
+ ground_truth.edge_index, new_method.edge_index)
+
+ for dsets in zip(naive_graph_ds, large_graph_ds):
+ assert results_are_close_enough(*dsets)
+
+
+def test_save_load(tmp_path):
+ graph = sample_triplets(1000)
+
+ node_feature_vecs = {s: featurize(s) for s in NODE_POOL}
+ edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL}
+
+ indexer = LargeGraphIndexer.from_triplets(graph)
+ indexer_nodes = indexer.get_unique_node_features()
+ indexer_node_vals = torch.Tensor(
+ [node_feature_vecs[n] for n in indexer_nodes])
+ indexer_edges = indexer.get_unique_edge_features(
+ feature_name=EDGE_RELATION)
+ indexer_edge_vals = torch.Tensor(
+ [edge_feature_vecs[e] for e in indexer_edges])
+ indexer.add_node_feature('x', indexer_node_vals)
+ indexer.add_edge_feature('edge_attr', indexer_edge_vals,
+ map_from_feature=EDGE_RELATION)
+
+ indexer.save(str(tmp_path))
+ assert indexer == LargeGraphIndexer.from_disk(str(tmp_path))
diff --git a/test/datasets/test_web_qsp_dataset.py b/test/datasets/test_web_qsp_dataset.py
new file mode 100644
index 000000000000..9dbb8218c65a
--- /dev/null
+++ b/test/datasets/test_web_qsp_dataset.py
@@ -0,0 +1,29 @@
+import pytest
+
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.testing import onlyFullTest, onlyOnline
+
+
+@pytest.mark.skip(reason="Times out")
+@onlyOnline
+@onlyFullTest
+def test_web_qsp_dataset():
+ dataset = WebQSPDataset()
+ assert len(dataset) == 4700
+ assert str(dataset) == "WebQSPDataset(4700)"
+
+
+@onlyOnline
+@onlyFullTest
+def test_web_qsp_dataset_limit(tmp_path):
+ dataset = WebQSPDataset(root=tmp_path, limit=100)
+ assert len(dataset) == 100
+ assert str(dataset) == "WebQSPDataset(100)"
+
+
+@onlyOnline
+@onlyFullTest
+def test_web_qsp_dataset_limit_no_pcst(tmp_path):
+ dataset = WebQSPDataset(root=tmp_path, limit=100, include_pcst=False)
+ assert len(dataset) == 100
+ assert str(dataset) == "WebQSPDataset(100)"
diff --git a/test/nn/models/test_g_retriever.py b/test/nn/models/test_g_retriever.py
index 899e70730cc9..24a74d1b6f6e 100644
--- a/test/nn/models/test_g_retriever.py
+++ b/test/nn/models/test_g_retriever.py
@@ -51,3 +51,52 @@ def test_g_retriever() -> None:
# Test inference:
pred = model.inference(question, x, edge_index, batch, edge_attr)
assert len(pred) == 1
+
+
+@onlyFullTest
+@withPackage('transformers', 'sentencepiece', 'accelerate')
+def test_g_retriever_many_tokens() -> None:
+ llm = LLM(
+ model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
+ num_params=1,
+ dtype=torch.float16,
+ )
+
+ gnn = GAT(
+ in_channels=1024,
+ out_channels=1024,
+ hidden_channels=1024,
+ num_layers=2,
+ heads=4,
+ norm='batch_norm',
+ )
+
+ model = GRetriever(
+ llm=llm,
+ gnn=gnn,
+ mlp_out_channels=2048,
+ mlp_out_tokens=2,
+ )
+ assert str(model) == ('GRetriever(\n'
+ ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n'
+ ' gnn=GAT(1024, 1024, num_layers=2),\n'
+ ')')
+
+ x = torch.randn(10, 1024)
+ edge_index = torch.tensor([
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
+ ])
+ edge_attr = torch.randn(edge_index.size(1), 1024)
+ batch = torch.zeros(x.size(0), dtype=torch.long)
+
+ question = ["Is PyG the best open-source GNN library?"]
+ label = ["yes!"]
+
+ # Test train:
+ loss = model(question, x, edge_index, batch, label, edge_attr)
+ assert loss >= 0
+
+ # Test inference:
+ pred = model.inference(question, x, edge_index, batch, edge_attr)
+ assert len(pred) == 1
diff --git a/test/profile/test_nvtx.py b/test/profile/test_nvtx.py
new file mode 100644
index 000000000000..56e28a9c2e59
--- /dev/null
+++ b/test/profile/test_nvtx.py
@@ -0,0 +1,136 @@
+from unittest.mock import call, patch
+
+from torch_geometric.profile import nvtxit
+
+
+def _setup_mock(torch_cuda_mock):
+ torch_cuda_mock.is_available.return_value = True
+ torch_cuda_mock.cudart.return_value.cudaProfilerStart.return_value = None
+ torch_cuda_mock.cudart.return_value.cudaProfilerStop.return_value = None
+ return torch_cuda_mock
+
+
+@patch('torch_geometric.profile.nvtx.torch.cuda')
+def test_nvtxit_base(torch_cuda_mock):
+ torch_cuda_mock = _setup_mock(torch_cuda_mock)
+
+ # dummy func calls a calls b
+
+ @nvtxit()
+ def call_b():
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ return 42
+
+ @nvtxit()
+ def call_a():
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ return call_b()
+
+ def dummy_func():
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ return call_a()
+
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ dummy_func()
+
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.nvtx.range_push.call_args_list == [
+ call('call_a_0'), call('call_b_0')
+ ]
+
+
+@patch('torch_geometric.profile.nvtx.torch.cuda')
+def test_nvtxit_rename(torch_cuda_mock):
+ torch_cuda_mock = _setup_mock(torch_cuda_mock)
+
+ # dummy func calls a calls b
+
+ @nvtxit()
+ def call_b():
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ return 42
+
+ @nvtxit('a_nvtx')
+ def call_a():
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ return call_b()
+
+ def dummy_func():
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ return call_a()
+
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ dummy_func()
+
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.nvtx.range_push.call_args_list == [
+ call('a_nvtx_0'), call('call_b_0')
+ ]
+
+
+@patch('torch_geometric.profile.nvtx.torch.cuda')
+def test_nvtxit_iters(torch_cuda_mock):
+ torch_cuda_mock = _setup_mock(torch_cuda_mock)
+
+ # dummy func calls a calls b
+
+ @nvtxit(n_iters=1)
+ def call_b():
+ return 42
+
+ @nvtxit()
+ def call_a():
+ return call_b()
+
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+
+ call_b()
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501
+ call_a()
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 2 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 2 # noqa: E501
+
+ assert torch_cuda_mock.nvtx.range_push.call_args_list == [
+ call('call_b_0'), call('call_a_0')
+ ]
+
+
+@patch('torch_geometric.profile.nvtx.torch.cuda')
+def test_nvtxit_warmups(torch_cuda_mock):
+ torch_cuda_mock = _setup_mock(torch_cuda_mock)
+
+ # dummy func calls a calls b
+
+ @nvtxit(n_warmups=1)
+ def call_b():
+ return 42
+
+ @nvtxit()
+ def call_a():
+ return call_b()
+
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+
+ call_b()
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501
+ call_a()
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501
+ assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501
+
+ assert torch_cuda_mock.nvtx.range_push.call_args_list == [
+ call('call_a_0'), call('call_b_1')
+ ]
diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py
index 821ef9c5c063..fee215b1a357 100644
--- a/torch_geometric/data/__init__.py
+++ b/torch_geometric/data/__init__.py
@@ -16,6 +16,7 @@
from .makedirs import makedirs
from .download import download_url, download_google_url
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
+from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups
from torch_geometric.lazy_loader import LazyLoader
@@ -27,6 +28,8 @@
'Dataset',
'InMemoryDataset',
'OnDiskDataset',
+ 'LargeGraphIndexer',
+ 'TripletLike',
]
remote_backend_classes = [
@@ -50,6 +53,8 @@
'extract_zip',
'extract_bz2',
'extract_gz',
+ 'get_features_for_triplets',
+ "get_features_for_triplets_groups",
]
__all__ = data_classes + remote_backend_classes + helper_functions
diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py
new file mode 100644
index 000000000000..51b0ce459b33
--- /dev/null
+++ b/torch_geometric/data/large_graph_indexer.py
@@ -0,0 +1,672 @@
+import os
+import pickle as pkl
+import shutil
+from dataclasses import dataclass
+from itertools import chain
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Hashable,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
+
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import Data
+
+TripletLike = Tuple[Hashable, Hashable, Hashable]
+
+KnowledgeGraphLike = Iterable[TripletLike]
+
+
+def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
+ return list(dict.fromkeys(values))
+
+
+# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
+
+NODE_PID = "pid"
+
+NODE_KEYS = {NODE_PID}
+
+EDGE_PID = "e_pid"
+EDGE_HEAD = "h"
+EDGE_RELATION = "r"
+EDGE_TAIL = "t"
+EDGE_INDEX = "edge_idx"
+
+EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
+
+FeatureValueType = Union[Sequence[Any], Tensor]
+
+
+@dataclass
+class MappedFeature:
+ name: str
+ values: FeatureValueType
+
+ def __eq__(self, value: "MappedFeature") -> bool:
+ eq = self.name == value.name
+ if isinstance(self.values, torch.Tensor):
+ eq &= torch.equal(self.values, value.values)
+ else:
+ eq &= self.values == value.values
+ return eq
+
+
+class LargeGraphIndexer:
+ """For a dataset that consists of mulitiple subgraphs that are assumed to
+ be part of a much larger graph, collate the values into a large graph store
+ to save resources.
+ """
+ def __init__(
+ self,
+ nodes: Iterable[Hashable],
+ edges: KnowledgeGraphLike,
+ node_attr: Optional[Dict[str, List[Any]]] = None,
+ edge_attr: Optional[Dict[str, List[Any]]] = None,
+ ) -> None:
+ r"""Constructs a new index that uniquely catalogs each node and edge
+ by id. Not meant to be used directly.
+
+ Args:
+ nodes (Iterable[Hashable]): Node ids in the graph.
+ edges (KnowledgeGraphLike): Edge ids in the graph.
+ node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
+ attribute name and list of their values in order of unique node
+ ids. Defaults to None.
+ edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge
+ attribute name and list of their values in order of unique edge
+ ids. Defaults to None.
+ """
+ self._nodes: Dict[Hashable, int] = dict()
+ self._edges: Dict[TripletLike, int] = dict()
+
+ self._mapped_node_features: Set[str] = set()
+ self._mapped_edge_features: Set[str] = set()
+
+ if len(nodes) != len(set(nodes)):
+ raise AttributeError("Nodes need to be unique")
+ if len(edges) != len(set(edges)):
+ raise AttributeError("Edges need to be unique")
+
+ if node_attr is not None:
+ # TODO: Validity checks btw nodes and node_attr
+ self.node_attr = node_attr
+ if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS:
+ raise AttributeError(
+ "Invalid node_attr object. Missing " +
+ f"{NODE_KEYS - set(self.node_attr.keys())}")
+ elif self.node_attr[NODE_PID] != nodes:
+ raise AttributeError(
+ "Nodes provided do not match those in node_attr")
+ else:
+ self.node_attr = dict()
+ self.node_attr[NODE_PID] = nodes
+
+ for i, node in enumerate(self.node_attr[NODE_PID]):
+ self._nodes[node] = i
+
+ if edge_attr is not None:
+ # TODO: Validity checks btw edges and edge_attr
+ self.edge_attr = edge_attr
+
+ if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS:
+ raise AttributeError(
+ "Invalid edge_attr object. Missing " +
+ f"{EDGE_KEYS - set(self.edge_attr.keys())}")
+ elif self.node_attr[EDGE_PID] != edges:
+ raise AttributeError(
+ "Edges provided do not match those in edge_attr")
+
+ else:
+ self.edge_attr = dict()
+ for default_key in EDGE_KEYS:
+ self.edge_attr[default_key] = list()
+ self.edge_attr[EDGE_PID] = edges
+
+ for i, tup in enumerate(edges):
+ h, r, t = tup
+ self.edge_attr[EDGE_HEAD].append(h)
+ self.edge_attr[EDGE_RELATION].append(r)
+ self.edge_attr[EDGE_TAIL].append(t)
+ self.edge_attr[EDGE_INDEX].append(
+ (self._nodes[h], self._nodes[t]))
+
+ for i, tup in enumerate(edges):
+ self._edges[tup] = i
+
+ @classmethod
+ def from_triplets(
+ cls,
+ triplets: KnowledgeGraphLike,
+ pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+ ) -> "LargeGraphIndexer":
+ r"""Generate a new index from a series of triplets that represent edge
+ relations between nodes.
+ Formatted like (source_node, edge, dest_node).
+
+ Args:
+ triplets (KnowledgeGraphLike): Series of triplets representing
+ knowledge graph relations.
+ pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+ Optional preprocessing function to apply to triplets.
+ Defaults to None.
+
+ Returns:
+ LargeGraphIndexer: Index of unique nodes and edges.
+ """
+ # NOTE: Right now assumes that all trips can be loaded into memory
+ nodes = set()
+ edges = set()
+
+ if pre_transform is not None:
+
+ def apply_transform(
+ trips: KnowledgeGraphLike) -> Iterator[TripletLike]:
+ for trip in trips:
+ yield pre_transform(trip)
+
+ triplets = apply_transform(triplets)
+
+ for h, r, t in triplets:
+
+ for node in (h, t):
+ nodes.add(node)
+
+ edge_idx = (h, r, t)
+ edges.add(edge_idx)
+
+ return cls(list(nodes), list(edges))
+
+ @classmethod
+ def collate(cls,
+ graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
+ r"""Combines a series of large graph indexes into a single large graph
+ index.
+
+ Args:
+ graphs (Iterable["LargeGraphIndexer"]): Indices to be
+ combined.
+
+ Returns:
+ LargeGraphIndexer: Singular unique index for all nodes and edges
+ in input indices.
+ """
+ # FIXME Needs to merge node attrs and edge attrs?
+ trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
+ return cls.from_triplets(trips)
+
+ def get_unique_node_features(
+ self, feature_name: str = NODE_PID) -> List[Hashable]:
+ r"""Get all the unique values for a specific node attribute.
+
+ Args:
+ feature_name (str, optional): Name of feature to get.
+ Defaults to NODE_PID.
+
+ Returns:
+ List[Hashable]: List of unique values for the specified feature.
+ """
+ try:
+ if feature_name in self._mapped_node_features:
+ raise IndexError(
+ "Only non-mapped features can be retrieved uniquely.")
+ return ordered_set(self.get_node_features(feature_name))
+
+ except KeyError:
+ raise AttributeError(
+ f"Nodes do not have a feature called {feature_name}")
+
+ def add_node_feature(
+ self,
+ new_feature_name: str,
+ new_feature_vals: FeatureValueType,
+ map_from_feature: str = NODE_PID,
+ ) -> None:
+ r"""Adds a new feature that corresponds to each unique node in
+ the graph.
+
+ Args:
+ new_feature_name (str): Name to call the new feature.
+ new_feature_vals (FeatureValueType): Values to map for that
+ new feature.
+ map_from_feature (str, optional): Key of feature to map from.
+ Size must match the number of feature values.
+ Defaults to NODE_PID.
+ """
+ if new_feature_name in self.node_attr:
+ raise AttributeError("Features cannot be overridden once created")
+ if map_from_feature in self._mapped_node_features:
+ raise AttributeError(
+ f"{map_from_feature} is already a feature mapping.")
+
+ feature_keys = self.get_unique_node_features(map_from_feature)
+ if len(feature_keys) != len(new_feature_vals):
+ raise AttributeError(
+ "Expected encodings for {len(feature_keys)} unique features," +
+ f" but got {len(new_feature_vals)} encodings.")
+
+ if map_from_feature == NODE_PID:
+ self.node_attr[new_feature_name] = new_feature_vals
+ else:
+ self.node_attr[new_feature_name] = MappedFeature(
+ name=map_from_feature, values=new_feature_vals)
+ self._mapped_node_features.add(new_feature_name)
+
+ def get_node_features(
+ self,
+ feature_name: str = NODE_PID,
+ pids: Optional[Iterable[Hashable]] = None,
+ ) -> List[Any]:
+ r"""Get node feature values for a given set of unique node ids.
+ Returned values are not necessarily unique.
+
+ Args:
+ feature_name (str, optional): Name of feature to fetch. Defaults
+ to NODE_PID.
+ pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
+ for. Defaults to None, which fetches all nodes.
+
+ Returns:
+ List[Any]: Node features corresponding to the specified ids.
+ """
+ if feature_name in self._mapped_node_features:
+ values = self.node_attr[feature_name].values
+ else:
+ values = self.node_attr[feature_name]
+
+ # TODO: torch_geometric.utils.select
+ if isinstance(values, torch.Tensor):
+ idxs = list(
+ self.get_node_features_iter(feature_name, pids,
+ index_only=True))
+ return values[idxs]
+ return list(self.get_node_features_iter(feature_name, pids))
+
+ def get_node_features_iter(
+ self,
+ feature_name: str = NODE_PID,
+ pids: Optional[Iterable[Hashable]] = None,
+ index_only: bool = False,
+ ) -> Iterator[Any]:
+ """Iterator version of get_node_features. If index_only is True,
+ yields indices instead of values.
+ """
+ if pids is None:
+ pids = self.node_attr[NODE_PID]
+
+ if feature_name in self._mapped_node_features:
+ feature_map_info = self.node_attr[feature_name]
+ from_feature_name, to_feature_vals = (
+ feature_map_info.name,
+ feature_map_info.values,
+ )
+ from_feature_vals = self.get_unique_node_features(
+ from_feature_name)
+ feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
+
+ for pid in pids:
+ idx = self._nodes[pid]
+ from_feature_val = self.node_attr[from_feature_name][idx]
+ to_feature_idx = feature_mapping[from_feature_val]
+ if index_only:
+ yield to_feature_idx
+ else:
+ yield to_feature_vals[to_feature_idx]
+ else:
+ for pid in pids:
+ idx = self._nodes[pid]
+ if index_only:
+ yield idx
+ else:
+ yield self.node_attr[feature_name][idx]
+
+ def get_unique_edge_features(
+ self, feature_name: str = EDGE_PID) -> List[Hashable]:
+ r"""Get all the unique values for a specific edge attribute.
+
+ Args:
+ feature_name (str, optional): Name of feature to get.
+ Defaults to EDGE_PID.
+
+ Returns:
+ List[Hashable]: List of unique values for the specified feature.
+ """
+ try:
+ if feature_name in self._mapped_edge_features:
+ raise IndexError(
+ "Only non-mapped features can be retrieved uniquely.")
+ return ordered_set(self.get_edge_features(feature_name))
+ except KeyError:
+ raise AttributeError(
+ f"Edges do not have a feature called {feature_name}")
+
+ def add_edge_feature(
+ self,
+ new_feature_name: str,
+ new_feature_vals: FeatureValueType,
+ map_from_feature: str = EDGE_PID,
+ ) -> None:
+ r"""Adds a new feature that corresponds to each unique edge in
+ the graph.
+
+ Args:
+ new_feature_name (str): Name to call the new feature.
+ new_feature_vals (FeatureValueType): Values to map for that new
+ feature.
+ map_from_feature (str, optional): Key of feature to map from.
+ Size must match the number of feature values.
+ Defaults to EDGE_PID.
+ """
+ if new_feature_name in self.edge_attr:
+ raise AttributeError("Features cannot be overridden once created")
+ if map_from_feature in self._mapped_edge_features:
+ raise AttributeError(
+ f"{map_from_feature} is already a feature mapping.")
+
+ feature_keys = self.get_unique_edge_features(map_from_feature)
+ if len(feature_keys) != len(new_feature_vals):
+ raise AttributeError(
+ f"Expected encodings for {len(feature_keys)} unique features, "
+ + f"but got {len(new_feature_vals)} encodings.")
+
+ if map_from_feature == EDGE_PID:
+ self.edge_attr[new_feature_name] = new_feature_vals
+ else:
+ self.edge_attr[new_feature_name] = MappedFeature(
+ name=map_from_feature, values=new_feature_vals)
+ self._mapped_edge_features.add(new_feature_name)
+
+ def get_edge_features(
+ self,
+ feature_name: str = EDGE_PID,
+ pids: Optional[Iterable[Hashable]] = None,
+ ) -> List[Any]:
+ r"""Get edge feature values for a given set of unique edge ids.
+ Returned values are not necessarily unique.
+
+ Args:
+ feature_name (str, optional): Name of feature to fetch.
+ Defaults to EDGE_PID.
+ pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
+ for. Defaults to None, which fetches all edges.
+
+ Returns:
+ List[Any]: Node features corresponding to the specified ids.
+ """
+ if feature_name in self._mapped_edge_features:
+ values = self.edge_attr[feature_name].values
+ else:
+ values = self.edge_attr[feature_name]
+
+ # TODO: torch_geometric.utils.select
+ if isinstance(values, torch.Tensor):
+ idxs = list(
+ self.get_edge_features_iter(feature_name, pids,
+ index_only=True))
+ return values[idxs]
+ return list(self.get_edge_features_iter(feature_name, pids))
+
+ def get_edge_features_iter(
+ self,
+ feature_name: str = EDGE_PID,
+ pids: Optional[KnowledgeGraphLike] = None,
+ index_only: bool = False,
+ ) -> Iterator[Any]:
+ """Iterator version of get_edge_features. If index_only is True,
+ yields indices instead of values.
+ """
+ if pids is None:
+ pids = self.edge_attr[EDGE_PID]
+
+ if feature_name in self._mapped_edge_features:
+ feature_map_info = self.edge_attr[feature_name]
+ from_feature_name, to_feature_vals = (
+ feature_map_info.name,
+ feature_map_info.values,
+ )
+ from_feature_vals = self.get_unique_edge_features(
+ from_feature_name)
+ feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
+
+ for pid in pids:
+ idx = self._edges[pid]
+ from_feature_val = self.edge_attr[from_feature_name][idx]
+ to_feature_idx = feature_mapping[from_feature_val]
+ if index_only:
+ yield to_feature_idx
+ else:
+ yield to_feature_vals[to_feature_idx]
+ else:
+ for pid in pids:
+ idx = self._edges[pid]
+ if index_only:
+ yield idx
+ else:
+ yield self.edge_attr[feature_name][idx]
+
+ def to_triplets(self) -> Iterator[TripletLike]:
+ return iter(self.edge_attr[EDGE_PID])
+
+ def save(self, path: str) -> None:
+ if os.path.exists(path):
+ shutil.rmtree(path)
+ os.makedirs(path, exist_ok=True)
+ with open(path + "/edges", "wb") as f:
+ pkl.dump(self._edges, f)
+ with open(path + "/nodes", "wb") as f:
+ pkl.dump(self._nodes, f)
+
+ with open(path + "/mapped_edges", "wb") as f:
+ pkl.dump(self._mapped_edge_features, f)
+ with open(path + "/mapped_nodes", "wb") as f:
+ pkl.dump(self._mapped_node_features, f)
+
+ node_attr_path = path + "/node_attr"
+ os.makedirs(node_attr_path, exist_ok=True)
+ for attr_name, vals in self.node_attr.items():
+ torch.save(vals, node_attr_path + f"/{attr_name}.pt")
+
+ edge_attr_path = path + "/edge_attr"
+ os.makedirs(edge_attr_path, exist_ok=True)
+ for attr_name, vals in self.edge_attr.items():
+ torch.save(vals, edge_attr_path + f"/{attr_name}.pt")
+
+ @classmethod
+ def from_disk(cls, path: str) -> "LargeGraphIndexer":
+ indexer = cls(list(), list())
+ with open(path + "/edges", "rb") as f:
+ indexer._edges = pkl.load(f)
+ with open(path + "/nodes", "rb") as f:
+ indexer._nodes = pkl.load(f)
+
+ with open(path + "/mapped_edges", "rb") as f:
+ indexer._mapped_edge_features = pkl.load(f)
+ with open(path + "/mapped_nodes", "rb") as f:
+ indexer._mapped_node_features = pkl.load(f)
+
+ node_attr_path = path + "/node_attr"
+ for fname in os.listdir(node_attr_path):
+ full_fname = f"{node_attr_path}/{fname}"
+ key = fname.split(".")[0]
+ indexer.node_attr[key] = torch.load(full_fname)
+
+ edge_attr_path = path + "/edge_attr"
+ for fname in os.listdir(edge_attr_path):
+ full_fname = f"{edge_attr_path}/{fname}"
+ key = fname.split(".")[0]
+ indexer.edge_attr[key] = torch.load(full_fname)
+
+ return indexer
+
+ def to_data(self, node_feature_name: str,
+ edge_feature_name: Optional[str] = None) -> Data:
+ """Return a Data object containing all the specified node and
+ edge features and the graph.
+
+ Args:
+ node_feature_name (str): Feature to use for nodes
+ edge_feature_name (Optional[str], optional): Feature to use for
+ edges. Defaults to None.
+
+ Returns:
+ Data: Data object containing the specified node and
+ edge features and the graph.
+ """
+ x = torch.Tensor(self.get_node_features(node_feature_name))
+ node_id = torch.LongTensor(range(len(x)))
+
+ edge_index = torch.t(
+ torch.LongTensor(self.get_edge_features(EDGE_INDEX)))
+
+ edge_attr = (self.get_edge_features(edge_feature_name)
+ if edge_feature_name is not None else None)
+ edge_id = torch.LongTensor(range(len(edge_attr)))
+
+ return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
+ edge_id=edge_id, node_id=node_id)
+
+ def __eq__(self, value: "LargeGraphIndexer") -> bool:
+ eq = True
+ eq &= self._nodes == value._nodes
+ eq &= self._edges == value._edges
+ eq &= self.node_attr.keys() == value.node_attr.keys()
+ eq &= self.edge_attr.keys() == value.edge_attr.keys()
+ eq &= self._mapped_node_features == value._mapped_node_features
+ eq &= self._mapped_edge_features == value._mapped_edge_features
+
+ for k in self.node_attr:
+ eq &= isinstance(self.node_attr[k], type(value.node_attr[k]))
+ if isinstance(self.node_attr[k], torch.Tensor):
+ eq &= torch.equal(self.node_attr[k], value.node_attr[k])
+ else:
+ eq &= self.node_attr[k] == value.node_attr[k]
+ for k in self.edge_attr:
+ eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k]))
+ if isinstance(self.edge_attr[k], torch.Tensor):
+ eq &= torch.equal(self.edge_attr[k], value.edge_attr[k])
+ else:
+ eq &= self.edge_attr[k] == value.edge_attr[k]
+ return eq
+
+
+def get_features_for_triplets_groups(
+ indexer: LargeGraphIndexer,
+ triplet_groups: Iterable[KnowledgeGraphLike],
+ node_feature_name: str = "x",
+ edge_feature_name: str = "edge_attr",
+ pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+ verbose: bool = False,
+) -> Iterator[Data]:
+ """Given an indexer and a series of triplet groups (like a dataset),
+ retrieve the specified node and edge features for each triplet from the
+ index.
+
+ Args:
+ indexer (LargeGraphIndexer): Indexer containing desired features
+ triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
+ triplets to fetch features for
+ node_feature_name (str, optional): Node feature to fetch.
+ Defaults to "x".
+ edge_feature_name (str, optional): edge feature to fetch.
+ Defaults to "edge_attr".
+ pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+ Optional preprocessing to perform on triplets.
+ Defaults to None.
+ verbose (bool, optional): Whether to print progress. Defaults to False.
+
+ Yields:
+ Iterator[Data]: For each triplet group, yield a data object containing
+ the unique graph and features from the index.
+ """
+ if pre_transform is not None:
+
+ def apply_transform(trips):
+ for trip in trips:
+ yield pre_transform(tuple(trip))
+
+ # TODO: Make this safe for large amounts of triplets?
+ triplet_groups = (list(apply_transform(triplets))
+ for triplets in triplet_groups)
+
+ node_keys = []
+ edge_keys = []
+ edge_index = []
+
+ for triplets in tqdm(triplet_groups, disable=not verbose):
+ small_graph_indexer = LargeGraphIndexer.from_triplets(
+ triplets, pre_transform=pre_transform)
+
+ node_keys.append(small_graph_indexer.get_node_features())
+ edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
+ edge_index.append(
+ small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))
+
+ node_feats = indexer.get_node_features(feature_name=node_feature_name,
+ pids=chain.from_iterable(node_keys))
+ edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
+ pids=chain.from_iterable(edge_keys))
+
+ last_node_idx, last_edge_idx = 0, 0
+ for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index):
+ nlen, elen = len(nkeys), len(ekeys)
+ x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
+ last_node_idx += len(nkeys)
+
+ edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
+ elen])
+ last_edge_idx += len(ekeys)
+
+ edge_idx = torch.LongTensor(eidx).T
+
+ data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
+ data_obj[NODE_PID] = node_keys
+ data_obj[EDGE_PID] = edge_keys
+ data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
+ data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
+
+ yield data_obj
+
+
+def get_features_for_triplets(
+ indexer: LargeGraphIndexer,
+ triplets: KnowledgeGraphLike,
+ node_feature_name: str = "x",
+ edge_feature_name: str = "edge_attr",
+ pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+ verbose: bool = False,
+) -> Data:
+ """For a given set of triplets retrieve a Data object containing the
+ unique graph and features from the index.
+
+ Args:
+ indexer (LargeGraphIndexer): Indexer containing desired features
+ triplets (KnowledgeGraphLike): Triplets to fetch features for
+ node_feature_name (str, optional): Feature to use for node features.
+ Defaults to "x".
+ edge_feature_name (str, optional): Feature to use for edge features.
+ Defaults to "edge_attr".
+ pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+ Optional preprocessing function for triplets. Defaults to None.
+ verbose (bool, optional): Whether to print progress. Defaults to False.
+
+ Returns:
+ Data: Data object containing the unique graph and features from the
+ index for the given triplets.
+ """
+ gen = get_features_for_triplets_groups(indexer, [triplets],
+ node_feature_name,
+ edge_feature_name, pre_transform,
+ verbose)
+ return next(gen)
diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py
index 96d51032d818..236753293c2e 100644
--- a/torch_geometric/datasets/__init__.py
+++ b/torch_geometric/datasets/__init__.py
@@ -111,6 +111,7 @@
import torch_geometric.datasets.utils
homo_datasets = [
+ 'WebQSPDataset',
'KarateClub',
'TUDataset',
'GNNBenchmarkDataset',
diff --git a/torch_geometric/datasets/web_qsp_dataset.py b/torch_geometric/datasets/web_qsp_dataset.py
index bcc85e070920..fd3caeb99b70 100644
--- a/torch_geometric/datasets/web_qsp_dataset.py
+++ b/torch_geometric/datasets/web_qsp_dataset.py
@@ -1,12 +1,21 @@
# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
-from typing import Any, Dict, List, Tuple, no_type_check
+import os
+from itertools import chain
+from typing import Any, Iterator, List, Tuple, no_type_check
import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm
-from torch_geometric.data import Data, InMemoryDataset
+from torch_geometric.data import (
+ Data,
+ InMemoryDataset,
+ LargeGraphIndexer,
+ TripletLike,
+ get_features_for_triplets_groups,
+)
+from torch_geometric.data.large_graph_indexer import EDGE_RELATION
from torch_geometric.nn.nlp import SentenceTransformer
@@ -19,9 +28,11 @@ def retrieval_via_pcst(
topk: int = 3,
topk_e: int = 3,
cost_e: float = 0.5,
+ save_idx: bool = False,
+ override: bool = False,
) -> Tuple[Data, str]:
c = 0.01
- if len(textual_nodes) == 0 or len(textual_edges) == 0:
+ if len(textual_nodes) == 0 or len(textual_edges) == 0 or override:
desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
index=False,
columns=["src", "edge_attr", "dst"],
@@ -114,15 +125,28 @@ def retrieval_via_pcst(
src = [mapping[i] for i in edge_index[0].tolist()]
dst = [mapping[i] for i in edge_index[1].tolist()]
+ # HACK Added so that the subset of nodes and edges selected can be tracked
+ if save_idx:
+ node_idx = np.array(data.node_idx)[selected_nodes]
+ edge_idx = np.array(data.edge_idx)[selected_edges]
+
data = Data(
x=data.x[selected_nodes],
edge_index=torch.tensor([src, dst]),
edge_attr=data.edge_attr[selected_edges],
)
+ if save_idx:
+ data['node_idx'] = node_idx
+ data['edge_idx'] = edge_idx
return data, desc
+def preprocess_triplet(triplet: TripletLike) -> TripletLike:
+ h, r, t = triplet
+ return str(h).lower(), str(r), str(t).lower()
+
+
class WebQSPDataset(InMemoryDataset):
r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
Labeling for Knowledge Base Question Answering"
@@ -135,107 +159,214 @@ class WebQSPDataset(InMemoryDataset):
If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
+ limit (int, optional): Construct only the first n samples.
+ Defaults to -1 to construct all samples.
+ include_pcst (bool, optional): Whether to include PCST step
+ (See GRetriever paper). Defaults to True.
+ verbose (bool, optional): Whether to print output. Defaults to False.
"""
def __init__(
self,
root: str,
split: str = "train",
force_reload: bool = False,
+ limit: int = -1,
+ include_pcst: bool = True,
+ verbose: bool = False,
) -> None:
+ self.limit = limit
+ self.split = split
+ self.include_pcst = include_pcst
+ # TODO Confirm why the dependency checks and device setting were removed here # noqa
+ '''
+ self.device = torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu")
+ self._check_dependencies()
+ '''
+ self.verbose = verbose
+ self.force_reload = force_reload
super().__init__(root, force_reload=force_reload)
- if split not in {'train', 'val', 'test'}:
+ if split not in {'train', 'val', 'test'} and limit < 0:
raise ValueError(f"Invalid 'split' argument (got {split})")
- path = self.processed_paths[['train', 'val', 'test'].index(split)]
- self.load(path)
+ self._load_raw_data()
+ self.load(self.processed_paths[0])
+
+ '''
+ def _check_dependencies(self) -> None:
+ missing_str_list = []
+ if not WITH_PCST:
+ missing_str_list.append('pcst_fast')
+ if not WITH_DATASETS:
+ missing_str_list.append('datasets')
+ if not WITH_PANDAS:
+ missing_str_list.append('pandas')
+ if len(missing_str_list) > 0:
+ missing_str = ' '.join(missing_str_list)
+ error_out = f"`pip install {missing_str}` to use this dataset."
+ raise ImportError(error_out)
+ '''
+
+ @property
+ def raw_file_names(self) -> List[str]:
+ return ["raw_data", "split_idxs"]
@property
def processed_file_names(self) -> List[str]:
- return ['train_data.pt', 'val_data.pt', 'test_data.pt']
+ file_lst = [
+ "train_data.pt",
+ "val_data.pt",
+ "test_data.pt",
+ "pre_filter.pt",
+ "pre_transform.pt",
+ "large_graph_indexer",
+ ]
+ split_file = file_lst.pop(['train', 'val', 'test'].index(self.split))
+ file_lst.insert(0, split_file)
+ return file_lst
+
+ def _save_raw_data(self) -> None:
+ self.raw_dataset.save_to_disk(self.raw_paths[0])
+ torch.save(self.split_idxs, self.raw_paths[1])
+
+ def _load_raw_data(self) -> None:
+ import datasets
+ if not hasattr(self, "raw_dataset"):
+ self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])
+ if not hasattr(self, "split_idxs"):
+ self.split_idxs = torch.load(self.raw_paths[1])
- def process(self) -> None:
+ def download(self) -> None:
import datasets
- import pandas as pd
-
- datasets = datasets.load_dataset('rmanluo/RoG-webqsp')
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model_name = 'sentence-transformers/all-roberta-large-v1'
- model = SentenceTransformer(model_name).to(device)
- model.eval()
-
- for dataset, path in zip(
- [datasets['train'], datasets['validation'], datasets['test']],
- self.processed_paths,
- ):
- questions = [example["question"] for example in dataset]
- question_embs = model.encode(
- questions,
- batch_size=256,
- output_device='cpu',
+
+ dataset = datasets.load_dataset("rmanluo/RoG-webqsp")
+ self.raw_dataset = datasets.concatenate_datasets(
+ [dataset["train"], dataset["validation"], dataset["test"]])
+ self.split_idxs = {
+ "train":
+ torch.arange(len(dataset["train"])),
+ "val":
+ torch.arange(len(dataset["validation"])) + len(dataset["train"]),
+ "test":
+ torch.arange(len(dataset["test"])) + len(dataset["train"]) +
+ len(dataset["validation"]),
+ }
+
+ if self.limit >= 0:
+ self.raw_dataset = self.raw_dataset.select(range(self.limit))
+
+ # HACK
+ self.split_idxs = {
+ "train":
+ torch.arange(self.limit // 2),
+ "val":
+ torch.arange(self.limit // 4) + self.limit // 2,
+ "test":
+ torch.arange(self.limit // 4) + self.limit // 2 +
+ self.limit // 4,
+ }
+ self._save_raw_data()
+
+ def _get_trips(self) -> Iterator[TripletLike]:
+ return chain.from_iterable(
+ iter(ds["graph"]) for ds in self.raw_dataset)
+
+ def _build_graph(self) -> None:
+ trips = self._get_trips()
+ self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets(
+ trips, pre_transform=preprocess_triplet)
+
+ # Nodes:
+ nodes = self.indexer.get_unique_node_features()
+ x = self.model.encode(
+ nodes, # type: ignore
+ batch_size=256,
+ output_device='cpu')
+ self.indexer.add_node_feature(new_feature_name="x", new_feature_vals=x)
+
+ # Edges:
+ edges = self.indexer.get_unique_edge_features(
+ feature_name=EDGE_RELATION)
+ edge_attr = self.model.encode(
+ edges, # type: ignore
+ batch_size=256,
+ output_device='cpu')
+ self.indexer.add_edge_feature(
+ new_feature_name="edge_attr",
+ new_feature_vals=edge_attr,
+ map_from_feature=EDGE_RELATION,
+ )
+
+ print("Saving graph...")
+ self.indexer.save(self.processed_paths[-1])
+
+ def _retrieve_subgraphs(self) -> None:
+ print("Encoding questions...")
+ self.questions = [str(ds["question"]) for ds in self.raw_dataset]
+ q_embs = self.model.encode(self.questions, batch_size=256,
+ output_device='cpu')
+ list_of_graphs = []
+ print("Retrieving subgraphs...")
+ textual_nodes = self.textual_nodes
+ textual_edges = self.textual_edges
+ graph_gen = get_features_for_triplets_groups(
+ self.indexer, (ds['graph'] for ds in self.raw_dataset),
+ pre_transform=preprocess_triplet, verbose=self.verbose)
+
+ for index in tqdm(range(len(self.raw_dataset)),
+ disable=not self.verbose):
+ data_i = self.raw_dataset[index]
+ graph = next(graph_gen)
+ textual_nodes = self.textual_nodes.iloc[
+ graph["node_idx"]].reset_index()
+ textual_edges = self.textual_edges.iloc[
+ graph["edge_idx"]].reset_index()
+ pcst_subgraph, desc = retrieval_via_pcst(
+ graph,
+ q_embs[index],
+ textual_nodes,
+ textual_edges,
+ topk=3,
+ topk_e=5,
+ cost_e=0.5,
+ override=not self.include_pcst,
)
+ question = f"Question: {data_i['question']}\nAnswer: "
+ label = ("|").join(data_i["answer"]).lower()
- data_list = []
- for i, example in enumerate(tqdm(dataset)):
- raw_nodes: Dict[str, int] = {}
- raw_edges = []
- for tri in example["graph"]:
- h, r, t = tri
- h = h.lower()
- t = t.lower()
- if h not in raw_nodes:
- raw_nodes[h] = len(raw_nodes)
- if t not in raw_nodes:
- raw_nodes[t] = len(raw_nodes)
- raw_edges.append({
- "src": raw_nodes[h],
- "edge_attr": r,
- "dst": raw_nodes[t]
- })
- nodes = pd.DataFrame([{
- "node_id": v,
- "node_attr": k,
- } for k, v in raw_nodes.items()],
- columns=["node_id", "node_attr"])
- edges = pd.DataFrame(raw_edges,
- columns=["src", "edge_attr", "dst"])
-
- nodes.node_attr = nodes.node_attr.fillna("")
- x = model.encode(
- nodes.node_attr.tolist(),
- batch_size=256,
- output_device='cpu',
- )
- edge_attr = model.encode(
- edges.edge_attr.tolist(),
- batch_size=256,
- output_device='cpu',
- )
- edge_index = torch.tensor([
- edges.src.tolist(),
- edges.dst.tolist(),
- ], dtype=torch.long)
-
- question = f"Question: {example['question']}\nAnswer: "
- label = ('|').join(example['answer']).lower()
- data = Data(
- x=x,
- edge_index=edge_index,
- edge_attr=edge_attr,
- )
- data, desc = retrieval_via_pcst(
- data,
- question_embs[i],
- nodes,
- edges,
- topk=3,
- topk_e=5,
- cost_e=0.5,
- )
- data.question = question
- data.label = label
- data.desc = desc
- data_list.append(data)
-
- self.save(data_list, path)
+ pcst_subgraph["question"] = question
+ pcst_subgraph["label"] = label
+ pcst_subgraph["desc"] = desc
+ list_of_graphs.append(pcst_subgraph.to("cpu"))
+ print("Saving subgraphs...")
+ self.save(list_of_graphs, self.processed_paths[0])
+
+ def process(self) -> None:
+ from pandas import DataFrame
+ self._load_raw_data()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.model = SentenceTransformer(
+ 'sentence-transformers/all-roberta-large-v1').to(device)
+ self.model.eval()
+ if self.force_reload or not os.path.exists(self.processed_paths[-1]):
+ print("Encoding graph...")
+ self._build_graph()
+ else:
+ print("Loading graph...")
+ self.indexer = LargeGraphIndexer.from_disk(
+ self.processed_paths[-1])
+ self.textual_nodes = DataFrame.from_dict(
+ {"node_attr": self.indexer.get_node_features()})
+ self.textual_nodes["node_id"] = self.textual_nodes.index
+ self.textual_nodes = self.textual_nodes[["node_id", "node_attr"]]
+ self.textual_edges = DataFrame(self.indexer.get_edge_features(),
+ columns=["src", "edge_attr", "dst"])
+ self.textual_edges["src"] = [
+ self.indexer._nodes[h] for h in self.textual_edges["src"]
+ ]
+ self.textual_edges["dst"] = [
+ self.indexer._nodes[h] for h in self.textual_edges["dst"]
+ ]
+ self._retrieve_subgraphs()
diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py
index 266f498a113b..7e83c35befb6 100644
--- a/torch_geometric/loader/__init__.py
+++ b/torch_geometric/loader/__init__.py
@@ -22,6 +22,7 @@
from .prefetch import PrefetchLoader
from .cache import CachedLoader
from .mixin import AffinityMixin
+from .rag_loader import RAGQueryLoader
__all__ = classes = [
'DataLoader',
@@ -50,6 +51,7 @@
'PrefetchLoader',
'CachedLoader',
'AffinityMixin',
+ 'RAGQueryLoader',
]
RandomNodeSampler = deprecated(
diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py
new file mode 100644
index 000000000000..33d6cf0e868e
--- /dev/null
+++ b/torch_geometric/loader/rag_loader.py
@@ -0,0 +1,106 @@
+from abc import abstractmethod
+from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+
+from torch_geometric.data import Data, FeatureStore, HeteroData
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+class RAGFeatureStore(Protocol):
+ """Feature store for remote GNN RAG backend."""
+ @abstractmethod
+ def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
+ """Makes a comparison between the query and all the nodes to get all
+ the closest nodes. Return the indices of the nodes that are to be seeds
+ for the RAG Sampler.
+ """
+ ...
+
+ @abstractmethod
+ def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
+ """Makes a comparison between the query and all the edges to get all
+ the closest nodes. Returns the edge indices that are to be the seeds
+ for the RAG Sampler.
+ """
+ ...
+
+ @abstractmethod
+ def load_subgraph(
+ self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+ ) -> Union[Data, HeteroData]:
+ """Combines sampled subgraph output with features in a Data object."""
+ ...
+
+
+class RAGGraphStore(Protocol):
+ """Graph store for remote GNN RAG backend."""
+ @abstractmethod
+ def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
+ **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
+ """Sample a subgraph using the seeded nodes and edges."""
+ ...
+
+ @abstractmethod
+ def register_feature_store(self, feature_store: FeatureStore):
+ """Register a feature store to be used with the sampler. Samplers need
+ info from the feature store in order to work properly on HeteroGraphs.
+ """
+ ...
+
+
+# TODO: Make compatible with Heterographs
+
+
+class RAGQueryLoader:
+ def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
+ local_filter: Optional[Callable[[Data, Any], Data]] = None,
+ seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
+ seed_edges_kwargs: Optional[Dict[str, Any]] = None,
+ sampler_kwargs: Optional[Dict[str, Any]] = None,
+ loader_kwargs: Optional[Dict[str, Any]] = None):
+ """Loader meant for making queries from a remote backend.
+
+ Args:
+ data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
+ and GraphStore to load from. Assumed to conform to the
+ protocols listed above.
+ local_filter (Optional[Callable[[Data, Any], Data]], optional):
+ Optional local transform to apply to data after retrieval.
+ Defaults to None.
+ seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters
+ to pass into process for fetching seed nodes. Defaults to None.
+ seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
+ to pass into process for fetching seed edges. Defaults to None.
+ sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+ pass into process for sampling graph. Defaults to None.
+ loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+ pass into process for loading graph features. Defaults to None.
+ """
+ fstore, gstore = data
+ self.feature_store = fstore
+ self.graph_store = gstore
+ self.graph_store.register_feature_store(self.feature_store)
+ self.local_filter = local_filter
+ self.seed_nodes_kwargs = seed_nodes_kwargs or {}
+ self.seed_edges_kwargs = seed_edges_kwargs or {}
+ self.sampler_kwargs = sampler_kwargs or {}
+ self.loader_kwargs = loader_kwargs or {}
+
+ def query(self, query: Any) -> Data:
+ """Retrieve a subgraph associated with the query with all its feature
+ attributes.
+ """
+ seed_nodes = self.feature_store.retrieve_seed_nodes(
+ query, **self.seed_nodes_kwargs)
+ seed_edges = self.feature_store.retrieve_seed_edges(
+ query, **self.seed_edges_kwargs)
+
+ subgraph_sample = self.graph_store.sample_subgraph(
+ seed_nodes, seed_edges, **self.sampler_kwargs)
+
+ data = self.feature_store.load_subgraph(sample=subgraph_sample,
+ **self.loader_kwargs)
+
+ if self.local_filter:
+ data = self.local_filter(data, query)
+ return data
diff --git a/torch_geometric/nn/models/g_retriever.py b/torch_geometric/nn/models/g_retriever.py
index 6f8fbcc644dc..f7529ae721b7 100644
--- a/torch_geometric/nn/models/g_retriever.py
+++ b/torch_geometric/nn/models/g_retriever.py
@@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module):
(default: :obj:`False`)
mlp_out_channels (int, optional): The size of each graph embedding
after projection. (default: :obj:`4096`)
+ mlp_out_tokens (int, optional): Number of LLM prefix tokens to
+ reserve for GNN output. (default: :obj:`1`)
.. warning::
This module has been tested with the following HuggingFace models
@@ -43,6 +45,7 @@ def __init__(
gnn: torch.nn.Module,
use_lora: bool = False,
mlp_out_channels: int = 4096,
+ mlp_out_tokens: int = 1,
) -> None:
super().__init__()
@@ -77,7 +80,9 @@ def __init__(
self.projector = torch.nn.Sequential(
torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
torch.nn.Sigmoid(),
- torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
+ torch.nn.Linear(mlp_hidden_channels,
+ mlp_out_channels * mlp_out_tokens),
+ torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
).to(self.llm.device)
def encode(
@@ -126,6 +131,9 @@ def forward(
x = self.projector(x)
xs = x.split(1, dim=0)
+ # Handle case where theres more than one embedding for each sample
+ xs = [x.squeeze(0) for x in xs]
+
# Handle questions without node features:
batch_unique = batch.unique()
batch_size = len(question)
@@ -182,6 +190,9 @@ def inference(
x = self.projector(x)
xs = x.split(1, dim=0)
+ # Handle case where theres more than one embedding for each sample
+ xs = [x.squeeze(0) for x in xs]
+
# Handle questions without node features:
batch_unique = batch.unique()
batch_size = len(question)
diff --git a/torch_geometric/profile/__init__.py b/torch_geometric/profile/__init__.py
index 833ee657d0e7..22d3039f4c83 100644
--- a/torch_geometric/profile/__init__.py
+++ b/torch_geometric/profile/__init__.py
@@ -20,6 +20,7 @@
get_gpu_memory_from_nvidia_smi,
get_model_size,
)
+from .nvtx import nvtxit
__all__ = [
'profileit',
@@ -38,6 +39,7 @@
'get_gpu_memory_from_nvidia_smi',
'get_gpu_memory_from_ipex',
'benchmark',
+ 'nvtxit',
]
classes = __all__
diff --git a/torch_geometric/profile/nvtx.py b/torch_geometric/profile/nvtx.py
new file mode 100644
index 000000000000..8dbce375ae5a
--- /dev/null
+++ b/torch_geometric/profile/nvtx.py
@@ -0,0 +1,66 @@
+from functools import wraps
+from typing import Optional
+
+import torch
+
+CUDA_PROFILE_STARTED = False
+
+
+def begin_cuda_profile():
+ global CUDA_PROFILE_STARTED
+ prev_state = CUDA_PROFILE_STARTED
+ if prev_state is False:
+ CUDA_PROFILE_STARTED = True
+ torch.cuda.cudart().cudaProfilerStart()
+ return prev_state
+
+
+def end_cuda_profile(prev_state: bool):
+ global CUDA_PROFILE_STARTED
+ CUDA_PROFILE_STARTED = prev_state
+ if prev_state is False:
+ torch.cuda.cudart().cudaProfilerStop()
+
+
+def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+ n_iters: Optional[int] = None):
+ """Enables NVTX profiling for a function.
+
+ Args:
+ name (Optional[str], optional): Name to give the reference frame for
+ the function being wrapped. Defaults to the name of the
+ function in code.
+ n_warmups (int, optional): Number of iters to call that function
+ before starting. Defaults to 0.
+ n_iters (Optional[int], optional): Number of iters of that function to
+ record. Defaults to all of them.
+ """
+ def nvtx(func):
+
+ nonlocal name
+ iters_so_far = 0
+ if name is None:
+ name = func.__name__
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ nonlocal iters_so_far
+ if not torch.cuda.is_available():
+ return func(*args, **kwargs)
+ elif iters_so_far < n_warmups:
+ iters_so_far += 1
+ return func(*args, **kwargs)
+ elif n_iters is None or iters_so_far < n_iters + n_warmups:
+ prev_state = begin_cuda_profile()
+ torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+ result = func(*args, **kwargs)
+ torch.cuda.nvtx.range_pop()
+ end_cuda_profile(prev_state)
+ iters_so_far += 1
+ return result
+ else:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return nvtx