diff --git a/CHANGELOG.md b/CHANGELOG.md
index 271acd853b5d..221f02fcd2e8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added various GRetriever Architecture Benchmarking examples ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+- Added `profiler.nvtxit` with some examples ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+- Added `loader.RagQueryLoader` with Remote Backend Example ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+- Added `data.LargeGraphIndexer` ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
+
 ### Changed

 ### Deprecated
diff --git a/docs/source/_figures/flowchart.svg b/docs/source/_figures/flowchart.svg
new file mode 100644
index 000000000000..188d37b14f41
--- /dev/null
+++ b/docs/source/_figures/flowchart.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/_figures/multihop_example.svg b/docs/source/_figures/multihop_example.svg
new file mode 100644
index 000000000000..4925dcb9713d
--- /dev/null
+++ b/docs/source/_figures/multihop_example.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/_figures/remote_backend.svg b/docs/source/_figures/remote_backend.svg
new file mode 100644
index 000000000000..c5791f0a95de
--- /dev/null
+++ b/docs/source/_figures/remote_backend.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/advanced/rag.rst b/docs/source/advanced/rag.rst
new file mode 100644
index 000000000000..2d8513fd56e4
--- /dev/null
+++ b/docs/source/advanced/rag.rst
@@ -0,0 +1,614 @@
+Working with LLM RAG in PyTorch Geometric
+=========================================
+
+This series aims to provide a starting point for multi-step LLM
+Retrieval Augmented Generation (RAG) using Graph Neural Networks.
+
+Motivation
+----------
+
+As Large Language Models (LLMs) quickly grow to dominate industry, they
+are increasingly being deployed at scale in use cases that require very
+specific contextual expertise. LLMs often struggle with these cases out
+of the box, as they will hallucinate answers that are not included in
+their training data. At the same time, many businesses already have
+large graph databases full of domain-specific context that can reduce
+hallucination and improve answer fidelity for LLMs. Graph Neural
+Networks (GNNs) provide a means for efficiently encoding this
+contextual information into the model, which can help LLMs to better
+understand and generate answers. Hence, there is an open research
+question as to how GNN encodings can be used most effectively for this
+purpose, and the tooling provided here can help investigate it.
+
+Architecture
+------------
+
+To model the use-case of RAG from a large knowledge graph of millions
+of nodes, we present the following architecture:
+
+.. figure:: ../_figures/flowchart.svg
+   :align: center
+   :width: 100%
+
+Graph RAG, as shown in the diagram above, follows this order of
+operations:
+
+0. To start (not pictured here), there must exist a large knowledge
+   graph that serves as a source of truth. The nodes and edges of this
+   knowledge graph provide the context from which subgraphs are
+   retrieved.
+
+During inference time, RAG implementations that follow this
+architecture are composed of the following steps:
+
+1. Tokenize and encode the query using the LLM Encoder
+2. Retrieve a subgraph of the larger knowledge graph (KG) relevant to
+   the query and encode it using a GNN
+3. Jointly embed the GNN embedding with the LLM embedding
+4. Utilize the LLM Decoder to decode the joint embedding and generate
+   a response
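+
+The following minimal sketch summarizes these steps in code. It is
+purely illustrative: ``retriever``, ``gnn``, ``llm``, and ``combine``
+are placeholders for the concrete components introduced throughout
+this tutorial, not a fixed API.
+
+.. code:: python
+
+    # Hypothetical end-to-end flow; every name here is a placeholder.
+    def graph_rag_answer(query: str) -> str:
+        query_emb = llm.encode(query)              # 1. encode the query
+        subgraph = retriever.retrieve(query_emb)   # 2. retrieve a relevant subgraph
+        graph_emb = gnn(subgraph)                  # 3. encode the subgraph with a GNN
+        joint_emb = combine(graph_emb, query_emb)  #    and embed it jointly with the query
+        return llm.decode(joint_emb)               # 4. decode to generate a response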
+
+
+Encoding a Large Knowledge Graph
+================================
+
+To start, a Large Knowledge Graph needs to be created from triplets or
+multiple subgraphs in a dataset.
+
+Example 1: Building from Already Existing Datasets
+--------------------------------------------------
+
+In most RAG scenarios, the subset of the information corpus that gets
+retrieved is crucial for whether the LLM produces an appropriate
+response. The same is true for GNN-based RAG. For example, consider
+the WebQSPDataset.
+
+.. code:: python
+
+    from torch_geometric.datasets import WebQSPDataset
+
+    num_questions = 100
+    ds = WebQSPDataset('small_sample', limit=num_questions)
+
+WebQSP is a dataset that is based on a subset of the Freebase Knowledge
+Graph, which is an open-source knowledge graph formerly maintained by
+Google. For each question-answer pair in the dataset, a subgraph was
+chosen based on a Semantic SPARQL search on the larger knowledge graph,
+to provide relevant context for finding the answer. Each entry in the
+dataset therefore consists of:
+
+- A question to be answered
+- The answer
+- A knowledge graph subgraph of Freebase that has the context needed
+  to answer the question
+
+.. code:: python
+
+    ds.raw_dataset
+
+    >>> Dataset({
+        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
+        num_rows: 100
+    })
+
+.. code:: python
+
+    ds.raw_dataset[0]
+
+    >>> {'id': 'WebQTrn-0',
+    'question': 'what is the name of justin bieber brother',
+    'answer': ['Jaxon Bieber'],
+    'q_entity': ['Justin Bieber'],
+    'a_entity': ['Jaxon Bieber'],
+    'graph': [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'],
+    ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'],
+    ...],
+    'choices': []}
+
+Although this dataset can be trained on as-is, a couple of problems
+emerge from doing so:
+
+1. A retrieval algorithm needs to be implemented and executed during
+   inference time, and it might not correspond to the algorithm that
+   was used to generate the dataset subgraphs.
+2. The dataset, as-is, is not stored in a computationally efficient
+   manner, as there will exist many duplicate nodes and edges that are
+   shared between the questions.
+
+As a result, it makes sense in this scenario to encode all the entries
+into one large knowledge graph, so that duplicate nodes and edges can
+be avoided, and so that alternative retrieval algorithms can be tried.
+We can do this with the LargeGraphIndexer class:
+
+.. code:: python
+
+    from torch_geometric.data import LargeGraphIndexer, Data, get_features_for_triplets_groups
+    from torch_geometric.nn.nlp import SentenceTransformer
+    import time
+    import torch
+    import tqdm
+    from itertools import chain
+    import networkx as nx
+
+.. code:: python
+
+    raw_dataset_graphs = [[tuple(trip) for trip in graph] for graph in ds.raw_dataset['graph']]
+    print(raw_dataset_graphs[0][:10])
+
+    >>> [('P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'), ('1Club.FM: Power', 'broadcast.content.artist', 'P!nk'), ...]
+
+To show the benefits of this indexer in action, we will use the
+following model to encode this sample of graphs twice: once naively,
+question by question, and once using the LargeGraphIndexer.
+
+.. code:: python
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = SentenceTransformer(model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+First, we compare the clock times of encoding using both methods.
+
+.. code:: python
+
+    # Indexing question-by-question
+    dataset_graphs_embedded = []
+    start = time.time()
+    for graph in tqdm.tqdm(raw_dataset_graphs):
+        nodes_map = dict()
+        edges_map = dict()
+        edge_idx_base = []
+
+        for src, edge, dst in graph:
+            # Collect nodes
+            if src not in nodes_map:
+                nodes_map[src] = len(nodes_map)
+            if dst not in nodes_map:
+                nodes_map[dst] = len(nodes_map)
+
+            # Collect edge types
+            if edge not in edges_map:
+                edges_map[edge] = len(edges_map)
+
+            # Record edge
+            edge_idx_base.append((nodes_map[src], edges_map[edge], nodes_map[dst]))
+
+        # Encode nodes and edges
+        sorted_nodes = list(sorted(nodes_map.keys(), key=lambda x: nodes_map[x]))
+        sorted_edges = list(sorted(edges_map.keys(), key=lambda x: edges_map[x]))
+
+        x = model.encode(sorted_nodes, batch_size=256)
+        edge_attrs_map = model.encode(sorted_edges, batch_size=256)
+
+        edge_attrs = []
+        edge_idx = []
+        for trip in edge_idx_base:
+            edge_attrs.append(edge_attrs_map[trip[1]])
+            edge_idx.append([trip[0], trip[2]])
+
+        dataset_graphs_embedded.append(Data(x=x, edge_index=torch.tensor(edge_idx).T, edge_attr=torch.stack(edge_attrs, dim=0)))
+
+    print(time.time()-start)
+
+    >>> 121.68579435348511
+
+.. code:: python
+
+    # Using LargeGraphIndexer to make one large knowledge graph
+    from torch_geometric.data.large_graph_indexer import EDGE_RELATION
+
+    start = time.time()
+    all_triplets_together = chain.from_iterable(raw_dataset_graphs)
+    # Index as one large graph
+    print('Indexing...')
+    indexer = LargeGraphIndexer.from_triplets(all_triplets_together)
+
+    # First the nodes
+    unique_nodes = indexer.get_unique_node_features()
+    node_encs = model.encode(unique_nodes, batch_size=256)
+    indexer.add_node_feature(new_feature_name='x', new_feature_vals=node_encs)
+
+    # Then the edges
+    unique_edges = indexer.get_unique_edge_features(feature_name=EDGE_RELATION)
+    edge_attr = model.encode(unique_edges, batch_size=256)
+    indexer.add_edge_feature(new_feature_name="edge_attr", new_feature_vals=edge_attr, map_from_feature=EDGE_RELATION)
+
+    ckpt_time = time.time()
+    whole_knowledge_graph = indexer.to_data(node_feature_name='x', edge_feature_name='edge_attr')
+    whole_graph_done = time.time()
+    print(f"Time to create whole knowledge_graph: {whole_graph_done-start}")
+
+    # Compute this to make sure we're comparing like to like on the final time printout
+    whole_graph_diff = whole_graph_done-ckpt_time
+
+    # Retrieve subgraphs
+    print('Retrieving Subgraphs...')
+    dataset_graphs_embedded_largegraphindexer = [graph for graph in tqdm.tqdm(get_features_for_triplets_groups(indexer=indexer, triplet_groups=raw_dataset_graphs), total=num_questions)]
+    print(time.time()-start-whole_graph_diff)
+
+    >>> Indexing...
+    >>> Time to create whole knowledge_graph: 114.01080107688904
+    >>> Retrieving Subgraphs...
+    >>> 114.66037964820862
+
+The LargeGraphIndexer allows us to compute the entire knowledge graph
+from a series of samples, so that new retrieval methods can also be
+tested on the entire graph. We will see this in practice later on.
+
+It's worth noting that, although the times are relatively similar
+right now, the speedup with LargeGraphIndexer will be much higher as
+the size of the knowledge graph grows. This is because the encoding
+work scales with the number of unique nodes and edges in the graph,
+rather than with the total number across all samples.
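+
+To get a rough sense of the potential savings on a given dataset
+before committing to either approach, you can compare the total number
+of per-question node strings against the number of globally unique
+ones. The snippet below is illustrative only and is not part of the
+indexer API:
+
+.. code:: python
+
+    # Rough estimate of the deduplication factor for node-encoding work:
+    total_nodes = sum(len({n for t in g for n in (t[0], t[2])}) for g in raw_dataset_graphs)
+    unique_nodes = len({n for g in raw_dataset_graphs for t in g for n in (t[0], t[2])})
+    print(f"Naive encoding processes ~{total_nodes / unique_nodes:.1f}x more node strings")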
+
+We expect the two results to be functionally identical, with the
+differences being due to floating point jitter.
+
+.. code:: python
+
+    def results_are_close_enough(ground_truth: Data, new_method: Data, thresh=.8):
+        def _sorted_tensors_are_close(tensor1, tensor2):
+            return torch.all(torch.isclose(tensor1.sort(dim=0)[0], tensor2.sort(dim=0)[0]).float().mean(axis=1) > thresh)
+        def _graphs_are_same(tensor1, tensor2):
+            return nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor1.T)) == nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor2.T))
+        return _sorted_tensors_are_close(ground_truth.x, new_method.x) \
+            and _sorted_tensors_are_close(ground_truth.edge_attr, new_method.edge_attr) \
+            and _graphs_are_same(ground_truth.edge_index, new_method.edge_index)
+
+    all_results_match = True
+    for old_graph, new_graph in tqdm.tqdm(zip(dataset_graphs_embedded, dataset_graphs_embedded_largegraphindexer), total=num_questions):
+        all_results_match &= results_are_close_enough(old_graph, new_graph)
+    all_results_match
+
+    >>> True
+
+When scaled up to the entire dataset, we see a 2x speedup with
+indexing this way on the WebQSP dataset.
+
+Example 2: Building a new Dataset from Questions and an already-existing Knowledge Graph
+----------------------------------------------------------------------------------------
+
+Motivation
+~~~~~~~~~~
+
+One potential application of knowledge graph structural encodings is
+capturing the relationships between different entities that are
+multiple hops apart. This can be challenging for an LLM to recognize
+from prepended graph information. Here's a motivating example (credit
+to @Rishi Puri):
+
+.. figure:: ../_figures/multihop_example.svg
+   :align: center
+   :width: 100%
+
+In this example, the question can only be answered by reasoning about
+the relationships between the entities in the knowledge graph.
+
+Building a Multi-Hop QA Dataset
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To see an example of encoding a large knowledge graph starting from an
+existing set of triplets, check out the multi-hop example in
+`examples/llm/multihop_rag`.
+
+Question: How do we extract a contextual subgraph for a given query?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The chosen retrieval algorithm is a critical component of the
+pipeline, as it strongly affects RAG performance. In the next section,
+we will demonstrate a naive method of retrieval for a large knowledge
+graph.
+
+Retrieval Algorithms and Scaling Retrieval
+==========================================
+
+Motivation
+----------
+
+When building a RAG Pipeline for inference, the retrieval component is
+important for the following reasons:
+
+1. A given algorithm for retrieving subgraph context can have a marked
+   effect on the hallucination rate of the responses in the model.
+2. A given retrieval algorithm needs to be able to scale to larger
+   graphs of millions of nodes and edges in order to be practical for
+   production.
+
+In this section, we will explore how to construct a RAG retrieval
+algorithm from a given subgraph, and conduct some experiments to
+evaluate its runtime performance.
+
+We want to do so in line with PyTorch Geometric's in-house framework
+for remote backends:
+
+.. figure:: ../_figures/remote_2.png
+   :align: center
+   :width: 100%
+
+As seen here, the GraphStore is used to store the neighbor relations
+between the nodes of the graph, whereas the FeatureStore is used to
+store the node and edge features in the graph.
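+
+In code, these two halves travel together: as the loader construction
+later in this section shows, PyG represents such a remote backend as a
+``(FeatureStore, GraphStore)`` tuple that is passed wherever a
+``Data`` object would otherwise go. The sketch below is only a summary
+of this division of labor, not runnable setup code:
+
+.. code:: python
+
+    # feature_store: "what do node/edge i look like?"  -> feature tensors
+    # graph_store:   "who is connected to node i?"     -> edge indices
+    # Loaders that support remote backends accept the pair directly, e.g.:
+    # loader = RAGQueryLoader(data=(feature_store, graph_store), ...)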
+
+Let's start by loading in a knowledge graph dataset for the sake of
+our experiment:
+
+.. code:: python
+
+    from torch_geometric.data import LargeGraphIndexer
+    from torch_geometric.datasets import WebQSPDataset
+    from itertools import chain
+
+    ds = WebQSPDataset(root='demo', limit=10)
+
+Let's set up our set of questions and graph triplets:
+
+.. code:: python
+
+    questions = ds.raw_dataset['question']
+    questions
+
+    >>> ['what is the name of justin bieber brother',
+    'what character did natalie portman play in star wars',
+    'what country is the grand bahama island in',
+    'what kind of money to take to bahamas',
+    'what character did john noble play in lord of the rings',
+    'who does joakim noah play for',
+    'where are the nfl redskins from',
+    'where did saki live',
+    'who did draco malloy end up marrying',
+    'which countries border the us']
+
+    ds.raw_dataset[:10]['graph'][0][:10]
+
+    >>> [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'],
+    ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'],
+    ['Somebody to Love', 'music.recording.contributions', 'm.0rqp4h0'],
+    ['Rudolph Valentino', 'freebase.valuenotation.is_reviewed', 'Place of birth'],
+    ['Ice Cube', 'broadcast.artist.content', '.977 The Hits Channel'],
+    ['Colbie Caillat', 'broadcast.artist.content', 'Hot Wired Radio'],
+    ['Stephen Melton', 'people.person.nationality', 'United States of America'],
+    ['Record producer',
+    'music.performance_role.regular_performances',
+    'm.012m1vf1'],
+    ['Justin Bieber', 'award.award_winner.awards_won', 'm.0yrkc0l'],
+    ['1.FM Top 40', 'broadcast.content.artist', 'Geri Halliwell']]
+
+    all_triplets = chain.from_iterable((row['graph'] for row in ds.raw_dataset))
+
+With these questions and triplets, we want to:
+
+1. Consolidate all the relations in these triplets into a Knowledge Graph
+2. Create a FeatureStore that encodes all the nodes and edges in the knowledge graph
+3. Create a GraphStore that encodes all the edge indices in the knowledge graph
+
+In order to create a remote backend, we need to define a FeatureStore
+and GraphStore locally, as well as a method for initializing their
+state from triplets. The methods used in this tutorial can be found in
+`examples/llm/g_retriever_utils`.
+
+.. code:: python
+
+    from torch_geometric.datasets.web_qsp_dataset import preprocess_triplet
+    from rag_backend_utils import create_remote_backend_from_triplets, RemoteGraphBackendLoader
+
+    # We define this GraphStore to sample the neighbors of a node locally.
+    # Ideally for a real remote backend, this interface would be replaced with an API to a Graph DB, such as Neo4j.
+    from rag_graph_store import NeighborSamplingRAGGraphStore
+
+    # We define this FeatureStore to encode the nodes and edges locally, and perform approximate KNN when indexing.
+    # Ideally for a real remote backend, this interface would be replaced with an API to a vector DB, such as Pinecone.
+    from rag_feature_store import SentenceTransformerFeatureStore
+
+.. code:: python
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = SentenceTransformer(model_name="sentence-transformers/all-roberta-large-v1").to(device)
+
+    backend_loader: RemoteGraphBackendLoader = create_remote_backend_from_triplets(
+        triplets=all_triplets,  # All the triplets to insert into the backend
+        node_embedding_model=model,  # Embedding model to process triplets with
+        node_method_to_call="encode",  # This method will encode the nodes/edges with 'model.encode' in this case.
+ path="backend", # Save path + pre_transform=preprocess_triplet, # Preprocessing function to apply to triplets before invoking embedding model. + node_method_kwargs={"batch_size": 256}, # Keyword arguments to pass to the node_method_to_call. + graph_db=NeighborSamplingRAGGraphStore, # Graph Store to use + feature_db=SentenceTransformerFeatureStore # Feature Store to use + ) + # This loader saves a copy of the processed data locally to be transformed into a graphstore and featurestore when load() is called. + feature_store, graph_store = backend_loader.load() + +Now that we have initialized our remote backends, we can now retrieve +from them using a Loader to query the backends, as shown in this +diagram: + + +.. figure:: ../_figures/remote_3.png + :align: center + :width: 100% + + + +.. code:: python + + from torch_geometric.loader import RAGQueryLoader + + query_loader = RAGQueryLoader( + data=(feature_store, graph_store), # Remote Rag Graph Store and Feature Store + # Arguments to pass into the seed node/edge retrieval methods for the FeatureStore. + # In this case, it's k for the KNN on the nodes and edges. + seed_nodes_kwargs={"k_nodes": 10}, seed_edges_kwargs={"k_edges": 10}, + # Arguments to pass into the GraphStore's Neighbor sampling method. + # In this case, the GraphStore implements a NeighborLoader, so it takes the same arguments. + sampler_kwargs={"num_neighbors": [40]*3}, + # Arguments to pass into the FeatureStore's feature loading method. + loader_kwargs={}, + # An optional local transform that can be applied on the returned subgraph. + local_filter=None, + ) + +To make better sense of this loader’s arguments, let’s take a closer +look at the retrieval process for a remote backend: + + +.. figure:: ../_figures/remote_backend.svg + :align: center + :width: 100% + + + +As we see here, there are 3 important steps to any remote backend +procedure for graphs: +1. Retrieve the seed nodes and edges to begin our retrieval process from. +2. Traverse the graph neighborhood of the seed nodes/edges to gather local context. +3. Fetch the features associated with the subgraphs obtained from the traversal. + +We can see that our Query Loader construction allows us to specify +unique hyperparameters for each unique step in this retrieval. + +Now we can submit our queries to the remote backend to retrieve our +subgraphs: + +.. code:: python + + sub_graphs = [] + for q in tqdm.tqdm(questions): + sub_graphs.append(query_loader.query(q)) + + + sub_graphs[0] + + >>> Data(x=[2251, 1024], edge_index=[2, 7806], edge_attr=[7806, 1024], node_idx=[2251], edge_idx=[7806]) + + + +These subgraphs are now retrieved using a different retrieval method +when compared to the original WebQSP dataset. Can we compare the +properties of this method to the original WebQSPDataset’s retrieval +method? Let’s compare some basics properties of the subgraphs: + +.. 
+
+These subgraphs are now retrieved using a different retrieval method
+than the one used to build the original WebQSP dataset. Can we compare
+the properties of this method to the original WebQSPDataset's
+retrieval method? Let's compare some basic properties of the
+subgraphs:
+
+.. code:: python
+
+    def _eidx_helper(subg: Data, ground_truth: Data):
+        subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+        if isinstance(subg_eidx, torch.Tensor):
+            subg_eidx = subg_eidx.tolist()
+        if isinstance(gt_eidx, torch.Tensor):
+            gt_eidx = gt_eidx.tolist()
+        subg_e = set(subg_eidx)
+        gt_e = set(gt_eidx)
+        return subg_e, gt_e
+    def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+        subg_e, gt_e = _eidx_helper(subg, ground_truth)
+        total_e = set(range(num_edges))
+        tp = len(subg_e & gt_e)
+        tn = len(total_e-(subg_e | gt_e))
+        return (tp+tn)/num_edges
+    def check_retrieval_precision(subg: Data, ground_truth: Data):
+        subg_e, gt_e = _eidx_helper(subg, ground_truth)
+        return len(subg_e & gt_e) / len(subg_e)
+    def check_retrieval_recall(subg: Data, ground_truth: Data):
+        subg_e, gt_e = _eidx_helper(subg, ground_truth)
+        return len(subg_e & gt_e) / len(gt_e)
+
+    ground_truth_graphs = get_features_for_triplets_groups(ds.indexer, (d['graph'] for d in ds.raw_dataset), pre_transform=preprocess_triplet)
+    num_edges = len(ds.indexer._edges)
+
+    for subg, ground_truth in tqdm.tqdm(zip((query_loader.query(q) for q in questions), ground_truth_graphs)):
+        print(f"Size: {len(subg.x)}, Ground Truth Size: {len(ground_truth.x)}, Accuracy: {check_retrieval_accuracy(subg, ground_truth, num_edges)}, Precision: {check_retrieval_precision(subg, ground_truth)}, Recall: {check_retrieval_recall(subg, ground_truth)}")
+
+    >>> Size: 2193, Ground Truth Size: 1709, Accuracy: 0.6636780705203827, Precision: 0.22923807012918535, Recall: 0.1994037381034285
+    >>> Size: 2682, Ground Truth Size: 1251, Accuracy: 0.7158736400576746, Precision: 0.10843513670738801, Recall: 0.22692963233503774
+    >>> Size: 2087, Ground Truth Size: 1285, Accuracy: 0.7979813868134749, Precision: 0.0547879177377892, Recall: 0.15757855822550831
+    >>> Size: 2975, Ground Truth Size: 1988, Accuracy: 0.6956088609254162, Precision: 0.14820555621795636, Recall: 0.21768826619964973
+    >>> Size: 2594, Ground Truth Size: 633, Accuracy: 0.78849128326124, Precision: 0.04202616198163095, Recall: 0.2032301480484522
+    >>> Size: 2462, Ground Truth Size: 1044, Accuracy: 0.7703499803381832, Precision: 0.07646643109540636, Recall: 0.19551861221539574
+    >>> Size: 2011, Ground Truth Size: 1382, Accuracy: 0.7871804954777821, Precision: 0.10117783355860205, Recall: 0.13142713819914723
+    >>> Size: 2011, Ground Truth Size: 1052, Accuracy: 0.802831301612269, Precision: 0.06452691407556001, Recall: 0.16702726092600606
+    >>> Size: 2892, Ground Truth Size: 1012, Accuracy: 0.7276182985974571, Precision: 0.10108615156751419, Recall: 0.20860927152317882
+    >>> Size: 1817, Ground Truth Size: 1978, Accuracy: 0.7530475815965395, Precision: 0.1677807486631016, Recall: 0.11696178937558248
+
+Note that, since we're only comparing the results of 10 graphs here,
+this retrieval algorithm is not taking into account the full corpus of
+nodes in the dataset. If you want to see a full example, look at
+``rag_generate.py`` or ``rag_generate_multihop.py``. These examples
+generate datasets for the entirety of the WebQSP dataset and for the
+WikiData multi-hop dataset discussed in Example 2, respectively.
+
+Evaluating Runtime Performance
+------------------------------
+
+PyTorch Geometric provides multiple methods for evaluating runtime
+performance. In this tutorial, we utilize NVTX to profile the
+different components of our RAG Query Loader.
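+
+As a minimal usage sketch, ``nvtxit`` is applied as a decorator to any
+function whose calls should appear as named ranges in the NVTX
+timeline. The import path below is an assumption based on this
+release's profiler additions, and ``timed_query`` is a placeholder
+name:
+
+.. code:: python
+
+    from torch_geometric.profile import nvtxit
+
+    @nvtxit()  # calls to timed_query now show up as an NVTX range
+    def timed_query(question):
+        return query_loader.query(question)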
+
+The method ``nvtxit`` allows for profiling the utilization and timings
+of any methods that get wrapped by it in a Python script.
+
+To see an example of this, check out
+``nvtx_examples/nvtx_rag_backend_example.py``.
+
+This script mirrors this tutorial's functionality, but notably, it
+includes the following code snippet:
+
+.. code:: python
+
+    # Patch FeatureStore and GraphStore
+
+    SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_nodes)
+    SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_edges)
+    SentenceTransformerFeatureStore.load_subgraph = nvtxit()(SentenceTransformerFeatureStore.load_subgraph)
+    NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()(NeighborSamplingRAGGraphStore.sample_subgraph)
+    rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query)
+
+Importantly, this snippet wraps the methods of the FeatureStore, the
+GraphStore, and the query method of the QueryLoader so that each will
+be recognized as a unique frame in NVTX.
+
+This can be executed by the included shell script ``nvtx_run.sh``:
+
+.. code:: bash
+
+    ...
+
+    # Get the base name of the Python file
+    python_file=$(basename "$1")
+
+    # Run nsys profile on the Python file
+    nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1"
+
+    echo "Profile data saved as profile_${python_file%.py}.nsys-rep"
+
+The resulting ``.nsys-rep`` file can be visualized using tools like
+Nsight Systems, which can show the relative timings of the
+FeatureStore, GraphStore, and QueryLoader methods.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 1c67eeddeec6..02315461c6ec 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -44,6 +44,7 @@ In addition, it consists of easy-to-use mini-batch loaders for operating on many
    advanced/remote
    advanced/graphgym
    advanced/cpu_affinity
+   advanced/rag

 .. toctree::
    :maxdepth: 1
diff --git a/examples/README.md b/examples/README.md
index 64fc2dffdcdd..af324f09f79d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -14,9 +14,9 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see
 - [`ogbn_papers_100m.py`](./ogbn_papers_100m.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset, containing approximately ~1.6B edges.
 - [`ogbn_papers_100m_cugraph.py`](./ogbn_papers_100m_cugraph.py) shows how to accelerate the `ogbn-papers100m` workflow using [CuGraph](https://github.com/rapidsai/cugraph).

-For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).
+For examples on co-training LLMs with GNNs, see the examples and README under [`examples/llm`](./llm).

-For examples on scaling PyG up via multi-GPUs, see the examples under [`examples/multi_gpu`](./multi_gpu).
+For examples on using `torch.compile`, see the examples and README under [`examples/compile`](./compile).

 For examples on working with heterogeneous data, see the examples under [`examples/hetero`](./hetero).
diff --git a/examples/llm/README.md b/examples/llm/README.md
index f1f01428d991..5177e501d610 100644
--- a/examples/llm/README.md
+++ b/examples/llm/README.md
@@ -1,5 +1,8 @@
 # Examples for Co-training LLMs and GNNs

-| Example                              | Description                                                                                                                                                  |
-| ------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information  |
+| Example                                      | Description                                                                                                                                                  |
+| -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| [`g_retriever.py`](./g_retriever.py)         | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information  |
+| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods.                                             |
+| [`multihop_rag/`](./multihop_rag/)           | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA.                                               |
+| [`nvtx_examples/`](./nvtx_examples/)         | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis.                                                                |
diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py
index 48b012917553..2844eb11ff3b 100644
--- a/examples/llm/g_retriever.py
+++ b/examples/llm/g_retriever.py
@@ -8,25 +8,53 @@
 `pip install datasets transformers pcst_fast sentencepiece accelerate`
 """
 import argparse
+import gc
 import math
-import os.path as osp
+import multiprocessing as mp
 import re
 import time
+from os import path
+from typing import Any, Callable, Dict, List, Type

 import pandas as pd
 import torch
+import torch.nn as nn
 from torch import Tensor
 from torch.nn.utils import clip_grad_norm_
 from tqdm import tqdm

 from torch_geometric import seed_everything
+from torch_geometric.data import Dataset
 from torch_geometric.datasets import WebQSPDataset
 from torch_geometric.loader import DataLoader
 from torch_geometric.nn.models import GAT, GRetriever
 from torch_geometric.nn.nlp import LLM


-def compute_metrics(eval_output):
+def _detect_hallucinate(inp):
+    pred, label = inp
+    try:
+        split_pred = pred.split('[/s]')[0].strip().split('|')
+        correct_hit = len(re.findall(split_pred[0], label)) > 0
+        correct_hit = correct_hit or any(
+            [label_i in pred.lower() for label_i in label.split('|')])
+        hallucination = not correct_hit
+        return hallucination
+    except:  # noqa
+        return "skip"
+
+
+def detect_hallucinate(pred_batch, label_batch):
+    r"""An approximation for the unsolved task of detecting hallucinations.
+    We define a hallucination as an output that contains no instances of
+    an acceptable label.
+ """ + with mp.Pool(len(pred_batch)) as p: + res = p.map(_detect_hallucinate, zip(pred_batch, label_batch)) + return res + + +def compute_accuracy(eval_output) -> float: df = pd.concat([pd.DataFrame(d) for d in eval_output]) all_hit = [] all_precision = [] @@ -68,6 +95,12 @@ def compute_metrics(eval_output): print(f'Recall: {recall:.4f}') print(f'F1: {f1:.4f}') + return hit + + +def compute_n_parameters(model: torch.nn.Module) -> int: + return sum([p.numel() for p in model.parameters() if p.requires_grad]) + def save_params_dict(model, save_path): state_dict = model.state_dict() @@ -112,6 +145,9 @@ def train( lr, checkpointing=False, tiny_llama=False, + model=None, + dataset=None, + model_save_name=None, ): def adjust_learning_rate(param_group, LR, epoch): # Decay the learning rate with half-cycle cosine after warmup @@ -127,13 +163,19 @@ def adjust_learning_rate(param_group, LR, epoch): return lr start_time = time.time() - path = osp.dirname(osp.realpath(__file__)) - path = osp.join(path, '..', '..', 'data', 'WebQSPDataset') - train_dataset = WebQSPDataset(path, split='train') - val_dataset = WebQSPDataset(path, split='val') - test_dataset = WebQSPDataset(path, split='test') - seed_everything(42) + if dataset is None: + dataset = WebQSPDataset() + gc.collect() + elif not isinstance(dataset, Dataset) and callable(dataset): + dataset = dataset() + gc.collect() + idx_split = dataset.split_idxs + + # Step 1: Build Node Classification Dataset + train_dataset = [dataset[i] for i in idx_split['train']] + val_dataset = [dataset[i] for i in idx_split['val']] + test_dataset = [dataset[i] for i in idx_split['test']] train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, pin_memory=True, shuffle=True) @@ -142,24 +184,28 @@ def adjust_learning_rate(param_group, LR, epoch): test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, drop_last=False, pin_memory=True, shuffle=False) - gnn = GAT( - in_channels=1024, - hidden_channels=hidden_channels, - out_channels=1024, - num_layers=num_gnn_layers, - heads=4, - ) - if tiny_llama: - llm = LLM( - model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', - num_params=1, + if model is None: + gc.collect() + gnn = GAT( + in_channels=1024, + hidden_channels=hidden_channels, + out_channels=1024, + num_layers=num_gnn_layers, + heads=4, ) - model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048) - else: - llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7) - model = GRetriever(llm=llm, gnn=gnn) + if tiny_llama: + llm = LLM( + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + ) + model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048) + else: + llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7) + model = GRetriever(llm=llm, gnn=gnn) + + if model_save_name is None: + model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm' - model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm' params = [p for _, p in model.named_parameters() if p.requires_grad] optimizer = torch.optim.AdamW([ { @@ -240,10 +286,376 @@ def adjust_learning_rate(param_group, LR, epoch): eval_output.append(eval_data) progress_bar_test.update(1) + # Step 6 Post-processing & compute metrics compute_metrics(eval_output) print(f"Total Training Time: {time.time() - start_time:2f}s") save_params_dict(model, f'{model_save_name}.pt') torch.save(eval_output, f'{model_save_name}_eval_outs.pt') + print("Done!") + return prep_time, dataset, eval_output + + +def _eval_hallucinations_on_loader(outs, loader, 
+                                   eval_batch_size):
+    model_save_list = []
+    model_preds = []
+    for out in outs:
+        model_preds += out['pred']
+    for i, batch in enumerate(loader):
+        correct_answer = batch.label
+
+        model_pred = model_preds[i * eval_batch_size:(i + 1) * eval_batch_size]
+        model_hallucinates = detect_hallucinate(model_pred, correct_answer)
+        model_save_list += [tup for tup in zip(model_pred, model_hallucinates)]
+    return model_save_list
+
+
+def benchmark_models(models: List[Type[nn.Module]], model_names: List[str],
+                     model_kwargs: List[Dict[str, Any]], dataset: Dataset,
+                     lr: float, epochs: int, batch_size: int,
+                     eval_batch_size: int, loss_fn: Callable,
+                     inference_fn: Callable, skip_LLMs: bool = True,
+                     tiny_llama: bool = False, checkpointing: bool = True,
+                     force: bool = False, root_dir='.'):
+    """Utility function for creating a model benchmark for GRetriever that
+    grid searches over hyperparameters. Produces a DataFrame containing
+    metrics for each model.
+
+    Args:
+        models (List[Type[nn.Module]]): Models to be benchmarked.
+        model_names (List[str]): Names of save files for model checkpoints.
+        model_kwargs (List[Dict[str, Any]]): Parameters to use for each
+            particular model.
+        dataset (Dataset): Input dataset to train on.
+        lr (float): Learning rate.
+        epochs (int): Number of epochs.
+        batch_size (int): Batch size for training.
+        eval_batch_size (int): Batch size for eval. Also determines
+            hallucination detection concurrency.
+        loss_fn (Callable): Loss function.
+        inference_fn (Callable): Inference function.
+        skip_LLMs (bool, optional): Whether to skip LLM-only runs.
+            Defaults to True.
+        tiny_llama (bool, optional): Whether to use tiny llama as LLM.
+            Defaults to False.
+        checkpointing (bool, optional): Whether to checkpoint models.
+            Defaults to True.
+        force (bool, optional): Whether to rerun already existing results.
+            Defaults to False.
+        root_dir (str, optional): Dir to save results and checkpoints in.
+            Defaults to '.'.
+ """ + model_log: Dict[str, Dict[str, Any]] = dict() + idx_split = dataset.split_idxs + test_dataset = [dataset[i] for i in idx_split['test']] + loader = DataLoader(test_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + if not skip_LLMs: + if tiny_llama: + pure_llm = LLM( + model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1", + num_params=1, + ) + else: + pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf", + num_params=7) + + if force or not path.exists(root_dir + "/pure_llm_model_log.pt"): + model_log["pure_llm"] = dict() + + pure_preds = [] + for batch in tqdm(loader): + pure_llm_preds = pure_llm.inference(batch.question, batch.desc, + max_tokens=256) + pure_preds += pure_llm_preds + pure_preds = [{"pred": pred} for pred in pure_preds] + + model_log["pure_llm"]["preds"] = pure_preds + model_log["pure_llm"]["hallucinates_list"] = \ + _eval_hallucinations_on_loader(pure_preds, loader, + eval_batch_size) + model_log["pure_llm"]["n_params"] = compute_n_parameters(pure_llm) + torch.save(model_log["pure_llm"], + root_dir + "/pure_llm_model_log.pt") + else: + model_log["pure_llm"] = \ + torch.load(root_dir+"/pure_llm_model_log.pt") + + # LORA + if force or not path.exists(root_dir + "/tuned_llm_model_log.pt"): + model_log["tuned_llm"] = dict() + since = time.time() + gc.collect() + prep_time, _, lora_eval_outs = train(since, epochs, None, None, + batch_size, eval_batch_size, + lr, loss_fn, inference_fn, + model=pure_llm, + dataset=dataset) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + gc.collect() + e2e_time = round(time.time() - since, 2) + model_log["tuned_llm"]["prep_time"] = prep_time + model_log["tuned_llm"]["e2e_time"] = e2e_time + model_log["tuned_llm"]["eval_output"] = lora_eval_outs + print("E2E time (e2e_time) =", e2e_time, "seconds") + print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds") + + model_log["tuned_llm"]["hallucinates_list"] = \ + _eval_hallucinations_on_loader(lora_eval_outs, loader, + eval_batch_size) + model_log["tuned_llm"]["n_params"] = compute_n_parameters(pure_llm) + torch.save(model_log["tuned_llm"], + root_dir + "/tuned_llm_model_log.pt") + else: + model_log["tuned_llm"] = \ + torch.load(root_dir+"/tuned_llm_model_log.pt") + + del pure_llm + gc.collect() + + # All other models + for name, Model, kwargs in zip(model_names, models, model_kwargs): + model_log[name] = dict() + train_model = True + if path.exists(root_dir + f"/{name}.pt") and not force: + print(f"Model {name} appears to already exist.") + print("Would you like to retrain?") + train_model = str(input("(y/n):")).lower() == "y" + + if train_model: + since = time.time() + gc.collect() + model = Model(**kwargs) + prep_time, _, model_eval_outs = train( + since=since, num_epochs=epochs, hidden_channels=None, + num_gnn_layers=None, batch_size=batch_size, + eval_batch_size=eval_batch_size, lr=lr, loss_fn=loss_fn, + inference_fn=inference_fn, checkpointing=checkpointing, + tiny_llama=tiny_llama, dataset=dataset, + model_save_name=root_dir + '/' + name, model=model) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + gc.collect() + e2e_time = round(time.time() - since, 2) + model_log[name]["prep_time"] = prep_time + model_log[name]["e2e_time"] = e2e_time + model_log[name]["eval_output"] = model_eval_outs + print("E2E time (e2e_time) =", e2e_time, "seconds") + print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds") + model_log[name]["n_params"] = compute_n_parameters(model) + del model + gc.collect() 
+        else:
+            model_eval_outs = torch.load(root_dir + f"/{name}_eval_outs.pt")
+
+        # Calculate Hallucinations
+        skip_hallucination_detection = False
+
+        if path.exists(root_dir + f"/{name}_model_log.pt") and not force:
+            print(f"Saved outputs for {name} have been found.")
+            print("Would you like to redo?")
+            user_input = str(input("(y/n):")).lower()
+            skip_hallucination_detection = user_input != "y"
+
+        if not skip_hallucination_detection:
+            model_save_list = _eval_hallucinations_on_loader(
+                model_eval_outs, loader, eval_batch_size)
+
+            model_log[name]["hallucinates_list"] = model_save_list
+            torch.save(model_log[name], root_dir + f"/{name}_model_log.pt")
+        else:
+            model_log[name]["hallucinates_list"] = \
+                torch.load(
+                    root_dir+f"/{name}_model_log.pt"
+                )["hallucinates_list"]
+
+    hal_dict = {
+        k: [tup[1] for tup in v["hallucinates_list"]]
+        for (k, v) in model_log.items()
+    }
+    hallucinates_df = pd.DataFrame(hal_dict).astype(str)
+    hallucinates_df = hallucinates_df.apply(pd.Series.value_counts).transpose()
+    hallucinates_df['e2e_time'] = pd.Series(
+        {k: v.get('e2e_time')
+         for (k, v) in model_log.items()})
+    hallucinates_df['n_params'] = pd.Series(
+        {k: v.get('n_params')
+         for (k, v) in model_log.items()})
+    print(hallucinates_df)
+    hallucinates_df.to_csv(root_dir + "/hallucinates_df.csv", index=False)
+
+
+def minimal_demo(gnn_llm_eval_outs, dataset, lr, epochs, batch_size,
+                 eval_batch_size, loss_fn, inference_fn,
+                 skip_pretrained_LLM=False, tiny_llama=False):
+    if not skip_pretrained_LLM:
+        print("First comparing against a pretrained LLM...")
+    # Step 1: Define a test loader over the test split
+    idx_split = dataset.split_idxs
+    test_dataset = [dataset[i] for i in idx_split['test']]
+    # Batches of size eval_batch_size for simplicity
+    loader = DataLoader(test_dataset, batch_size=eval_batch_size,
+                        drop_last=False, pin_memory=True, shuffle=False)
+    if tiny_llama:
+        pure_llm = LLM(
+            model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
+            num_params=1,
+        )
+    else:
+        pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf",
+                       num_params=7)
+    if path.exists("demo_save_dict.pt"):
+        print("Saved outputs for the first step of the demo found.")
+        print("Would you like to redo?")
+        user_input = str(input("(y/n):")).lower()
+        skip_step_one = user_input == "n"
+    else:
+        skip_step_one = False
+
+    if not skip_step_one:
+        gnn_llm_hallucin_sum = 0
+        pure_llm_hallucin_sum = 0
+        gnn_save_list = []
+        untuned_llm_save_list = []
+        gnn_llm_preds = []
+        for out in gnn_llm_eval_outs:
+            gnn_llm_preds += out['pred']
+        if skip_pretrained_LLM:
+            print("Checking GNN+LLM for hallucinations...")
+        else:
+            print(
+                "Checking pretrained LLM vs trained GNN+LLM for hallucinations..."  # noqa
+            )
+        for i, batch in enumerate(tqdm(loader)):
+            question = batch.question
+            correct_answer = batch.label
+
+            gnn_llm_pred = gnn_llm_preds[i * eval_batch_size:(i + 1) *
+                                         eval_batch_size]
+            gnn_llm_hallucinates = detect_hallucinate(gnn_llm_pred,
+                                                      correct_answer)
+            gnn_save_list += [
+                tup for tup in zip(gnn_llm_pred, gnn_llm_hallucinates)
+            ]
+
+            if not skip_pretrained_LLM:
+                # GNN+LLM only using 32 tokens to answer.
+                # Allow more output tokens for untrained LLM
+                pure_llm_pred = pure_llm.inference(batch.question, batch.desc,
+                                                   max_tokens=256)
+                pure_llm_hallucinates = detect_hallucinate(
+                    pure_llm_pred, correct_answer)
+            else:
+                pure_llm_pred = [''] * len(gnn_llm_hallucinates)
+                pure_llm_hallucinates = [False] * len(gnn_llm_hallucinates)
+            untuned_llm_save_list += [
+                tup for tup in zip(pure_llm_pred, pure_llm_hallucinates)
+            ]
+
+            for gnn_llm_hal, pure_llm_hal in zip(gnn_llm_hallucinates,
+                                                 pure_llm_hallucinates):
+                if gnn_llm_hal == "skip" or pure_llm_hal == "skip":  # noqa
+                    # skipping when hallucination is hard to eval
+                    continue
+                gnn_llm_hallucin_sum += int(gnn_llm_hal)
+                pure_llm_hallucin_sum += int(pure_llm_hal)
+        if not skip_pretrained_LLM:
+            print("Total Pure LLM Hallucinations:", pure_llm_hallucin_sum)
+            print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum)
+            percent = 100.0 * round(
+                1 - (gnn_llm_hallucin_sum / pure_llm_hallucin_sum), 2)
+            print(f"GNN reduces pretrained LLM hallucinations by: ~{percent}%")
+            print("Note: hallucinations detected by regex hence the ~")
+            print("Now we see how the LLM compares when finetuned...")
+        print("Saving outputs of GNN+LLM and pretrained LLM...")
+        save_dict = {
+            "gnn_save_list": gnn_save_list,
+            "untuned_llm_save_list": untuned_llm_save_list,
+            "gnn_llm_hallucin_sum": gnn_llm_hallucin_sum,
+            "pure_llm_hallucin_sum": pure_llm_hallucin_sum
+        }
+        torch.save(save_dict, "demo_save_dict.pt")
+        print("Done!")
+    else:
+        save_dict = torch.load("demo_save_dict.pt")
+        gnn_save_list = save_dict["gnn_save_list"]
+        untuned_llm_save_list = save_dict["untuned_llm_save_list"]
+        gnn_llm_hallucin_sum = save_dict["gnn_llm_hallucin_sum"]
+        pure_llm_hallucin_sum = save_dict["pure_llm_hallucin_sum"]
+
+    trained_llm_hallucin_sum = 0
+    untuned_llm_hallucin_sum = pure_llm_hallucin_sum
+    final_prnt_str = ""
+    if path.exists("llm.pt") and path.exists("llm_eval_outs.pt"):
+        print("Existing finetuned LLM found.")
+        print("Would you like to retrain?")
+        user_input = str(input("(y/n):")).lower()
+        retrain = user_input == "y"
+    else:
+        retrain = True
+    if retrain:
+        print("Finetuning LLM...")
+        since = time.time()
+        _, _, pure_llm_eval_outputs = train(since, epochs, None, None,
+                                            batch_size, eval_batch_size, lr,
+                                            loss_fn, inference_fn,
+                                            model=pure_llm, dataset=dataset)
+        e2e_time = round(time.time() - since, 2)
+        print("E2E time (e2e_time) =", e2e_time, "seconds")
+    else:
+        pure_llm_eval_outputs = torch.load("llm_eval_outs.pt")
+    pure_llm_preds = []
+    for out in pure_llm_eval_outputs:
+        pure_llm_preds += out['pred']
+    print("Final comparison between all models...")
+    for i, batch in enumerate(tqdm(loader)):
+        question = batch.question
+        correct_answer = batch.label
+        gnn_llm_pred, gnn_llm_hallucinates = list(
+            zip(*gnn_save_list[i * eval_batch_size:(i + 1) * eval_batch_size]))
+        untuned_llm_pred, untuned_llm_hallucinates = list(
+            zip(*untuned_llm_save_list[i * eval_batch_size:(i + 1) *
+                                       eval_batch_size]))
+        pure_llm_pred = pure_llm_preds[i * eval_batch_size:(i + 1) *
+                                       eval_batch_size]
+        pure_llm_hallucinates = detect_hallucinate(pure_llm_pred,
+                                                   correct_answer)
+        for j in range(len(gnn_llm_pred)):
+            if skip_pretrained_LLM:
+                # we did not check the untrained LLM, so do not decide to demo
+                # based on this.
+                # HACK
+                untuned_llm_hallucinates = {j: True}
+            if gnn_llm_hallucinates[j] == "skip" or untuned_llm_hallucinates[
+                    j] == "skip" or pure_llm_hallucinates[j] == "skip":
+                continue
+            trained_llm_hallucin_sum += int(pure_llm_hallucinates[j])
+            if untuned_llm_hallucinates[j] and pure_llm_hallucinates[
+                    j] and not gnn_llm_hallucinates[j]:  # noqa
+                final_prnt_str += "Prompt: '" + question[j] + "'\n"
+                final_prnt_str += "Label: '" + correct_answer[j] + "'\n"
+                if not skip_pretrained_LLM:
+                    final_prnt_str += "Untuned LLM Output: '" \
+                        + untuned_llm_pred[j] + "'\n"  # noqa
+                final_prnt_str += "Tuned LLM Output: '" + pure_llm_pred[
+                    j] + "'\n"
+                final_prnt_str += "GNN+LLM Output: '" + gnn_llm_pred[j] + "'\n"
+                final_prnt_str += "\n" + "#" * 20 + "\n\n"
+    if not skip_pretrained_LLM:
+        print("Total untuned LLM Hallucinations:", untuned_llm_hallucin_sum)
+    print("Total tuned LLM Hallucinations:", trained_llm_hallucin_sum)
+    print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum)
+    if not skip_pretrained_LLM:
+        percent = 100.0 * round(
+            1 - (gnn_llm_hallucin_sum / untuned_llm_hallucin_sum), 2)
+        print(f"GNN reduces untuned LLM hallucinations by: ~{percent}%")
+    tuned_percent = 100.0 * round(
+        1 - (gnn_llm_hallucin_sum / trained_llm_hallucin_sum), 2)
+    print(f"GNN reduces tuned LLM hallucinations by: ~{tuned_percent}%")
+    print("Note: hallucinations detected by regex hence the ~")
+    print("Potential instances where GNN solves the hallucinations of LLM:")
+    print(final_prnt_str)


 if __name__ == '__main__':
@@ -256,6 +668,9 @@ def adjust_learning_rate(param_group, LR, epoch):
     parser.add_argument('--eval_batch_size', type=int, default=16)
     parser.add_argument('--checkpointing', action='store_true')
     parser.add_argument('--tiny_llama', action='store_true')
+    parser.add_argument(
+        "--skip_pretrained_llm_eval", action="store_true",
+        help="This flag will skip the evaluation of the pretrained LLM.")

     args = parser.parse_args()
     start_time = time.time()
diff --git a/examples/llm/g_retriever_utils/README.md b/examples/llm/g_retriever_utils/README.md
new file mode 100644
index 000000000000..e072e6746b7c
--- /dev/null
+++ b/examples/llm/g_retriever_utils/README.md
@@ -0,0 +1,11 @@
+# Examples for LLM and GNN co-training
+
+| Example                                                           | Description                                                                                                                                                               |
+| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`rag_feature_store.py`](./rag_feature_store.py)                  | A proof-of-concept implementation of a RAG-enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend                           |
+| [`rag_graph_store.py`](./rag_graph_store.py)                      | A proof-of-concept implementation of a RAG-enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend                             |
+| [`rag_backend_utils.py`](./rag_backend_utils.py)                  | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore                                  |
+| [`rag_generate.py`](./rag_generate.py)                            | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom-defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) |
+| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py)  | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevant architecture parameters and datasets.                                                  |
+
+NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes.
diff --git a/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
new file mode 100644
index 000000000000..76148cfc09e5
--- /dev/null
+++ b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
@@ -0,0 +1,105 @@
+"""Used to benchmark the performance of an untuned/fine-tuned LLM against
+GRetriever with various architectures and layer depths.
+"""
+# %%
+import argparse
+import sys
+
+import torch
+
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.nn.models import GAT, MLP, GRetriever
+
+sys.path.append('..')
+from g_retriever import (  # noqa: E402
+    benchmark_models,
+    get_loss,
+    inference_step,
+)
+
+# %%
+parser = argparse.ArgumentParser(description="""Benchmarker for GRetriever
+NOTE: Evaluating with smaller samples may result in poorer performance for the trained models compared to untrained models."""
+                                 )
+parser.add_argument("--hidden_channels", type=int, default=1024)
+parser.add_argument("--learning_rate", type=float, default=1e-5)
+parser.add_argument("--epochs", type=int, default=2)
+parser.add_argument("--batch_size", type=int, default=8)
+parser.add_argument("--eval_batch_size", type=int, default=16)
+parser.add_argument("--tiny_llama", action='store_true')
+
+parser.add_argument("--dataset_path", type=str, required=False)
+# Default to WebQSP split
+parser.add_argument("--num_train", type=int, default=2826)
+parser.add_argument("--num_val", type=int, default=246)
+parser.add_argument("--num_test", type=int, default=1628)
+
+args = parser.parse_args()
+
+# %%
+hidden_channels = args.hidden_channels
+lr = args.learning_rate
+epochs = args.epochs
+batch_size = args.batch_size
+eval_batch_size = args.eval_batch_size
+
+# %%
+if not args.dataset_path:
+    ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True)
+else:
+    # We just assume that the size of the dataset accommodates the
+    # train/val/test split, because checking may be expensive.
+    dataset = torch.load(args.dataset_path)
+
+    class MockDataset:
+        """Utility class to patch the fields in WebQSPDataset used by
+        GRetriever.
+ """ + def __init__(self) -> None: + pass + + @property + def split_idxs(self) -> dict: + # Imitates the WebQSP split method + return { + "train": + torch.arange(args.num_train), + "val": + torch.arange(args.num_val) + args.num_train, + "test": + torch.arange(args.num_test) + args.num_train + args.num_val, + } + + def __getitem__(self, idx: int): + return dataset[idx] + + ds = MockDataset() + +# %% +model_names = [] +model_classes = [] +model_kwargs = [] +model_type = ["GAT", "MLP"] +models = {"GAT": GAT, "MLP": MLP} +# Use to vary the depth of the GNN model +num_layers = [4] +# Use to vary the number of LLM tokens reserved for GNN output +num_tokens = [1] +for m_type in model_type: + for n_layer in num_layers: + for n_tokens in num_tokens: + model_names.append(f"{m_type}_{n_layer}_{n_tokens}") + model_classes.append(GRetriever) + kwargs = dict(gnn_hidden_channels=hidden_channels, + num_gnn_layers=n_layer, gnn_to_use=models[m_type], + mlp_out_tokens=n_tokens) + if args.tiny_llama: + kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1' + kwargs['mlp_out_dim'] = 2048 + kwargs['num_llm_params'] = 1 + model_kwargs.append(kwargs) + +# %% +benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs, + batch_size, eval_batch_size, get_loss, inference_step, + skip_LLMs=False, tiny_llama=args.tiny_llama, force=True) diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py new file mode 100644 index 000000000000..0f1c0e1b87ec --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_backend_utils.py @@ -0,0 +1,224 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Protocol, + Tuple, + Type, + runtime_checkable, +) + +import torch +from torch import Tensor +from torch.nn import Module + +from torch_geometric.data import ( + FeatureStore, + GraphStore, + LargeGraphIndexer, + TripletLike, +) +from torch_geometric.data.large_graph_indexer import EDGE_RELATION +from torch_geometric.distributed import ( + LocalFeatureStore, + LocalGraphStore, + Partitioner, +) +from torch_geometric.typing import EdgeType, NodeType + +RemoteGraphBackend = Tuple[FeatureStore, GraphStore] + +# TODO: Make everything compatible with Hetero graphs aswell + + +# Adapted from LocalGraphStore +@runtime_checkable +class ConvertableGraphStore(Protocol): + @classmethod + def from_data( + cls, + edge_id: Tensor, + edge_index: Tensor, + num_nodes: int, + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_hetero_data( + cls, + edge_id_dict: Dict[EdgeType, Tensor], + edge_index_dict: Dict[EdgeType, Tensor], + num_nodes_dict: Dict[NodeType, int], + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_partition(cls, root: str, pid: int) -> GraphStore: + ... + + +# Adapted from LocalFeatureStore +@runtime_checkable +class ConvertableFeatureStore(Protocol): + @classmethod + def from_data( + cls, + node_id: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + edge_id: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + ) -> FeatureStore: + ... + + @classmethod + def from_hetero_data( + cls, + node_id_dict: Dict[NodeType, Tensor], + x_dict: Optional[Dict[NodeType, Tensor]] = None, + y_dict: Optional[Dict[NodeType, Tensor]] = None, + edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None, + edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None, + ) -> FeatureStore: + ... 
+
+    @classmethod
+    def from_partition(cls, root: str, pid: int) -> FeatureStore:
+        ...
+
+
+class RemoteDataType(Enum):
+    DATA = auto()
+    PARTITION = auto()
+
+
+@dataclass
+class RemoteGraphBackendLoader:
+    """Utility class to load triplets into a RAG Backend."""
+    path: str
+    datatype: RemoteDataType
+    graph_store_type: Type[ConvertableGraphStore]
+    feature_store_type: Type[ConvertableFeatureStore]
+
+    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
+        if self.datatype == RemoteDataType.DATA:
+            data_obj = torch.load(self.path)
+            graph_store = self.graph_store_type.from_data(
+                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
+                num_nodes=data_obj.num_nodes)
+            feature_store = self.feature_store_type.from_data(
+                node_id=data_obj['node_id'], x=data_obj.x,
+                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
+        elif self.datatype == RemoteDataType.PARTITION:
+            assert pid is not None, \
+                "Partition ID must be defined for loading from a " \
+                "partitioned store."
+            graph_store = self.graph_store_type.from_partition(self.path, pid)
+            feature_store = self.feature_store_type.from_partition(
+                self.path, pid)
+        else:
+            raise NotImplementedError
+        return (feature_store, graph_store)
+
+
+# TODO: make profilable
+def create_remote_backend_from_triplets(
+        triplets: Iterable[TripletLike], node_embedding_model: Module,
+        edge_embedding_model: Module | None = None,
+        graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
+        feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
+        node_method_to_call: str = "forward",
+        edge_method_to_call: str | None = None,
+        pre_transform: Callable[[TripletLike], TripletLike] | None = None,
+        path: str = '', n_parts: int = 1,
+        node_method_kwargs: Optional[Dict[str, Any]] = None,
+        edge_method_kwargs: Optional[Dict[str, Any]] = None
+) -> RemoteGraphBackendLoader:
+    """Utility function that can be used to create a RAG Backend from
+    triplets.
+
+    Args:
+        triplets (Iterable[TripletLike]): Triplets to load into the RAG
+            Backend.
+        node_embedding_model (Module): Model to embed nodes into a feature
+            space.
+        edge_embedding_model (Module | None, optional): Model to embed edges
+            into a feature space. Defaults to the node model.
+        graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
+            use. Defaults to LocalGraphStore.
+        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
+            class to use. Defaults to LocalFeatureStore.
+        node_method_to_call (str, optional): Method to call for embeddings on
+            the node model. Defaults to "forward".
+        edge_method_to_call (str | None, optional): Method to call for
+            embeddings on the edge model. Defaults to the node method.
+        pre_transform (Callable[[TripletLike], TripletLike] | None, optional):
+            Optional preprocessing function for triplets. Defaults to None.
+        path (str, optional): Path to save the resulting stores to.
+            Defaults to ''.
+        n_parts (int, optional): Number of partitions to store the data in.
+            Defaults to 1.
+        node_method_kwargs (Optional[Dict[str, Any]], optional): Args to pass
+            into the node encoding method. Defaults to None.
+        edge_method_kwargs (Optional[Dict[str, Any]], optional): Args to pass
+            into the edge encoding method. Defaults to None.
+
+    Returns:
+        RemoteGraphBackendLoader: Loader to load the RAG backend from disk or
+            memory.
+    """
+    # Duck-type check: getattr raises AttributeError for any missing method
+    if not issubclass(graph_db, ConvertableGraphStore):
+        getattr(graph_db, "from_data")
+        getattr(graph_db, "from_hetero_data")
+        getattr(graph_db, "from_partition")
+    if not issubclass(feature_db, ConvertableFeatureStore):
+        getattr(feature_db, "from_data")
+        getattr(feature_db, "from_hetero_data")
+        getattr(feature_db, "from_partition")
+
+    # Resolve callable methods
+    node_method_kwargs = node_method_kwargs \
+        if node_method_kwargs is not None else dict()
+
+    edge_embedding_model = edge_embedding_model \
+        if edge_embedding_model is not None else node_embedding_model
+    edge_method_to_call = edge_method_to_call \
+        if edge_method_to_call is not None else node_method_to_call
+    edge_method_kwargs = edge_method_kwargs \
+        if edge_method_kwargs is not None else node_method_kwargs
+
+    # These will return AttributeErrors if they don't exist
+    node_model = getattr(node_embedding_model, node_method_to_call)
+    edge_model = getattr(edge_embedding_model, edge_method_to_call)
+
+    indexer = LargeGraphIndexer.from_triplets(triplets,
+                                              pre_transform=pre_transform)
+
+    node_feats = node_model(indexer.get_node_features(), **node_method_kwargs)
+    indexer.add_node_feature('x', node_feats)
+
+    edge_feats = edge_model(
+        indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
+        **edge_method_kwargs)
+    indexer.add_edge_feature(new_feature_name="edge_attr",
+                             new_feature_vals=edge_feats,
+                             map_from_feature=EDGE_RELATION)
+
+    data = indexer.to_data(node_feature_name='x',
+                           edge_feature_name='edge_attr')
+
+    if n_parts == 1:
+        torch.save(data, path)
+        return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
+                                        feature_db)
+    else:
+        partitioner = Partitioner(data=data, num_parts=n_parts, root=path)
+        partitioner.generate_partition()
+        return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION,
+                                        graph_db, feature_db)
diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py
new file mode 100644
index 000000000000..e01e9e59bb88
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_feature_store.py
@@ -0,0 +1,189 @@
+import gc
+from collections.abc import Iterable, Iterator
+from typing import Any, Dict, Optional, Type, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from torchmetrics.functional import pairwise_cosine_similarity
+
+from torch_geometric.data import Data, HeteroData
+from torch_geometric.distributed import LocalFeatureStore
+from torch_geometric.nn.nlp import SentenceTransformer
+from torch_geometric.nn.pool import ApproxMIPSKNNIndex
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+# NOTE: Only compatible with Homogeneous graphs for now
+class KNNRAGFeatureStore(LocalFeatureStore):
+    def __init__(self, enc_model: Type[Module],
+                 model_kwargs: Optional[Dict[str, Any]] = None, *args,
+                 **kwargs):
+        self.device = torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu")
+        self.enc_model = enc_model(*args, **kwargs).to(self.device)
+        self.enc_model.eval()
+        self.model_kwargs = \
+            model_kwargs if model_kwargs is not None else dict()
+        super().__init__()
+
+    @property
+    def x(self) -> Tensor:
+        return self.get_tensor(group_name=None, attr_name='x')
+
+    @property
+    def edge_attr(self) -> Tensor:
+        return self.get_tensor(group_name=(None, None), attr_name='edge_attr')
+
+    def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) ->
InputNodes: + result = next(self._retrieve_seed_nodes_batch([query], k_nodes)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device)) + topk = min(k_nodes, len(self.x)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges: + result = next(self._retrieve_seed_edges_batch([query], k_edges)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + + prizes = pairwise_cosine_similarity(query_enc, + self.edge_attr.to(self.device)) + topk = min(k_edges, len(self.edge_attr)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + + if isinstance(sample, HeteroSamplerOutput): + raise NotImplementedError + + # NOTE: torch_geometric.loader.utils.filter_custom_store can be used + # here if it supported edge features + node_id = sample.node + edge_id = sample.edge + edge_index = torch.stack((sample.row, sample.col), dim=0) + x = self.x[node_id] + edge_attr = self.edge_attr[edge_id] + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + node_idx=node_id, edge_idx=edge_id) + + +# TODO: Refactor because composition >> inheritance + + +def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor, + device: torch.device, batch_size: int = 2**20): + """Add new features to the existing KNN index in batches. + + Args: + knn_index (ApproxMIPSKNNIndex): Index to add features to. + emb (Tensor): Embeddings to add. + device (torch.device): Device to store in + batch_size (int, optional): Batch size to iterate by. + Defaults to 2**20, which equates to 4GB if working with + 1024 dim floats. + """ + for i in range(0, emb.size(0), batch_size): + if emb.size(0) - i >= batch_size: + emb_batch = emb[i:i + batch_size].to(device) + else: + emb_batch = emb[i:].to(device) + knn_index.add(emb_batch) + + +class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore): + def __init__(self, enc_model: Type[Module], + model_kwargs: Optional[Dict[str, + Any]] = None, *args, **kwargs): + # TODO: Add kwargs for approx KNN to parameters here. 
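+        # Note: the ANN indices below are built lazily on the first
+        # retrieval call and populated in batches via
+        # _add_features_to_knn_index to bound peak GPU memory.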
+ super().__init__(enc_model, model_kwargs, *args, **kwargs) + self.node_knn_index = None + self.edge_knn_index = None + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.node_knn_index is None: + self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.node_knn_index, self.x, + self.device) + + output = self.node_knn_index.search(query_enc, k=k_nodes) + yield from output.index + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.edge_knn_index is None: + self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.edge_knn_index, self.edge_attr, + self.device) + + output = self.edge_knn_index.search(query_enc, k=k_edges) + yield from output.index + + +# TODO: These two classes should be refactored +class SentenceTransformerFeatureStore(KNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) + + +class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) diff --git a/examples/llm/g_retriever_utils/rag_generate.py b/examples/llm/g_retriever_utils/rag_generate.py new file mode 100644 index 000000000000..c6895b453b0c --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_generate.py @@ -0,0 +1,137 @@ +# %% +import argparse +from itertools import chain +from typing import Tuple + +import pandas as pd +import torch +import tqdm +from rag_backend_utils import create_remote_backend_from_triplets +from rag_feature_store import SentenceTransformerFeatureStore +from rag_graph_store import NeighborSamplingRAGGraphStore + +from torch_geometric.data import Data +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +# %% +parser = argparse.ArgumentParser(description="""Generate new WebQSP subgraphs +NOTE: Evaluating with smaller samples may result in poorer performance for the trained models compared to untrained models.""" + ) +# TODO: Add more arguments for configuring rag params +parser.add_argument("--use_pcst", action="store_true") +parser.add_argument("--num_samples", type=int, default=4700) +parser.add_argument("--out_file", default="subg_results.pt") +args = parser.parse_args() + +# %% +ds = WebQSPDataset("dataset", 
limit=args.num_samples, verbose=True, + force_reload=True) + +# %% +triplets = chain.from_iterable(d['graph'] for d in ds.raw_dataset) + +# %% +questions = ds.raw_dataset['question'] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerFeatureStore).load() + +# %% + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +def apply_retrieval_with_text(graph: Data, query: str) -> Tuple[Data, str]: + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + desc = ( + textual_nodes.to_csv(index=False) + "\n" + + textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"])) + graph["desc"] = desc + return graph + + +transform = apply_retrieval_via_pcst \ + if args.use_pcst else apply_retrieval_with_text + +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5}, + seed_edges_kwargs={"k_edges": 5}, + sampler_kwargs={"num_neighbors": [50] * 2}, + local_filter=transform) + + +# %% +# Accuracy Metrics to be added to Profiler +def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e + + +def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e - (subg_e | gt_e)) + return (tp + tn) / num_edges + + +def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) + + +def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + +# %% +retrieval_stats = {"precision": [], "recall": [], "accuracy": []} +subgs = [] +node_len = [] +edge_len = [] +for subg in tqdm.tqdm(query_loader.query(q) for q in questions): + subgs.append(subg) + node_len.append(subg['x'].shape[0]) + edge_len.append(subg['edge_attr'].shape[0]) + +for i, subg in enumerate(subgs): + subg['question'] = questions[i] + subg['label'] = ds[i]['label'] + +pd.DataFrame.from_dict(retrieval_stats).to_csv( + args.out_file.split('.')[0] + '_metadata.csv') +torch.save(subgs, args.out_file) diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py new file mode 100644 index 000000000000..48473f287233 --- /dev/null +++ 
b/examples/llm/g_retriever_utils/rag_graph_store.py
@@ -0,0 +1,107 @@
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.data import FeatureStore
+from torch_geometric.distributed import LocalGraphStore
+from torch_geometric.sampler import (
+    HeteroSamplerOutput,
+    NeighborSampler,
+    NodeSamplerInput,
+    SamplerOutput,
+)
+from torch_geometric.sampler.neighbor_sampler import NumNeighborsType
+from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes
+
+
+class NeighborSamplingRAGGraphStore(LocalGraphStore):
+    def __init__(self, feature_store: Optional[FeatureStore] = None,
+                 num_neighbors: NumNeighborsType = [1], **kwargs):
+        self.feature_store = feature_store
+        self._num_neighbors = num_neighbors
+        self.sample_kwargs = kwargs
+        self._sampler_is_initialized = False
+        super().__init__()
+
+    def _init_sampler(self):
+        if self.feature_store is None:
+            raise AttributeError("Feature store not registered yet.")
+        self.sampler = NeighborSampler(data=(self.feature_store, self),
+                                       num_neighbors=self._num_neighbors,
+                                       **self.sample_kwargs)
+        self._sampler_is_initialized = True
+
+    def register_feature_store(self, feature_store: FeatureStore):
+        self.feature_store = feature_store
+        self._sampler_is_initialized = False
+
+    def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool:
+        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
+        self._sampler_is_initialized = False
+        return ret
+
+    @property
+    def edge_index(self):
+        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)
+
+    def put_edge_index(self, edge_index: EdgeTensorType, *args,
+                       **kwargs) -> bool:
+        ret = super().put_edge_index(edge_index, *args, **kwargs)
+        # HACK: Cache the args so the edge_index property can re-fetch it
+        self.edge_idx_args = args
+        self.edge_idx_kwargs = kwargs
+        self._sampler_is_initialized = False
+        return ret
+
+    @property
+    def num_neighbors(self):
+        return self._num_neighbors
+
+    @num_neighbors.setter
+    def num_neighbors(self, num_neighbors: NumNeighborsType):
+        self._num_neighbors = num_neighbors
+        if hasattr(self, 'sampler'):
+            self.sampler.num_neighbors = num_neighbors
+
+    def sample_subgraph(
+        self, seed_nodes: InputNodes, seed_edges: InputEdges,
+        num_neighbors: Optional[NumNeighborsType] = None
+    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
+        """Sample the graph starting from the given nodes and edges using the
+        built-in NeighborSampler.
+
+        Args:
+            seed_nodes (InputNodes): Seed nodes to start sampling from.
+            seed_edges (InputEdges): Seed edges to start sampling from.
+            num_neighbors (Optional[NumNeighborsType], optional): Parameters
+                to determine how many hops and number of neighbors per hop.
+                Defaults to None.
+
+        Returns:
+            Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput
+                for the input.
+        """
+        if not self._sampler_is_initialized:
+            self._init_sampler()
+        if num_neighbors is not None:
+            self.num_neighbors = num_neighbors
+
+        # FIXME: Right now, only input nodes/edges as tensors are supported
+        if not isinstance(seed_nodes, Tensor):
+            raise NotImplementedError
+        if not isinstance(seed_edges, Tensor):
+            raise NotImplementedError
+        device = seed_nodes.device
+
+        # TODO: Call sample_from_edges for seed_edges
+        # Turning them into nodes for now.
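+        # Each seed edge contributes both of its endpoints as additional
+        # seed nodes, so the neighbor sampling below expands around the
+        # retrieved edges as well.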
+ seed_edges = self.edge_index.to(device).T[seed_edges.to( + device)].reshape(-1) + seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0) + + seed_nodes = seed_nodes.unique().contiguous() + node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes) + out = self.sampler.sample_from_nodes(node_sample_input) + + return out diff --git a/examples/llm/multihop_rag/README.md b/examples/llm/multihop_rag/README.md new file mode 100644 index 000000000000..ff43b16a2c05 --- /dev/null +++ b/examples/llm/multihop_rag/README.md @@ -0,0 +1,9 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | +| [`multihop_download.sh`](./multihop_download.sh) | Downloads all the components of the multihop dataset. | +| [`multihop_preprocess.py`](./multihop_preprocess.py) | Preprocesses the dataset to pair questions/answers with components in the knowledge graph. Contains documentation to describe the process. | +| [`rag_generate_multihop.py`](./rag_generate_multihop.py) | Utilizes the sample remote backend in [`g_retriever_utils`](../g_retriever_utils/) to generate subgraphs for the multihop dataset. | + +NOTE: Performance of GRetriever on this dataset has not been evaluated. diff --git a/examples/llm/multihop_rag/multihop_download.sh b/examples/llm/multihop_rag/multihop_download.sh new file mode 100644 index 000000000000..3c1970d39440 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_download.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +# Wikidata5m + +wget -O "wikidata5m_alias.tar.gz" "https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz" +tar -xvf "wikidata5m_alias.tar.gz" +wget -O "wikidata5m_all_triplet.txt.gz" "https://www.dropbox.com/s/563omb11cxaqr83/wikidata5m_all_triplet.txt.gz" +gzip -d "wikidata5m_all_triplet.txt.gz" -f + +# 2Multihopqa +wget -O "data_ids_april7.zip" "https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip" +unzip -o "data_ids_april7.zip" diff --git a/examples/llm/multihop_rag/multihop_preprocess.py b/examples/llm/multihop_rag/multihop_preprocess.py new file mode 100644 index 000000000000..46052bdf1b15 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_preprocess.py @@ -0,0 +1,276 @@ +"""Example workflow for downloading and assembling a multihop QA dataset.""" + +import argparse +import json +from subprocess import call + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import LargeGraphIndexer + +# %% [markdown] +# # Encoding A Large Knowledge Graph Part 2 + +# %% [markdown] +# In this notebook, we will continue where we left off by building a new +# multi-hop QA dataset based on Wikidata. + +# %% [markdown] +# ## Example 2: Building a new Dataset from Questions and an already-existing +# Knowledge Graph + +# %% [markdown] +# ### Motivation + +# %% [markdown] +# One potential application of knowledge graph structural encodings is +# capturing the relationships between different entities that are multiple +# hops apart. This can be challenging for an LLM to recognize from prepended +# graph information. Here's a motivating example (credit to @Rishi Puri): + +# %% [markdown] +# In this example, the question can only be answered by reasoning about the +# relationships between the entities in the knowledge graph. 
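+
+# %% [markdown]
+# As a concrete sketch (with hypothetical dataset entries), a two-hop
+# question such as "Who directed the film that won the Academy Award for
+# Best Picture in 1995?" can only be resolved by chaining two triplets that
+# share an intermediate entity:
+#
+# ('Forrest Gump', 'award received', 'Academy Award for Best Picture')
+# ('Forrest Gump', 'director', 'Robert Zemeckis')
+#
+# Neither triplet alone answers the question, so the retriever must surface
+# both.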
+ +# %% [markdown] +# ### Building a Multi-Hop QA Dataset + +# %% [markdown] +# To start, we need to download the raw data of a knowledge graph. +# In this case, we use WikiData5M +# ([Wang et al] +# (https://paperswithcode.com/paper/kepler-a-unified-model-for-knowledge)). +# Here we download the raw triplets and their entity codes. Information about +# this dataset can be found +# [here](https://deepgraphlearning.github.io/project/wikidata5m). + +# %% [markdown] +# The following download contains the ID to plaintext mapping for all the +# entities and relations in the knowledge graph: + +rv = call("./multihop_download.sh") + +# %% [markdown] +# To start, we are going to preprocess the knowledge graph to substitute each +# of the entity/relation codes with their plaintext aliases. This makes it +# easier to use a pre-trained textual encoding model to create triplet +# embeddings, as such a model likely won't understand how to properly embed +# the entity codes. + +# %% + +# %% +parser = argparse.ArgumentParser(description="Preprocess wikidata5m") +parser.add_argument("--n_triplets", type=int, default=-1) +args = parser.parse_args() + +# %% +# Substitute entity codes with their aliases +# Picking the first alias for each entity (rather arbitrarily) +alias_map = {} +rel_alias_map = {} +for line in open('wikidata5m_entity.txt'): + parts = line.strip().split('\t') + entity_id = parts[0] + aliases = parts[1:] + alias_map[entity_id] = aliases[0] +for line in open('wikidata5m_relation.txt'): + parts = line.strip().split('\t') + relation_id = parts[0] + relation_name = parts[1] + rel_alias_map[relation_id] = relation_name + +# %% +full_graph = [] +missing_total = 0 +total = 0 +limit = None if args.n_triplets == -1 else args.n_triplets +i = 0 + +for line in tqdm.tqdm(open('wikidata5m_all_triplet.txt')): + if limit is not None and i >= limit: + break + src, rel, dst = line.strip().split('\t') + if src not in alias_map: + missing_total += 1 + if dst not in alias_map: + missing_total += 1 + if rel not in rel_alias_map: + missing_total += 1 + total += 3 + full_graph.append([ + alias_map.get(src, src), + rel_alias_map.get(rel, rel), + alias_map.get(dst, dst) + ]) + i += 1 +print(f"Missing aliases: {missing_total}/{total}") + +# %% [markdown] +# Now `full_graph` represents the knowledge graph triplets in +# understandable plaintext. + +# %% [markdown] +# Next, we need a set of multi-hop questions that the Knowledge Graph will +# provide us with context for. We utilize a subset of +# [HotPotQA](https://hotpotqa.github.io/) +# ([Yang et. al.](https://arxiv.org/pdf/1809.09600)) called +# [2WikiMultiHopQA](https://github.com/Alab-NII/2wikimultihop) +# ([Ho et. al.](https://aclanthology.org/2020.coling-main.580.pdf)), +# which includes a subgraph of entities that serve as the ground truth +# justification for answering each multi-hop question: + +# %% +with open('train.json') as f: + train_data = json.load(f) +train_df = pd.DataFrame(train_data) +train_df['split_type'] = 'train' + +with open('dev.json') as f: + dev_data = json.load(f) +dev_df = pd.DataFrame(dev_data) +dev_df['split_type'] = 'dev' + +with open('test.json') as f: + test_data = json.load(f) +test_df = pd.DataFrame(test_data) +test_df['split_type'] = 'test' + +df = pd.concat([train_df, dev_df, test_df]) + +# %% [markdown] +# Now we need to extract the subgraphs + +# %% +df['graph_size'] = df['evidences_id'].apply(lambda row: len(row)) + +# %% [markdown] +# (Optional) We take only questions where the evidence graph is greater than +# 0. 
(Note: this gets rid of the test set): + +# %% +# df = df[df['graph_size'] > 0] + +# %% +refined_df = df[[ + '_id', 'question', 'answer', 'split_type', 'evidences_id', 'type', + 'graph_size' +]] + +# %% [markdown] +# Checkpoint: + +# %% +refined_df.to_csv('wikimultihopqa_refined.csv', index=False) + +# %% [markdown] +# Now we need to check that all the entities mentioned in the question/answer +# set are also present in the Wikidata graph: + +# %% +relation_map = {} +with open('wikidata5m_relation.txt') as f: + for line in tqdm.tqdm(f): + parts = line.strip().split('\t') + for i in range(1, len(parts)): + if parts[i] not in relation_map: + relation_map[parts[i]] = [] + relation_map[parts[i]].append(parts[0]) + +# %% +entity_set = set() +with open('wikidata5m_entity.txt') as f: + for line in tqdm.tqdm(f): + entity_set.add(line.strip().split('\t')[0]) + +# %% +missing_entities = set() +missing_entity_idx = set() +for i, row in enumerate(refined_df.itertuples()): + for trip in row.evidences_id: + entities = trip[0], trip[2] + for entity in entities: + if entity not in entity_set: + # print( + # f'The following entity was not found in the KG: {entity}' + # ) + missing_entities.add(entity) + missing_entity_idx.add(i) + +# %% [markdown] +# Right now, we drop the missing entity entries. Additional preprocessing can +# be done here to resolve the entity/relation collisions, but that is out of +# the scope for this notebook. + +# %% +# missing relations are ok, but missing entities cannot be mapped to +# plaintext, so they should be dropped. +refined_df.reset_index(inplace=True, drop=True) + +# %% +cleaned_df = refined_df.drop(missing_entity_idx) + +# %% [markdown] +# Now we save the resulting graph and questions/answers dataset: + +# %% +cleaned_df.to_csv('wikimultihopqa_cleaned.csv', index=False) + +# %% + +# %% +torch.save(full_graph, 'wikimultihopqa_full_graph.pt') + +# %% [markdown] +# ### Question: How do we extract a contextual subgraph for a given query? + +# %% [markdown] +# The chosen retrieval algorithm is a critical component in the pipeline for +# affecting RAG performance. In the next section (1), we will demonstrate a +# naive method of retrieval for a large knowledge graph, and how to apply it +# to this dataset along with WebQSP. + +# %% [markdown] +# ### Preparing a Textualized Graph for LLM + +# %% [markdown] +# For now however, we need to prepare the graph data to be used as a plaintext +# prefix to the LLM. In order to do this, we want to prompt the LLM to use the +# unique nodes, and unique edge triplets of a given subgraph. In order to do +# this, we prepare a unique indexed node df and edge df for the knowledge +# graph now. This process occurs trivially with the LargeGraphIndexer: + +# %% + +# %% +indexer = LargeGraphIndexer.from_triplets(full_graph) + +# %% +# Node DF +textual_nodes = pd.DataFrame.from_dict( + {"node_attr": indexer.get_node_features()}) +textual_nodes["node_id"] = textual_nodes.index +textual_nodes = textual_nodes[["node_id", "node_attr"]] + +# %% [markdown] +# Notice how LargeGraphIndexer ensures that there are no duplicate indices: + +# %% +# Edge DF +textual_edges = pd.DataFrame(indexer.get_edge_features(), + columns=["src", "edge_attr", "dst"]) +textual_edges["src"] = [indexer._nodes[h] for h in textual_edges["src"]] +textual_edges["dst"] = [indexer._nodes[h] for h in textual_edges["dst"]] + +# %% [markdown] +# Note: The edge table refers to each node by its index in the node table. +# We will see how this gets utilized later when indexing a subgraph. 
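+
+# %% [markdown]
+# As a quick sanity check (the exact rows depend on the triplets assembled
+# above), the first few entries of each table can be inspected with:
+
+# %%
+print(textual_nodes.head())
+print(textual_edges.head())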
+ +# %% [markdown] +# Now we can save the result + +# %% +textual_nodes.to_csv('wikimultihopqa_textual_nodes.csv', index=False) +textual_edges.to_csv('wikimultihopqa_textual_edges.csv', index=False) diff --git a/examples/llm/multihop_rag/rag_generate_multihop.py b/examples/llm/multihop_rag/rag_generate_multihop.py new file mode 100644 index 000000000000..de93a9e75dd1 --- /dev/null +++ b/examples/llm/multihop_rag/rag_generate_multihop.py @@ -0,0 +1,88 @@ +# %% +import argparse +import sys +from typing import Tuple + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import Data +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +sys.path.append('..') + +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerApproxFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +parser = argparse.ArgumentParser( + description="Generate new multihop dataset for rag") +# TODO: Add more arguments for configuring rag params +parser.add_argument("--num_samples", type=int) +args = parser.parse_args() + +# %% +triplets = torch.load('wikimultihopqa_full_graph.pt') + +# %% +df = pd.read_csv('wikimultihopqa_cleaned.csv') +questions = df['question'][:args.num_samples] +labels = df['answer'][:args.num_samples] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerApproxFeatureStore).load() + +# %% + +all_textual_nodes = pd.read_csv('wikimultihopqa_textual_nodes.csv') +all_textual_edges = pd.read_csv('wikimultihopqa_textual_edges.csv') + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = all_textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = all_textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +# %% +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10}, + seed_edges_kwargs={"k_edges": 10}, + sampler_kwargs={"num_neighbors": [40] * 3}, + local_filter=apply_retrieval_via_pcst) + +# %% +subgs = [] +for q, l in tqdm.tqdm(zip(questions, labels)): + subg = query_loader.query(q) + subg['question'] = q + subg['label'] = l + subgs.append(subg) + +torch.save(subgs, 'subg_results.pt') diff --git a/examples/llm/nvtx_examples/README.md b/examples/llm/nvtx_examples/README.md new file mode 100644 index 000000000000..aa4f070d9824 --- /dev/null +++ b/examples/llm/nvtx_examples/README.md @@ -0,0 +1,7 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------------- | 
------------------------------------------------------------------------------------------------------------------------------- | +| [`nvtx_run.sh`](./nvtx_run.sh) | Runs nsys profiler on a given Python file that contains NVTX calls. | +| [`nvtx_rag_backend_example.py`](./nvtx_rag_backend_example.py) | Example script for nsys profiling a RAG Backend such as that used in [`rag_generate.py`](../g_retriever_utils/rag_generate.py). | +| [`nvtx_webqsp_example.py`](./nvtx_webqsp_example.py) | Example script for nsys profiling the WebQSP dataset. | diff --git a/examples/llm/nvtx_examples/nvtx_rag_backend_example.py b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py new file mode 100644 index 000000000000..b30e34b8c7b1 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py @@ -0,0 +1,144 @@ +# %% +import argparse +import sys +from itertools import chain +from typing import Tuple + +import torch + +from torch_geometric.data import Data, get_features_for_triplets_groups +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import rag_loader +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.profile.nvtx import nvtxit + +sys.path.append('..') +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +# Patch FeatureStore and GraphStore + +SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_nodes) +SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_edges) +SentenceTransformerFeatureStore.load_subgraph = nvtxit()( + SentenceTransformerFeatureStore.load_subgraph) +NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()( + NeighborSamplingRAGGraphStore.sample_subgraph) +rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query) + +# %% +ds = WebQSPDataset("small_ds_1", force_reload=True, limit=10) + +# %% +triplets = list(chain.from_iterable(d['graph'] for d in ds.raw_dataset)) + +# %% +questions = ds.raw_dataset['question'] + +# %% +ground_truth_graphs = get_features_for_triplets_groups( + ds.indexer, (d['graph'] for d in ds.raw_dataset), + pre_transform=preprocess_triplet) +num_edges = len(ds.indexer._edges) + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to( + device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerFeatureStore).load() + +# %% + + +@nvtxit() +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = 
desc
+    return out_graph
+
+
+# %%
+query_loader = rag_loader.RAGQueryLoader(
+    data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10},
+    seed_edges_kwargs={"k_edges": 10},
+    sampler_kwargs={"num_neighbors": [40] * 10},
+    local_filter=apply_retrieval_via_pcst)
+
+
+# %%
+# Accuracy Metrics to be added to Profiler
+def _eidx_helper(subg: Data, ground_truth: Data):
+    subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+    if isinstance(subg_eidx, torch.Tensor):
+        subg_eidx = subg_eidx.tolist()
+    if isinstance(gt_eidx, torch.Tensor):
+        gt_eidx = gt_eidx.tolist()
+    subg_e = set(subg_eidx)
+    gt_e = set(gt_eidx)
+    return subg_e, gt_e
+
+
+def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    total_e = set(range(num_edges))
+    tp = len(subg_e & gt_e)
+    tn = len(total_e - (subg_e | gt_e))
+    return (tp + tn) / num_edges
+
+
+def check_retrieval_precision(subg: Data, ground_truth: Data):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    return len(subg_e & gt_e) / len(subg_e)
+
+
+def check_retrieval_recall(subg: Data, ground_truth: Data):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    return len(subg_e & gt_e) / len(gt_e)
+
+
+# %%
+
+
+@nvtxit()
+def _run_eval():
+    for subg, gt in zip((query_loader.query(q) for q in questions),
+                        ground_truth_graphs):
+        print(check_retrieval_accuracy(subg, gt, num_edges),
+              check_retrieval_precision(subg, gt),
+              check_retrieval_recall(subg, gt))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
+    args = parser.parse_args()
+    if args.capture_torch_kernels:
+        with torch.autograd.profiler.emit_nvtx():
+            _run_eval()
+    else:
+        _run_eval()
diff --git a/examples/llm/nvtx_examples/nvtx_run.sh b/examples/llm/nvtx_examples/nvtx_run.sh
new file mode 100644
index 000000000000..4c6fce7c8224
--- /dev/null
+++ b/examples/llm/nvtx_examples/nvtx_run.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Check if the user provided a Python file
+if [ -z "$1" ]; then
+    echo "Usage: $0 <python_file>"
+    exit 1
+fi
+
+# Check if the provided file exists
+if [[ ! -f "$1" ]]; then
+    echo "Error: File '$1' does not exist."
+    exit 1
+fi
+
+# Check if the provided file is a Python file
+if [[ ! "$1" == *.py ]]; then
+    echo "Error: '$1' is not a Python file."
+ exit 1 +fi + +# Get the base name of the Python file +python_file=$(basename "$1") + +# Run nsys profile on the Python file +nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1" + +echo "Profile data saved as profile_${python_file%.py}.nsys-rep" diff --git a/examples/llm/nvtx_examples/nvtx_webqsp_example.py b/examples/llm/nvtx_examples/nvtx_webqsp_example.py new file mode 100644 index 000000000000..a1e9611ee5cc --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_webqsp_example.py @@ -0,0 +1,22 @@ +import argparse + +import torch + +from torch_geometric.datasets import web_qsp_dataset +from torch_geometric.profile import nvtxit + +# Apply Patches +web_qsp_dataset.retrieval_via_pcst = nvtxit()( + web_qsp_dataset.retrieval_via_pcst) +web_qsp_dataset.WebQSPDataset.process = nvtxit()( + web_qsp_dataset.WebQSPDataset.process) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + ds = web_qsp_dataset.WebQSPDataset('baseline', limit=10) + else: + ds = web_qsp_dataset.WebQSPDataset('baseline', limit=10) diff --git a/test/data/test_large_graph_indexer.py b/test/data/test_large_graph_indexer.py new file mode 100644 index 000000000000..840299ba4934 --- /dev/null +++ b/test/data/test_large_graph_indexer.py @@ -0,0 +1,170 @@ +import random +import string +from typing import List + +import networkx as nx +import torch + +from torch_geometric.data import ( + Data, + LargeGraphIndexer, + TripletLike, + get_features_for_triplets, +) +from torch_geometric.data.large_graph_indexer import ( + EDGE_PID, + EDGE_RELATION, + NODE_PID, +) + +# create possible nodes and edges for graph +strkeys = string.ascii_letters + string.digits +NODE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(1000)}) +EDGE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(50)}) + + +def featurize(s: str) -> int: + return int.from_bytes(s.encode(), 'little') + + +def sample_triplets(amount: int = 1) -> List[TripletLike]: + trips = [] + for i in range(amount): + h, t = random.sample(NODE_POOL, k=2) + r = random.sample(EDGE_POOL, k=1)[0] + trips.append(tuple([h, r, t])) + return trips + + +def preprocess_triplet(triplet: TripletLike) -> TripletLike: + h, r, t = triplet + return h.lower(), r, t.lower() + + +def test_basic_collate(): + graphs = [sample_triplets(1000) for i in range(2)] + + indexer_0 = LargeGraphIndexer.from_triplets( + graphs[0], pre_transform=preprocess_triplet) + indexer_1 = LargeGraphIndexer.from_triplets( + graphs[1], pre_transform=preprocess_triplet) + + big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1]) + + assert len(indexer_0._nodes) + len( + indexer_1._nodes) - len(indexer_0._nodes.keys() + & indexer_1._nodes.keys()) == len( + big_indexer._nodes) + assert len(indexer_0._edges) + len( + indexer_1._edges) - len(indexer_0._edges.keys() + & indexer_1._edges.keys()) == len( + big_indexer._edges) + + assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes) + assert len(set(big_indexer._edges.values())) == len(big_indexer._edges) + + for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()): + assert big_indexer.node_attr[NODE_PID][ + big_indexer._nodes[node]] == node + + +def test_large_graph_index(): + graphs = 
[sample_triplets(1000) for i in range(100)] + + # Preprocessing of trips lowercases nodes but not edges + node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + def encode_graph_from_trips(triplets: List[TripletLike]) -> Data: + seen_nodes = dict() + edge_attrs = list() + edge_idx = [] + for trip in triplets: + trip = preprocess_triplet(trip) + h, r, t = trip + seen_nodes[h] = len( + seen_nodes) if h not in seen_nodes else seen_nodes[h] + seen_nodes[t] = len( + seen_nodes) if t not in seen_nodes else seen_nodes[t] + edge_attrs.append(edge_feature_vecs[r]) + edge_idx.append((seen_nodes[h], seen_nodes[t])) + + x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()]) + edge_idx = torch.LongTensor(edge_idx).T + edge_attrs = torch.Tensor(edge_attrs) + return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs) + + naive_graph_ds = [ + encode_graph_from_trips(triplets=trips) for trips in graphs + ] + + indexer = LargeGraphIndexer.collate([ + LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet) + for g in graphs + ]) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + large_graph_ds = [ + get_features_for_triplets(indexer=indexer, triplets=g, + node_feature_name='x', + edge_feature_name='edge_attr', + pre_transform=preprocess_triplet) + for g in graphs + ] + + for ds in large_graph_ds: + assert NODE_PID in ds + assert EDGE_PID in ds + assert "node_idx" in ds + assert "edge_idx" in ds + + def results_are_close_enough(ground_truth: Data, new_method: Data, + thresh=.99): + def _sorted_tensors_are_close(tensor1, tensor2): + return torch.all( + torch.isclose(tensor1.sort()[0], + tensor2.sort()[0]) > thresh) + + def _graphs_are_same(tensor1, tensor2): + return nx.weisfeiler_lehman_graph_hash(nx.Graph( + tensor1.T)) == nx.weisfeiler_lehman_graph_hash( + nx.Graph(tensor2.T)) + return _sorted_tensors_are_close( + ground_truth.x, new_method.x) \ + and _sorted_tensors_are_close( + ground_truth.edge_attr, new_method.edge_attr) \ + and _graphs_are_same( + ground_truth.edge_index, new_method.edge_index) + + for dsets in zip(naive_graph_ds, large_graph_ds): + assert results_are_close_enough(*dsets) + + +def test_save_load(tmp_path): + graph = sample_triplets(1000) + + node_feature_vecs = {s: featurize(s) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + indexer = LargeGraphIndexer.from_triplets(graph) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + + indexer.save(str(tmp_path)) + assert indexer == LargeGraphIndexer.from_disk(str(tmp_path)) diff --git a/test/datasets/test_web_qsp_dataset.py b/test/datasets/test_web_qsp_dataset.py new file mode 100644 index 000000000000..9dbb8218c65a --- 
/dev/null +++ b/test/datasets/test_web_qsp_dataset.py @@ -0,0 +1,29 @@ +import pytest + +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.testing import onlyFullTest, onlyOnline + + +@pytest.mark.skip(reason="Times out") +@onlyOnline +@onlyFullTest +def test_web_qsp_dataset(): + dataset = WebQSPDataset() + assert len(dataset) == 4700 + assert str(dataset) == "WebQSPDataset(4700)" + + +@onlyOnline +@onlyFullTest +def test_web_qsp_dataset_limit(tmp_path): + dataset = WebQSPDataset(root=tmp_path, limit=100) + assert len(dataset) == 100 + assert str(dataset) == "WebQSPDataset(100)" + + +@onlyOnline +@onlyFullTest +def test_web_qsp_dataset_limit_no_pcst(tmp_path): + dataset = WebQSPDataset(root=tmp_path, limit=100, include_pcst=False) + assert len(dataset) == 100 + assert str(dataset) == "WebQSPDataset(100)" diff --git a/test/nn/models/test_g_retriever.py b/test/nn/models/test_g_retriever.py index 899e70730cc9..24a74d1b6f6e 100644 --- a/test/nn/models/test_g_retriever.py +++ b/test/nn/models/test_g_retriever.py @@ -51,3 +51,52 @@ def test_g_retriever() -> None: # Test inference: pred = model.inference(question, x, edge_index, batch, edge_attr) assert len(pred) == 1 + + +@onlyFullTest +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_g_retriever_many_tokens() -> None: + llm = LLM( + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.float16, + ) + + gnn = GAT( + in_channels=1024, + out_channels=1024, + hidden_channels=1024, + num_layers=2, + heads=4, + norm='batch_norm', + ) + + model = GRetriever( + llm=llm, + gnn=gnn, + mlp_out_channels=2048, + mlp_out_tokens=2, + ) + assert str(model) == ('GRetriever(\n' + ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n' + ' gnn=GAT(1024, 1024, num_layers=2),\n' + ')') + + x = torch.randn(10, 1024) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + ]) + edge_attr = torch.randn(edge_index.size(1), 1024) + batch = torch.zeros(x.size(0), dtype=torch.long) + + question = ["Is PyG the best open-source GNN library?"] + label = ["yes!"] + + # Test train: + loss = model(question, x, edge_index, batch, label, edge_attr) + assert loss >= 0 + + # Test inference: + pred = model.inference(question, x, edge_index, batch, edge_attr) + assert len(pred) == 1 diff --git a/test/profile/test_nvtx.py b/test/profile/test_nvtx.py new file mode 100644 index 000000000000..56e28a9c2e59 --- /dev/null +++ b/test/profile/test_nvtx.py @@ -0,0 +1,136 @@ +from unittest.mock import call, patch + +from torch_geometric.profile import nvtxit + + +def _setup_mock(torch_cuda_mock): + torch_cuda_mock.is_available.return_value = True + torch_cuda_mock.cudart.return_value.cudaProfilerStart.return_value = None + torch_cuda_mock.cudart.return_value.cudaProfilerStop.return_value = None + return torch_cuda_mock + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_base(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit() + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert 
torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_a_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_rename(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit('a_nvtx') + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('a_nvtx_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_iters(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_iters=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + call_a() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 2 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 2 # noqa: E501 + + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_b_0'), call('call_a_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_warmups(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_warmups=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: 
E501
+    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501
+    call_a()
+    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501
+    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1  # noqa: E501
+
+    assert torch_cuda_mock.nvtx.range_push.call_args_list == [
+        call('call_a_0'), call('call_b_1')
+    ]
diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py
index 821ef9c5c063..fee215b1a357 100644
--- a/torch_geometric/data/__init__.py
+++ b/torch_geometric/data/__init__.py
@@ -16,6 +16,7 @@
 from .makedirs import makedirs
 from .download import download_url, download_google_url
 from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
+from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups

 from torch_geometric.lazy_loader import LazyLoader

@@ -27,6 +28,8 @@
     'Dataset',
     'InMemoryDataset',
     'OnDiskDataset',
+    'LargeGraphIndexer',
+    'TripletLike',
 ]

 remote_backend_classes = [
@@ -50,6 +53,8 @@
     'extract_zip',
     'extract_bz2',
     'extract_gz',
+    'get_features_for_triplets',
+    'get_features_for_triplets_groups',
 ]

 __all__ = data_classes + remote_backend_classes + helper_functions
diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py
new file mode 100644
index 000000000000..51b0ce459b33
--- /dev/null
+++ b/torch_geometric/data/large_graph_indexer.py
@@ -0,0 +1,672 @@
+import os
+import pickle as pkl
+import shutil
+from dataclasses import dataclass
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import Data
+
+TripletLike = Tuple[Hashable, Hashable, Hashable]
+
+KnowledgeGraphLike = Iterable[TripletLike]
+
+
+def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
+    return list(dict.fromkeys(values))
+
+
+# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
+
+NODE_PID = "pid"
+
+NODE_KEYS = {NODE_PID}
+
+EDGE_PID = "e_pid"
+EDGE_HEAD = "h"
+EDGE_RELATION = "r"
+EDGE_TAIL = "t"
+EDGE_INDEX = "edge_idx"
+
+EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
+
+FeatureValueType = Union[Sequence[Any], Tensor]
+
+
+@dataclass
+class MappedFeature:
+    name: str
+    values: FeatureValueType
+
+    def __eq__(self, value: "MappedFeature") -> bool:
+        eq = self.name == value.name
+        if isinstance(self.values, torch.Tensor):
+            eq &= torch.equal(self.values, value.values)
+        else:
+            eq &= self.values == value.values
+        return eq
+
+
+class LargeGraphIndexer:
+    """For a dataset that consists of multiple subgraphs that are assumed to
+    be part of a much larger graph, collate the values into a large graph
+    store to save resources.
+    """
+    def __init__(
+        self,
+        nodes: Iterable[Hashable],
+        edges: KnowledgeGraphLike,
+        node_attr: Optional[Dict[str, List[Any]]] = None,
+        edge_attr: Optional[Dict[str, List[Any]]] = None,
+    ) -> None:
+        r"""Constructs a new index that uniquely catalogs each node and edge
+        by id. Not meant to be used directly.
+
+        Args:
+            nodes (Iterable[Hashable]): Node ids in the graph.
+            edges (KnowledgeGraphLike): Edge ids in the graph.
+            node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
+                attribute name and list of their values in order of unique node
+                ids. Defaults to None.
+            edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge
+                attribute name and list of their values in order of unique edge
+                ids. Defaults to None.
+        """
+        self._nodes: Dict[Hashable, int] = dict()
+        self._edges: Dict[TripletLike, int] = dict()
+
+        self._mapped_node_features: Set[str] = set()
+        self._mapped_edge_features: Set[str] = set()
+
+        if len(nodes) != len(set(nodes)):
+            raise AttributeError("Nodes need to be unique")
+        if len(edges) != len(set(edges)):
+            raise AttributeError("Edges need to be unique")
+
+        if node_attr is not None:
+            # TODO: Validity checks btw nodes and node_attr
+            self.node_attr = node_attr
+            if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS:
+                raise AttributeError(
+                    "Invalid node_attr object. Missing " +
+                    f"{NODE_KEYS - set(self.node_attr.keys())}")
+            elif self.node_attr[NODE_PID] != nodes:
+                raise AttributeError(
+                    "Nodes provided do not match those in node_attr")
+        else:
+            self.node_attr = dict()
+            self.node_attr[NODE_PID] = nodes
+
+        for i, node in enumerate(self.node_attr[NODE_PID]):
+            self._nodes[node] = i
+
+        if edge_attr is not None:
+            # TODO: Validity checks btw edges and edge_attr
+            self.edge_attr = edge_attr
+
+            if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS:
+                raise AttributeError(
+                    "Invalid edge_attr object. Missing " +
+                    f"{EDGE_KEYS - set(self.edge_attr.keys())}")
+            elif self.edge_attr[EDGE_PID] != edges:
+                raise AttributeError(
+                    "Edges provided do not match those in edge_attr")
+
+        else:
+            self.edge_attr = dict()
+            for default_key in EDGE_KEYS:
+                self.edge_attr[default_key] = list()
+            self.edge_attr[EDGE_PID] = edges
+
+            for i, tup in enumerate(edges):
+                h, r, t = tup
+                self.edge_attr[EDGE_HEAD].append(h)
+                self.edge_attr[EDGE_RELATION].append(r)
+                self.edge_attr[EDGE_TAIL].append(t)
+                self.edge_attr[EDGE_INDEX].append(
+                    (self._nodes[h], self._nodes[t]))
+
+        for i, tup in enumerate(edges):
+            self._edges[tup] = i
+
+    @classmethod
+    def from_triplets(
+        cls,
+        triplets: KnowledgeGraphLike,
+        pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+    ) -> "LargeGraphIndexer":
+        r"""Generate a new index from a series of triplets that represent edge
+        relations between nodes.
+        Formatted like (source_node, edge, dest_node).
+
+        Args:
+            triplets (KnowledgeGraphLike): Series of triplets representing
+                knowledge graph relations.
+            pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+                Optional preprocessing function to apply to triplets.
+                Defaults to None.
+
+        Returns:
+            LargeGraphIndexer: Index of unique nodes and edges.
+        """
+        # NOTE: Right now assumes that all trips can be loaded into memory
+        nodes = set()
+        edges = set()
+
+        if pre_transform is not None:
+
+            def apply_transform(
+                    trips: KnowledgeGraphLike) -> Iterator[TripletLike]:
+                for trip in trips:
+                    yield pre_transform(trip)
+
+            triplets = apply_transform(triplets)
+
+        for h, r, t in triplets:
+
+            for node in (h, t):
+                nodes.add(node)
+
+            edge_idx = (h, r, t)
+            edges.add(edge_idx)
+
+        return cls(list(nodes), list(edges))
+
+    @classmethod
+    def collate(cls,
+                graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
+        r"""Combines a series of large graph indexes into a single large graph
+        index.
+
+        Args:
+            graphs (Iterable["LargeGraphIndexer"]): Indices to be
+                combined.
+
+        Returns:
+            LargeGraphIndexer: Singular unique index for all nodes and edges
+                in input indices.
+        """
+        # FIXME Needs to merge node attrs and edge attrs?
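+        # For now only the raw triplets survive a collate: any features
+        # added via add_node_feature/add_edge_feature on the inputs must be
+        # re-attached to the combined index afterwards.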
+        trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
+        return cls.from_triplets(trips)
+
+    def get_unique_node_features(
+            self, feature_name: str = NODE_PID) -> List[Hashable]:
+        r"""Get all the unique values for a specific node attribute.
+
+        Args:
+            feature_name (str, optional): Name of feature to get.
+                Defaults to NODE_PID.
+
+        Returns:
+            List[Hashable]: List of unique values for the specified feature.
+        """
+        try:
+            if feature_name in self._mapped_node_features:
+                raise IndexError(
+                    "Only non-mapped features can be retrieved uniquely.")
+            return ordered_set(self.get_node_features(feature_name))
+
+        except KeyError:
+            raise AttributeError(
+                f"Nodes do not have a feature called {feature_name}")
+
+    def add_node_feature(
+        self,
+        new_feature_name: str,
+        new_feature_vals: FeatureValueType,
+        map_from_feature: str = NODE_PID,
+    ) -> None:
+        r"""Adds a new feature that corresponds to each unique node in
+        the graph.
+
+        Args:
+            new_feature_name (str): Name to call the new feature.
+            new_feature_vals (FeatureValueType): Values to map for that
+                new feature.
+            map_from_feature (str, optional): Key of feature to map from.
+                Size must match the number of feature values.
+                Defaults to NODE_PID.
+        """
+        if new_feature_name in self.node_attr:
+            raise AttributeError("Features cannot be overridden once created")
+        if map_from_feature in self._mapped_node_features:
+            raise AttributeError(
+                f"{map_from_feature} is already a feature mapping.")
+
+        feature_keys = self.get_unique_node_features(map_from_feature)
+        if len(feature_keys) != len(new_feature_vals):
+            raise AttributeError(
+                f"Expected encodings for {len(feature_keys)} unique features," +
+                f" but got {len(new_feature_vals)} encodings.")
+
+        if map_from_feature == NODE_PID:
+            self.node_attr[new_feature_name] = new_feature_vals
+        else:
+            self.node_attr[new_feature_name] = MappedFeature(
+                name=map_from_feature, values=new_feature_vals)
+            self._mapped_node_features.add(new_feature_name)
+
+    def get_node_features(
+        self,
+        feature_name: str = NODE_PID,
+        pids: Optional[Iterable[Hashable]] = None,
+    ) -> List[Any]:
+        r"""Get node feature values for a given set of unique node ids.
+        Returned values are not necessarily unique.
+
+        Args:
+            feature_name (str, optional): Name of feature to fetch. Defaults
+                to NODE_PID.
+            pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
+                for. Defaults to None, which fetches all nodes.
+
+        Returns:
+            List[Any]: Node features corresponding to the specified ids.
+        """
+        if feature_name in self._mapped_node_features:
+            values = self.node_attr[feature_name].values
+        else:
+            values = self.node_attr[feature_name]
+
+        # TODO: torch_geometric.utils.select
+        if isinstance(values, torch.Tensor):
+            idxs = list(
+                self.get_node_features_iter(feature_name, pids,
+                                            index_only=True))
+            return values[idxs]
+        return list(self.get_node_features_iter(feature_name, pids))
+
+    def get_node_features_iter(
+        self,
+        feature_name: str = NODE_PID,
+        pids: Optional[Iterable[Hashable]] = None,
+        index_only: bool = False,
+    ) -> Iterator[Any]:
+        """Iterator version of get_node_features. If index_only is True,
+        yields indices instead of values.
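+
+        Example (illustrative; assumes an :obj:`indexer` with an :obj:`"x"`
+        feature added via :meth:`add_node_feature`):
+            >>> pids = indexer.get_unique_node_features()
+            >>> feats = list(indexer.get_node_features_iter("x", pids[:5]))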
+ """ + if pids is None: + pids = self.node_attr[NODE_PID] + + if feature_name in self._mapped_node_features: + feature_map_info = self.node_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_node_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._nodes[pid] + from_feature_val = self.node_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._nodes[pid] + if index_only: + yield idx + else: + yield self.node_attr[feature_name][idx] + + def get_unique_edge_features( + self, feature_name: str = EDGE_PID) -> List[Hashable]: + r"""Get all the unique values for a specific edge attribute. + + Args: + feature_name (str, optional): Name of feature to get. + Defaults to EDGE_PID. + + Returns: + List[Hashable]: List of unique values for the specified feature. + """ + try: + if feature_name in self._mapped_edge_features: + raise IndexError( + "Only non-mapped features can be retrieved uniquely.") + return ordered_set(self.get_edge_features(feature_name)) + except KeyError: + raise AttributeError( + f"Edges do not have a feature called {feature_name}") + + def add_edge_feature( + self, + new_feature_name: str, + new_feature_vals: FeatureValueType, + map_from_feature: str = EDGE_PID, + ) -> None: + r"""Adds a new feature that corresponds to each unique edge in + the graph. + + Args: + new_feature_name (str): Name to call the new feature. + new_feature_vals (FeatureValueType): Values to map for that new + feature. + map_from_feature (str, optional): Key of feature to map from. + Size must match the number of feature values. + Defaults to EDGE_PID. + """ + if new_feature_name in self.edge_attr: + raise AttributeError("Features cannot be overridden once created") + if map_from_feature in self._mapped_edge_features: + raise AttributeError( + f"{map_from_feature} is already a feature mapping.") + + feature_keys = self.get_unique_edge_features(map_from_feature) + if len(feature_keys) != len(new_feature_vals): + raise AttributeError( + f"Expected encodings for {len(feature_keys)} unique features, " + + f"but got {len(new_feature_vals)} encodings.") + + if map_from_feature == EDGE_PID: + self.edge_attr[new_feature_name] = new_feature_vals + else: + self.edge_attr[new_feature_name] = MappedFeature( + name=map_from_feature, values=new_feature_vals) + self._mapped_edge_features.add(new_feature_name) + + def get_edge_features( + self, + feature_name: str = EDGE_PID, + pids: Optional[Iterable[Hashable]] = None, + ) -> List[Any]: + r"""Get edge feature values for a given set of unique edge ids. + Returned values are not necessarily unique. + + Args: + feature_name (str, optional): Name of feature to fetch. + Defaults to EDGE_PID. + pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch + for. Defaults to None, which fetches all edges. + + Returns: + List[Any]: Node features corresponding to the specified ids. 
+ """ + if feature_name in self._mapped_edge_features: + values = self.edge_attr[feature_name].values + else: + values = self.edge_attr[feature_name] + + # TODO: torch_geometric.utils.select + if isinstance(values, torch.Tensor): + idxs = list( + self.get_edge_features_iter(feature_name, pids, + index_only=True)) + return values[idxs] + return list(self.get_edge_features_iter(feature_name, pids)) + + def get_edge_features_iter( + self, + feature_name: str = EDGE_PID, + pids: Optional[KnowledgeGraphLike] = None, + index_only: bool = False, + ) -> Iterator[Any]: + """Iterator version of get_edge_features. If index_only is True, + yields indices instead of values. + """ + if pids is None: + pids = self.edge_attr[EDGE_PID] + + if feature_name in self._mapped_edge_features: + feature_map_info = self.edge_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_edge_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._edges[pid] + from_feature_val = self.edge_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._edges[pid] + if index_only: + yield idx + else: + yield self.edge_attr[feature_name][idx] + + def to_triplets(self) -> Iterator[TripletLike]: + return iter(self.edge_attr[EDGE_PID]) + + def save(self, path: str) -> None: + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + with open(path + "/edges", "wb") as f: + pkl.dump(self._edges, f) + with open(path + "/nodes", "wb") as f: + pkl.dump(self._nodes, f) + + with open(path + "/mapped_edges", "wb") as f: + pkl.dump(self._mapped_edge_features, f) + with open(path + "/mapped_nodes", "wb") as f: + pkl.dump(self._mapped_node_features, f) + + node_attr_path = path + "/node_attr" + os.makedirs(node_attr_path, exist_ok=True) + for attr_name, vals in self.node_attr.items(): + torch.save(vals, node_attr_path + f"/{attr_name}.pt") + + edge_attr_path = path + "/edge_attr" + os.makedirs(edge_attr_path, exist_ok=True) + for attr_name, vals in self.edge_attr.items(): + torch.save(vals, edge_attr_path + f"/{attr_name}.pt") + + @classmethod + def from_disk(cls, path: str) -> "LargeGraphIndexer": + indexer = cls(list(), list()) + with open(path + "/edges", "rb") as f: + indexer._edges = pkl.load(f) + with open(path + "/nodes", "rb") as f: + indexer._nodes = pkl.load(f) + + with open(path + "/mapped_edges", "rb") as f: + indexer._mapped_edge_features = pkl.load(f) + with open(path + "/mapped_nodes", "rb") as f: + indexer._mapped_node_features = pkl.load(f) + + node_attr_path = path + "/node_attr" + for fname in os.listdir(node_attr_path): + full_fname = f"{node_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.node_attr[key] = torch.load(full_fname) + + edge_attr_path = path + "/edge_attr" + for fname in os.listdir(edge_attr_path): + full_fname = f"{edge_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.edge_attr[key] = torch.load(full_fname) + + return indexer + + def to_data(self, node_feature_name: str, + edge_feature_name: Optional[str] = None) -> Data: + """Return a Data object containing all the specified node and + edge features and the graph. 
+
+        Args:
+            node_feature_name (str): Feature to use for nodes.
+            edge_feature_name (Optional[str], optional): Feature to use for
+                edges. Defaults to None.
+
+        Returns:
+            Data: Data object containing the specified node and
+            edge features and the graph.
+        """
+        x = torch.Tensor(self.get_node_features(node_feature_name))
+        node_id = torch.LongTensor(range(len(x)))
+
+        edge_index = torch.t(
+            torch.LongTensor(self.get_edge_features(EDGE_INDEX)))
+
+        edge_attr = (self.get_edge_features(edge_feature_name)
+                     if edge_feature_name is not None else None)
+        # Use the number of edges here, as `edge_attr` may be `None`:
+        edge_id = torch.LongTensor(range(edge_index.size(1)))
+
+        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
+                    edge_id=edge_id, node_id=node_id)
+
+    def __eq__(self, value: "LargeGraphIndexer") -> bool:
+        eq = True
+        eq &= self._nodes == value._nodes
+        eq &= self._edges == value._edges
+        eq &= self.node_attr.keys() == value.node_attr.keys()
+        eq &= self.edge_attr.keys() == value.edge_attr.keys()
+        eq &= self._mapped_node_features == value._mapped_node_features
+        eq &= self._mapped_edge_features == value._mapped_edge_features
+
+        for k in self.node_attr:
+            eq &= isinstance(self.node_attr[k], type(value.node_attr[k]))
+            if isinstance(self.node_attr[k], torch.Tensor):
+                eq &= torch.equal(self.node_attr[k], value.node_attr[k])
+            else:
+                eq &= self.node_attr[k] == value.node_attr[k]
+        for k in self.edge_attr:
+            eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k]))
+            if isinstance(self.edge_attr[k], torch.Tensor):
+                eq &= torch.equal(self.edge_attr[k], value.edge_attr[k])
+            else:
+                eq &= self.edge_attr[k] == value.edge_attr[k]
+        return eq
+
+
+def get_features_for_triplets_groups(
+    indexer: LargeGraphIndexer,
+    triplet_groups: Iterable[KnowledgeGraphLike],
+    node_feature_name: str = "x",
+    edge_feature_name: str = "edge_attr",
+    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+    verbose: bool = False,
+) -> Iterator[Data]:
+    """Given an indexer and a series of triplet groups (like a dataset),
+    retrieve the specified node and edge features for each triplet from the
+    index.
+
+    Args:
+        indexer (LargeGraphIndexer): Indexer containing desired features.
+        triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
+            triplets to fetch features for.
+        node_feature_name (str, optional): Node feature to fetch.
+            Defaults to "x".
+        edge_feature_name (str, optional): Edge feature to fetch.
+            Defaults to "edge_attr".
+        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+            Optional preprocessing to perform on triplets.
+            Defaults to None.
+        verbose (bool, optional): Whether to print progress. Defaults to False.
+
+    Yields:
+        Iterator[Data]: For each triplet group, yield a data object containing
+        the unique graph and features from the index.
+    """
+    if pre_transform is not None:

+        def apply_transform(trips):
+            for trip in trips:
+                yield pre_transform(tuple(trip))
+
+        # TODO: Make this safe for large amounts of triplets?
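+        # Lazily transform each group; only one group is materialized at a
+        # time in the loop below.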
+        triplet_groups = (list(apply_transform(triplets))
+                          for triplets in triplet_groups)
+
+    node_keys = []
+    edge_keys = []
+    edge_index = []
+
+    for triplets in tqdm(triplet_groups, disable=not verbose):
+        small_graph_indexer = LargeGraphIndexer.from_triplets(
+            triplets, pre_transform=pre_transform)
+
+        node_keys.append(small_graph_indexer.get_node_features())
+        edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
+        edge_index.append(
+            small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))
+
+    node_feats = indexer.get_node_features(feature_name=node_feature_name,
+                                           pids=chain.from_iterable(node_keys))
+    edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
+                                           pids=chain.from_iterable(edge_keys))
+
+    last_node_idx, last_edge_idx = 0, 0
+    for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index):
+        nlen, elen = len(nkeys), len(ekeys)
+        x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
+        last_node_idx += len(nkeys)
+
+        edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
+                                            elen])
+        last_edge_idx += len(ekeys)
+
+        edge_idx = torch.LongTensor(eidx).T
+
+        data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
+        # Store the per-group keys, not the accumulated lists:
+        data_obj[NODE_PID] = nkeys
+        data_obj[EDGE_PID] = ekeys
+        data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
+        data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
+
+        yield data_obj
+
+
+def get_features_for_triplets(
+    indexer: LargeGraphIndexer,
+    triplets: KnowledgeGraphLike,
+    node_feature_name: str = "x",
+    edge_feature_name: str = "edge_attr",
+    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+    verbose: bool = False,
+) -> Data:
+    """For a given set of triplets retrieve a Data object containing the
+    unique graph and features from the index.
+
+    Args:
+        indexer (LargeGraphIndexer): Indexer containing desired features.
+        triplets (KnowledgeGraphLike): Triplets to fetch features for.
+        node_feature_name (str, optional): Feature to use for node features.
+            Defaults to "x".
+        edge_feature_name (str, optional): Feature to use for edge features.
+            Defaults to "edge_attr".
+        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+            Optional preprocessing function for triplets. Defaults to None.
+        verbose (bool, optional): Whether to print progress. Defaults to False.
+
+    Returns:
+        Data: Data object containing the unique graph and features from the
+        index for the given triplets.
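+
+    Example (illustrative; assumes an indexer that holds :obj:`"x"` and
+    :obj:`"edge_attr"` features for the given triplets):
+        >>> data = get_features_for_triplets(indexer, triplets)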
+ """ + gen = get_features_for_triplets_groups(indexer, [triplets], + node_feature_name, + edge_feature_name, pre_transform, + verbose) + return next(gen) diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 96d51032d818..236753293c2e 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -111,6 +111,7 @@ import torch_geometric.datasets.utils homo_datasets = [ + 'WebQSPDataset', 'KarateClub', 'TUDataset', 'GNNBenchmarkDataset', diff --git a/torch_geometric/datasets/web_qsp_dataset.py b/torch_geometric/datasets/web_qsp_dataset.py index bcc85e070920..fd3caeb99b70 100644 --- a/torch_geometric/datasets/web_qsp_dataset.py +++ b/torch_geometric/datasets/web_qsp_dataset.py @@ -1,12 +1,21 @@ # Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630 -from typing import Any, Dict, List, Tuple, no_type_check +import os +from itertools import chain +from typing import Any, Iterator, List, Tuple, no_type_check import numpy as np import torch from torch import Tensor from tqdm import tqdm -from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.data import ( + Data, + InMemoryDataset, + LargeGraphIndexer, + TripletLike, + get_features_for_triplets_groups, +) +from torch_geometric.data.large_graph_indexer import EDGE_RELATION from torch_geometric.nn.nlp import SentenceTransformer @@ -19,9 +28,11 @@ def retrieval_via_pcst( topk: int = 3, topk_e: int = 3, cost_e: float = 0.5, + save_idx: bool = False, + override: bool = False, ) -> Tuple[Data, str]: c = 0.01 - if len(textual_nodes) == 0 or len(textual_edges) == 0: + if len(textual_nodes) == 0 or len(textual_edges) == 0 or override: desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv( index=False, columns=["src", "edge_attr", "dst"], @@ -114,15 +125,28 @@ def retrieval_via_pcst( src = [mapping[i] for i in edge_index[0].tolist()] dst = [mapping[i] for i in edge_index[1].tolist()] + # HACK Added so that the subset of nodes and edges selected can be tracked + if save_idx: + node_idx = np.array(data.node_idx)[selected_nodes] + edge_idx = np.array(data.edge_idx)[selected_edges] + data = Data( x=data.x[selected_nodes], edge_index=torch.tensor([src, dst]), edge_attr=data.edge_attr[selected_edges], ) + if save_idx: + data['node_idx'] = node_idx + data['edge_idx'] = edge_idx return data, desc +def preprocess_triplet(triplet: TripletLike) -> TripletLike: + h, r, t = triplet + return str(h).lower(), str(r), str(t).lower() + + class WebQSPDataset(InMemoryDataset): r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse Labeling for Knowledge Base Question Answering" @@ -135,107 +159,214 @@ class WebQSPDataset(InMemoryDataset): If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) + limit (int, optional): Construct only the first n samples. + Defaults to -1 to construct all samples. + include_pcst (bool, optional): Whether to include PCST step + (See GRetriever paper). Defaults to True. + verbose (bool, optional): Whether to print output. Defaults to False. 
""" def __init__( self, root: str, split: str = "train", force_reload: bool = False, + limit: int = -1, + include_pcst: bool = True, + verbose: bool = False, ) -> None: + self.limit = limit + self.split = split + self.include_pcst = include_pcst + # TODO Confirm why the dependency checks and device setting were removed here # noqa + ''' + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") + self._check_dependencies() + ''' + self.verbose = verbose + self.force_reload = force_reload super().__init__(root, force_reload=force_reload) - if split not in {'train', 'val', 'test'}: + if split not in {'train', 'val', 'test'} and limit < 0: raise ValueError(f"Invalid 'split' argument (got {split})") - path = self.processed_paths[['train', 'val', 'test'].index(split)] - self.load(path) + self._load_raw_data() + self.load(self.processed_paths[0]) + + ''' + def _check_dependencies(self) -> None: + missing_str_list = [] + if not WITH_PCST: + missing_str_list.append('pcst_fast') + if not WITH_DATASETS: + missing_str_list.append('datasets') + if not WITH_PANDAS: + missing_str_list.append('pandas') + if len(missing_str_list) > 0: + missing_str = ' '.join(missing_str_list) + error_out = f"`pip install {missing_str}` to use this dataset." + raise ImportError(error_out) + ''' + + @property + def raw_file_names(self) -> List[str]: + return ["raw_data", "split_idxs"] @property def processed_file_names(self) -> List[str]: - return ['train_data.pt', 'val_data.pt', 'test_data.pt'] + file_lst = [ + "train_data.pt", + "val_data.pt", + "test_data.pt", + "pre_filter.pt", + "pre_transform.pt", + "large_graph_indexer", + ] + split_file = file_lst.pop(['train', 'val', 'test'].index(self.split)) + file_lst.insert(0, split_file) + return file_lst + + def _save_raw_data(self) -> None: + self.raw_dataset.save_to_disk(self.raw_paths[0]) + torch.save(self.split_idxs, self.raw_paths[1]) + + def _load_raw_data(self) -> None: + import datasets + if not hasattr(self, "raw_dataset"): + self.raw_dataset = datasets.load_from_disk(self.raw_paths[0]) + if not hasattr(self, "split_idxs"): + self.split_idxs = torch.load(self.raw_paths[1]) - def process(self) -> None: + def download(self) -> None: import datasets - import pandas as pd - - datasets = datasets.load_dataset('rmanluo/RoG-webqsp') - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model_name = 'sentence-transformers/all-roberta-large-v1' - model = SentenceTransformer(model_name).to(device) - model.eval() - - for dataset, path in zip( - [datasets['train'], datasets['validation'], datasets['test']], - self.processed_paths, - ): - questions = [example["question"] for example in dataset] - question_embs = model.encode( - questions, - batch_size=256, - output_device='cpu', + + dataset = datasets.load_dataset("rmanluo/RoG-webqsp") + self.raw_dataset = datasets.concatenate_datasets( + [dataset["train"], dataset["validation"], dataset["test"]]) + self.split_idxs = { + "train": + torch.arange(len(dataset["train"])), + "val": + torch.arange(len(dataset["validation"])) + len(dataset["train"]), + "test": + torch.arange(len(dataset["test"])) + len(dataset["train"]) + + len(dataset["validation"]), + } + + if self.limit >= 0: + self.raw_dataset = self.raw_dataset.select(range(self.limit)) + + # HACK + self.split_idxs = { + "train": + torch.arange(self.limit // 2), + "val": + torch.arange(self.limit // 4) + self.limit // 2, + "test": + torch.arange(self.limit // 4) + self.limit // 2 + + self.limit // 4, + } + self._save_raw_data() + + def 
_get_trips(self) -> Iterator[TripletLike]: + return chain.from_iterable( + iter(ds["graph"]) for ds in self.raw_dataset) + + def _build_graph(self) -> None: + trips = self._get_trips() + self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets( + trips, pre_transform=preprocess_triplet) + + # Nodes: + nodes = self.indexer.get_unique_node_features() + x = self.model.encode( + nodes, # type: ignore + batch_size=256, + output_device='cpu') + self.indexer.add_node_feature(new_feature_name="x", new_feature_vals=x) + + # Edges: + edges = self.indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + edge_attr = self.model.encode( + edges, # type: ignore + batch_size=256, + output_device='cpu') + self.indexer.add_edge_feature( + new_feature_name="edge_attr", + new_feature_vals=edge_attr, + map_from_feature=EDGE_RELATION, + ) + + print("Saving graph...") + self.indexer.save(self.processed_paths[-1]) + + def _retrieve_subgraphs(self) -> None: + print("Encoding questions...") + self.questions = [str(ds["question"]) for ds in self.raw_dataset] + q_embs = self.model.encode(self.questions, batch_size=256, + output_device='cpu') + list_of_graphs = [] + print("Retrieving subgraphs...") + textual_nodes = self.textual_nodes + textual_edges = self.textual_edges + graph_gen = get_features_for_triplets_groups( + self.indexer, (ds['graph'] for ds in self.raw_dataset), + pre_transform=preprocess_triplet, verbose=self.verbose) + + for index in tqdm(range(len(self.raw_dataset)), + disable=not self.verbose): + data_i = self.raw_dataset[index] + graph = next(graph_gen) + textual_nodes = self.textual_nodes.iloc[ + graph["node_idx"]].reset_index() + textual_edges = self.textual_edges.iloc[ + graph["edge_idx"]].reset_index() + pcst_subgraph, desc = retrieval_via_pcst( + graph, + q_embs[index], + textual_nodes, + textual_edges, + topk=3, + topk_e=5, + cost_e=0.5, + override=not self.include_pcst, ) + question = f"Question: {data_i['question']}\nAnswer: " + label = ("|").join(data_i["answer"]).lower() - data_list = [] - for i, example in enumerate(tqdm(dataset)): - raw_nodes: Dict[str, int] = {} - raw_edges = [] - for tri in example["graph"]: - h, r, t = tri - h = h.lower() - t = t.lower() - if h not in raw_nodes: - raw_nodes[h] = len(raw_nodes) - if t not in raw_nodes: - raw_nodes[t] = len(raw_nodes) - raw_edges.append({ - "src": raw_nodes[h], - "edge_attr": r, - "dst": raw_nodes[t] - }) - nodes = pd.DataFrame([{ - "node_id": v, - "node_attr": k, - } for k, v in raw_nodes.items()], - columns=["node_id", "node_attr"]) - edges = pd.DataFrame(raw_edges, - columns=["src", "edge_attr", "dst"]) - - nodes.node_attr = nodes.node_attr.fillna("") - x = model.encode( - nodes.node_attr.tolist(), - batch_size=256, - output_device='cpu', - ) - edge_attr = model.encode( - edges.edge_attr.tolist(), - batch_size=256, - output_device='cpu', - ) - edge_index = torch.tensor([ - edges.src.tolist(), - edges.dst.tolist(), - ], dtype=torch.long) - - question = f"Question: {example['question']}\nAnswer: " - label = ('|').join(example['answer']).lower() - data = Data( - x=x, - edge_index=edge_index, - edge_attr=edge_attr, - ) - data, desc = retrieval_via_pcst( - data, - question_embs[i], - nodes, - edges, - topk=3, - topk_e=5, - cost_e=0.5, - ) - data.question = question - data.label = label - data.desc = desc - data_list.append(data) - - self.save(data_list, path) + pcst_subgraph["question"] = question + pcst_subgraph["label"] = label + pcst_subgraph["desc"] = desc + list_of_graphs.append(pcst_subgraph.to("cpu")) + 
print("Saving subgraphs...") + self.save(list_of_graphs, self.processed_paths[0]) + + def process(self) -> None: + from pandas import DataFrame + self._load_raw_data() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = SentenceTransformer( + 'sentence-transformers/all-roberta-large-v1').to(device) + self.model.eval() + if self.force_reload or not os.path.exists(self.processed_paths[-1]): + print("Encoding graph...") + self._build_graph() + else: + print("Loading graph...") + self.indexer = LargeGraphIndexer.from_disk( + self.processed_paths[-1]) + self.textual_nodes = DataFrame.from_dict( + {"node_attr": self.indexer.get_node_features()}) + self.textual_nodes["node_id"] = self.textual_nodes.index + self.textual_nodes = self.textual_nodes[["node_id", "node_attr"]] + self.textual_edges = DataFrame(self.indexer.get_edge_features(), + columns=["src", "edge_attr", "dst"]) + self.textual_edges["src"] = [ + self.indexer._nodes[h] for h in self.textual_edges["src"] + ] + self.textual_edges["dst"] = [ + self.indexer._nodes[h] for h in self.textual_edges["dst"] + ] + self._retrieve_subgraphs() diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 266f498a113b..7e83c35befb6 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -22,6 +22,7 @@ from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin +from .rag_loader import RAGQueryLoader __all__ = classes = [ 'DataLoader', @@ -50,6 +51,7 @@ 'PrefetchLoader', 'CachedLoader', 'AffinityMixin', + 'RAGQueryLoader', ] RandomNodeSampler = deprecated( diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py new file mode 100644 index 000000000000..33d6cf0e868e --- /dev/null +++ b/torch_geometric/loader/rag_loader.py @@ -0,0 +1,106 @@ +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +from torch_geometric.data import Data, FeatureStore, HeteroData +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +class RAGFeatureStore(Protocol): + """Feature store for remote GNN RAG backend.""" + @abstractmethod + def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: + """Makes a comparison between the query and all the nodes to get all + the closest nodes. Return the indices of the nodes that are to be seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges: + """Makes a comparison between the query and all the edges to get all + the closest nodes. Returns the edge indices that are to be the seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + """Combines sampled subgraph output with features in a Data object.""" + ... + + +class RAGGraphStore(Protocol): + """Graph store for remote GNN RAG backend.""" + @abstractmethod + def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges, + **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample a subgraph using the seeded nodes and edges.""" + ... + + @abstractmethod + def register_feature_store(self, feature_store: FeatureStore): + """Register a feature store to be used with the sampler. Samplers need + info from the feature store in order to work properly on HeteroGraphs. 
+ """ + ... + + +# TODO: Make compatible with Heterographs + + +class RAGQueryLoader: + def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore], + local_filter: Optional[Callable[[Data, Any], Data]] = None, + seed_nodes_kwargs: Optional[Dict[str, Any]] = None, + seed_edges_kwargs: Optional[Dict[str, Any]] = None, + sampler_kwargs: Optional[Dict[str, Any]] = None, + loader_kwargs: Optional[Dict[str, Any]] = None): + """Loader meant for making queries from a remote backend. + + Args: + data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore + and GraphStore to load from. Assumed to conform to the + protocols listed above. + local_filter (Optional[Callable[[Data, Any], Data]], optional): + Optional local transform to apply to data after retrieval. + Defaults to None. + seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters + to pass into process for fetching seed nodes. Defaults to None. + seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters + to pass into process for fetching seed edges. Defaults to None. + sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for sampling graph. Defaults to None. + loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for loading graph features. Defaults to None. + """ + fstore, gstore = data + self.feature_store = fstore + self.graph_store = gstore + self.graph_store.register_feature_store(self.feature_store) + self.local_filter = local_filter + self.seed_nodes_kwargs = seed_nodes_kwargs or {} + self.seed_edges_kwargs = seed_edges_kwargs or {} + self.sampler_kwargs = sampler_kwargs or {} + self.loader_kwargs = loader_kwargs or {} + + def query(self, query: Any) -> Data: + """Retrieve a subgraph associated with the query with all its feature + attributes. + """ + seed_nodes = self.feature_store.retrieve_seed_nodes( + query, **self.seed_nodes_kwargs) + seed_edges = self.feature_store.retrieve_seed_edges( + query, **self.seed_edges_kwargs) + + subgraph_sample = self.graph_store.sample_subgraph( + seed_nodes, seed_edges, **self.sampler_kwargs) + + data = self.feature_store.load_subgraph(sample=subgraph_sample, + **self.loader_kwargs) + + if self.local_filter: + data = self.local_filter(data, query) + return data diff --git a/torch_geometric/nn/models/g_retriever.py b/torch_geometric/nn/models/g_retriever.py index 6f8fbcc644dc..f7529ae721b7 100644 --- a/torch_geometric/nn/models/g_retriever.py +++ b/torch_geometric/nn/models/g_retriever.py @@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module): (default: :obj:`False`) mlp_out_channels (int, optional): The size of each graph embedding after projection. (default: :obj:`4096`) + mlp_out_tokens (int, optional): Number of LLM prefix tokens to + reserve for GNN output. (default: :obj:`1`) .. 
         This module has been tested with the following HuggingFace models
@@ -43,6 +45,7 @@ def __init__(
         gnn: torch.nn.Module,
         use_lora: bool = False,
         mlp_out_channels: int = 4096,
+        mlp_out_tokens: int = 1,
     ) -> None:
         super().__init__()
@@ -77,7 +80,9 @@ def __init__(
         self.projector = torch.nn.Sequential(
             torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
             torch.nn.Sigmoid(),
-            torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
+            torch.nn.Linear(mlp_hidden_channels,
+                            mlp_out_channels * mlp_out_tokens),
+            torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
         ).to(self.llm.device)
 
     def encode(
@@ -126,6 +131,9 @@ def forward(
         x = self.projector(x)
         xs = x.split(1, dim=0)
 
+        # Handle case where there's more than one embedding for each sample:
+        xs = [x.squeeze(0) for x in xs]
+
         # Handle questions without node features:
         batch_unique = batch.unique()
         batch_size = len(question)
@@ -182,6 +190,9 @@ def inference(
         x = self.projector(x)
         xs = x.split(1, dim=0)
 
+        # Handle case where there's more than one embedding for each sample:
+        xs = [x.squeeze(0) for x in xs]
+
         # Handle questions without node features:
         batch_unique = batch.unique()
         batch_size = len(question)
diff --git a/torch_geometric/profile/__init__.py b/torch_geometric/profile/__init__.py
index 833ee657d0e7..22d3039f4c83 100644
--- a/torch_geometric/profile/__init__.py
+++ b/torch_geometric/profile/__init__.py
@@ -20,6 +20,7 @@
     get_gpu_memory_from_nvidia_smi,
     get_model_size,
 )
+from .nvtx import nvtxit
 
 __all__ = [
     'profileit',
@@ -38,6 +39,7 @@
     'get_gpu_memory_from_nvidia_smi',
     'get_gpu_memory_from_ipex',
     'benchmark',
+    'nvtxit',
 ]
 
 classes = __all__
diff --git a/torch_geometric/profile/nvtx.py b/torch_geometric/profile/nvtx.py
new file mode 100644
index 000000000000..8dbce375ae5a
--- /dev/null
+++ b/torch_geometric/profile/nvtx.py
@@ -0,0 +1,66 @@
+from functools import wraps
+from typing import Optional
+
+import torch
+
+CUDA_PROFILE_STARTED = False
+
+
+def begin_cuda_profile():
+    global CUDA_PROFILE_STARTED
+    prev_state = CUDA_PROFILE_STARTED
+    if prev_state is False:
+        CUDA_PROFILE_STARTED = True
+        torch.cuda.cudart().cudaProfilerStart()
+    return prev_state
+
+
+def end_cuda_profile(prev_state: bool):
+    global CUDA_PROFILE_STARTED
+    CUDA_PROFILE_STARTED = prev_state
+    if prev_state is False:
+        torch.cuda.cudart().cudaProfilerStop()
+
+
+def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+           n_iters: Optional[int] = None):
+    """Enables NVTX profiling for a function.
+
+    Args:
+        name (Optional[str], optional): Name to give the reference frame for
+            the function being wrapped. Defaults to the name of the
+            function in code.
+        n_warmups (int, optional): Number of iterations to call the function
+            before recording starts. Defaults to 0.
+        n_iters (Optional[int], optional): Number of iterations of the
+            function to record. Defaults to None, which records all of them.
+    """
+    def nvtx(func):
+
+        nonlocal name
+        iters_so_far = 0
+        if name is None:
+            name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal iters_so_far
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+            elif iters_so_far < n_warmups:
+                iters_so_far += 1
+                return func(*args, **kwargs)
+            elif n_iters is None or iters_so_far < n_iters + n_warmups:
+                prev_state = begin_cuda_profile()
+                torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+                result = func(*args, **kwargs)
+                torch.cuda.nvtx.range_pop()
+                end_cuda_profile(prev_state)
+                iters_so_far += 1
+                return result
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return nvtx
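+
+
+# Example (illustrative; `forward_pass` is a hypothetical user function):
+# record NVTX ranges for two calls after one warmup iteration, e.g. when
+# running under `nsys profile --capture-range=cudaProfilerApi`:
+#
+#     @nvtxit(name="forward", n_warmups=1, n_iters=2)
+#     def forward_pass(batch):
+#         ...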