From 5570242624e163e541716a993918b7d534de6369 Mon Sep 17 00:00:00 2001 From: "Valentin G." Date: Mon, 18 Nov 2024 19:51:32 +0100 Subject: [PATCH 01/45] solve issue 9755 (fix typo) (#9790) Fixed the typo in the description of NeighborLoader. --- torch_geometric/loader/neighbor_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index 341f2f5a23b6..5814724f0c48 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader): This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. - More specifically, :obj:`num_neighbors` denotes how much neighbors are + More specifically, :obj:`num_neighbors` denotes how many neighbors are sampled for each node in each iteration. :class:`~torch_geometric.loader.NeighborLoader` takes in this list of :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for From e1a925b792631d4362e4dab1cc75d64f41848250 Mon Sep 17 00:00:00 2001 From: Junhao Shen Date: Tue, 19 Nov 2024 14:18:30 -0600 Subject: [PATCH 02/45] add GLEM model, TAGDataset and example of GLEM (#9662) reopened #9591 Feature summary: - Add GLEM as GNN & LLM Co-training model to PyG - adapt GLEM's LM to AutoModelForSequenceClassification from transformers - Lora support - LM/LLM support - ogbn-products/ogbn-arxiv testing finished - TAGDataset can be used as a wrapper class for any node classification dataset in PyG with LM tokenizer and associate raw text - external prediction as pseudo labels supported --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rishi Puri Co-authored-by: Akihiro Nitta --- CHANGELOG.md | 2 + examples/llm/README.md | 7 +- examples/llm/glem.py | 443 ++++++++++++++++++++++++ torch_geometric/datasets/__init__.py | 2 + torch_geometric/datasets/tag_dataset.py | 350 +++++++++++++++++++ torch_geometric/nn/models/__init__.py | 3 +- torch_geometric/nn/models/glem.py | 384 ++++++++++++++++++++ 7 files changed, 1187 insertions(+), 4 deletions(-) create mode 100644 examples/llm/glem.py create mode 100644 torch_geometric/datasets/tag_dataset.py create mode 100644 torch_geometric/nn/models/glem.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 73a782026189..91da66973cef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) +- Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748)) - Added PyTorch 2.5 support ([#9779](https://github.com/pyg-team/pytorch_geometric/pull/9779), [#9779](https://github.com/pyg-team/pytorch_geometric/pull/9780)) - Support 3D tetrahedral mesh elements of shape `[4, num_faces]` in the `FaceToEdge` transformation ([#9776](https://github.com/pyg-team/pytorch_geometric/pull/9776)) diff --git a/examples/llm/README.md b/examples/llm/README.md index f1f01428d991..e0ac02d87f2e 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,5 +1,6 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| Example | Description | +| ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | diff --git a/examples/llm/glem.py b/examples/llm/glem.py new file mode 100644 index 000000000000..ec76cef4c010 --- /dev/null +++ b/examples/llm/glem.py @@ -0,0 +1,443 @@ +"""This example run GLEM model using PyG. +Original Paper: https://arxiv.org/abs/2210.14709 +“Learning on Large-scale Text-attributed Graphs via Variational Inference“. +Requirements on top of basic PyG: +`pip install ogb transformers peft tqdm`. +GLEM is a data augmentation co-training strategy for LM and GNN, our +implementation extended original implementation from LM to LLM and opt for LoRA +from peft. + +``note:: + use addtional trick, please add your external prediction by assigning + `ext_pred_path` and combine it into pretraining phase and node features +""" + +import argparse +import os +import os.path as osp +import time + +import torch +from ogb.nodeproppred import Evaluator, PygNodePropPredDataset + +from torch_geometric import seed_everything +from torch_geometric.data import download_google_url +from torch_geometric.datasets import TAGDataset +from torch_geometric.loader import DataLoader, NeighborLoader +from torch_geometric.nn.models import GAT, GCN, GLEM, GraphSAGE + + +def get_n_params(model): + pp = 0 + for p in list(model.parameters()): + nn = 1 + for s in list(p.size()): + nn = nn * s + pp += nn + return pp + + +def main(args): + gpu = args.gpu + dataset_name = args.dataset + root = osp.join('data', 'ogb') + hf_model = args.hf_model + pl_ratio = args.pl_ratio + gnn_lr = args.gnn_lr + lm_lr = args.lm_lr + em_order = args.em_order + gnn_epochs = args.gnn_epochs + lm_epochs = args.lm_epochs + patience = args.patience + verbose = args.verbose + out_dir = args.out_dir + lm_batch_size = args.lm_batch_size + gnn_batch_size = args.gnn_batch_size + lm_use_lora = args.lm_use_lora + token_on_disk = args.token_on_disk + num_em_iters = args.num_em_iters + start_time = time.time() + train_without_ext_pred = args.train_without_ext_pred + ext_pred = None + pretrain_augmented = False + ext_pseudo_labels = None + device = torch.device( + f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu') + print(f'Running on: {torch.cuda.get_device_name({gpu})}') + torch.cuda.empty_cache() + + if not train_without_ext_pred: + ext_pred_path = download_google_url( + id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY', + folder='data/ogb/ogbn_products/ext_preds', + filename='giant_sagn_scr.pt', log=True) + ext_pred = torch.load(ext_pred_path, map_location=device) + ext_pseudo_labels = ext_pred.argmax(dim=-1) + pretrain_augmented = True + + seed_everything(42) + + dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) + split_idx = dataset.get_idx_split() + data = dataset.data + + tag_dataset = TAGDataset(root, dataset, hf_model, + token_on_disk=token_on_disk) + text_dataset = tag_dataset.to_text_dataset() + print(tag_dataset.num_classes, tag_dataset.raw_file_names) + + num_classes = tag_dataset.num_classes + num_features = data.num_features + # =========================== LM Data split =============================== + split_idx = tag_dataset.get_idx_split() + + # GLEM train with augmented data, mark original train data as gold data, + gold_idx = split_idx['train'] + split_idx['valid'] + test_idx = split_idx['test'] + + # randome sample pseudo labels nodes, generate their index + num_pseudo_labels = int(gold_idx.numel() * pl_ratio) + idx_to_select = torch.randperm(test_idx.numel())[:num_pseudo_labels] + pseudo_labels_idx = test_idx[idx_to_select] + train_idx = torch.cat( + (gold_idx, pseudo_labels_idx)) # augmented train_indx + + print(f'train_idx: {train_idx.size(0)}, ' + f'gold_idx: {gold_idx.size(0)}, ' + f'pseudo labels ratio: {pl_ratio}, ' + f'{train_idx.size(0)/gold_idx.size(0) - 1.0}') + gold_dataset = torch.utils.data.Subset(dataset=text_dataset, + indices=gold_idx) + train_dataset = torch.utils.data.Subset(dataset=text_dataset, + indices=train_idx) + # ========================== LM Data Loader =============================== + + print('Building language model dataloader...', end='-->') + + # if set train_without_ext_pred == True, use this for pretrain + text_pretrain_loader = DataLoader(gold_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + # training with augmented data, + text_train_loader = DataLoader(train_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + text_test_loader = DataLoader(text_dataset, batch_size=lm_batch_size * 4, + drop_last=False, pin_memory=True, + shuffle=False) + print('done') + + # =========================== GNN Data Loader ============================= + initial_memory = torch.cuda.memory_allocated() + data = data.to(device) + if ext_pred is not None: + data.x = torch.cat((data.x, ext_pred), dim=1) + num_features += ext_pred.size(1) + current_memory_1 = torch.cuda.max_memory_allocated() + # 1 GB = 1073741824 Byte + gpu_usage = float(current_memory_1 - initial_memory) / 1073741824 + # Print the maximum memory usage after running the model + print(f'GPU memory usage -- data to gpu: {gpu_usage:.2f} GB') + + print('build GNN dataloader(GraphSAGE NeighborLoader)', end='-->') + + # train on gold data w/o pseudo labels + graph_pretrain_loader = NeighborLoader( + data, + input_nodes=gold_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # graph data loader w/ pseudo labels in M-step + graph_train_loader = NeighborLoader( + data, + input_nodes=train_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # for gnn inference + subgraph_loader = NeighborLoader( + data, + input_nodes=None, + num_neighbors=[-1], + batch_size=gnn_batch_size * 4, + num_workers=12, + persistent_workers=True, + ) + # =========================== internal function =========================== + + evaluator = Evaluator(name=f'ogbn-{dataset_name}') + + def evaluate(out, split): + y_true = data.y.cpu() + y_pred = out.argmax(dim=-1, keepdim=True) + train_acc, val_acc, test_acc = None, None, None + if 'train' in split: + train_acc = evaluator.eval({ + 'y_true': y_true[split_idx['train']], + 'y_pred': y_pred[split_idx['train']], + })['acc'] + if 'valid' in split: + val_acc = evaluator.eval({ + 'y_true': y_true[split_idx['valid']], + 'y_pred': y_pred[split_idx['valid']], + })['acc'] + if 'test' in split: + test_acc = evaluator.eval({ + 'y_true': y_true[split_idx['test']], + 'y_pred': y_pred[split_idx['test']], + })['acc'] + + return train_acc, val_acc, test_acc + + # =========================== Build GNN Model ============================= + gnn = None + if args.gnn_model == 'SAGE': + gnn = GraphSAGE( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + elif args.gnn_model == 'GAT': + gnn = GAT(in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, heads=args.gat_heads) + else: + gnn = GCN( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + + print("# GNN Params:", get_n_params(gnn)) + # =========================== Build LM Model ============================== + + model = GLEM(lm_to_use=hf_model, gnn_to_use=gnn, out_channels=num_classes, + lm_use_lora=lm_use_lora, device=device) + lm = model.lm + print("# LM Params:", get_n_params(lm)) + gnn_opt = torch.optim.Adam(gnn.parameters(), lr=gnn_lr) + lm_opt = torch.optim.Adam(lm.parameters(), lr=lm_lr) + + def load_model(em_phase): + print(f'Move {em_phase} model from cpu memory') + if em_phase == 'lm': + model.lm = model.lm.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.lm.parameters(), lr=lm_lr) + if em_phase == 'gnn': + model.gnn = model.gnn.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.gnn.parameters(), lr=gnn_lr) + return optimizer + + # ================================= Run GLEM ============================== + preds_filename = 'lm_pretrain' + preds_dir = f'{out_dir}preds/{dataset_name}/' + gnn_test_acc = 0.0 + lm_test_acc = 0.0 + # =============================== GLEM pretraining ======================== + pretrain_phase = 'lm' + if em_order == 'lm': + pretrain_phase = 'gnn' + pretrain_start_time = time.time() + # pretraining + pretrain_loader = graph_pretrain_loader + test_loader = subgraph_loader + pretrain_num_epochs = gnn_epochs + pretrain_opt = gnn_opt + if pretrain_phase == 'gnn': + model.gnn = model.gnn.to(device) + print('pretraining gnn to generate pseudo labels') + if not train_without_ext_pred: + pretrain_loader = graph_train_loader + preds_filename = 'gnn_pretrain' + elif pretrain_phase == 'lm': + model.lm = model.lm.to(device) + print('pretraining lm to generate pseudo labels') + pretrain_num_epochs = lm_epochs + pretrain_loader = text_pretrain_loader + test_loader = text_test_loader + pretrain_opt = lm_opt + if not train_without_ext_pred: + pretrain_loader = text_train_loader + preds_filename = 'lm_pretrain' + + early_stopping = 0 + best_val_acc = 0.0 + for epoch in range(1, pretrain_num_epochs + 1): + acc, loss = model.train(pretrain_phase, pretrain_loader, pretrain_opt, + ext_pseudo_labels, epoch, pretrain_augmented, + verbose) + if epoch >= 5 or epoch == pretrain_num_epochs: + pretrain_preds = model.inference(pretrain_phase, test_loader, + verbose=verbose) + train_acc, val_acc, _ = evaluate(pretrain_preds, + ['train', 'valid']) + + print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}') + + if val_acc <= best_val_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Pretrain Early stopped by Epoch: {epoch}') + break + else: + best_val_acc = val_acc + preds = model.inference(pretrain_phase, test_loader, verbose=verbose) + train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) + if pretrain_phase == 'gnn': + gnn_test_acc = max(gnn_test_acc, test_acc) + model.gnn = model.gnn.to('cpu', non_blocking=True) + else: + lm_test_acc = max(lm_test_acc, test_acc) + model.lm = model.lm.to('cpu', non_blocking=True) + torch.cuda.empty_cache() + + pretrain_phase_time = time.time() - pretrain_start_time + print(f'Pretrain {pretrain_phase} time: {pretrain_phase_time:.2f}s') + os.makedirs(osp.dirname(preds_dir), exist_ok=True) + torch.save(preds, osp.join(preds_dir, f'{preds_filename}.pt')) + print( + f'Saved predictions to {osp.join(preds_dir, f"{preds_filename}.pt")}') + train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) + print(f'Pretraining acc: {train_acc:.4f}, Val: {val_acc:.4f}, ' + f'Test: {test_acc:.4f}') + + # EM iterations + + em_phase = em_order + """ + We run E-step(LM training) and M-Step(GNN training) alternatively in each + em iterations, so the total number of iterations is num_em_iter * 2 and + we switch the em_phase at end of each iteration in following loop + """ + gnn_val_acc = lm_val_acc = 0.0 + for em_it in range(1, num_em_iters * 2 + 1): + pseudo_labels = preds.argmax(dim=-1) + best_val_acc = 0.0 + print(f'EM iteration: {em_it}, EM phase: {em_phase}') + optimizer = load_model(em_phase) + num_epochs = lm_epochs + train_loader = text_train_loader + test_loader = text_test_loader + early_stopping = 0 + if em_phase == 'gnn': + train_loader = graph_train_loader + num_epochs = gnn_epochs + test_loader = subgraph_loader + for epoch in range(1, num_epochs + 1): + acc, loss = model.train(em_phase, train_loader, optimizer, + pseudo_labels, epoch, True, verbose) + if epoch >= 5 or epoch == num_epochs: + cur_preds = model.inference(em_phase, test_loader, + verbose=verbose) + train_acc, val_acc, _ = evaluate(cur_preds, ['train', 'valid']) + + print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f},') + + if val_acc <= best_val_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'''Early stopped by Epoch: {epoch}, \ + Best acc: {best_val_acc}''') + break + else: + best_val_acc = val_acc + + preds = model.inference(em_phase, test_loader, verbose=verbose) + if em_phase == 'gnn': + gnn_val_acc = max(gnn_val_acc, best_val_acc) + model.gnn = model.gnn.to('cpu', non_blocking=True) + em_phase = 'lm' + else: + lm_val_acc = max(lm_val_acc, best_val_acc) + model.lm = model.lm.to('cpu', non_blocking=True) + em_phase = 'gnn' + torch.cuda.empty_cache() + print(f'Best GNN validation acc: {gnn_val_acc},' + f'LM validation acc: {lm_val_acc}') + print('============================') + if gnn_val_acc > lm_val_acc: + em_phase = 'gnn' + model.gnn = model.gnn.to(device, non_blocking=True) + else: + em_phase = 'lm' + model.lm = model.lm.to(device, non_blocking=True) + test_preds = model.inference(em_phase, test_loader, verbose=verbose) + train_acc, val_acc, test_acc = evaluate(test_preds, + ['train', 'valid', 'test']) + final_test_acc = max(gnn_test_acc, max(lm_test_acc, test_acc)) + print(f'Best test acc: {final_test_acc}, model: {em_phase}') + end_time = time.time() + running_time = (end_time - start_time) / 3600 + print(f'Total running time: {running_time:.2f} hours') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='GLEM Example:') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--num_runs', type=int, default=10, + help='number of runs') + parser.add_argument('--num_em_iters', type=int, default=1, + help='number of iterations') + parser.add_argument("--dataset", type=str, default='products', + help='arxiv or products') + parser.add_argument("--pl_ratio", type=float, default=0.5, + help="pseudo labels ratio") + parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', + help='huggingface model repo id') + parser.add_argument( + '--gnn_model', type=str, default='SAGE', + help='gnn model for node classification,' + 'options: SAGE, GAT, GCN') + parser.add_argument('--gnn_hidden_channels', type=int, default=256) + parser.add_argument('--gnn_num_layers', type=int, default=3) + parser.add_argument('--gat_heads', type=int, default=4, + help='Number of multi-head-attentions for GAT ') + parser.add_argument('--lm_batch_size', type=int, default=256) + parser.add_argument('--gnn_batch_size', type=int, default=1024) + parser.add_argument( + '--external_pred_path', type=str, default=None, + help="Other model's output logits during the " + "pretraining phase or simply concatenate it with" + "node features as augmented data for gnn") + parser.add_argument('--alpha', type=float, default=0.5, + help='pseudo label weight in E-step') + parser.add_argument('--beta', type=float, default=0.5, + help='pseudo label weight in M-step') + parser.add_argument('--lm_epochs', type=int, default=10) + parser.add_argument('--gnn_epochs', type=int, default=50) + parser.add_argument('--gnn_lr', type=float, default=0.002) + parser.add_argument('--lm_lr', type=float, default=0.001) + parser.add_argument('--patience', type=int, default=3, + help='Patience for early stopping') + parser.add_argument('--verbose', action='store_true', + help='show progress bar during training or not') + parser.add_argument('--em_order', type=str, default='lm', + help='decide train LM first or GNN first') + parser.add_argument('--lm_use_lora', action='store_true', + help='use Lora to fine-tune model or not') + parser.add_argument( + '--token_on_disk', action='store_true', + help='save token on disk and load token from disk' + 'for reducing duplicated tokenizing') + parser.add_argument('--out_dir', type=str, default='output/', + help='output directory') + parser.add_argument( + '--train_without_ext_pred', action='store_true', + help='train glem without using additional pseudo labels ' + 'for augmenting data only available for ogbn-products') + args = parser.parse_args() + print(args) + main(args) diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 96d51032d818..0b6569d3f92b 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -77,6 +77,7 @@ from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset from .web_qsp_dataset import WebQSPDataset +from .tag_dataset import TAGDataset from .dbp15k import DBP15K from .aminer import AMiner @@ -190,6 +191,7 @@ 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', + 'TAGDataset', ] hetero_datasets = [ diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py new file mode 100644 index 000000000000..f25992ced989 --- /dev/null +++ b/torch_geometric/datasets/tag_dataset.py @@ -0,0 +1,350 @@ +import os +import os.path as osp +from collections.abc import Sequence +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +from torch import Tensor +from tqdm import tqdm + +from torch_geometric.data import InMemoryDataset, download_google_url +from torch_geometric.data.data import BaseData + +try: + from pandas import DataFrame, read_csv + WITH_PANDAS = True +except ImportError: + WITH_PANDAS = False + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class TAGDataset(InMemoryDataset): + r"""The Text Attributed Graph datasets from the + `"Learning on Large-scale Text-attributed Graphs via Variational Inference + " `_ paper. + This dataset is aiming on transform `ogbn products`, `ogbn arxiv` + into Text Attributed Graph that each node in graph is associate with a + raw text, that dataset can be adapt to DataLoader (for LM training) and + NeighborLoader(for GNN training). In addition, this class can be use as a + wrapper class by convert a InMemoryDataset with Tokenizer and text into + Text Attributed Graph. + + Args: + root (str): Root directory where the dataset should be saved. + dataset (InMemoryDataset): The name of the dataset + (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`). + tokenizer_name (str): The tokenizer name for language model, + Be sure to use same tokenizer name as your `model id` of model repo + on huggingface.co. + text (List[str]): list of raw text associate with node, the order of + list should be align with node list + split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary, + for saving split index, it is required that if your dataset doesn't + have get_split_idx function + tokenize_batch_size (int): batch size of tokenizing text, the + tokenizing process will run on cpu, default: 256 + token_on_disk (bool): save token as .pt file on disk or not, + default: False + text_on_disk (bool): save given text(list of str) as dataframe on disk + or not, default: False + force_reload (bool): default: False + .. note:: + See `example/llm_plus_gnn/glem.py` for example usage + """ + raw_text_id = { + 'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3', + 'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt' + } + + def __init__(self, root: str, dataset: InMemoryDataset, + tokenizer_name: str, text: Optional[List[str]] = None, + split_idx: Optional[Dict[str, Tensor]] = None, + tokenize_batch_size: int = 256, token_on_disk: bool = False, + text_on_disk: bool = False, + force_reload: bool = False) -> None: + # list the vars you want to pass in before run download & process + self.name = dataset.name + self.text = text + self.tokenizer_name = tokenizer_name + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.dir_name = '_'.join(dataset.name.split('-')) + self.root = osp.join(root, self.dir_name) + missing_str_list = [] + if not WITH_PANDAS: + missing_str_list.append('pandas') + if len(missing_str_list) > 0: + missing_str = ' '.join(missing_str_list) + error_out = f"`pip install {missing_str}` to use this dataset." + raise ImportError(error_out) + if hasattr(dataset, 'get_idx_split'): + self.split_idx = dataset.get_idx_split() + elif split_idx is not None: + self.split_idx = split_idx + else: + raise ValueError("TAGDataset need split idx for generating " + "is_gold mask, please pass splited index " + "in format of dictionaty with 'train', 'valid' " + "'test' index tensor to 'split_idx'") + if text is not None and text_on_disk: + self.save_node_text(text) + self.text_on_disk = text_on_disk + # init will call download and process + super().__init__(self.root, transform=None, pre_transform=None, + pre_filter=None, force_reload=force_reload) + # after processing and download + # Dataset has to have BaseData as _data + assert dataset._data is not None + self._data = dataset._data # reassign reference + assert self._data is not None + assert dataset._data.y is not None + assert isinstance(self._data, BaseData) + assert self._data.num_nodes is not None + assert isinstance(dataset._data.num_nodes, int) + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + is_good_tensor = self.load_gold_mask() + self._is_gold = is_good_tensor.squeeze() + self._data['is_gold'] = is_good_tensor + if self.text is not None and len(self.text) != self._data.num_nodes: + raise ValueError("The number of text sequence in 'text' should be " + "equal to number of nodes!") + self.token_on_disk = token_on_disk + self.tokenize_batch_size = tokenize_batch_size + self._token = self.tokenize_graph(self.tokenize_batch_size) + self.__num_classes__ = dataset.num_classes + + @property + def num_classes(self) -> int: + return self.__num_classes__ + + @property + def raw_file_names(self) -> List[str]: + file_names = [] + for root, _, files in os.walk(osp.join(self.root, 'raw')): + for file in files: + file_names.append(file) + return file_names + + @property + def processed_file_names(self) -> List[str]: + return [ + 'geometric_data_processed.pt', 'pre_filter.pt', + 'pre_transformed.pt' + ] + + @property + def token(self) -> Dict[str, Tensor]: + if self._token is None: # lazy load + self._token = self.tokenize_graph() + return self._token + + # load is_gold after init + @property + def is_gold(self) -> Tensor: + if self._is_gold is None: + print('lazy load is_gold!!') + self._is_gold = self.load_gold_mask() + return self._is_gold + + def get_n_id(self, node_idx: IndexType) -> Tensor: + if self._n_id is None: + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + return self._n_id[node_idx] + + def load_gold_mask(self) -> Tensor: + r"""Use original train split as gold split, generating is_gold mask + for picking ground truth labels and pseudo labels. + """ + train_split_idx = self.get_idx_split()['train'] + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + is_good_tensor = torch.zeros(self._data.num_nodes, + dtype=torch.bool).view(-1, 1) + is_good_tensor[train_split_idx] = True + return is_good_tensor + + def get_gold(self, node_idx: IndexType) -> Tensor: + r"""Get gold mask for given node_idx. + + Args: + node_idx (torch.tensor): a tensor contain node idx + """ + if self._is_gold is None: + self._is_gold = self.is_gold + return self._is_gold[node_idx] + + def get_idx_split(self) -> Dict[str, Tensor]: + return self.split_idx + + def download(self) -> None: + print('downloading raw text') + raw_text_path = download_google_url(id=self.raw_text_id[self.name], + folder=f'{self.root}/raw', + filename='node-text.csv.gz', + log=True) + text_df = read_csv(raw_text_path) + self.text = list(text_df['text']) + + def process(self) -> None: + if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')): + text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz')) + self.text = list(text_df['text']) + elif self.name in self.raw_text_id: + self.download() + else: + print('The dataset is not ogbn-products nor ogbn-arxiv,' + 'please pass in your raw text string list to `text`') + if self.text is None: + raise ValueError("The TAGDataset only have ogbn-products and " + "ogbn-arxiv raw text in default " + "The raw text of each node is not specified" + "Please pass in 'text' when convert your dataset " + "to Text Attribute Graph Dataset") + + def save_node_text(self, text: List[str]) -> None: + node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz') + if osp.exists(node_text_path): + print(f'The raw text is existed at {node_text_path}') + else: + print(f'Saving raw text file at {node_text_path}') + os.makedirs(f'{self.root}/raw', exist_ok=True) + text_df = DataFrame(text, columns=['text']) + text_df.to_csv(osp.join(node_text_path), compression='gzip', + index=False) + + def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: + r"""Tokenizing the text associate with each node, running in cpu. + + Args: + batch_size (Optional[int]): batch size of list of text for + generating emebdding + Returns: + Dict[str, torch.Tensor]: tokenized graph + """ + data_len = 0 + if self.text is not None: + data_len = len(self.text) + else: + raise ValueError("The TAGDataset need text for tokenization") + token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] + path = os.path.join(self.processed_dir, 'token', self.tokenizer_name) + # Check if the .pt files already exist + token_files_exist = any( + os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys) + + if token_files_exist and self.token_on_disk: + print('Found tokenized file, loading may take several minutes...') + all_encoded_token = { + k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True) + for k in token_keys + if os.path.exists(os.path.join(path, f'{k}.pt')) + } + return all_encoded_token + + all_encoded_token = {k: [] for k in token_keys} + pbar = tqdm(total=data_len) + + pbar.set_description('Tokenizing Text Attributed Graph') + for i in range(0, data_len, batch_size): + end_index = min(data_len, i + batch_size) + token = self.tokenizer(self.text[i:min(i + batch_size, data_len)], + padding='max_length', truncation=True, + max_length=512, return_tensors="pt") + for k in token.keys(): + all_encoded_token[k].append(token[k]) + pbar.update(end_index - i) + pbar.close() + + all_encoded_token = { + k: torch.cat(v) + for k, v in all_encoded_token.items() if len(v) > 0 + } + if self.token_on_disk: + os.makedirs(path, exist_ok=True) + print('Saving tokens on Disk') + for k, tensor in all_encoded_token.items(): + torch.save(tensor, os.path.join(path, f'{k}.pt')) + print('Token saved:', os.path.join(path, f'{k}.pt')) + os.environ["TOKENIZERS_PARALLELISM"] = 'true' # supressing warning + return all_encoded_token + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + class TextDataset(torch.utils.data.Dataset): + r"""This nested dataset provides textual data for each node in + the graph. Factory method to create TextDataset from TAGDataset. + + Args: + tag_dataset (TAGDataset): the parent dataset + """ + def __init__(self, tag_dataset: 'TAGDataset') -> None: + self.tag_dataset = tag_dataset + self.token = tag_dataset.token + assert tag_dataset._data is not None + self._data = tag_dataset._data + + assert tag_dataset._data.y is not None + self.labels = tag_dataset._data.y + + def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]: + r"""This function will be called in __getitem__(). + + Args: + node_idx (IndexType): selected node idx in each batch + Returns: + items (Dict[str, Tensor]): input for LM + """ + items = {k: v[node_idx] for k, v in self.token.items()} + return items + + # for LM training + def __getitem__( + self, node_id: IndexType + ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: + r"""This function will override the function in + torch.utils.data.Dataset, and will be called when you + iterate batch in the dataloader, make sure all following + key value pairs are present in the return dict. + + Args: + node_id (List[int]): list of node idx for selecting tokens, + labels etc. when iterating data loader for LM + Returns: + items (dict): input k,v pairs for Language model training and + inference + """ + item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {} + item['input'] = self.get_token(node_id) + item['labels'] = self.labels[node_id] + item['is_gold'] = self.tag_dataset.get_gold(node_id) + item['n_id'] = self.tag_dataset.get_n_id(node_id) + return item + + def __len__(self) -> int: + assert self._data.num_nodes is not None + return self._data.num_nodes + + def get(self, idx: int) -> BaseData: + return self._data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + def to_text_dataset(self) -> TextDataset: + r"""Factory Build text dataset from Text Attributed Graph Dataset + each data point is node's associated text token. + """ + return TAGDataset.TextDataset(self) diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 7cfadf0143b2..5860db311ac3 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -29,7 +29,7 @@ from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .g_retriever import GRetriever - +from .glem import GLEM # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) @@ -77,4 +77,5 @@ 'NeuralFingerprint', 'ViSNet', 'GRetriever', + 'GLEM', ] diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py new file mode 100644 index 000000000000..afc8b09d77c7 --- /dev/null +++ b/torch_geometric/nn/models/glem.py @@ -0,0 +1,384 @@ +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from tqdm import tqdm + +from torch_geometric.loader import DataLoader, NeighborLoader +from torch_geometric.nn.models import GraphSAGE, basic_gnn + + +class GLEM(torch.nn.Module): + r"""This GNN+LM co-training model is based on GLEM from the `"Learning on + Large-scale Text-attributed Graphs via Variational Inference" + `_ paper. + + Args: + lm_to_use (str): A TextEncoder from huggingface model repo + with a classifier(default: TinyBERT) + gnn_to_use (torch_geometric.nn.models): (default: GraphSAGE) + out_channels (int): output channels for LM and GNN, should be same + num_gnn_heads Optional[int]: Number of heads for attention, if needed + num_gnn_layers (int): number of gnn layers + gnn_loss: loss function for gnn, (default: CrossEntropyLoss) + lm_loss: loss function for Language Model, (default: CrossEntropyLoss) + alpha (float): pseudo label weight of E-step, LM optimization, + (default: 0.5) + beta (float): pseudo label weight of M-step, GNN optimization, + (default: 0.5) + lm_dtype (torch.dtype): the data type once you load LM into memory, + (default: torch.bfloat16) + lm_use_lora (bool): choose if LM use Lora peft for fine tune, + (default: True) + lora_target_modules: The names of the target modules to apply the lora + adapter to, e.g. ['q_proj', 'v_proj'] for LLM , (default: None) + + .. note:: + See `examples/llm_plus_gnn/glem.py` for example usage. + """ + def __init__( + self, + lm_to_use: str = 'prajjwal1/bert-tiny', + gnn_to_use: basic_gnn = GraphSAGE, + out_channels: int = 47, + gnn_loss=nn.CrossEntropyLoss(reduction='mean'), + lm_loss=nn.CrossEntropyLoss(reduction='mean'), + alpha: float = 0.5, + beta: float = 0.5, + lm_dtype: torch.dtype = torch.bfloat16, + lm_use_lora: bool = True, + lora_target_modules: Optional[Union[List[str], str]] = None, + device: Union[str, torch.device] = torch.device('cpu'), + ): + super().__init__() + self.device = device + self.lm_loss = lm_loss + self.gnn = gnn_to_use + self.gnn_loss = gnn_loss + self.alpha = alpha + self.beta = beta + self.gnn_loss = gnn_loss + self.lm = lm_to_use + from transformers import AutoModelForSequenceClassification + self.lm = AutoModelForSequenceClassification.from_pretrained( + lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype, + offload_folder="offload", trust_remote_code=True) + if lm_use_lora: + from peft import ( + LoraConfig, + TaskType, + get_peft_model, + prepare_model_for_kbit_training, + ) + print("Training LM with LORA!") + self.lm = prepare_model_for_kbit_training(self.lm) + config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16, + lora_alpha=16, lora_dropout=0.05, bias="none", + target_modules=lora_target_modules) + self.lm = get_peft_model(self.lm, config) + self.lm.print_trainable_parameters() + self.lm.config.pad_token_id = self.lm.config.eos_token_id + self.lm_device = self.lm.device + + if self.lm.num_labels != self.gnn.out_channels: + raise ValueError('''The output channel of language model \ + and gnn should be the same''') + + def pre_train_gnn(self, train_loader: NeighborLoader, + optimizer: torch.optim.Optimizer, num_epochs: int, + patience: int, ext_pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + # Pretrain GNN, optional steps if you do not have pseudo labels. + best_acc = 0 + early_stopping = 0 + # training only based on gold data + for epoch in range(0, num_epochs): + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, + verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def pre_train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, num_epochs: int, + patience: int, ext_pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + # Pretrain language model + best_acc = 0 + early_stopping = 0 + for epoch in range(1, num_epochs + 1): + acc, loss = self.train_lm(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def train(self, em_phase: str, train_loader: Union[DataLoader, + NeighborLoader], + optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor, + epoch: int, is_augmented: bool = False, verbose: bool = False): + r"""GLEM training step, EM steps. + + Args: + em_phase(str): 'gnn' or 'lm' choose which phase you are training on + train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for + lm training, include tokenized data, labels is_gold mask. + use NeighborLoader for gnn training, include x, edge_index. + optimizer (torch.optim.Optimizer): optimizer for training + pseudo_labels(torch.Tensor): the predicted labels used as pseudo + labels + epoch (int): current epoch + is_augmented (bool): will use pseudo_labels or not + verbose (bool): print training progress bar or not + + Returns: + acc (float): training accuracy + loss (float): loss value + """ + pseudo_labels = pseudo_labels.to(self.device) + if em_phase == 'gnn': + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + if em_phase == 'lm': + acc, loss = self.train_lm(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + return acc, loss + + def train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""Language model Training in every epoch. + + Args: + train_loader (loader.dataloader.DataLoader): text token dataloader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn + is_augmented (bool): train with pseudo labels or not + verbose (bool): print training progress bar or not + + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + + """ + all_out = [] + total_loss = total_correct = 0 + num_nodes = train_loader.dataset.indices.size(0) + self.lm.train() + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + for batch in train_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + out = self.lm(**inputs).logits + labels = batch['labels'].to(self.device).squeeze() + # training with pseudo labels or not + if is_augmented: + pl_batch = pseudo_labels[batch['n_id']].to(self.device) + else: + pl_batch = None + loss = self.loss(out, labels, self.lm_loss, + batch['is_gold'].to(self.device), pl_batch, + self.alpha, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + all_out.append(out) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + total_loss += float(loss) + if verbose: + pbar.update(batch['n_id'].size(0)) + + all_out = torch.cat(all_out, dim=0) + approx_acc = total_correct / num_nodes + loss = total_loss / len(train_loader) + if verbose: + pbar.close() + print(f'Epoch {epoch:02d} Loss: {loss:.4f} ' + f'Approx. Train: {approx_acc:.4f}') + return approx_acc, loss + + def train_gnn(self, train_loader: NeighborLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""GNN training step in every epoch. + + Args: + train_loader (loader.NeighborLoader): gnn Neighbor node loader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels(torch.tensor): 1-D tensor, predictions from lm + is_augmented(bool): use pseudo labeled node or not + verbose (bool): print training progress or not + + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + """ + self.gnn.train() + num_nodes = train_loader.input_nodes.size(0) + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + total_loss = total_correct = 0 + all_out = [] + for batch in train_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + all_out.append(out) + labels = batch.y[:batch.batch_size].squeeze() + is_gold_batch = batch.is_gold[:batch.batch_size].squeeze() + # training with pseudo labels or not + if is_augmented and pseudo_labels is not None: + pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]] + else: + pl_batch = None + loss = self.loss(out, labels, self.gnn_loss, is_gold_batch, + pl_batch, self.beta, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + total_loss += float(loss) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + if verbose: + pbar.update(batch.batch_size) + + all_out = torch.cat(all_out, dim=0) + loss = total_loss / len(train_loader) + approx_acc = total_correct / num_nodes + if verbose: + pbar.close() + print(f'Epoch: {epoch:02d} Loss: {loss:.4f} ' + f'Approx. Train: {approx_acc:.4f}') + return approx_acc, loss + + @torch.no_grad() + def inference(self, em_phase: str, data_loader: Union[NeighborLoader, + DataLoader], + verbose: bool = False): + r"""GLEM inference step. + + Args: + em_phase(str): 'gnn' or 'lm' + data_loader(dataloader or Neighborloader): + dataloader: for lm training, include tokenized data + nodeloader: for gnn training, include x, edge_index + verbose(bool): print inference progress or not + + Returns: + out (torch.Tensor): n * m tensor, m is number of classes, + n is number of nodes + """ + out = None + if em_phase == 'gnn': + self.gnn.eval() + out = self.inference_gnn(data_loader, verbose) + elif em_phase == 'lm': + self.lm.eval() + out = self.inference_lm(data_loader, verbose) + return out + + @torch.no_grad() + def inference_lm(self, data_loader: DataLoader, verbose: bool = True): + r"""LM inference step. + + Args: + data_loader (Dataloader): include token, labels, and gold mask + verbose (bool): print progress bar or not + + Returns: + preds (tensor): prediction from GNN, convert to pseudo labels + by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.dataset._data.num_nodes) + pbar.set_description('LM inference stage') + self.lm.eval() + preds = [] + for batch in data_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + logits = self.lm(**inputs).logits + preds.append(logits) + if verbose: + pbar.update(batch['n_id'].size(0)) + if verbose: + pbar.close() + preds = torch.cat(preds) + return preds + + @torch.no_grad() + def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True): + r"""GNN inference step. + + Args: + data_loader(NeighborLoader): include x, edge_index, + verbose (bool): print progress bar or not + + Returns: + preds (tensor): prediction from GNN, + convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.data.num_nodes) + pbar.set_description('GNN inference stage') + preds = [] + self.gnn.eval() + for batch in data_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + preds.append(out) + if verbose: + pbar.update(batch.batch_size) + if verbose: + pbar.close() + preds = torch.cat(preds, dim=0) + return preds + + def loss(self, logits: torch.Tensor, labels: torch.Tensor, + loss_func: torch.nn.functional, is_gold: torch.Tensor, + pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5, + is_augmented: bool = True): + r"""Core function of variational EM inference, this function is aming + on combining loss value on gold(original train) and loss value on + pseudo labels. + + Reference: + # noqa + + Args: + logits(torch.tensor): predict results from LM or GNN + labels(torch.tensor): combined node labels from ground truth and + pseudo labels(if provided) + loss_func(torch.nn.modules.loss): loss function for classification + is_gold(tensor): a tensor with bool value that mask ground truth + label and during training, thus ~is_gold mask pseudo labels + pseudo_labels(torch.tensor): predictions from other model + pl_weight: the pseudo labels used in E-step and M-step optimization + alpha in E-step, beta in M-step respectively + is_augmented: use EM or just train GNN and LM with gold data + + """ + def deal_nan(x): + return 0 if torch.isnan(x) else x + + if is_augmented and (sum(~is_gold) > 0): + mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold])) + # all other labels beside from ground truth(gold labels) + pseudo_label_loss = deal_nan( + loss_func(logits[~is_gold], pseudo_labels[~is_gold])) + loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss + else: + loss = loss_func(logits, labels) + return loss From 529237c15198d483f8a0b417cc03bc6391854cc2 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Wed, 20 Nov 2024 09:44:31 +0800 Subject: [PATCH 03/45] Add MoleculeGPT (#9710) ### Issue - #9694 - https://github.com/pyg-team/pytorch_geometric/issues/9698 ### Feature Summary - Add `MoleculeGPTDataset` - Add `MoleculeGPT` as GNN & LLM Co-training model to PyG - Add an example for training and testing - Split the PR into 3 sub-PRs (#9723, #9724, #9725) - Limited hardware resources, can't load `lmsys/vicuna-7b-v1.5`, use `TinyLlama/TinyLlama-1.1B-Chat-v0.1` instead, and the full training pipeline was not tested --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Giovanni Gatti <100698520+giovanni-gatti@users.noreply.github.com> Co-authored-by: Rishi Puri --- CHANGELOG.md | 1 + examples/llm/README.md | 9 +- examples/llm/molecule_gpt.py | 193 +++++++ test/datasets/test_molecule_gpt_dataset.py | 10 + test/nn/attention/test_qformer.py | 13 + test/nn/models/test_molecule_gpt.py | 60 +++ torch_geometric/datasets/__init__.py | 2 + .../datasets/molecule_gpt_dataset.py | 480 ++++++++++++++++++ torch_geometric/nn/attention/__init__.py | 6 +- torch_geometric/nn/attention/qformer.py | 71 +++ torch_geometric/nn/models/__init__.py | 2 + torch_geometric/nn/models/molecule_gpt.py | 222 ++++++++ torch_geometric/nn/nlp/llm.py | 2 +- .../nn/nlp/sentence_transformer.py | 3 + 14 files changed, 1068 insertions(+), 6 deletions(-) create mode 100644 examples/llm/molecule_gpt.py create mode 100644 test/datasets/test_molecule_gpt_dataset.py create mode 100644 test/nn/attention/test_qformer.py create mode 100644 test/nn/models/test_molecule_gpt.py create mode 100644 torch_geometric/datasets/molecule_gpt_dataset.py create mode 100644 torch_geometric/nn/attention/qformer.py create mode 100644 torch_geometric/nn/models/molecule_gpt.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 91da66973cef..9240420208bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710)) - Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748)) diff --git a/examples/llm/README.md b/examples/llm/README.md index e0ac02d87f2e..d860232aa56b 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,6 +1,7 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | -| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | +| Example | Description | +| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | +| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | diff --git a/examples/llm/molecule_gpt.py b/examples/llm/molecule_gpt.py new file mode 100644 index 000000000000..8f6c6024014d --- /dev/null +++ b/examples/llm/molecule_gpt.py @@ -0,0 +1,193 @@ +"""This example implements the MoleculeGPT model +(https://ai4d3.github.io/papers/34.pdf) using PyG. +""" +import argparse +import math +import os.path as osp +import time + +import torch +from torch.nn.utils import clip_grad_norm_ +from tqdm import tqdm + +from torch_geometric import seed_everything +from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn import GINEConv +from torch_geometric.nn.models import MoleculeGPT +from torch_geometric.nn.nlp import LLM, SentenceTransformer + + +def save_params_dict(model, save_path): + state_dict = model.state_dict() + param_grad_dict = { + k: v.requires_grad + for (k, v) in model.named_parameters() + } + for k in list(state_dict.keys()): + if k in param_grad_dict.keys() and not param_grad_dict[k]: + del state_dict[k] # Delete parameters that do not require gradient + torch.save(state_dict, save_path) + + +@torch.no_grad() +def eval(model, data_loader): + model.eval() + loss = 0 + + for batch in data_loader: + batch_loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.instruction, + batch.y) + loss += batch_loss.item() / len(data_loader) + return loss + + +def train( + num_epochs: int, + lr: float, + batch_size: int, + checkpointing: bool, +): + def adjust_learning_rate(param_group, LR, epoch): + # Decay the learning rate with half-cycle cosine after warmup + min_lr = 5e-6 + warmup_epochs = 1 + if epoch < warmup_epochs: + lr = LR + else: + lr = min_lr + (LR - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / + (num_epochs - warmup_epochs))) + param_group['lr'] = lr + return lr + + start_time = time.time() + # Load dataset ================================================ + path = osp.dirname(osp.realpath(__file__)) + path = osp.join(path, '..', '..', 'data', 'MoleculeGPT') + dataset = MoleculeGPTDataset(path) + train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset)) + train_dataset = dataset[:train_size] + val_dataset = dataset[train_size:train_size + val_size] + test_dataset = dataset[train_size + val_size:] + + seed_everything(42) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + # Create model =============================================== + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + + graph_encoder = GINEConv( + nn=torch.nn.Sequential( + torch.nn.Linear(6, 768), + torch.nn.ReLU(), + torch.nn.Linear(768, 768), + ), + train_eps=True, + edge_dim=4, + ) + + smiles_encoder = SentenceTransformer( + model_name='DeepChem/ChemBERTa-77M-MTR', + pooling_strategy='last_hidden_state', + ) + + model = MoleculeGPT( + llm=llm, + graph_encoder=graph_encoder, + smiles_encoder=smiles_encoder, + ) + + # Train and eval ============================================ + params = [p for _, p in model.named_parameters() if p.requires_grad] + optimizer = torch.optim.AdamW([ + { + 'params': params, + 'lr': lr, + 'weight_decay': 0.05, + }, + ], betas=(0.9, 0.95)) + grad_steps = 2 + + best_epoch = 0 + best_val_loss = float('inf') + for epoch in range(num_epochs): + # Train + model.train() + epoch_loss = 0 + if epoch == 0: + print(f"Total Preparation Time: {time.time() - start_time:2f}s") + start_time = time.time() + print("Training beginning...") + epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' + loader = tqdm(train_loader, desc=epoch_str) + + for step, batch in enumerate(loader): + optimizer.zero_grad() + loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.instruction, + batch.y) + loss.backward() + clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) + + if (step + 1) % grad_steps == 0: + adjust_learning_rate(optimizer.param_groups[0], lr, + step / len(train_loader) + epoch) + + optimizer.step() + epoch_loss += loss.item() + + if (step + 1) % grad_steps == 0: + lr = optimizer.param_groups[0]['lr'] + train_loss = epoch_loss / len(train_loader) + + # Eval + val_loss = eval(model, val_loader) + print( + f'{epoch_str}, Train loss: {train_loss:4f}, Val loss: {val_loss:4f}' # noqa: E501 + ) + + if checkpointing and val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + save_params_dict( + model, + f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 + ) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + print(f"Total Training Time: {time.time() - start_time:2f}s") + # Test + test_loss = eval(model, test_loader) + print(f'Test loss: {test_loss:4f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--epochs', type=int, default=3) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--batch_size', type=int, default=2) + parser.add_argument('--checkpointing', type=bool, default=True) + args = parser.parse_args() + + start_time = time.time() + train( + args.epochs, + args.lr, + args.batch_size, + args.checkpointing, + ) + print(f'Total Time: {time.time() - start_time:2f}s') diff --git a/test/datasets/test_molecule_gpt_dataset.py b/test/datasets/test_molecule_gpt_dataset.py new file mode 100644 index 000000000000..7c00c5efc1b6 --- /dev/null +++ b/test/datasets/test_molecule_gpt_dataset.py @@ -0,0 +1,10 @@ +from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.testing import withPackage + + +@withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit') +def test_molecule_gpt_dataset(): + dataset = MoleculeGPTDataset(root='./data/MoleculeGPT') + assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})' + assert dataset.num_edge_features == 4 + assert dataset.num_node_features == 6 diff --git a/test/nn/attention/test_qformer.py b/test/nn/attention/test_qformer.py new file mode 100644 index 000000000000..0de023708fd8 --- /dev/null +++ b/test/nn/attention/test_qformer.py @@ -0,0 +1,13 @@ +import torch + +from torch_geometric.nn.attention import QFormer + + +def test_qformer(): + x = torch.randn(1, 4, 16) + attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4, + num_layers=2) + out = attn(x) + + assert out.shape == (1, 4, 32) + assert str(attn) == ('QFormer(num_heads=4, num_layers=2)') diff --git a/test/nn/models/test_molecule_gpt.py b/test/nn/models/test_molecule_gpt.py new file mode 100644 index 000000000000..c9f0a53403ee --- /dev/null +++ b/test/nn/models/test_molecule_gpt.py @@ -0,0 +1,60 @@ +import torch +from torch.nn import Linear as Lin +from torch.nn import ReLU +from torch.nn import Sequential as Seq + +from torch_geometric.nn import GINEConv, MoleculeGPT +from torch_geometric.nn.nlp import LLM, SentenceTransformer +from torch_geometric.testing import onlyFullTest, withPackage + + +@onlyFullTest +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_molecule_gpt() -> None: + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + + graph_encoder = GINEConv(nn=Seq(Lin(16, 16), ReLU(), Lin(16, 16)), + train_eps=True, edge_dim=16) + + smiles_encoder = SentenceTransformer( + model_name='DeepChem/ChemBERTa-77M-MTR', + pooling_strategy='last_hidden_state', + ) + + model = MoleculeGPT( + llm=llm, + graph_encoder=graph_encoder, + smiles_encoder=smiles_encoder, + ) + + assert str(model) == ( + 'MoleculeGPT(\n' + ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n' + ' graph=GINEConv,\n' + ' smiles=SentenceTransformer(model_name=DeepChem/ChemBERTa-77M-MTR),\n' # noqa: E501 + ')') + + x = torch.randn(10, 16) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + ]) + edge_attr = torch.randn(edge_index.size(1), 16) + batch = torch.zeros(x.size(0), dtype=torch.long) + smiles = ['CCCCCCCCCC'] + instructions = ['What is ∼ functional related to?'] + label = ['I do not know!'] + + # Test train: + loss = model(x, edge_index, batch, edge_attr, smiles, instructions, label) + assert loss >= 0 + + # Test inference: + pred = model.inference(x, edge_index, batch, edge_attr, smiles, + instructions) + assert len(pred) == 1 diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 0b6569d3f92b..c086a85df779 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -77,6 +77,7 @@ from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset from .web_qsp_dataset import WebQSPDataset +from .molecule_gpt_dataset import MoleculeGPTDataset from .tag_dataset import TAGDataset from .dbp15k import DBP15K @@ -191,6 +192,7 @@ 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', + 'MoleculeGPTDataset', 'TAGDataset', ] diff --git a/torch_geometric/datasets/molecule_gpt_dataset.py b/torch_geometric/datasets/molecule_gpt_dataset.py new file mode 100644 index 000000000000..b1da09f38570 --- /dev/null +++ b/torch_geometric/datasets/molecule_gpt_dataset.py @@ -0,0 +1,480 @@ +import gzip +import json +import multiprocessing +import os +import sys +from collections import defaultdict +from multiprocessing import Pool +from typing import Callable, List, Optional, Tuple + +import numpy as np +import requests +import torch +from tqdm import tqdm + +from torch_geometric.data import Data, InMemoryDataset, download_url +from torch_geometric.io import fs +from torch_geometric.nn.nlp import LLM +from torch_geometric.utils import one_hot + + +def clean_up_description(description: str) -> str: + description = description + " " + + # extra adj Pure + if description.startswith("Pure "): + description = description.replace("Pure ", "") + # fix typo + if description.startswith("Mercurycombines"): + description = description.replace("Mercurycombines", + "Mercury combines") + + # a special case + description = description.replace( + "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ", + "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is ") + + # a special case + description = description.replace("5-Thymidylic acid. ", + "5-Thymidylic acid. is ") + + # a special case + description = description.replace( + "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ", + "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is ") + + # a special case + description = description.replace( + ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" + " with phosphorothioic acid. "), + ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" + " with phosphorothioic acid is ")) + + # a special case + description = description.replace("5'-Uridylic acid. ", + "5'-Uridylic acid is ") + + # a special case + description = description.replace("5'-Adenylic acid, ", + "5'-Adenylic acid is ") + + # a special case + description = description.replace( + "Uridine 5'-(tetrahydrogen triphosphate). ", + "Uridine 5'-(tetrahydrogen triphosphate). is ") + + # a special case + description = description.replace("Inosine 5'-Monophosphate. ", + "Inosine 5'-Monophosphate. is ") + + # a special case + description = description.replace("Pivaloyloxymethyl butyrate (AN-9), ", + "Pivaloyloxymethyl butyrate (AN-9) is ") + + # a special case + description = description.replace( + "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ", + "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is ") + + # a special case + description = description.replace( + "Cardamonin (also known as Dihydroxymethoxychalcone), ", + "Cardamonin (also known as Dihydroxymethoxychalcone) is ") + + # a special case + description = description.replace("Lithium has been used to treat ", + "Lithium is ") + + # a special case + description = description.replace("4,4'-Methylenebis ", + "4,4'-Methylenebis is ") + + # a special case + description = description.replace( + "2,3,7,8-Tetrachlorodibenzo-p-dioxin", + "2,3,7,8-Tetrachlorodibenzo-p-dioxin is ") + + # a special case + description = description.replace("Exposure to 2,4,5-trichlorophenol ", + "2,4,5-Trichlorophenol exposure ") + + index = 0 + L = len(description) + if description.startswith('C.I. '): + start_index = len('C.I. ') + elif description.startswith('Nectriapyrone. D '): + start_index = len('Nectriapyrone. D ') + elif description.startswith( + 'Salmonella enterica sv. Minnesota LPS core oligosaccharide'): + start_index = len( + 'Salmonella enterica sv. Minnesota LPS core oligosaccharide') + else: + start_index = 0 + for index in range(start_index, L - 1): + if index < L - 2: + if description[index] == '.' and description[ + index + 1] == ' ' and 'A' <= description[index + 2] <= 'Z': + break + elif index == L - 2: + break + + first_sentence = description[:index + 1] + return first_sentence + + +def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]: + first_sentence = clean_up_description(description) + + splitter = ' -- -- ' + if ' are ' in first_sentence or ' were ' in first_sentence: + replaced_words = 'These molecules' + else: + replaced_words = 'This molecule' + + first_sentence = first_sentence.replace(' is ', splitter) + first_sentence = first_sentence.replace(' are ', splitter) + first_sentence = first_sentence.replace(' was ', splitter) + first_sentence = first_sentence.replace(' were ', splitter) + first_sentence = first_sentence.replace(' appears ', splitter) + first_sentence = first_sentence.replace(' occurs ', splitter) + first_sentence = first_sentence.replace(' stands for ', splitter) + first_sentence = first_sentence.replace(' belongs to ', splitter) + first_sentence = first_sentence.replace(' exists ', + splitter) # only for CID=11443 + first_sentence = first_sentence.replace(' has been used in trials ', + splitter) + first_sentence = first_sentence.replace(' has been investigated ', + splitter) + first_sentence = first_sentence.replace(' has many uses ', splitter) + + if splitter in first_sentence: + extracted_name = first_sentence.split(splitter, 1)[0] + elif first_sentence.startswith(name_raw): + extracted_name = name_raw + elif name_raw in first_sentence: + extracted_name = name_raw + extracted_name = None + print("=====", name_raw) + print("first sentence: ", first_sentence) + else: + extracted_name = None + + if extracted_name is not None: + extracted_description = description.replace(extracted_name, + replaced_words) + else: + extracted_description = description + + return extracted_name, extracted_description, first_sentence + + +class MoleculeGPTDataset(InMemoryDataset): + r"""The dataset from the `"MoleculeGPT: Instruction Following Large + Language Models for Molecular Property Prediction" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + total_page_num (int, optional): The number of pages from PubChem. + (default: :obj:`10`) + total_block_num (int, optional): The blocks of SDF files from PubChem. + (default: :obj:`1`) + """ + description_url = ( + 'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/' + 'heading/json?heading_type=Compound&heading=Record+Description&page={}' + ) + compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/' + 'CURRENT-Full/SDF') + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + total_page_num: int = 10, + total_block_num: int = 1, + ): + self.total_page_num = total_page_num + self.total_block_num = total_block_num + + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return ['pubchem.csv'] + + @property + def processed_file_names(self) -> List[str]: + return ['data.pt'] + + def download(self) -> None: + # Step 01. Extract description + step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description" + if not os.path.exists(step1_folder): + os.makedirs(step1_folder) + valid_CID_set = set() + CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict( + list) + CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict( + list) + + for page_index in tqdm(range(self.total_page_num)): + page_num = page_index + 1 + f_out = open( + f"{step1_folder}/Compound_description_{page_num}.txt", "w") + + description_data = requests.get( + self.description_url.format(page_num)).json() + + description_data = description_data["Annotations"] + assert description_data["Page"] == page_num + + record_list = description_data["Annotation"] + + for record in record_list: + try: + CID = record["LinkedRecords"]["CID"][0] + if "Name" in record: + name_raw = record["Name"] + CID2name_raw[CID].append(name_raw) + else: + name_raw = None + + data_list = record["Data"] + for data in data_list: + description = data["Value"]["StringWithMarkup"][0][ + "String"].strip() + + extracted_name, extracted_description, _ = extract_name( # noqa: E501 + name_raw, description) + if extracted_name is not None: + CID2name_extracted[CID].append(extracted_name) + + CID2text_raw[CID].append(description) + CID2text_extracted[CID].append( + extracted_description) + + valid_CID_set.add(CID) + f_out.write(f"{CID}\n") + f_out.write(f"{extracted_description}\n\n") + except Exception: + continue + + valid_CID_list = sorted(list(valid_CID_set)) + print(f"Total CID (with raw name) {len(CID2name_raw)}") + print(f"Total CID (with extracted name) {len(CID2name_extracted)}") + print(f"Total CID {len(valid_CID_list)}") + + with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f: + json.dump(CID2name_raw, f) + + with open(f"{self.raw_dir}/CID2name.json", "w") as f: + json.dump(CID2name_extracted, f) + + with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f: + json.dump(CID2text_raw, f) + + with open(f"{self.raw_dir}/CID2text.json", "w") as f: + json.dump(CID2text_extracted, f) + + # Step 02. Download SDF Files + step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" + if not os.path.exists(step2_folder): + for block_id in tqdm(range(self.total_block_num)): + block_size = 500000 + l_id = block_id * block_size + 1 + r_id = (block_id + 1) * block_size + + compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" + download_url(f"{self.compound_url}/{compound_file_name}", + step2_folder) + + def process(self, use_mp: bool = False) -> None: + try: + from rdkit import Chem + from rdkit.Chem.rdchem import BondType as BT + WITH_RDKIT = True + + except ImportError: + WITH_RDKIT = False + + if not WITH_RDKIT: + print(("Using a pre-processed version of the dataset. Please " + "install 'rdkit' to alternatively process the raw data."), + file=sys.stderr) + + data_list = fs.torch_load(self.raw_paths[0]) + data_list = [Data(**data_dict) for data_dict in data_list] + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.save(data_list, self.processed_paths[0]) + return + + # Step 03. Filter out SDF + step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" + step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered" + if not os.path.exists(step3_folder): + os.makedirs(step3_folder) + with open(f"{self.raw_dir}/CID2text.json") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + + block_size = 500000 + + def extract_one_SDF_file(block_id: int) -> None: + valid_mol_count = 0 + + writer = Chem.SDWriter( + f'{step3_folder}/filtered_{block_id}.sdf') + l_id = block_id * block_size + 1 + r_id = (block_id + 1) * block_size + + compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" + gzip_loader = gzip.open(f"{step2_folder}/{compound_file_name}") + suppl = Chem.ForwardSDMolSupplier(gzip_loader) + + for mol in tqdm(suppl): + if mol is None: + continue + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + + if cid not in target_CID_list: + continue + + writer.write(mol) + valid_mol_count += 1 + + print(f"block id: {block_id}\nfound {valid_mol_count}\n\n") + sys.stdout.flush() + return + + if use_mp: + num_process = multiprocessing.cpu_count() + print(f"{num_process} CPUs") + num_process = 8 + p = Pool(num_process) + + block_id_list = np.arange(self.total_block_num) + with p: + p.map(extract_one_SDF_file, block_id_list) + else: + for block_id in range(self.total_block_num): + extract_one_SDF_file(block_id) + + # Step 04. Merge SDF + with open(f"{self.raw_dir}/CID2text.json") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + print(f'The length of target_CID_list: {len(target_CID_list)}') + + writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf') + + found_CID_set = set() + for block_id in range(self.total_block_num + 1): + compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf" + try: + suppl = Chem.SDMolSupplier(compound_file_path) + + for mol in tqdm(suppl): + writer.write(mol) + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + found_CID_set.add(cid) + except Exception: + print(f"block id: {block_id} with 0 valid SDF file") + continue + + print(f"In total: {len(found_CID_set)} molecules") + + # Step 05. Convert to PyG data format + types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5} + bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} + + data_list = [] + # Real data + CID2text_file = f'{self.raw_dir}/CID2text.json' + + with open(CID2text_file) as f: + CID2text_data = json.load(f) + + suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf') + + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + prompt = ("Propose a question regarding the molecule '∼' " + "whose answer is: {}:") + for mol in tqdm(suppl): + if mol.HasProp('PUBCHEM_COMPOUND_CID'): + CID = mol.GetProp("PUBCHEM_COMPOUND_CID") + CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES") + + m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES) + if m is None: + continue + RDKit_CAN_SMILES = Chem.MolToSmiles(m) + + ground_truth = CID2text_data[CID][0] + + instruction = llm.inference([prompt.format(ground_truth)])[0] + + x: torch.Tensor = torch.tensor([ + types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 + for atom in m.GetAtoms() # type: ignore + ]) + x = one_hot(x, num_classes=len(types), dtype=torch.float) + + rows, cols, edge_types = [], [], [] + for bond in m.GetBonds(): # type: ignore + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_types += [bonds[bond.GetBondType()]] * 2 + rows += [i, j] + cols += [j, i] + + edge_index = torch.tensor([rows, cols], dtype=torch.long) + edge_type = torch.tensor(edge_types, dtype=torch.long) + edge_attr = one_hot(edge_type, num_classes=len(bonds)) + + data = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + smiles=RDKit_CAN_SMILES, + instruction=instruction, + y=ground_truth, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0]) diff --git a/torch_geometric/nn/attention/__init__.py b/torch_geometric/nn/attention/__init__.py index 947d5850173b..6b4064cd34b9 100644 --- a/torch_geometric/nn/attention/__init__.py +++ b/torch_geometric/nn/attention/__init__.py @@ -1,3 +1,7 @@ from .performer import PerformerAttention +from .qformer import QFormer -__all__ = ['PerformerAttention'] +__all__ = [ + 'PerformerAttention', + 'QFormer', +] diff --git a/torch_geometric/nn/attention/qformer.py b/torch_geometric/nn/attention/qformer.py new file mode 100644 index 000000000000..3a8f512d3f83 --- /dev/null +++ b/torch_geometric/nn/attention/qformer.py @@ -0,0 +1,71 @@ +from typing import Callable + +import torch + + +class QFormer(torch.nn.Module): + r"""The Querying Transformer (Q-Former) from + `"BLIP-2: Bootstrapping Language-Image Pre-training + with Frozen Image Encoders and Large Language Models" + `_ paper. + + Args: + input_dim (int): The number of features in the input. + hidden_dim (int): The dimension of the fnn in the encoder layer. + output_dim (int): The final output dimension. + num_heads (int): The number of multi-attention-heads. + num_layers (int): The number of sub-encoder-layers in the encoder. + dropout (int): The dropout value in each encoder layer. + + + .. note:: + This is a simplified version of the original Q-Former implementation. + """ + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_heads: int, + num_layers: int, + dropout: float = 0.0, + activation: Callable = torch.nn.ReLU(), + ) -> None: + + super().__init__() + self.num_layers = num_layers + self.num_heads = num_heads + + self.layer_norm = torch.nn.LayerNorm(input_dim) + self.encoder_layer = torch.nn.TransformerEncoderLayer( + d_model=input_dim, + nhead=num_heads, + dim_feedforward=hidden_dim, + dropout=dropout, + activation=activation, + batch_first=True, + ) + self.encoder = torch.nn.TransformerEncoder( + self.encoder_layer, + num_layers=num_layers, + ) + self.project = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""Forward pass. + + Args: + x (torch.Tensor): Input sequence to the encoder layer. + :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with + batch-size :math:`B`, sequence length :math:`N`, + and feature dimension :math:`F`. + """ + x = self.layer_norm(x) + x = self.encoder(x) + out = self.project(x) + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'num_heads={self.num_heads}, ' + f'num_layers={self.num_layers})') diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 5860db311ac3..9aeac020264a 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -29,6 +29,7 @@ from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .g_retriever import GRetriever +from .molecule_gpt import MoleculeGPT from .glem import GLEM # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, @@ -77,5 +78,6 @@ 'NeuralFingerprint', 'ViSNet', 'GRetriever', + 'MoleculeGPT', 'GLEM', ] diff --git a/torch_geometric/nn/models/molecule_gpt.py b/torch_geometric/nn/models/molecule_gpt.py new file mode 100644 index 000000000000..a0ac73ad9abb --- /dev/null +++ b/torch_geometric/nn/models/molecule_gpt.py @@ -0,0 +1,222 @@ +from typing import List, Optional + +import torch +from torch import Tensor + +from torch_geometric.nn.attention import QFormer +from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS +from torch_geometric.utils import to_dense_batch + + +def pad_or_truncate(embeddings: Tensor, max_seq_len: int, + padding_value: int = 0) -> Tensor: + batch_size, current_seq_len, d = embeddings.size() + + if current_seq_len > max_seq_len: + return embeddings[:, :max_seq_len, :] + elif current_seq_len < max_seq_len: + pad_tensor = torch.full((batch_size, max_seq_len - current_seq_len, d), + padding_value, dtype=embeddings.dtype, + device=embeddings.device) + return torch.cat([embeddings, pad_tensor], dim=1) + else: + return embeddings + + +class MoleculeGPT(torch.nn.Module): + r"""The MoleculeGPT model from the `"MoleculeGPT: Instruction + Following Large Language Models for Molecular Property Prediction" + `_ paper. + + Args: + llm (LLM): The LLM to use. + graph_encoder (torch.nn.Module): Encode 2D molecule graph. + smiles_encoder (torch.nn.Module): Encode 1D SMILES. + mlp_out_channels (int, optional): The size of each embedding + after qformer encoding. (default: :obj:`32`) + max_tokens (int, optional): Max output tokens of 1D/2D encoder. + (default: :obj:`20`) + + .. warning:: + This module has been tested with the following HuggingFace models + + * :obj:`llm_to_use="lmsys/vicuna-7b-v1.5"` + + and may not work with other models. See other models at `HuggingFace + Models `_ and let us know if you + encounter any issues. + + .. note:: + For an example of using :class:`MoleculeGPT`, see + `examples/llm/molecule_gpt.py `_. + """ + def __init__( + self, + llm: LLM, + graph_encoder: torch.nn.Module, + smiles_encoder: torch.nn.Module, + mlp_out_channels: int = 32, + max_tokens: Optional[int] = 20, + ) -> None: + super().__init__() + self.llm = llm + self.graph_encoder = graph_encoder.to(self.llm.device) + self.smiles_encoder = smiles_encoder.to(self.llm.device) + + self.graph_qformer = QFormer( + input_dim=self.graph_encoder.nn[-1].out_features, + hidden_dim=mlp_out_channels, + output_dim=mlp_out_channels, + num_heads=4, + num_layers=2, + ).to(self.llm.device) + + self.smiles_qformer = QFormer( + input_dim=self.smiles_encoder.model.pooler.dense.out_features, + hidden_dim=mlp_out_channels, + output_dim=mlp_out_channels, + num_heads=4, + num_layers=2, + ).to(self.llm.device) + + self.max_tokens = max_tokens + + self.word_embedding = self.llm.word_embedding + self.llm_generator = self.llm.llm + + # LLMs + in_dim = 2 * mlp_out_channels * max_tokens + out_dim = self.llm.llm.model.embed_tokens.embedding_dim + self.projector = torch.nn.Sequential( + torch.nn.Linear(in_dim, in_dim), + torch.nn.Sigmoid(), + torch.nn.Linear(in_dim, out_dim), + ).to(self.llm.device) + + def encode( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + ) -> Tensor: + batch_size = len(smiles) + # 2D Graph Branch: [bs, node_len, d] + x = x.to(self.llm.device) + edge_index = edge_index.to(self.llm.device) + if edge_attr is not None: + edge_attr = edge_attr.to(self.llm.device) + batch = batch.to(self.llm.device) + + x_graph = self.graph_encoder(x, edge_index, edge_attr=edge_attr) + x_graph = to_dense_batch(x_graph, batch)[0] + out_graph = self.graph_qformer(x_graph) + out_graph = pad_or_truncate(out_graph, max_seq_len=self.max_tokens, + padding_value=0) + out_graph = out_graph.view(batch_size, -1) + + # 1D SMILES Branch: [bs, seq_len, d] + x_smiles = self.smiles_encoder.encode(smiles, + output_device=self.llm.device) + out_smiles = self.smiles_qformer(x_smiles) + out_smiles = pad_or_truncate(out_smiles, max_seq_len=self.max_tokens, + padding_value=0) + out_smiles = out_smiles.view(batch_size, -1) + + # Merge into LLMs + x_cat = torch.cat([out_graph, out_smiles], dim=1) + return x_cat + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + instructions: List[str], + label: List[str], + additional_text_context: Optional[List[str]] = None, + ): + x = self.encode(x, edge_index, batch, edge_attr, smiles) + x = self.projector(x) + xs = x.split(1, dim=0) + + batch_unique = batch.unique() + batch_size = len(instructions) + if len(batch_unique) < batch_size: + xs = [ + xs[i] if i in batch_unique else None for i in range(batch_size) + ] + + ( + inputs_embeds, + attention_mask, + label_input_ids, + ) = self.llm._get_embeds(instructions, additional_text_context, xs, + label) + + with self.llm.autocast_context: + outputs = self.llm_generator( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=label_input_ids, + ) + + return outputs.loss + + @torch.no_grad() + def inference( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + instructions: List[str], + additional_text_context: Optional[List[str]] = None, + max_out_tokens: Optional[int] = MAX_NEW_TOKENS, + ): + x = self.encode(x, edge_index, batch, edge_attr, smiles) + x = self.projector(x) + xs = x.split(1, dim=0) + + # Handle questions without node features: + batch_unique = batch.unique() + batch_size = len(instructions) + if len(batch_unique) < batch_size: + xs = [ + xs[i] if i in batch_unique else None for i in range(batch_size) + ] + + inputs_embeds, attention_mask, _ = self.llm._get_embeds( + instructions, additional_text_context, xs) + + bos_token = self.llm.tokenizer( + BOS, + add_special_tokens=False, + ).input_ids[0] + + with self.llm.autocast_context: + outputs = self.llm_generator.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=max_out_tokens, + attention_mask=attention_mask, + bos_token_id=bos_token, + use_cache=True # Important to set! + ) + + return self.llm.tokenizer.batch_decode( + outputs, + skip_special_tokens=True, + ) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(\n' + f' llm={self.llm},\n' + f' graph={self.graph_encoder.__class__.__name__},\n' + f' smiles={self.smiles_encoder},\n' + f')') diff --git a/torch_geometric/nn/nlp/llm.py b/torch_geometric/nn/nlp/llm.py index b58059f8e098..d18aa42382f7 100644 --- a/torch_geometric/nn/nlp/llm.py +++ b/torch_geometric/nn/nlp/llm.py @@ -56,7 +56,7 @@ class LLM(torch.nn.Module): allocate the correct number of GPUs needed, given the available GPU memory of your GPUs. dtype (torch.dtype, optional): The data type to use for the LLM. - (default :obj: `torch.bloat16`) + (default :obj: `torch.bfloat16`) """ def __init__( self, diff --git a/torch_geometric/nn/nlp/sentence_transformer.py b/torch_geometric/nn/nlp/sentence_transformer.py index c66677e8fa24..715f343bfc19 100644 --- a/torch_geometric/nn/nlp/sentence_transformer.py +++ b/torch_geometric/nn/nlp/sentence_transformer.py @@ -10,6 +10,7 @@ class PoolingStrategy(Enum): MEAN = 'mean' LAST = 'last' CLS = 'cls' + LAST_HIDDEN_STATE = 'last_hidden_state' class SentenceTransformer(torch.nn.Module): @@ -38,6 +39,8 @@ def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: emb = mean_pooling(emb, attention_mask) elif self.pooling_strategy == PoolingStrategy.LAST: emb = last_pooling(emb, attention_mask) + elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE: + emb = out.last_hidden_state else: assert self.pooling_strategy == PoolingStrategy.CLS emb = emb[:, 0, :] From 83485e3cad23c8e93e399314a305463d2bc94c52 Mon Sep 17 00:00:00 2001 From: Rishi Puri Date: Wed, 20 Nov 2024 09:36:53 -0800 Subject: [PATCH 04/45] Add comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo (#9797) --- CHANGELOG.md | 1 + examples/llm/g_retriever.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9240420208bb..d7f1db3c8748 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9797)) - Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710)) - Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index 1735d17f5249..984ce3f010e7 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -6,6 +6,9 @@ Requirements: `pip install datasets transformers pcst_fast sentencepiece accelerate` + +Example repo for integration with Neo4j Graph DB: +https://github.com/neo4j-product-examples/neo4j-gnn-llm-example """ import argparse import math From f73242708262640c644c539e4eae6ea30ad3f477 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Mon, 25 Nov 2024 13:39:08 +0800 Subject: [PATCH 05/45] Add GIT-Mol (#9730) ### Issue - #9694 - https://github.com/pyg-team/pytorch_geometric/issues/9700 ### Feature Summary - Add `GitMolDataset` - Add `GITMol` as GNN & LLM Co-training model to PyG - Add an example for pre-training - Limited hardware resources, so the full training pipeline was not tested - Multi modal cross attention shares the same weight, not aligned with the original paper --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rishi Puri --- CHANGELOG.md | 1 + examples/llm/README.md | 1 + examples/llm/git_mol.py | 133 +++++++ test/datasets/test_git_mol_dataset.py | 19 + test/nn/models/test_git_mol.py | 24 ++ test/nn/nlp/test_vision_transformer.py | 26 ++ torch_geometric/datasets/__init__.py | 2 + torch_geometric/datasets/git_mol_dataset.py | 263 ++++++++++++++ torch_geometric/nn/models/__init__.py | 2 + torch_geometric/nn/models/git_mol.py | 336 ++++++++++++++++++ torch_geometric/nn/nlp/__init__.py | 2 + .../nn/nlp/sentence_transformer.py | 30 ++ torch_geometric/nn/nlp/vision_transformer.py | 33 ++ 13 files changed, 872 insertions(+) create mode 100644 examples/llm/git_mol.py create mode 100644 test/datasets/test_git_mol_dataset.py create mode 100644 test/nn/models/test_git_mol.py create mode 100644 test/nn/nlp/test_vision_transformer.py create mode 100644 torch_geometric/datasets/git_mol_dataset.py create mode 100644 torch_geometric/nn/models/git_mol.py create mode 100644 torch_geometric/nn/nlp/vision_transformer.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d7f1db3c8748..86e06d8835e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730)) - Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9797)) - Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710)) - Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) diff --git a/examples/llm/README.md b/examples/llm/README.md index d860232aa56b..eb471563de8e 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -3,5 +3,6 @@ | Example | Description | | -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | | [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | | [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | diff --git a/examples/llm/git_mol.py b/examples/llm/git_mol.py new file mode 100644 index 000000000000..d05104db050c --- /dev/null +++ b/examples/llm/git_mol.py @@ -0,0 +1,133 @@ +"""This example implements the GIT-Mol model +(https://arxiv.org/abs/2308.06911) using PyG. +""" +import argparse +import os.path as osp + +import torch +from accelerate import Accelerator +from torch.optim.lr_scheduler import StepLR +from tqdm import tqdm + +from torch_geometric import seed_everything +from torch_geometric.datasets import GitMolDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn.models import GITMol + + +@torch.no_grad() +def eval(model, data_loader): + model.eval() + loss = 0 + + for batch in data_loader: + batch_loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.image, + batch.caption) + loss += batch_loss.item() / len(data_loader) + return loss + + +def train( + num_epochs: int, + lr: float, + weight_decay: float, + batch_size: int, + checkpointing: bool, +): + # Load dataset ================================================ + path = osp.dirname(osp.realpath(__file__)) + path = osp.join(path, '..', '..', 'data', 'GITMol') + train_dataset = GitMolDataset(path, split=0) + val_dataset = GitMolDataset(path, split=1) + test_dataset = GitMolDataset(path, split=2) + + seed_everything(42) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + # Create model =============================================== + accelerator = Accelerator() + device = accelerator.device + model = GITMol().to(device) + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], lr=lr, + weight_decay=weight_decay) + scheduler = StepLR(optimizer, step_size=1, gamma=0.1) + model, optimizer, train_loader, scheduler = accelerator.prepare( + model, optimizer, train_loader, scheduler) + val_loader = accelerator.prepare_data_loader(val_loader, + device_placement=True) + test_loader = accelerator.prepare_data_loader(test_loader, + device_placement=True) + + # Train and eval ============================================ + best_epoch = 0 + best_val_loss = float('inf') + for epoch in range(num_epochs): + # Train + model.train() + epoch_loss = 0 + if epoch == 0: + print("Training beginning...") + epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' + + for batch in tqdm(train_loader, desc=epoch_str): + optimizer.zero_grad() + loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.image, + batch.caption) + accelerator.backward(loss) + + optimizer.step() + epoch_loss += loss.item() + + train_loss = epoch_loss / len(train_loader) + + # Eval + val_loss = eval(model, val_loader) + print( + f'{epoch_str}, Train loss: {train_loss:4f}, Val loss: {val_loss:4f}' # noqa: E501 + ) + + if checkpointing and val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + torch.save( + { + 'model_state_dict': + accelerator.unwrap_model(model).state_dict(), + 'best_loss': + best_val_loss + }, + f'gitmol_pretrain_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 + ) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + # Test + test_loss = eval(model, test_loader) + print(f'Test loss: {test_loss:4f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--epochs', type=int, default=3) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument('--checkpointing', type=bool, default=True) + args = parser.parse_args() + + train( + args.epochs, + args.lr, + args.weight_decay, + args.batch_size, + args.checkpointing, + ) diff --git a/test/datasets/test_git_mol_dataset.py b/test/datasets/test_git_mol_dataset.py new file mode 100644 index 000000000000..3da72f2f2182 --- /dev/null +++ b/test/datasets/test_git_mol_dataset.py @@ -0,0 +1,19 @@ +import pytest + +from torch_geometric.datasets import GitMolDataset +from torch_geometric.testing import withPackage + + +@withPackage('torchvision', 'rdkit', 'PIL') +@pytest.mark.parametrize('split', [ + (0, 3610), + (1, 451), + (2, 451), +]) +def test_git_mol_dataset(split): + dataset = GitMolDataset(root='./data/GITMol', split=split[0]) + + assert len(dataset) == split[1] + assert dataset[0].image.size() == (1, 3, 224, 224) + assert dataset[0].num_node_features == 9 + assert dataset[0].num_edge_features == 3 diff --git a/test/nn/models/test_git_mol.py b/test/nn/models/test_git_mol.py new file mode 100644 index 000000000000..ee557bfaa9fc --- /dev/null +++ b/test/nn/models/test_git_mol.py @@ -0,0 +1,24 @@ +import torch + +from torch_geometric.nn.models import GITMol +from torch_geometric.testing import withPackage + + +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_git_mol(): + model = GITMol() + + x = torch.ones(10, 16, dtype=torch.long) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], + ]) + edge_attr = torch.zeros(edge_index.size(1), 16, dtype=torch.long) + batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O'] + captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.'] + images = torch.randn(1, 3, 224, 224) + + # Test train: + loss = model(x, edge_index, batch, edge_attr, smiles, images, captions) + assert loss >= 0 diff --git a/test/nn/nlp/test_vision_transformer.py b/test/nn/nlp/test_vision_transformer.py new file mode 100644 index 000000000000..7500ebc7fd0e --- /dev/null +++ b/test/nn/nlp/test_vision_transformer.py @@ -0,0 +1,26 @@ +import torch + +from torch_geometric.nn.nlp import VisionTransformer +from torch_geometric.testing import onlyFullTest, withCUDA, withPackage + + +@withCUDA +@onlyFullTest +@withPackage('transformers') +def test_vision_transformer(device): + model = VisionTransformer( + model_name='microsoft/swin-base-patch4-window7-224', ).to(device) + assert model.device == device + assert str( + model + ) == 'VisionTransformer(model_name=microsoft/swin-base-patch4-window7-224)' + + images = torch.randn(2, 3, 224, 224).to(device) + + out = model(images) + assert out.device == device + assert out.size() == (2, 49, 1024) + + out = model(images, output_device='cpu') + assert out.is_cpu + assert out.size() == (2, 49, 1024) diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index c086a85df779..12895ad1dbac 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -77,6 +77,7 @@ from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset from .web_qsp_dataset import WebQSPDataset +from .git_mol_dataset import GitMolDataset from .molecule_gpt_dataset import MoleculeGPTDataset from .tag_dataset import TAGDataset @@ -192,6 +193,7 @@ 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', + 'GitMolDataset', 'MoleculeGPTDataset', 'TAGDataset', ] diff --git a/torch_geometric/datasets/git_mol_dataset.py b/torch_geometric/datasets/git_mol_dataset.py new file mode 100644 index 000000000000..4b7cfa78117c --- /dev/null +++ b/torch_geometric/datasets/git_mol_dataset.py @@ -0,0 +1,263 @@ +import sys +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from torch_geometric.data import ( + Data, + InMemoryDataset, + download_google_url, + extract_zip, +) +from torch_geometric.io import fs + + +def safe_index(lst: List[Any], e: int) -> int: + return lst.index(e) if e in lst else len(lst) - 1 + + +class GitMolDataset(InMemoryDataset): + r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model + for Molecular Science with Graph, Image, and Text" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + split (int, optional): Datasets split, train/valid/test=0/1/2. + (default: :obj:`0`) + """ + + raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg' + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + split: int = 0, + ): + from torchvision import transforms + + self.split = split + + if self.split == 0: + self.img_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomRotation(15), + transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + else: + self.img_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl'] + + @property + def processed_file_names(self) -> str: + return ['train.pt', 'valid.pt', 'test.pt'][self.split] + + def download(self) -> None: + file_path = download_google_url( + self.raw_url_id, + self.raw_dir, + 'gitmol.zip', + ) + extract_zip(file_path, self.raw_dir) + + def process(self) -> None: + import pandas as pd + from PIL import Image + + try: + from rdkit import Chem, RDLogger + RDLogger.DisableLog('rdApp.*') # type: ignore + WITH_RDKIT = True + + except ImportError: + WITH_RDKIT = False + + if not WITH_RDKIT: + print(("Using a pre-processed version of the dataset. Please " + "install 'rdkit' to alternatively process the raw data."), + file=sys.stderr) + + data_list = fs.torch_load(self.raw_paths[0]) + data_list = [Data(**data_dict) for data_dict in data_list] + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.save(data_list, self.processed_paths[0]) + return + + allowable_features: Dict[str, List[Any]] = { + 'possible_atomic_num_list': + list(range(1, 119)) + ['misc'], + 'possible_formal_charge_list': + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], + 'possible_chirality_list': [ + Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + Chem.rdchem.ChiralType.CHI_OTHER + ], + 'possible_hybridization_list': [ + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc' + ], + 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], + 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], + 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], + 'possible_is_aromatic_list': [False, True], + 'possible_is_in_ring_list': [False, True], + 'possible_bond_type_list': [ + Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC, + Chem.rdchem.BondType.ZERO + ], + 'possible_bond_dirs': [ # only for double bond stereo information + Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT, + Chem.rdchem.BondDir.ENDDOWNRIGHT + ], + 'possible_bond_stereo_list': [ + Chem.rdchem.BondStereo.STEREONONE, + Chem.rdchem.BondStereo.STEREOZ, + Chem.rdchem.BondStereo.STEREOE, + Chem.rdchem.BondStereo.STEREOCIS, + Chem.rdchem.BondStereo.STEREOTRANS, + Chem.rdchem.BondStereo.STEREOANY, + ], + 'possible_is_conjugated_list': [False, True] + } + + data = pd.read_pickle( + f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}') + + data_list = [] + for _, r in tqdm(data.iterrows(), total=data.shape[0]): + smiles = r['isosmiles'] + mol = Chem.MolFromSmiles(smiles.strip('\n')) + if mol is not None: + # text + summary = r['summary'] + # image + cid = r['cid'] + img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png' + img = Image.open(img_file).convert('RGB') + img = self.img_transform(img).unsqueeze(0) + # graph + atom_features_list = [] + for atom in mol.GetAtoms(): # type: ignore + atom_feature = [ + safe_index( + allowable_features['possible_atomic_num_list'], + atom.GetAtomicNum()), + allowable_features['possible_chirality_list'].index( + atom.GetChiralTag()), + safe_index(allowable_features['possible_degree_list'], + atom.GetTotalDegree()), + safe_index( + allowable_features['possible_formal_charge_list'], + atom.GetFormalCharge()), + safe_index(allowable_features['possible_numH_list'], + atom.GetTotalNumHs()), + safe_index( + allowable_features[ + 'possible_number_radical_e_list'], + atom.GetNumRadicalElectrons()), + safe_index( + allowable_features['possible_hybridization_list'], + atom.GetHybridization()), + allowable_features['possible_is_aromatic_list'].index( + atom.GetIsAromatic()), + allowable_features['possible_is_in_ring_list'].index( + atom.IsInRing()), + ] + atom_features_list.append(atom_feature) + x = torch.tensor(np.array(atom_features_list), + dtype=torch.long) + + edges_list = [] + edge_features_list = [] + for bond in mol.GetBonds(): # type: ignore + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_feature = [ + safe_index( + allowable_features['possible_bond_type_list'], + bond.GetBondType()), + allowable_features['possible_bond_stereo_list'].index( + bond.GetStereo()), + allowable_features['possible_is_conjugated_list']. + index(bond.GetIsConjugated()), + ] + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + + edge_index = torch.tensor( + np.array(edges_list).T, + dtype=torch.long, + ) + edge_attr = torch.tensor( + np.array(edge_features_list), + dtype=torch.long, + ) + + data = Data( + x=x, + edge_index=edge_index, + smiles=smiles, + edge_attr=edge_attr, + image=img, + caption=summary, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0]) diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 9aeac020264a..9ade58cebc05 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -29,6 +29,7 @@ from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .g_retriever import GRetriever +from .git_mol import GITMol from .molecule_gpt import MoleculeGPT from .glem import GLEM # Deprecated: @@ -78,6 +79,7 @@ 'NeuralFingerprint', 'ViSNet', 'GRetriever', + 'GITMol', 'MoleculeGPT', 'GLEM', ] diff --git a/torch_geometric/nn/models/git_mol.py b/torch_geometric/nn/models/git_mol.py new file mode 100644 index 000000000000..c06b44671931 --- /dev/null +++ b/torch_geometric/nn/models/git_mol.py @@ -0,0 +1,336 @@ +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential + +from torch_geometric.nn import GINEConv +from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer +from torch_geometric.utils import add_self_loops, to_dense_batch + + +class GraphEncoder(torch.nn.Module): + def __init__( + self, + num_layers: int, + in_channels: int, + dropout: float = 0., + num_atom_type: int = 120, + num_chirality_tag: int = 3, + num_bond_type: int = 6, + num_bond_direction: int = 3, + ) -> None: + super().__init__() + + self.num_layers = num_layers + self.dropout = dropout + + self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels) + self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels) + self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels) + self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels) + + self.gnns = torch.nn.ModuleList() + self.batch_norms = torch.nn.ModuleList() + for _ in range(num_layers): + self.gnns.append( + GINEConv( + nn=Sequential( + Linear(in_channels, in_channels * 2), + ReLU(), + Linear(in_channels * 2, in_channels), + ), + train_eps=True, + edge_dim=in_channels, + )) + self.batch_norms.append(BatchNorm1d(in_channels)) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.x_embed1.weight.data) + torch.nn.init.xavier_uniform_(self.x_embed2.weight.data) + torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data) + torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data) + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Tensor, + ) -> Tensor: + x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long()) + edge_index, edge_attr = add_self_loops( + edge_index, + edge_attr, + fill_value=0, + num_nodes=x.size(0), + ) + edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2( + edge_attr[:, 1]) + for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)): + x = gnn(x, edge_index, edge_attr) + x = bn(x) + if i < self.num_layers - 1: + x = F.relu(x) + x = F.dropout(x, self.dropout, training=self.training) + + x, mask = to_dense_batch(x, batch) + return x, mask + + +class GITFormer(torch.nn.Module): + def __init__( + self, + num_query_token: int, + vision_graph_width: int, + cross_attention_freq: int = 2, + ): + super().__init__() + from transformers import AutoConfig, AutoModel + + config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased") + config.encoder_width = vision_graph_width + # insert cross-attention layer every other block + config.add_cross_attention = True + config.is_decoder = True + config.cross_attention_freq = cross_attention_freq + config.query_length = num_query_token + self.Qformer = AutoModel.from_pretrained( + "allenai/scibert_scivocab_uncased", config=config) + self.query_tokens = torch.nn.Parameter( + torch.zeros(1, num_query_token, config.hidden_size)) + self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range) + + +class GITMol(torch.nn.Module): + r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language + Model for Molecular Science with Graph, Image, and Text" + `_ paper. + + .. note:: + For an example of using :class:`GITMol`, see + `examples/llm/git_mol.py `_. + """ + def __init__(self) -> None: + super().__init__() + # graph + self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16) + self.graph_proj = Linear(16, 768) + self.ln_graph = LayerNorm(768) + # text + self.text_encoder = SentenceTransformer( + model_name='allenai/scibert_scivocab_uncased', + pooling_strategy='last_hidden_state', + ) + self.text_proj = Linear(768, 768) + self.ln_text = LayerNorm(768) + # vision + self.vision_encoder = VisionTransformer( + model_name='microsoft/swin-base-patch4-window7-224', ) + self.vision_proj = Linear(1024, 768) + self.ln_vision = LayerNorm(768) + # cross-attention + self.gitformer = GITFormer(384, 768) + + self.xtm_head = torch.nn.ModuleDict({ + 'image': + Linear(self.gitformer.Qformer.config.hidden_size, 2), + 'graph': + Linear(self.gitformer.Qformer.config.hidden_size, 2), + 'cs_text': + Linear(self.gitformer.Qformer.config.hidden_size, 2), + }) + + self.xtc_proj = torch.nn.ModuleDict({ + 'image': + Linear(self.gitformer.Qformer.config.hidden_size, 768), + 'graph': + Linear(self.gitformer.Qformer.config.hidden_size, 768), + 'cs_text': + Linear(self.gitformer.Qformer.config.hidden_size, 768), + }) + self.temp = torch.nn.Parameter(0.07 * torch.ones([])) + self.model_freeze() + + def model_freeze(self) -> None: + for param in self.graph_encoder.parameters(): + param.requires_grad = False + + for param in self.vision_encoder.parameters(): + param.requires_grad = False + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + images: Tensor, + captions: List[str], + ) -> Tensor: + batch_size = len(smiles) + + x_vision = self.vision_encoder(images) + x_vision = self.vision_proj(x_vision) + x_vision = self.ln_vision(x_vision) # [bs, patch_len, d] + vision_atts = torch.ones(x_vision.size()[:-1], + dtype=torch.long).to(x_vision.device) + vision_targets = torch.arange(batch_size).to(x_vision.device) + + x_graph, graph_atts = self.graph_encoder(x, edge_index, batch, + edge_attr) + x_graph = self.graph_proj(x_graph) + x_graph = self.ln_graph(x_graph) # [bs, node_len, d] + graph_targets = torch.arange(batch_size).to(x_graph.device) + + x_smiles = self.text_encoder.encode(smiles) # [bs, seq_len, d] + smiles_atts = torch.ones(x_smiles.size()[:-1], + dtype=torch.long).to(x_smiles.device) + smiles_targets = torch.arange(batch_size).to(x_smiles.device) + + caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids( # noqa: E501 + captions) + + text_output = self.gitformer.Qformer( + caption_input_ids, + attention_mask=caption_attention_masks, + return_dict=True, + ) + text_feat = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1) + + loss = 0 + for x_embed, x_atts, x_targets, modal in zip( + [x_graph, x_smiles, x_vision], + [graph_atts, smiles_atts, vision_atts], + [graph_targets, smiles_targets, vision_targets], + ['graph', 'cs_text', 'image'], + ): + loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat, + modal) + loss += self._calc_xtm_loss(x_embed, caption_input_ids, + caption_attention_masks, modal) + + return loss / 6 + + def _calc_xtm_loss( + self, + x_embeds: Tensor, + input_ids: Tensor, + attention_mask: Tensor, + modal: str, + ) -> Tensor: + # Initializing lists to hold the original and negative samples + x_embeds_list = [] + text_input_ids_list = [] + text_attention_mask_list = [] + + batch_size = x_embeds.size(0) + for i in range(batch_size): + # Original samples + x_embeds_list.append(x_embeds[i]) + text_input_ids_list.append(input_ids[i, :]) + text_attention_mask_list.append(attention_mask[i, :]) + + if batch_size > 1: + # Negative samples (neg_text_input_ids corresponds to x_embeds) + neg_text_input_ids = input_ids[i - 1 if i == batch_size - + 1 else i + 1, :] + neg_text_attention_mask = attention_mask[i - + 1 if i == batch_size - + 1 else i + 1, :] + text_input_ids_list.append(neg_text_input_ids) + text_attention_mask_list.append(neg_text_attention_mask) + x_embeds_list.append(x_embeds[i, :]) + + # Negative samples (text_input_ids corresponds to neg_x_embeds) + neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i + + 1, :] + x_embeds_list.append(neg_x_embeds) + text_input_ids_list.append(input_ids[i, :]) + text_attention_mask_list.append(attention_mask[i, :]) + + # Stack all samples into two large tensors + x_embeds_all = torch.stack(x_embeds_list, dim=1) \ + .reshape(-1, x_embeds.size(1), x_embeds.size(2)) + text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \ + .reshape(-1, input_ids.size(1)) + # Create image attention masks for the concatenated tensor + image_attns_all = torch.ones(x_embeds_all.size()[:-1], + dtype=torch.long).to(x_embeds_all.device) + query_tokens_xtm = self.gitformer.query_tokens.expand( + text_input_ids_all.shape[0], -1, -1) + query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1], + dtype=torch.long).to(x_embeds_all.device) + + output_xtm = self.gitformer.Qformer( + inputs_embeds=query_tokens_xtm, + attention_mask=query_attns_xtm, + encoder_hidden_states=x_embeds_all, + encoder_attention_mask=image_attns_all, + return_dict=True, + ).last_hidden_state + + xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :] + + xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1) + # Create labels: 1 for the original samples, 0 for the negative samples + if batch_size > 1: + labels = torch.cat( + [torch.ones(batch_size), + torch.zeros(batch_size * 2)], dim=0) + else: + labels = torch.ones(batch_size) + labels = labels.long().to(xtm_logit.device) + + # Calculate cross entropy loss + return F.cross_entropy(xtm_logit, labels) + + def _calc_xtc_loss( + self, + x_embeds: Tensor, + x_atts: Tensor, + x_targets: Tensor, + text_feat: Tensor, + modal: str, + ) -> Tensor: + query_tokens = self.gitformer.query_tokens.expand( + x_embeds.shape[0], -1, -1) + + query_output = self.gitformer.Qformer( + inputs_embeds=query_tokens, + encoder_hidden_states=x_embeds, + encoder_attention_mask=x_atts, + return_dict=True, + ).last_hidden_state + + x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1) + + sim_q2t = torch.matmul( + x_feats.unsqueeze(1), + text_feat.unsqueeze(-1), + ).squeeze(-1) + + # modal-text similarity: aggregate across all query tokens + sim_x2t, _ = sim_q2t.max(-1) + sim_x2t = sim_x2t / self.temp + + # text-query similarity + sim_t2q = torch.matmul( + text_feat.unsqueeze(1).unsqueeze(1), + x_feats.permute(0, 2, 1), + ).squeeze(-2) + + # text-modal similarity: aggregate across all query tokens + sim_t2x, _ = sim_t2q.max(-1) + sim_t2x = sim_t2x / self.temp + + loss_itc = ( + F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2 + + return loss_itc diff --git a/torch_geometric/nn/nlp/__init__.py b/torch_geometric/nn/nlp/__init__.py index c101a359e3f5..434619352460 100644 --- a/torch_geometric/nn/nlp/__init__.py +++ b/torch_geometric/nn/nlp/__init__.py @@ -1,7 +1,9 @@ from .sentence_transformer import SentenceTransformer +from .vision_transformer import VisionTransformer from .llm import LLM __all__ = classes = [ 'SentenceTransformer', + 'VisionTransformer', 'LLM', ] diff --git a/torch_geometric/nn/nlp/sentence_transformer.py b/torch_geometric/nn/nlp/sentence_transformer.py index 715f343bfc19..6d904b8e0fbf 100644 --- a/torch_geometric/nn/nlp/sentence_transformer.py +++ b/torch_geometric/nn/nlp/sentence_transformer.py @@ -48,6 +48,36 @@ def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: emb = F.normalize(emb, p=2, dim=1) return emb + def get_input_ids( + self, + text: List[str], + batch_size: Optional[int] = None, + output_device: Optional[Union[torch.device, str]] = None, + ) -> Tensor: + is_empty = len(text) == 0 + text = ['dummy'] if is_empty else text + + batch_size = len(text) if batch_size is None else batch_size + + input_ids: List[Tensor] = [] + attention_masks: List[Tensor] = [] + for start in range(0, len(text), batch_size): + token = self.tokenizer( + text[start:start + batch_size], + padding=True, + truncation=True, + return_tensors='pt', + ) + input_ids.append(token.input_ids.to(self.device)) + attention_masks.append(token.attention_mask.to(self.device)) + + def _out(x: List[Tensor]) -> Tensor: + out = torch.cat(x, dim=0) if len(x) > 1 else x[0] + out = out[:0] if is_empty else out + return out.to(output_device) + + return _out(input_ids), _out(attention_masks) + @property def device(self) -> torch.device: return next(iter(self.model.parameters())).device diff --git a/torch_geometric/nn/nlp/vision_transformer.py b/torch_geometric/nn/nlp/vision_transformer.py new file mode 100644 index 000000000000..517a524f4d84 --- /dev/null +++ b/torch_geometric/nn/nlp/vision_transformer.py @@ -0,0 +1,33 @@ +from typing import Optional, Union + +import torch +from torch import Tensor + + +class VisionTransformer(torch.nn.Module): + def __init__( + self, + model_name: str, + ) -> None: + super().__init__() + self.model_name = model_name + + from transformers import SwinConfig, SwinModel + + self.config = SwinConfig.from_pretrained(model_name) + self.model = SwinModel(self.config) + + @torch.no_grad() + def forward( + self, + images: Tensor, + output_device: Optional[Union[torch.device, str]] = None, + ) -> Tensor: + return self.model(images).last_hidden_state.to(output_device) + + @property + def device(self) -> torch.device: + return next(iter(self.model.parameters())).device + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(model_name={self.model_name})' From f8760ec39951715def19a95ee62ae4fc3ed056fd Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 25 Nov 2024 11:24:00 +0100 Subject: [PATCH 06/45] Run `GitMolDataset` tests only in full test mode (#9804) --- test/datasets/test_git_mol_dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/datasets/test_git_mol_dataset.py b/test/datasets/test_git_mol_dataset.py index 3da72f2f2182..f4e652b6ae43 100644 --- a/test/datasets/test_git_mol_dataset.py +++ b/test/datasets/test_git_mol_dataset.py @@ -1,16 +1,19 @@ +from typing import Tuple + import pytest from torch_geometric.datasets import GitMolDataset -from torch_geometric.testing import withPackage +from torch_geometric.testing import onlyFullTest, withPackage +@onlyFullTest @withPackage('torchvision', 'rdkit', 'PIL') @pytest.mark.parametrize('split', [ (0, 3610), (1, 451), (2, 451), ]) -def test_git_mol_dataset(split): +def test_git_mol_dataset(split: Tuple[int, int]) -> None: dataset = GitMolDataset(root='./data/GITMol', split=split[0]) assert len(dataset) == split[1] From bd7876c86046a4ed8e17583bc0bbaebeb0518510 Mon Sep 17 00:00:00 2001 From: Rishi Puri Date: Mon, 25 Nov 2024 02:24:23 -0800 Subject: [PATCH 07/45] fix for cugraph (#9803) --- examples/multi_gpu/papers100m_gcn_cugraph.py | 4 ++-- examples/multi_gpu/papers100m_gcn_cugraph_multinode.py | 4 ++-- examples/ogbn_papers_100m_cugraph.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/multi_gpu/papers100m_gcn_cugraph.py b/examples/multi_gpu/papers100m_gcn_cugraph.py index 5413492a5bc5..799b6317c374 100644 --- a/examples/multi_gpu/papers100m_gcn_cugraph.py +++ b/examples/multi_gpu/papers100m_gcn_cugraph.py @@ -86,8 +86,8 @@ def run(rank, data, world_size, cugraph_id, model, epochs, batch_size, fan_out, )] = ixr feature_store = TensorDictFeatureStore() - feature_store['node', 'x'] = data.x - feature_store['node', 'y'] = data.y + feature_store['node', 'x', None] = data.x + feature_store['node', 'y', None] = data.y dist.barrier() diff --git a/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py b/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py index 4ea78eb64ad6..eb074defeafe 100644 --- a/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py +++ b/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py @@ -142,9 +142,9 @@ def load_partitioned_data(rank, edge_path, feature_path, label_path, meta_path, split_idx[split] = fs.torch_load(path) path = osp.join(feature_path, f'rank={rank}_x.pt') - feature_store['node', 'x'] = fs.torch_load(path) + feature_store['node', 'x', None] = fs.torch_load(path) path = osp.join(feature_path, f'rank={rank}_y.pt') - feature_store['node', 'y'] = fs.torch_load(path) + feature_store['node', 'y', None] = fs.torch_load(path) eix = fs.torch_load(osp.join(edge_path, f'rank={rank}.pt')) graph_store[dict( diff --git a/examples/ogbn_papers_100m_cugraph.py b/examples/ogbn_papers_100m_cugraph.py index 7c1da866056a..8ae35cd776a4 100644 --- a/examples/ogbn_papers_100m_cugraph.py +++ b/examples/ogbn_papers_100m_cugraph.py @@ -63,8 +63,8 @@ )] = data.edge_index feature_store = cugraph_pyg.data.TensorDictFeatureStore() -feature_store['node', 'x'] = data.x -feature_store['node', 'y'] = data.y +feature_store['node', 'x', None] = data.x +feature_store['node', 'y', None] = data.y data = (feature_store, graph_store) From 742f790090e420b8a3ebff295e5280a373d8aed1 Mon Sep 17 00:00:00 2001 From: zaristei Date: Mon, 25 Nov 2024 17:29:25 -0800 Subject: [PATCH 08/45] G-retriever API updates (NVTX, Remote Backend, Large Graph Indexer, Examples) (#9666) Follow up to [PR 9597](https://github.com/pyg-team/pytorch_geometric/pull/9597). Includes multiple changes related to LLM+GNN experiments and scaling up to a remote backend. Including: - LargeGraphIndexer for building a large knowledge graph locally from multiple samples in an arbitrary dataset - Remote Backend Loader and examples for deploying a Retrieval algorithm to a third party backend FeatureStore or GraphStore - NVTX profiling tools for nsys users - Quality of Life improvements and benchmarking scripts for G-Retriever. Updates using these for WebQSP will be moved to a seperate PR UPDATE: PR is being broken up into smaller PRs. These can be previewed here: - https://github.com/zaristei/pytorch_geometric/pull/6 - https://github.com/zaristei/pytorch_geometric/pull/7 - https://github.com/zaristei/pytorch_geometric/pull/8 --------- Co-authored-by: Zack Aristei Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Zachary Aristei Co-authored-by: Rishi Puri --- CHANGELOG.md | 4 + examples/llm/README.md | 15 +- examples/llm/g_retriever.py | 4 + examples/llm/g_retriever_utils/README.md | 11 + .../benchmark_model_archs_rag.py | 105 +++ .../llm/g_retriever_utils/minimal_demo.py | 638 +++++++++++++++++ .../g_retriever_utils/rag_backend_utils.py | 224 ++++++ .../g_retriever_utils/rag_feature_store.py | 189 +++++ .../llm/g_retriever_utils/rag_generate.py | 139 ++++ .../llm/g_retriever_utils/rag_graph_store.py | 107 +++ examples/llm/multihop_rag/README.md | 9 + .../llm/multihop_rag/multihop_download.sh | 12 + .../llm/multihop_rag/multihop_preprocess.py | 276 +++++++ .../llm/multihop_rag/rag_generate_multihop.py | 88 +++ examples/llm/nvtx_examples/README.md | 7 + .../nvtx_examples/nvtx_rag_backend_example.py | 144 ++++ examples/llm/nvtx_examples/nvtx_run.sh | 27 + .../llm/nvtx_examples/nvtx_webqsp_example.py | 22 + test/data/test_large_graph_indexer.py | 177 +++++ test/nn/models/test_g_retriever.py | 49 ++ test/profile/test_nvtx.py | 136 ++++ torch_geometric/data/__init__.py | 5 + torch_geometric/data/large_graph_indexer.py | 677 ++++++++++++++++++ torch_geometric/loader/__init__.py | 2 + torch_geometric/loader/rag_loader.py | 106 +++ torch_geometric/nn/models/g_retriever.py | 13 +- torch_geometric/profile/__init__.py | 2 + torch_geometric/profile/nvtx.py | 66 ++ 28 files changed, 3247 insertions(+), 7 deletions(-) create mode 100644 examples/llm/g_retriever_utils/README.md create mode 100644 examples/llm/g_retriever_utils/benchmark_model_archs_rag.py create mode 100644 examples/llm/g_retriever_utils/minimal_demo.py create mode 100644 examples/llm/g_retriever_utils/rag_backend_utils.py create mode 100644 examples/llm/g_retriever_utils/rag_feature_store.py create mode 100644 examples/llm/g_retriever_utils/rag_generate.py create mode 100644 examples/llm/g_retriever_utils/rag_graph_store.py create mode 100644 examples/llm/multihop_rag/README.md create mode 100644 examples/llm/multihop_rag/multihop_download.sh create mode 100644 examples/llm/multihop_rag/multihop_preprocess.py create mode 100644 examples/llm/multihop_rag/rag_generate_multihop.py create mode 100644 examples/llm/nvtx_examples/README.md create mode 100644 examples/llm/nvtx_examples/nvtx_rag_backend_example.py create mode 100644 examples/llm/nvtx_examples/nvtx_run.sh create mode 100644 examples/llm/nvtx_examples/nvtx_webqsp_example.py create mode 100644 test/data/test_large_graph_indexer.py create mode 100644 test/profile/test_nvtx.py create mode 100644 torch_geometric/data/large_graph_indexer.py create mode 100644 torch_geometric/loader/rag_loader.py create mode 100644 torch_geometric/profile/nvtx.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 86e06d8835e8..341be665fabf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730)) - Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9797)) - Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710)) diff --git a/examples/llm/README.md b/examples/llm/README.md index eb471563de8e..4503e28ce6ee 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,8 +1,11 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | -| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | -| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | -| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | +| Example | Description | +| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. | +| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA | +| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. | +| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | +| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | +| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index 984ce3f010e7..a48901f1ff0e 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -11,6 +11,7 @@ https://github.com/neo4j-product-examples/neo4j-gnn-llm-example """ import argparse +import gc import math import os.path as osp import re @@ -145,6 +146,9 @@ def adjust_learning_rate(param_group, LR, epoch): test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, drop_last=False, pin_memory=True, shuffle=False) + # To clean up after Data Preproc + gc.collect() + torch.cuda.empty_cache() gnn = GAT( in_channels=1024, hidden_channels=hidden_channels, diff --git a/examples/llm/g_retriever_utils/README.md b/examples/llm/g_retriever_utils/README.md new file mode 100644 index 000000000000..e072e6746b7c --- /dev/null +++ b/examples/llm/g_retriever_utils/README.md @@ -0,0 +1,11 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend | +| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend | +| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore | +| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) | +| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevent architecture parameters and datasets. | + +NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes. diff --git a/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py new file mode 100644 index 000000000000..6522aafca68b --- /dev/null +++ b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py @@ -0,0 +1,105 @@ +"""Used to benchmark the performance of an untuned/fine tuned LLM against +GRetriever with various architectures and layer depths. +""" +# %% +import argparse +import sys + +import torch + +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.nn.models import GAT, MLP, GRetriever + +sys.path.append('..') +from minimal_demo import ( # noqa: E402 # isort:skip + benchmark_models, get_loss, inference_step, +) + +# %% +parser = argparse.ArgumentParser( + description="""Benchmarker for GRetriever\n""" + + """NOTE: Evaluating with smaller samples may result in poorer""" + + """ performance for the trained models compared to """ + + """untrained models.""") +parser.add_argument("--hidden_channels", type=int, default=1024) +parser.add_argument("--learning_rate", type=float, default=1e-5) +parser.add_argument("--epochs", type=int, default=2) +parser.add_argument("--batch_size", type=int, default=8) +parser.add_argument("--eval_batch_size", type=int, default=16) +parser.add_argument("--tiny_llama", action='store_true') + +parser.add_argument("--dataset_path", type=str, required=False) +# Default to WebQSP split +parser.add_argument("--num_train", type=int, default=2826) +parser.add_argument("--num_val", type=int, default=246) +parser.add_argument("--num_test", type=int, default=1628) + +args = parser.parse_args() + +# %% +hidden_channels = args.hidden_channels +lr = args.learning_rate +epochs = args.epochs +batch_size = args.batch_size +eval_batch_size = args.eval_batch_size + +# %% +if not args.dataset_path: + ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True) +else: + # We just assume that the size of the dataset accomodates the + # train/val/test split, because checking may be expensive. + dataset = torch.load(args.dataset_path) + + class MockDataset: + """Utility class to patch the fields in WebQSPDataset used by + GRetriever. + """ + def __init__(self) -> None: + pass + + @property + def split_idxs(self) -> dict: + # Imitates the WebQSP split method + return { + "train": + torch.arange(args.num_train), + "val": + torch.arange(args.num_val) + args.num_train, + "test": + torch.arange(args.num_test) + args.num_train + args.num_val, + } + + def __getitem__(self, idx: int): + return dataset[idx] + + ds = MockDataset() + +# %% +model_names = [] +model_classes = [] +model_kwargs = [] +model_type = ["GAT", "MLP"] +models = {"GAT": GAT, "MLP": MLP} +# Use to vary the depth of the GNN model +num_layers = [4] +# Use to vary the number of LLM tokens reserved for GNN output +num_tokens = [1] +for m_type in model_type: + for n_layer in num_layers: + for n_tokens in num_tokens: + model_names.append(f"{m_type}_{n_layer}_{n_tokens}") + model_classes.append(GRetriever) + kwargs = dict(gnn_hidden_channels=hidden_channels, + num_gnn_layers=n_layer, gnn_to_use=models[m_type], + mlp_out_tokens=n_tokens) + if args.tiny_llama: + kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1' + kwargs['mlp_out_dim'] = 2048 + kwargs['num_llm_params'] = 1 + model_kwargs.append(kwargs) + +# %% +benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs, + batch_size, eval_batch_size, get_loss, inference_step, + skip_LLMs=False, tiny_llama=args.tiny_llama, force=True) diff --git a/examples/llm/g_retriever_utils/minimal_demo.py b/examples/llm/g_retriever_utils/minimal_demo.py new file mode 100644 index 000000000000..bdd78c3180cb --- /dev/null +++ b/examples/llm/g_retriever_utils/minimal_demo.py @@ -0,0 +1,638 @@ +"""This example implements the G-Retriever model +(https://arxiv.org/abs/2402.07630) using PyG. + +G-Retriever significantly reduces hallucinations by 54% compared to the +stand-alone LLM baseline. + +Requirements: +`pip install datasets transformers pcst_fast sentencepiece accelerate` +""" +import argparse +import gc +import math +import multiprocessing as mp +import re +import sys +import time +from os import path +from typing import Any, Callable, Dict, List, Type + +import pandas as pd +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.utils import clip_grad_norm_ +from tqdm import tqdm + +from torch_geometric import seed_everything +from torch_geometric.data import Dataset +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn.models import GAT, GRetriever +from torch_geometric.nn.nlp import LLM + +# NOTE: This used to be merged in the G-Retriever example. +# FIXME: Getting the demos working like before is a WIP +sys.path.append('..') +from g_retriever import ( # noqa: E402 # isort:skip + compute_metrics, load_params_dict, save_params_dict, +) + + +def _detect_hallucinate(inp): + pred, label = inp + try: + split_pred = pred.split('[/s]')[0].strip().split('|') + correct_hit = len(re.findall(split_pred[0], label)) > 0 + correct_hit = correct_hit or any( + [label_i in pred.lower() for label_i in label.split('|')]) + hallucination = not correct_hit + return hallucination + except: # noqa + return "skip" + + +def detect_hallucinate(pred_batch, label_batch): + r"""An approximation for the unsolved task of detecting hallucinations. + We define a hallucination as an output that contains no instances of + acceptable label. + """ + with mp.Pool(len(pred_batch)) as p: + res = p.map(_detect_hallucinate, zip(pred_batch, label_batch)) + return res + + +def compute_n_parameters(model: torch.nn.Module) -> int: + return sum([p.numel() for p in model.parameters() if p.requires_grad]) + + +def get_loss(model, batch, model_save_name) -> Tensor: + if model_save_name == 'llm': + return model(batch.question, batch.label, batch.desc) + else: + return model(batch.question, batch.x, batch.edge_index, batch.batch, + batch.label, batch.edge_attr, batch.desc) + + +def inference_step(model, batch, model_save_name): + if model_save_name == 'llm': + return model.inference(batch.question, batch.desc) + else: + return model.inference(batch.question, batch.x, batch.edge_index, + batch.batch, batch.edge_attr, batch.desc) + + +# TODO: Merge with G-Retriever example and make sure changes still work +def train( + num_epochs, + hidden_channels, + num_gnn_layers, + batch_size, + eval_batch_size, + lr, + checkpointing=False, + tiny_llama=False, + model=None, + dataset=None, + model_save_name=None, +): + def adjust_learning_rate(param_group, LR, epoch): + # Decay the learning rate with half-cycle cosine after warmup + min_lr = 5e-6 + warmup_epochs = 1 + if epoch < warmup_epochs: + lr = LR + else: + lr = min_lr + (LR - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / + (num_epochs - warmup_epochs))) + param_group['lr'] = lr + return lr + + start_time = time.time() + seed_everything(42) + if dataset is None: + dataset = WebQSPDataset() + gc.collect() + elif not isinstance(dataset, Dataset) and callable(dataset): + dataset = dataset() + gc.collect() + idx_split = dataset.split_idxs + + # Step 1: Build Node Classification Dataset + train_dataset = [dataset[i] for i in idx_split['train']] + val_dataset = [dataset[i] for i in idx_split['val']] + test_dataset = [dataset[i] for i in idx_split['test']] + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + if model is None: + gc.collect() + gnn = GAT( + in_channels=1024, + hidden_channels=hidden_channels, + out_channels=1024, + num_layers=num_gnn_layers, + heads=4, + ) + if tiny_llama: + llm = LLM( + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + ) + model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048) + else: + llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7) + model = GRetriever(llm=llm, gnn=gnn) + + if model_save_name is None: + model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm' + + model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm' + if model_save_name == 'llm': + model = llm + + params = [p for _, p in model.named_parameters() if p.requires_grad] + optimizer = torch.optim.AdamW([ + { + 'params': params, + 'lr': lr, + 'weight_decay': 0.05 + }, + ], betas=(0.9, 0.95)) + grad_steps = 2 + + best_epoch = 0 + best_val_loss = float('inf') + for epoch in range(num_epochs): + model.train() + epoch_loss = 0 + if epoch == 0: + print(f"Total Preparation Time: {time.time() - start_time:2f}s") + start_time = time.time() + print("Training beginning...") + epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' + loader = tqdm(train_loader, desc=epoch_str) + for step, batch in enumerate(loader): + optimizer.zero_grad() + loss = get_loss(model, batch, model_save_name) + loss.backward() + + clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) + + if (step + 1) % grad_steps == 0: + adjust_learning_rate(optimizer.param_groups[0], lr, + step / len(train_loader) + epoch) + + optimizer.step() + epoch_loss = epoch_loss + float(loss) + + if (step + 1) % grad_steps == 0: + lr = optimizer.param_groups[0]['lr'] + train_loss = epoch_loss / len(train_loader) + print(epoch_str + f', Train Loss: {train_loss:4f}') + + val_loss = 0 + eval_output = [] + model.eval() + with torch.no_grad(): + for step, batch in enumerate(val_loader): + loss = get_loss(model, batch, model_save_name) + val_loss += loss.item() + val_loss = val_loss / len(val_loader) + print(epoch_str + f", Val Loss: {val_loss:4f}") + if checkpointing and val_loss < best_val_loss: + print("Checkpointing best model...") + best_val_loss = val_loss + best_epoch = epoch + save_params_dict(model, f'{model_save_name}_best_val_loss_ckpt.pt') + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + if checkpointing and best_epoch != num_epochs - 1: + print("Loading best checkpoint...") + model = load_params_dict( + model, + f'{model_save_name}_best_val_loss_ckpt.pt', + ) + + model.eval() + eval_output = [] + print("Final evaluation...") + progress_bar_test = tqdm(range(len(test_loader))) + for step, batch in enumerate(test_loader): + with torch.no_grad(): + pred = inference_step(model, batch, model_save_name) + eval_data = { + 'pred': pred, + 'question': batch.question, + 'desc': batch.desc, + 'label': batch.label + } + eval_output.append(eval_data) + progress_bar_test.update(1) + + # Step 6 Post-processing & compute metrics + compute_metrics(eval_output) + print(f"Total Training Time: {time.time() - start_time:2f}s") + save_params_dict(model, f'{model_save_name}.pt') + torch.save(eval_output, f'{model_save_name}_eval_outs.pt') + print("Done!") + return prep_time, dataset, eval_output # noqa: F821 + + +def _eval_hallucinations_on_loader(outs, loader, eval_batch_size): + model_save_list = [] + model_preds = [] + for out in outs: + model_preds += out['pred'] + for i, batch in enumerate(loader): + correct_answer = batch.label + + model_pred = model_preds[i * eval_batch_size:(i + 1) * eval_batch_size] + model_hallucinates = detect_hallucinate(model_pred, correct_answer) + model_save_list += [tup for tup in zip(model_pred, model_hallucinates)] + return model_save_list + + +def benchmark_models(models: List[Type[nn.Module]], model_names: List[str], + model_kwargs: List[Dict[str, Any]], dataset: Dataset, + lr: float, epochs: int, batch_size: int, + eval_batch_size: int, loss_fn: Callable, + inference_fn: Callable, skip_LLMs: bool = True, + tiny_llama: bool = False, checkpointing: bool = True, + force: bool = False, root_dir='.'): + """Utility function for creating a model benchmark for GRetriever that + grid searches over hyperparameters. Produces a DataFrame containing + metrics for each model. + + Args: + models (List[Type[nn.Module]]): Models to be benchmarked. + model_names (List[str]): Name of save files for model checkpoints + model_kwargs (List[Dict[str, Any]]): Parameters to use for each + particular model. + dataset (Dataset): Input dataset to train on. + lr (float): Learning rate + epochs (int): Number of epochs + batch_size (int): Batch size for training + eval_batch_size (int): Batch size for eval. Also determines + hallucination detection concurrancy. + loss_fn (Callable): Loss function + inference_fn (Callable): Inference function + skip_LLMs (bool, optional): Whether to skip LLM-only runs. + Defaults to True. + tiny_llama (bool, optional): Whether to use tiny llama as LLM. + Defaults to False. + checkpointing (bool, optional): Whether to checkpoint models. + Defaults to True. + force (bool, optional): Whether to rerun already existing results. + Defaults to False. + root_dir (str, optional): Dir to save results and checkpoints in. + Defaults to '.'. + """ + model_log: Dict[str, Dict[str, Any]] = dict() + idx_split = dataset.split_idxs + test_dataset = [dataset[i] for i in idx_split['test']] + loader = DataLoader(test_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + if not skip_LLMs: + if tiny_llama: + pure_llm = LLM( + model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1", + num_params=1, + ) + else: + pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf", + num_params=7) + + if force or not path.exists(root_dir + "/pure_llm_model_log.pt"): + model_log["pure_llm"] = dict() + + pure_preds = [] + for batch in tqdm(loader): + pure_llm_preds = pure_llm.inference(batch.question, batch.desc, + max_tokens=256) + pure_preds += pure_llm_preds + pure_preds = [{"pred": pred} for pred in pure_preds] + + model_log["pure_llm"]["preds"] = pure_preds + model_log["pure_llm"]["hallucinates_list"] = \ + _eval_hallucinations_on_loader(pure_preds, loader, + eval_batch_size) + model_log["pure_llm"]["n_params"] = compute_n_parameters(pure_llm) + torch.save(model_log["pure_llm"], + root_dir + "/pure_llm_model_log.pt") + else: + model_log["pure_llm"] = \ + torch.load(root_dir+"/pure_llm_model_log.pt") + + # LORA + if force or not path.exists(root_dir + "/tuned_llm_model_log.pt"): + model_log["tuned_llm"] = dict() + since = time.time() + gc.collect() + prep_time, _, lora_eval_outs = train(since, epochs, None, None, + batch_size, eval_batch_size, + lr, loss_fn, inference_fn, + model=pure_llm, + dataset=dataset) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + gc.collect() + e2e_time = round(time.time() - since, 2) + model_log["tuned_llm"]["prep_time"] = prep_time + model_log["tuned_llm"]["e2e_time"] = e2e_time + model_log["tuned_llm"]["eval_output"] = lora_eval_outs + print("E2E time (e2e_time) =", e2e_time, "seconds") + print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds") + + model_log["tuned_llm"]["hallucinates_list"] = \ + _eval_hallucinations_on_loader(lora_eval_outs, loader, + eval_batch_size) + model_log["tuned_llm"]["n_params"] = compute_n_parameters(pure_llm) + torch.save(model_log["tuned_llm"], + root_dir + "/tuned_llm_model_log.pt") + else: + model_log["tuned_llm"] = \ + torch.load(root_dir+"/tuned_llm_model_log.pt") + + del pure_llm + gc.collect() + + # All other models + for name, Model, kwargs in zip(model_names, models, model_kwargs): + model_log[name] = dict() + train_model = True + if path.exists(root_dir + f"/{name}.pt") and not force: + print(f"Model {name} appears to already exist.") + print("Would you like to retrain?") + train_model = str(input("(y/n):")).lower() == "y" + + if train_model: + since = time.time() + gc.collect() + model = Model(**kwargs) + prep_time, _, model_eval_outs = train( + since=since, num_epochs=epochs, hidden_channels=None, + num_gnn_layers=None, batch_size=batch_size, + eval_batch_size=eval_batch_size, lr=lr, loss_fn=loss_fn, + inference_fn=inference_fn, checkpointing=checkpointing, + tiny_llama=tiny_llama, dataset=dataset, + model_save_name=root_dir + '/' + name, model=model) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + gc.collect() + e2e_time = round(time.time() - since, 2) + model_log[name]["prep_time"] = prep_time + model_log[name]["e2e_time"] = e2e_time + model_log[name]["eval_output"] = model_eval_outs + print("E2E time (e2e_time) =", e2e_time, "seconds") + print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds") + model_log[name]["n_params"] = compute_n_parameters(model) + del model + gc.collect() + else: + model_eval_outs = torch.load(root_dir + f"/{name}_eval_outs.pt") + + # Calculate Hallucinations + skip_hallucination_detection = False + + if path.exists(root_dir + f"/{name}_model_log.pt") and not force: + print(f"Saved outputs for {name} have been found.") + print("Would you like to redo?") + user_input = str(input("(y/n):")).lower() + skip_hallucination_detection = user_input != "y" + + if not skip_hallucination_detection: + model_save_list = _eval_hallucinations_on_loader( + model_eval_outs, loader, eval_batch_size) + + model_log[name]["hallucinates_list"] = model_save_list + torch.save(model_log[name], root_dir + f"/{name}_model_log.pt") + else: + model_log[name]["hallucinates_list"] = \ + torch.load( + root_dir+f"/{name}_model_log.pt" + )["hallucinates_list"] + + hal_dict = { + k: [tup[1] for tup in v["hallucinates_list"]] + for (k, v) in model_log.items() + } + hallucinates_df = pd.DataFrame(hal_dict).astype(str) + hallucinates_df = hallucinates_df.apply(pd.Series.value_counts).transpose() + hallucinates_df['e2e_time'] = pd.Series( + {k: v.get('e2e_time') + for (k, v) in model_log.items()}) + hallucinates_df['n_params'] = pd.Series( + {k: v.get('n_params') + for (k, v) in model_log.items()}) + print(hallucinates_df) + hallucinates_df.to_csv(root_dir + "/hallucinates_df.csv", index=False) + + +def minimal_demo(gnn_llm_eval_outs, dataset, lr, epochs, batch_size, + eval_batch_size, loss_fn, inference_fn, + skip_pretrained_LLM=False, tiny_llama=False): + if not skip_pretrained_LLM: + print("First comparing against a pretrained LLM...") + # Step 1: Define a single batch size test loader + idx_split = dataset.split_idxs + test_dataset = [dataset[i] for i in idx_split['test']] + # batch size 1 loader for simplicity + loader = DataLoader(test_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + if tiny_llama: + pure_llm = LLM( + model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1", + num_params=1, + ) + else: + pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf", + num_params=7) + if path.exists("demo_save_dict.pt"): + print("Saved outputs for the first step of the demo found.") + print("Would you like to redo?") + user_input = str(input("(y/n):")).lower() + skip_step_one = user_input == "n" + else: + skip_step_one = False + + if not skip_step_one: + gnn_llm_hallucin_sum = 0 + pure_llm_hallucin_sum = 0 + gnn_save_list = [] + untuned_llm_save_list = [] + gnn_llm_preds = [] + for out in gnn_llm_eval_outs: + gnn_llm_preds += out['pred'] + if skip_pretrained_LLM: + print("Checking GNN+LLM for hallucinations...") + else: + print( + "Checking pretrained LLM vs trained GNN+LLM for hallucinations..." # noqa + ) + for i, batch in enumerate(tqdm(loader)): + question = batch.question + correct_answer = batch.label + + gnn_llm_pred = gnn_llm_preds[i * eval_batch_size:(i + 1) * + eval_batch_size] + gnn_llm_hallucinates = detect_hallucinate(gnn_llm_pred, + correct_answer) + gnn_save_list += [ + tup for tup in zip(gnn_llm_pred, gnn_llm_hallucinates) + ] + + if not skip_pretrained_LLM: + # GNN+LLM only using 32 tokens to answer. + # Allow more output tokens for untrained LLM + pure_llm_pred = pure_llm.inference(batch.question, batch.desc, + max_tokens=256) + pure_llm_hallucinates = detect_hallucinate( + pure_llm_pred, correct_answer) + else: + pure_llm_pred = [''] * len(gnn_llm_hallucinates) + pure_llm_hallucinates = [False] * len(gnn_llm_hallucinates) + untuned_llm_save_list += [ + tup for tup in zip(pure_llm_pred, pure_llm_hallucinates) + ] + + for gnn_llm_hal, pure_llm_hal in zip(gnn_llm_hallucinates, + pure_llm_hallucinates): + if gnn_llm_hal == "skip" or pure_llm_hal == "skip": # noqa + # skipping when hallucination is hard to eval + continue + gnn_llm_hallucin_sum += int(gnn_llm_hal) + pure_llm_hallucin_sum += int(pure_llm_hal) + if not skip_pretrained_LLM: + print("Total Pure LLM Hallucinations:", pure_llm_hallucin_sum) + print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum) + percent = 100.0 * round( + 1 - (gnn_llm_hallucin_sum / pure_llm_hallucin_sum), 2) + print(f"GNN reduces pretrained LLM hallucinations by: ~{percent}%") + print("Note: hallucinations detected by regex hence the ~") + print("Now we see how the LLM compares when finetuned...") + print("Saving outputs of GNN+LLM and pretrained LLM...") + save_dict = { + "gnn_save_list": gnn_save_list, + "untuned_llm_save_list": untuned_llm_save_list, + "gnn_llm_hallucin_sum": gnn_llm_hallucin_sum, + "pure_llm_hallucin_sum": pure_llm_hallucin_sum + } + torch.save(save_dict, "demo_save_dict.pt") + print("Done!") + else: + save_dict = torch.load("demo_save_dict.pt") + gnn_save_list = save_dict["gnn_save_list"] + untuned_llm_save_list = save_dict["untuned_llm_save_list"] + gnn_llm_hallucin_sum = save_dict["gnn_llm_hallucin_sum"] + pure_llm_hallucin_sum = save_dict["pure_llm_hallucin_sum"] + + trained_llm_hallucin_sum = 0 + untuned_llm_hallucin_sum = pure_llm_hallucin_sum + final_prnt_str = "" + if path.exists("llm.pt") and path.exists("llm_eval_outs.pt"): + print("Existing finetuned LLM found.") + print("Would you like to retrain?") + user_input = str(input("(y/n):")).lower() + retrain = user_input == "y" + else: + retrain = True + if retrain: + print("Finetuning LLM...") + since = time.time() + _, _, pure_llm_eval_outputs = train(since, epochs, None, None, + batch_size, eval_batch_size, lr, + loss_fn, inference_fn, + model=pure_llm, dataset=dataset) + e2e_time = round(time.time() - since, 2) + print("E2E time (e2e_time) =", e2e_time, "seconds") + else: + pure_llm_eval_outputs = torch.load("llm_eval_outs.pt") + pure_llm_preds = [] + for out in pure_llm_eval_outputs: + pure_llm_preds += out['pred'] + print("Final comparison between all models...") + for i, batch in enumerate(tqdm(loader)): + question = batch.question + correct_answer = batch.label + gnn_llm_pred, gnn_llm_hallucinates = list( + zip(*gnn_save_list[i * eval_batch_size:(i + 1) * eval_batch_size])) + untuned_llm_pred, untuned_llm_hallucinates = list( + zip(*untuned_llm_save_list[i * eval_batch_size:(i + 1) * + eval_batch_size])) + pure_llm_pred = pure_llm_preds[i * eval_batch_size:(i + 1) * + eval_batch_size] + pure_llm_hallucinates = detect_hallucinate(pure_llm_pred, + correct_answer) + for j in range(len(gnn_llm_pred)): + if skip_pretrained_LLM: + # we did not check the untrained LLM, so do not decide to demo + # based on this. + # HACK + untuned_llm_hallucinates = {j: True} + if gnn_llm_hallucinates[j] == "skip" or untuned_llm_hallucinates[ + j] == "skip" or pure_llm_hallucinates[j] == "skip": + continue + trained_llm_hallucin_sum += int(pure_llm_hallucinates[j]) + if untuned_llm_hallucinates[j] and pure_llm_hallucinates[ + j] and not gnn_llm_hallucinates[j]: # noqa + final_prnt_str += "Prompt: '" + question[j] + "'\n" + final_prnt_str += "Label: '" + correct_answer[j] + "'\n" + if not skip_pretrained_LLM: + final_prnt_str += "Untuned LLM Output: '" \ + + untuned_llm_pred[j] + "'\n" # noqa + final_prnt_str += "Tuned LLM Output: '" + pure_llm_pred[ + j] + "'\n" + final_prnt_str += "GNN+LLM Output: '" + gnn_llm_pred[j] + "'\n" + final_prnt_str += "\n" + "#" * 20 + "\n\n" + if not skip_pretrained_LLM: + print("Total untuned LLM Hallucinations:", untuned_llm_hallucin_sum) + print("Total tuned LLM Hallucinations:", trained_llm_hallucin_sum) + print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum) + if not skip_pretrained_LLM: + percent = 100.0 * round( + 1 - (gnn_llm_hallucin_sum / untuned_llm_hallucin_sum), 2) + print(f"GNN reduces untuned LLM hallucinations by: ~{percent}%") + tuned_percent = 100.0 * round( + 1 - (gnn_llm_hallucin_sum / trained_llm_hallucin_sum), 2) + print(f"GNN reduces tuned LLM hallucinations by: ~{tuned_percent}%") + print("Note: hallucinations detected by regex hence the ~") + print("Potential instances where GNN solves the hallucinations of LLM:") + print(final_prnt_str) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--gnn_hidden_channels', type=int, default=1024) + parser.add_argument('--num_gnn_layers', type=int, default=4) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--epochs', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--eval_batch_size', type=int, default=16) + parser.add_argument('--checkpointing', action='store_true') + parser.add_argument('--tiny_llama', action='store_true') + parser.add_argument( + "--skip_pretrained_llm_eval", action="store_true", + help="This flag will skip the evaluation of the pretrained LLM.") + args = parser.parse_args() + + start_time = time.time() + train( + args.epochs, + args.gnn_hidden_channels, + args.num_gnn_layers, + args.batch_size, + args.eval_batch_size, + args.lr, + checkpointing=args.checkpointing, + tiny_llama=args.tiny_llama, + ) + print(f"Total Time: {time.time() - start_time:2f}s") diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py new file mode 100644 index 000000000000..0f1c0e1b87ec --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_backend_utils.py @@ -0,0 +1,224 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Protocol, + Tuple, + Type, + runtime_checkable, +) + +import torch +from torch import Tensor +from torch.nn import Module + +from torch_geometric.data import ( + FeatureStore, + GraphStore, + LargeGraphIndexer, + TripletLike, +) +from torch_geometric.data.large_graph_indexer import EDGE_RELATION +from torch_geometric.distributed import ( + LocalFeatureStore, + LocalGraphStore, + Partitioner, +) +from torch_geometric.typing import EdgeType, NodeType + +RemoteGraphBackend = Tuple[FeatureStore, GraphStore] + +# TODO: Make everything compatible with Hetero graphs aswell + + +# Adapted from LocalGraphStore +@runtime_checkable +class ConvertableGraphStore(Protocol): + @classmethod + def from_data( + cls, + edge_id: Tensor, + edge_index: Tensor, + num_nodes: int, + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_hetero_data( + cls, + edge_id_dict: Dict[EdgeType, Tensor], + edge_index_dict: Dict[EdgeType, Tensor], + num_nodes_dict: Dict[NodeType, int], + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_partition(cls, root: str, pid: int) -> GraphStore: + ... + + +# Adapted from LocalFeatureStore +@runtime_checkable +class ConvertableFeatureStore(Protocol): + @classmethod + def from_data( + cls, + node_id: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + edge_id: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + ) -> FeatureStore: + ... + + @classmethod + def from_hetero_data( + cls, + node_id_dict: Dict[NodeType, Tensor], + x_dict: Optional[Dict[NodeType, Tensor]] = None, + y_dict: Optional[Dict[NodeType, Tensor]] = None, + edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None, + edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None, + ) -> FeatureStore: + ... + + @classmethod + def from_partition(cls, root: str, pid: int) -> FeatureStore: + ... + + +class RemoteDataType(Enum): + DATA = auto() + PARTITION = auto() + + +@dataclass +class RemoteGraphBackendLoader: + """Utility class to load triplets into a RAG Backend.""" + path: str + datatype: RemoteDataType + graph_store_type: Type[ConvertableGraphStore] + feature_store_type: Type[ConvertableFeatureStore] + + def load(self, pid: Optional[int] = None) -> RemoteGraphBackend: + if self.datatype == RemoteDataType.DATA: + data_obj = torch.load(self.path) + graph_store = self.graph_store_type.from_data( + edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index, + num_nodes=data_obj.num_nodes) + feature_store = self.feature_store_type.from_data( + node_id=data_obj['node_id'], x=data_obj.x, + edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr) + elif self.datatype == RemoteDataType.PARTITION: + if pid is None: + assert pid is not None, \ + "Partition ID must be defined for loading from a " \ + + "partitioned store." + graph_store = self.graph_store_type.from_partition(self.path, pid) + feature_store = self.feature_store_type.from_partition( + self.path, pid) + else: + raise NotImplementedError + return (feature_store, graph_store) + + +# TODO: make profilable +def create_remote_backend_from_triplets( + triplets: Iterable[TripletLike], node_embedding_model: Module, + edge_embedding_model: Module | None = None, + graph_db: Type[ConvertableGraphStore] = LocalGraphStore, + feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore, + node_method_to_call: str = "forward", + edge_method_to_call: str | None = None, + pre_transform: Callable[[TripletLike], TripletLike] | None = None, + path: str = '', n_parts: int = 1, + node_method_kwargs: Optional[Dict[str, Any]] = None, + edge_method_kwargs: Optional[Dict[str, Any]] = None +) -> RemoteGraphBackendLoader: + """Utility function that can be used to create a RAG Backend from triplets. + + Args: + triplets (Iterable[TripletLike]): Triplets to load into the RAG + Backend. + node_embedding_model (Module): Model to embed nodes into a feature + space. + edge_embedding_model (Module | None, optional): Model to embed edges + into a feature space. Defaults to the node model. + graph_db (Type[ConvertableGraphStore], optional): GraphStore class to + use. Defaults to LocalGraphStore. + feature_db (Type[ConvertableFeatureStore], optional): FeatureStore + class to use. Defaults to LocalFeatureStore. + node_method_to_call (str, optional): method to call for embeddings on + the node model. Defaults to "forward". + edge_method_to_call (str | None, optional): method to call for + embeddings on the edge model. Defaults to the node method. + pre_transform (Callable[[TripletLike], TripletLike] | None, optional): + optional preprocessing function for triplets. Defaults to None. + path (str, optional): path to save resulting stores. Defaults to ''. + n_parts (int, optional): Number of partitons to store in. + Defaults to 1. + node_method_kwargs (Optional[Dict[str, Any]], optional): args to pass + into node encoding method. Defaults to None. + edge_method_kwargs (Optional[Dict[str, Any]], optional): args to pass + into edge encoding method. Defaults to None. + + Returns: + RemoteGraphBackendLoader: Loader to load RAG backend from disk or + memory. + """ + # Will return attribute errors for missing attributes + if not issubclass(graph_db, ConvertableGraphStore): + getattr(graph_db, "from_data") + getattr(graph_db, "from_hetero_data") + getattr(graph_db, "from_partition") + elif not issubclass(feature_db, ConvertableFeatureStore): + getattr(feature_db, "from_data") + getattr(feature_db, "from_hetero_data") + getattr(feature_db, "from_partition") + + # Resolve callable methods + node_method_kwargs = node_method_kwargs \ + if node_method_kwargs is not None else dict() + + edge_embedding_model = edge_embedding_model \ + if edge_embedding_model is not None else node_embedding_model + edge_method_to_call = edge_method_to_call \ + if edge_method_to_call is not None else node_method_to_call + edge_method_kwargs = edge_method_kwargs \ + if edge_method_kwargs is not None else node_method_kwargs + + # These will return AttributeErrors if they don't exist + node_model = getattr(node_embedding_model, node_method_to_call) + edge_model = getattr(edge_embedding_model, edge_method_to_call) + + indexer = LargeGraphIndexer.from_triplets(triplets, + pre_transform=pre_transform) + + node_feats = node_model(indexer.get_node_features(), **node_method_kwargs) + indexer.add_node_feature('x', node_feats) + + edge_feats = edge_model( + indexer.get_unique_edge_features(feature_name=EDGE_RELATION), + **edge_method_kwargs) + indexer.add_edge_feature(new_feature_name="edge_attr", + new_feature_vals=edge_feats, + map_from_feature=EDGE_RELATION) + + data = indexer.to_data(node_feature_name='x', + edge_feature_name='edge_attr') + + if n_parts == 1: + torch.save(data, path) + return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db, + feature_db) + else: + partitioner = Partitioner(data=data, num_parts=n_parts, root=path) + partitioner.generate_partition() + return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION, + graph_db, feature_db) diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py new file mode 100644 index 000000000000..e01e9e59bb88 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_feature_store.py @@ -0,0 +1,189 @@ +import gc +from collections.abc import Iterable, Iterator +from typing import Any, Dict, Optional, Type, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torchmetrics.functional import pairwise_cosine_similarity + +from torch_geometric.data import Data, HeteroData +from torch_geometric.distributed import LocalFeatureStore +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.nn.pool import ApproxMIPSKNNIndex +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +# NOTE: Only compatible with Homogeneous graphs for now +class KNNRAGFeatureStore(LocalFeatureStore): + def __init__(self, enc_model: Type[Module], + model_kwargs: Optional[Dict[str, + Any]] = None, *args, **kwargs): + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") + self.enc_model = enc_model(*args, **kwargs).to(self.device) + self.enc_model.eval() + self.model_kwargs = \ + model_kwargs if model_kwargs is not None else dict() + super().__init__() + + @property + def x(self) -> Tensor: + return self.get_tensor(group_name=None, attr_name='x') + + @property + def edge_attr(self) -> Tensor: + return self.get_tensor(group_name=(None, None), attr_name='edge_attr') + + def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes: + result = next(self._retrieve_seed_nodes_batch([query], k_nodes)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device)) + topk = min(k_nodes, len(self.x)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges: + result = next(self._retrieve_seed_edges_batch([query], k_edges)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + + prizes = pairwise_cosine_similarity(query_enc, + self.edge_attr.to(self.device)) + topk = min(k_edges, len(self.edge_attr)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + + if isinstance(sample, HeteroSamplerOutput): + raise NotImplementedError + + # NOTE: torch_geometric.loader.utils.filter_custom_store can be used + # here if it supported edge features + node_id = sample.node + edge_id = sample.edge + edge_index = torch.stack((sample.row, sample.col), dim=0) + x = self.x[node_id] + edge_attr = self.edge_attr[edge_id] + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + node_idx=node_id, edge_idx=edge_id) + + +# TODO: Refactor because composition >> inheritance + + +def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor, + device: torch.device, batch_size: int = 2**20): + """Add new features to the existing KNN index in batches. + + Args: + knn_index (ApproxMIPSKNNIndex): Index to add features to. + emb (Tensor): Embeddings to add. + device (torch.device): Device to store in + batch_size (int, optional): Batch size to iterate by. + Defaults to 2**20, which equates to 4GB if working with + 1024 dim floats. + """ + for i in range(0, emb.size(0), batch_size): + if emb.size(0) - i >= batch_size: + emb_batch = emb[i:i + batch_size].to(device) + else: + emb_batch = emb[i:].to(device) + knn_index.add(emb_batch) + + +class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore): + def __init__(self, enc_model: Type[Module], + model_kwargs: Optional[Dict[str, + Any]] = None, *args, **kwargs): + # TODO: Add kwargs for approx KNN to parameters here. + super().__init__(enc_model, model_kwargs, *args, **kwargs) + self.node_knn_index = None + self.edge_knn_index = None + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.node_knn_index is None: + self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.node_knn_index, self.x, + self.device) + + output = self.node_knn_index.search(query_enc, k=k_nodes) + yield from output.index + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.edge_knn_index is None: + self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.edge_knn_index, self.edge_attr, + self.device) + + output = self.edge_knn_index.search(query_enc, k=k_edges) + yield from output.index + + +# TODO: These two classes should be refactored +class SentenceTransformerFeatureStore(KNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) + + +class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) diff --git a/examples/llm/g_retriever_utils/rag_generate.py b/examples/llm/g_retriever_utils/rag_generate.py new file mode 100644 index 000000000000..896fbd7598b1 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_generate.py @@ -0,0 +1,139 @@ +# %% +import argparse +from itertools import chain +from typing import Tuple + +import pandas as pd +import torch +import tqdm +from rag_backend_utils import create_remote_backend_from_triplets +from rag_feature_store import SentenceTransformerFeatureStore +from rag_graph_store import NeighborSamplingRAGGraphStore + +from torch_geometric.data import Data +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +# %% +parser = argparse.ArgumentParser( + description="""Generate new WebQSP subgraphs\n""" + + """NOTE: Evaluating with smaller samples may result in""" + + """ poorer performance for the trained models compared""" + + """ to untrained models.""") +# TODO: Add more arguments for configuring rag params +parser.add_argument("--use_pcst", action="store_true") +parser.add_argument("--num_samples", type=int, default=4700) +parser.add_argument("--out_file", default="subg_results.pt") +args = parser.parse_args() + +# %% +ds = WebQSPDataset("dataset", limit=args.num_samples, verbose=True, + force_reload=True) + +# %% +triplets = chain.from_iterable(d['graph'] for d in ds.raw_dataset) + +# %% +questions = ds.raw_dataset['question'] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerFeatureStore).load() + +# %% + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +def apply_retrieval_with_text(graph: Data, query: str) -> Tuple[Data, str]: + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + desc = ( + textual_nodes.to_csv(index=False) + "\n" + + textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"])) + graph["desc"] = desc + return graph + + +transform = apply_retrieval_via_pcst \ + if args.use_pcst else apply_retrieval_with_text + +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5}, + seed_edges_kwargs={"k_edges": 5}, + sampler_kwargs={"num_neighbors": [50] * 2}, + local_filter=transform) + + +# %% +# Accuracy Metrics to be added to Profiler +def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e + + +def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e - (subg_e | gt_e)) + return (tp + tn) / num_edges + + +def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) + + +def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + +# %% +retrieval_stats = {"precision": [], "recall": [], "accuracy": []} +subgs = [] +node_len = [] +edge_len = [] +for subg in tqdm.tqdm(query_loader.query(q) for q in questions): + subgs.append(subg) + node_len.append(subg['x'].shape[0]) + edge_len.append(subg['edge_attr'].shape[0]) + +for i, subg in enumerate(subgs): + subg['question'] = questions[i] + subg['label'] = ds[i]['label'] + +pd.DataFrame.from_dict(retrieval_stats).to_csv( + args.out_file.split('.')[0] + '_metadata.csv') +torch.save(subgs, args.out_file) diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py new file mode 100644 index 000000000000..48473f287233 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_graph_store.py @@ -0,0 +1,107 @@ +from typing import Optional, Union + +import torch +from torch import Tensor + +from torch_geometric.data import FeatureStore +from torch_geometric.distributed import LocalGraphStore +from torch_geometric.sampler import ( + HeteroSamplerOutput, + NeighborSampler, + NodeSamplerInput, + SamplerOutput, +) +from torch_geometric.sampler.neighbor_sampler import NumNeighborsType +from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes + + +class NeighborSamplingRAGGraphStore(LocalGraphStore): + def __init__(self, feature_store: Optional[FeatureStore] = None, + num_neighbors: NumNeighborsType = [1], **kwargs): + self.feature_store = feature_store + self._num_neighbors = num_neighbors + self.sample_kwargs = kwargs + self._sampler_is_initialized = False + super().__init__() + + def _init_sampler(self): + if self.feature_store is None: + raise AttributeError("Feature store not registered yet.") + self.sampler = NeighborSampler(data=(self.feature_store, self), + num_neighbors=self._num_neighbors, + **self.sample_kwargs) + self._sampler_is_initialized = True + + def register_feature_store(self, feature_store: FeatureStore): + self.feature_store = feature_store + self._sampler_is_initialized = False + + def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool: + ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs) + self._sampler_is_initialized = False + return ret + + @property + def edge_index(self): + return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs) + + def put_edge_index(self, edge_index: EdgeTensorType, *args, + **kwargs) -> bool: + ret = super().put_edge_index(edge_index, *args, **kwargs) + # HACK + self.edge_idx_args = args + self.edge_idx_kwargs = kwargs + self._sampler_is_initialized = False + return ret + + @property + def num_neighbors(self): + return self._num_neighbors + + @num_neighbors.setter + def num_neighbors(self, num_neighbors: NumNeighborsType): + self._num_neighbors = num_neighbors + if hasattr(self, 'sampler'): + self.sampler.num_neighbors = num_neighbors + + def sample_subgraph( + self, seed_nodes: InputNodes, seed_edges: InputEdges, + num_neighbors: Optional[NumNeighborsType] = None + ) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample the graph starting from the given nodes and edges using the + in-built NeighborSampler. + + Args: + seed_nodes (InputNodes): Seed nodes to start sampling from. + seed_edges (InputEdges): Seed edges to start sampling from. + num_neighbors (Optional[NumNeighborsType], optional): Parameters + to determine how many hops and number of neighbors per hop. + Defaults to None. + + Returns: + Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput + for the input. + """ + if not self._sampler_is_initialized: + self._init_sampler() + if num_neighbors is not None: + self.num_neighbors = num_neighbors + + # FIXME: Right now, only input nodes/edges as tensors are be supported + if not isinstance(seed_nodes, Tensor): + raise NotImplementedError + if not isinstance(seed_edges, Tensor): + raise NotImplementedError + device = seed_nodes.device + + # TODO: Call sample_from_edges for seed_edges + # Turning them into nodes for now. + seed_edges = self.edge_index.to(device).T[seed_edges.to( + device)].reshape(-1) + seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0) + + seed_nodes = seed_nodes.unique().contiguous() + node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes) + out = self.sampler.sample_from_nodes(node_sample_input) + + return out diff --git a/examples/llm/multihop_rag/README.md b/examples/llm/multihop_rag/README.md new file mode 100644 index 000000000000..ff43b16a2c05 --- /dev/null +++ b/examples/llm/multihop_rag/README.md @@ -0,0 +1,9 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | +| [`multihop_download.sh`](./multihop_download.sh) | Downloads all the components of the multihop dataset. | +| [`multihop_preprocess.py`](./multihop_preprocess.py) | Preprocesses the dataset to pair questions/answers with components in the knowledge graph. Contains documentation to describe the process. | +| [`rag_generate_multihop.py`](./rag_generate_multihop.py) | Utilizes the sample remote backend in [`g_retriever_utils`](../g_retriever_utils/) to generate subgraphs for the multihop dataset. | + +NOTE: Performance of GRetriever on this dataset has not been evaluated. diff --git a/examples/llm/multihop_rag/multihop_download.sh b/examples/llm/multihop_rag/multihop_download.sh new file mode 100644 index 000000000000..3c1970d39440 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_download.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +# Wikidata5m + +wget -O "wikidata5m_alias.tar.gz" "https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz" +tar -xvf "wikidata5m_alias.tar.gz" +wget -O "wikidata5m_all_triplet.txt.gz" "https://www.dropbox.com/s/563omb11cxaqr83/wikidata5m_all_triplet.txt.gz" +gzip -d "wikidata5m_all_triplet.txt.gz" -f + +# 2Multihopqa +wget -O "data_ids_april7.zip" "https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip" +unzip -o "data_ids_april7.zip" diff --git a/examples/llm/multihop_rag/multihop_preprocess.py b/examples/llm/multihop_rag/multihop_preprocess.py new file mode 100644 index 000000000000..46052bdf1b15 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_preprocess.py @@ -0,0 +1,276 @@ +"""Example workflow for downloading and assembling a multihop QA dataset.""" + +import argparse +import json +from subprocess import call + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import LargeGraphIndexer + +# %% [markdown] +# # Encoding A Large Knowledge Graph Part 2 + +# %% [markdown] +# In this notebook, we will continue where we left off by building a new +# multi-hop QA dataset based on Wikidata. + +# %% [markdown] +# ## Example 2: Building a new Dataset from Questions and an already-existing +# Knowledge Graph + +# %% [markdown] +# ### Motivation + +# %% [markdown] +# One potential application of knowledge graph structural encodings is +# capturing the relationships between different entities that are multiple +# hops apart. This can be challenging for an LLM to recognize from prepended +# graph information. Here's a motivating example (credit to @Rishi Puri): + +# %% [markdown] +# In this example, the question can only be answered by reasoning about the +# relationships between the entities in the knowledge graph. + +# %% [markdown] +# ### Building a Multi-Hop QA Dataset + +# %% [markdown] +# To start, we need to download the raw data of a knowledge graph. +# In this case, we use WikiData5M +# ([Wang et al] +# (https://paperswithcode.com/paper/kepler-a-unified-model-for-knowledge)). +# Here we download the raw triplets and their entity codes. Information about +# this dataset can be found +# [here](https://deepgraphlearning.github.io/project/wikidata5m). + +# %% [markdown] +# The following download contains the ID to plaintext mapping for all the +# entities and relations in the knowledge graph: + +rv = call("./multihop_download.sh") + +# %% [markdown] +# To start, we are going to preprocess the knowledge graph to substitute each +# of the entity/relation codes with their plaintext aliases. This makes it +# easier to use a pre-trained textual encoding model to create triplet +# embeddings, as such a model likely won't understand how to properly embed +# the entity codes. + +# %% + +# %% +parser = argparse.ArgumentParser(description="Preprocess wikidata5m") +parser.add_argument("--n_triplets", type=int, default=-1) +args = parser.parse_args() + +# %% +# Substitute entity codes with their aliases +# Picking the first alias for each entity (rather arbitrarily) +alias_map = {} +rel_alias_map = {} +for line in open('wikidata5m_entity.txt'): + parts = line.strip().split('\t') + entity_id = parts[0] + aliases = parts[1:] + alias_map[entity_id] = aliases[0] +for line in open('wikidata5m_relation.txt'): + parts = line.strip().split('\t') + relation_id = parts[0] + relation_name = parts[1] + rel_alias_map[relation_id] = relation_name + +# %% +full_graph = [] +missing_total = 0 +total = 0 +limit = None if args.n_triplets == -1 else args.n_triplets +i = 0 + +for line in tqdm.tqdm(open('wikidata5m_all_triplet.txt')): + if limit is not None and i >= limit: + break + src, rel, dst = line.strip().split('\t') + if src not in alias_map: + missing_total += 1 + if dst not in alias_map: + missing_total += 1 + if rel not in rel_alias_map: + missing_total += 1 + total += 3 + full_graph.append([ + alias_map.get(src, src), + rel_alias_map.get(rel, rel), + alias_map.get(dst, dst) + ]) + i += 1 +print(f"Missing aliases: {missing_total}/{total}") + +# %% [markdown] +# Now `full_graph` represents the knowledge graph triplets in +# understandable plaintext. + +# %% [markdown] +# Next, we need a set of multi-hop questions that the Knowledge Graph will +# provide us with context for. We utilize a subset of +# [HotPotQA](https://hotpotqa.github.io/) +# ([Yang et. al.](https://arxiv.org/pdf/1809.09600)) called +# [2WikiMultiHopQA](https://github.com/Alab-NII/2wikimultihop) +# ([Ho et. al.](https://aclanthology.org/2020.coling-main.580.pdf)), +# which includes a subgraph of entities that serve as the ground truth +# justification for answering each multi-hop question: + +# %% +with open('train.json') as f: + train_data = json.load(f) +train_df = pd.DataFrame(train_data) +train_df['split_type'] = 'train' + +with open('dev.json') as f: + dev_data = json.load(f) +dev_df = pd.DataFrame(dev_data) +dev_df['split_type'] = 'dev' + +with open('test.json') as f: + test_data = json.load(f) +test_df = pd.DataFrame(test_data) +test_df['split_type'] = 'test' + +df = pd.concat([train_df, dev_df, test_df]) + +# %% [markdown] +# Now we need to extract the subgraphs + +# %% +df['graph_size'] = df['evidences_id'].apply(lambda row: len(row)) + +# %% [markdown] +# (Optional) We take only questions where the evidence graph is greater than +# 0. (Note: this gets rid of the test set): + +# %% +# df = df[df['graph_size'] > 0] + +# %% +refined_df = df[[ + '_id', 'question', 'answer', 'split_type', 'evidences_id', 'type', + 'graph_size' +]] + +# %% [markdown] +# Checkpoint: + +# %% +refined_df.to_csv('wikimultihopqa_refined.csv', index=False) + +# %% [markdown] +# Now we need to check that all the entities mentioned in the question/answer +# set are also present in the Wikidata graph: + +# %% +relation_map = {} +with open('wikidata5m_relation.txt') as f: + for line in tqdm.tqdm(f): + parts = line.strip().split('\t') + for i in range(1, len(parts)): + if parts[i] not in relation_map: + relation_map[parts[i]] = [] + relation_map[parts[i]].append(parts[0]) + +# %% +entity_set = set() +with open('wikidata5m_entity.txt') as f: + for line in tqdm.tqdm(f): + entity_set.add(line.strip().split('\t')[0]) + +# %% +missing_entities = set() +missing_entity_idx = set() +for i, row in enumerate(refined_df.itertuples()): + for trip in row.evidences_id: + entities = trip[0], trip[2] + for entity in entities: + if entity not in entity_set: + # print( + # f'The following entity was not found in the KG: {entity}' + # ) + missing_entities.add(entity) + missing_entity_idx.add(i) + +# %% [markdown] +# Right now, we drop the missing entity entries. Additional preprocessing can +# be done here to resolve the entity/relation collisions, but that is out of +# the scope for this notebook. + +# %% +# missing relations are ok, but missing entities cannot be mapped to +# plaintext, so they should be dropped. +refined_df.reset_index(inplace=True, drop=True) + +# %% +cleaned_df = refined_df.drop(missing_entity_idx) + +# %% [markdown] +# Now we save the resulting graph and questions/answers dataset: + +# %% +cleaned_df.to_csv('wikimultihopqa_cleaned.csv', index=False) + +# %% + +# %% +torch.save(full_graph, 'wikimultihopqa_full_graph.pt') + +# %% [markdown] +# ### Question: How do we extract a contextual subgraph for a given query? + +# %% [markdown] +# The chosen retrieval algorithm is a critical component in the pipeline for +# affecting RAG performance. In the next section (1), we will demonstrate a +# naive method of retrieval for a large knowledge graph, and how to apply it +# to this dataset along with WebQSP. + +# %% [markdown] +# ### Preparing a Textualized Graph for LLM + +# %% [markdown] +# For now however, we need to prepare the graph data to be used as a plaintext +# prefix to the LLM. In order to do this, we want to prompt the LLM to use the +# unique nodes, and unique edge triplets of a given subgraph. In order to do +# this, we prepare a unique indexed node df and edge df for the knowledge +# graph now. This process occurs trivially with the LargeGraphIndexer: + +# %% + +# %% +indexer = LargeGraphIndexer.from_triplets(full_graph) + +# %% +# Node DF +textual_nodes = pd.DataFrame.from_dict( + {"node_attr": indexer.get_node_features()}) +textual_nodes["node_id"] = textual_nodes.index +textual_nodes = textual_nodes[["node_id", "node_attr"]] + +# %% [markdown] +# Notice how LargeGraphIndexer ensures that there are no duplicate indices: + +# %% +# Edge DF +textual_edges = pd.DataFrame(indexer.get_edge_features(), + columns=["src", "edge_attr", "dst"]) +textual_edges["src"] = [indexer._nodes[h] for h in textual_edges["src"]] +textual_edges["dst"] = [indexer._nodes[h] for h in textual_edges["dst"]] + +# %% [markdown] +# Note: The edge table refers to each node by its index in the node table. +# We will see how this gets utilized later when indexing a subgraph. + +# %% [markdown] +# Now we can save the result + +# %% +textual_nodes.to_csv('wikimultihopqa_textual_nodes.csv', index=False) +textual_edges.to_csv('wikimultihopqa_textual_edges.csv', index=False) diff --git a/examples/llm/multihop_rag/rag_generate_multihop.py b/examples/llm/multihop_rag/rag_generate_multihop.py new file mode 100644 index 000000000000..de93a9e75dd1 --- /dev/null +++ b/examples/llm/multihop_rag/rag_generate_multihop.py @@ -0,0 +1,88 @@ +# %% +import argparse +import sys +from typing import Tuple + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import Data +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +sys.path.append('..') + +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerApproxFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +parser = argparse.ArgumentParser( + description="Generate new multihop dataset for rag") +# TODO: Add more arguments for configuring rag params +parser.add_argument("--num_samples", type=int) +args = parser.parse_args() + +# %% +triplets = torch.load('wikimultihopqa_full_graph.pt') + +# %% +df = pd.read_csv('wikimultihopqa_cleaned.csv') +questions = df['question'][:args.num_samples] +labels = df['answer'][:args.num_samples] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerApproxFeatureStore).load() + +# %% + +all_textual_nodes = pd.read_csv('wikimultihopqa_textual_nodes.csv') +all_textual_edges = pd.read_csv('wikimultihopqa_textual_edges.csv') + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = all_textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = all_textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +# %% +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10}, + seed_edges_kwargs={"k_edges": 10}, + sampler_kwargs={"num_neighbors": [40] * 3}, + local_filter=apply_retrieval_via_pcst) + +# %% +subgs = [] +for q, l in tqdm.tqdm(zip(questions, labels)): + subg = query_loader.query(q) + subg['question'] = q + subg['label'] = l + subgs.append(subg) + +torch.save(subgs, 'subg_results.pt') diff --git a/examples/llm/nvtx_examples/README.md b/examples/llm/nvtx_examples/README.md new file mode 100644 index 000000000000..aa4f070d9824 --- /dev/null +++ b/examples/llm/nvtx_examples/README.md @@ -0,0 +1,7 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | +| [`nvtx_run.sh`](./nvtx_run.sh) | Runs nsys profiler on a given Python file that contains NVTX calls. | +| [`nvtx_rag_backend_example.py`](./nvtx_rag_backend_example.py) | Example script for nsys profiling a RAG Backend such as that used in [`rag_generate.py`](../g_retriever_utils/rag_generate.py). | +| [`nvtx_webqsp_example.py`](./nvtx_webqsp_example.py) | Example script for nsys profiling the WebQSP dataset. | diff --git a/examples/llm/nvtx_examples/nvtx_rag_backend_example.py b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py new file mode 100644 index 000000000000..b30e34b8c7b1 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py @@ -0,0 +1,144 @@ +# %% +import argparse +import sys +from itertools import chain +from typing import Tuple + +import torch + +from torch_geometric.data import Data, get_features_for_triplets_groups +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import rag_loader +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.profile.nvtx import nvtxit + +sys.path.append('..') +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +# Patch FeatureStore and GraphStore + +SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_nodes) +SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_edges) +SentenceTransformerFeatureStore.load_subgraph = nvtxit()( + SentenceTransformerFeatureStore.load_subgraph) +NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()( + NeighborSamplingRAGGraphStore.sample_subgraph) +rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query) + +# %% +ds = WebQSPDataset("small_ds_1", force_reload=True, limit=10) + +# %% +triplets = list(chain.from_iterable(d['graph'] for d in ds.raw_dataset)) + +# %% +questions = ds.raw_dataset['question'] + +# %% +ground_truth_graphs = get_features_for_triplets_groups( + ds.indexer, (d['graph'] for d in ds.raw_dataset), + pre_transform=preprocess_triplet) +num_edges = len(ds.indexer._edges) + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to( + device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerFeatureStore).load() + +# %% + + +@nvtxit() +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return graph + + +# %% +query_loader = rag_loader.RAGQueryLoader( + data=(fs, gs), seed_nodes_kwargs={"k_nodes": + 10}, seed_edges_kwargs={"k_edges": 10}, + sampler_kwargs={"num_neighbors": + [40] * 10}, local_filter=apply_retrieval_via_pcst) + + +# %% +# Accuracy Metrics to be added to Profiler +def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e + + +def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e - (subg_e | gt_e)) + return (tp + tn) / num_edges + + +def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) + + +def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + +# %% + + +@nvtxit() +def _run_eval(): + for subg, gt in zip((query_loader.query(q) for q in questions), + ground_truth_graphs): + print(check_retrieval_accuracy(subg, gt, num_edges), + check_retrieval_precision(subg, gt), + check_retrieval_recall(subg, gt)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + _run_eval() + else: + _run_eval() diff --git a/examples/llm/nvtx_examples/nvtx_run.sh b/examples/llm/nvtx_examples/nvtx_run.sh new file mode 100644 index 000000000000..4c6fce7c8224 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_run.sh @@ -0,0 +1,27 @@ +#!/bin/sh + +# Check if the user provided a Python file +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +# Check if the provided file exists +if [[ ! -f "$1" ]]; then + echo "Error: File '$1' does not exist." + exit 1 +fi + +# Check if the provided file is a Python file +if [[ ! "$1" == *.py ]]; then + echo "Error: '$1' is not a Python file." + exit 1 +fi + +# Get the base name of the Python file +python_file=$(basename "$1") + +# Run nsys profile on the Python file +nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1" + +echo "Profile data saved as profile_${python_file%.py}.nsys-rep" diff --git a/examples/llm/nvtx_examples/nvtx_webqsp_example.py b/examples/llm/nvtx_examples/nvtx_webqsp_example.py new file mode 100644 index 000000000000..5a9aad27f1c0 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_webqsp_example.py @@ -0,0 +1,22 @@ +import argparse + +import torch + +from torch_geometric.datasets import web_qsp_dataset +from torch_geometric.profile import nvtxit + +# Apply Patches +web_qsp_dataset.retrieval_via_pcst = nvtxit()( + web_qsp_dataset.retrieval_via_pcst) +web_qsp_dataset.WebQSPDataset.process = nvtxit()( + web_qsp_dataset.WebQSPDataset.process) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') + else: + ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') diff --git a/test/data/test_large_graph_indexer.py b/test/data/test_large_graph_indexer.py new file mode 100644 index 000000000000..b98fe7d7ddbf --- /dev/null +++ b/test/data/test_large_graph_indexer.py @@ -0,0 +1,177 @@ +import random +import string +from typing import List + +import pytest +import torch + +from torch_geometric.data import ( + Data, + LargeGraphIndexer, + TripletLike, + get_features_for_triplets, +) +from torch_geometric.data.large_graph_indexer import ( + EDGE_PID, + EDGE_RELATION, + NODE_PID, +) +from torch_geometric.typing import WITH_PT20 + +# create possible nodes and edges for graph +strkeys = string.ascii_letters + string.digits +NODE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(1000)}) +EDGE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(50)}) + + +def featurize(s: str) -> int: + return int.from_bytes(s.encode(), 'little') + + +def sample_triplets(amount: int = 1) -> List[TripletLike]: + trips = [] + for i in range(amount): + h, t = random.sample(NODE_POOL, k=2) + r = random.sample(EDGE_POOL, k=1)[0] + trips.append(tuple([h, r, t])) + return trips + + +def preprocess_triplet(triplet: TripletLike) -> TripletLike: + h, r, t = triplet + return h.lower(), r, t.lower() + + +def test_basic_collate(): + graphs = [sample_triplets(1000) for i in range(2)] + + indexer_0 = LargeGraphIndexer.from_triplets( + graphs[0], pre_transform=preprocess_triplet) + indexer_1 = LargeGraphIndexer.from_triplets( + graphs[1], pre_transform=preprocess_triplet) + + big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1]) + + assert len(indexer_0._nodes) + len( + indexer_1._nodes) - len(indexer_0._nodes.keys() + & indexer_1._nodes.keys()) == len( + big_indexer._nodes) + assert len(indexer_0._edges) + len( + indexer_1._edges) - len(indexer_0._edges.keys() + & indexer_1._edges.keys()) == len( + big_indexer._edges) + + assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes) + assert len(set(big_indexer._edges.values())) == len(big_indexer._edges) + + for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()): + assert big_indexer.node_attr[NODE_PID][ + big_indexer._nodes[node]] == node + + +def test_large_graph_index(): + graphs = [sample_triplets(1000) for i in range(100)] + + # Preprocessing of trips lowercases nodes but not edges + node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + def encode_graph_from_trips(triplets: List[TripletLike]) -> Data: + seen_nodes = dict() + edge_attrs = list() + edge_idx = [] + for trip in triplets: + trip = preprocess_triplet(trip) + h, r, t = trip + seen_nodes[h] = len( + seen_nodes) if h not in seen_nodes else seen_nodes[h] + seen_nodes[t] = len( + seen_nodes) if t not in seen_nodes else seen_nodes[t] + edge_attrs.append(edge_feature_vecs[r]) + edge_idx.append((seen_nodes[h], seen_nodes[t])) + + x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()]) + edge_idx = torch.LongTensor(edge_idx).T + edge_attrs = torch.Tensor(edge_attrs) + return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs) + + naive_graph_ds = [ + encode_graph_from_trips(triplets=trips) for trips in graphs + ] + + indexer = LargeGraphIndexer.collate([ + LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet) + for g in graphs + ]) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + large_graph_ds = [ + get_features_for_triplets(indexer=indexer, triplets=g, + node_feature_name='x', + edge_feature_name='edge_attr', + pre_transform=preprocess_triplet) + for g in graphs + ] + + for ds in large_graph_ds: + assert NODE_PID in ds + assert EDGE_PID in ds + assert "node_idx" in ds + assert "edge_idx" in ds + + def results_are_close_enough(ground_truth: Data, new_method: Data, + thresh=.99): + def _sorted_tensors_are_close(tensor1, tensor2): + return torch.all( + torch.isclose(tensor1.sort()[0], + tensor2.sort()[0]) > thresh) + + def _graphs_are_same(tensor1, tensor2): + if not WITH_PT20: + pytest.skip( + "This test requires a PyG version with NetworkX as a " + + "dependency.") + import networkx as nx + return nx.weisfeiler_lehman_graph_hash(nx.Graph( + tensor1.T)) == nx.weisfeiler_lehman_graph_hash( + nx.Graph(tensor2.T)) + return True + return _sorted_tensors_are_close( + ground_truth.x, new_method.x) \ + and _sorted_tensors_are_close( + ground_truth.edge_attr, new_method.edge_attr) \ + and _graphs_are_same( + ground_truth.edge_index, new_method.edge_index) + + for dsets in zip(naive_graph_ds, large_graph_ds): + assert results_are_close_enough(*dsets) + + +def test_save_load(tmp_path): + graph = sample_triplets(1000) + + node_feature_vecs = {s: featurize(s) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + indexer = LargeGraphIndexer.from_triplets(graph) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + + indexer.save(str(tmp_path)) + assert indexer == LargeGraphIndexer.from_disk(str(tmp_path)) diff --git a/test/nn/models/test_g_retriever.py b/test/nn/models/test_g_retriever.py index 899e70730cc9..24a74d1b6f6e 100644 --- a/test/nn/models/test_g_retriever.py +++ b/test/nn/models/test_g_retriever.py @@ -51,3 +51,52 @@ def test_g_retriever() -> None: # Test inference: pred = model.inference(question, x, edge_index, batch, edge_attr) assert len(pred) == 1 + + +@onlyFullTest +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_g_retriever_many_tokens() -> None: + llm = LLM( + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.float16, + ) + + gnn = GAT( + in_channels=1024, + out_channels=1024, + hidden_channels=1024, + num_layers=2, + heads=4, + norm='batch_norm', + ) + + model = GRetriever( + llm=llm, + gnn=gnn, + mlp_out_channels=2048, + mlp_out_tokens=2, + ) + assert str(model) == ('GRetriever(\n' + ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n' + ' gnn=GAT(1024, 1024, num_layers=2),\n' + ')') + + x = torch.randn(10, 1024) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + ]) + edge_attr = torch.randn(edge_index.size(1), 1024) + batch = torch.zeros(x.size(0), dtype=torch.long) + + question = ["Is PyG the best open-source GNN library?"] + label = ["yes!"] + + # Test train: + loss = model(question, x, edge_index, batch, label, edge_attr) + assert loss >= 0 + + # Test inference: + pred = model.inference(question, x, edge_index, batch, edge_attr) + assert len(pred) == 1 diff --git a/test/profile/test_nvtx.py b/test/profile/test_nvtx.py new file mode 100644 index 000000000000..56e28a9c2e59 --- /dev/null +++ b/test/profile/test_nvtx.py @@ -0,0 +1,136 @@ +from unittest.mock import call, patch + +from torch_geometric.profile import nvtxit + + +def _setup_mock(torch_cuda_mock): + torch_cuda_mock.is_available.return_value = True + torch_cuda_mock.cudart.return_value.cudaProfilerStart.return_value = None + torch_cuda_mock.cudart.return_value.cudaProfilerStop.return_value = None + return torch_cuda_mock + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_base(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit() + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_a_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_rename(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit('a_nvtx') + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('a_nvtx_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_iters(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_iters=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + call_a() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 2 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 2 # noqa: E501 + + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_b_0'), call('call_a_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_warmups(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_warmups=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + call_a() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_a_0'), call('call_b_1') + ] diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py index 821ef9c5c063..fee215b1a357 100644 --- a/torch_geometric/data/__init__.py +++ b/torch_geometric/data/__init__.py @@ -16,6 +16,7 @@ from .makedirs import makedirs from .download import download_url, download_google_url from .extract import extract_tar, extract_zip, extract_bz2, extract_gz +from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups from torch_geometric.lazy_loader import LazyLoader @@ -27,6 +28,8 @@ 'Dataset', 'InMemoryDataset', 'OnDiskDataset', + 'LargeGraphIndexer', + 'TripletLike', ] remote_backend_classes = [ @@ -50,6 +53,8 @@ 'extract_zip', 'extract_bz2', 'extract_gz', + 'get_features_for_triplets', + "get_features_for_triplets_groups", ] __all__ = data_classes + remote_backend_classes + helper_functions diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py new file mode 100644 index 000000000000..0644e2543303 --- /dev/null +++ b/torch_geometric/data/large_graph_indexer.py @@ -0,0 +1,677 @@ +import os +import pickle as pkl +import shutil +from dataclasses import dataclass +from itertools import chain +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import Tensor +from tqdm import tqdm + +from torch_geometric.data import Data +from torch_geometric.typing import WITH_PT24 + +TripletLike = Tuple[Hashable, Hashable, Hashable] + +KnowledgeGraphLike = Iterable[TripletLike] + + +def ordered_set(values: Iterable[Hashable]) -> List[Hashable]: + return list(dict.fromkeys(values)) + + +# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum? + +NODE_PID = "pid" + +NODE_KEYS = {NODE_PID} + +EDGE_PID = "e_pid" +EDGE_HEAD = "h" +EDGE_RELATION = "r" +EDGE_TAIL = "t" +EDGE_INDEX = "edge_idx" + +EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX} + +FeatureValueType = Union[Sequence[Any], Tensor] + + +@dataclass +class MappedFeature: + name: str + values: FeatureValueType + + def __eq__(self, value: "MappedFeature") -> bool: + eq = self.name == value.name + if isinstance(self.values, torch.Tensor): + eq &= torch.equal(self.values, value.values) + else: + eq &= self.values == value.values + return eq + + +if WITH_PT24: + torch.serialization.add_safe_globals([MappedFeature]) + + +class LargeGraphIndexer: + """For a dataset that consists of mulitiple subgraphs that are assumed to + be part of a much larger graph, collate the values into a large graph store + to save resources. + """ + def __init__( + self, + nodes: Iterable[Hashable], + edges: KnowledgeGraphLike, + node_attr: Optional[Dict[str, List[Any]]] = None, + edge_attr: Optional[Dict[str, List[Any]]] = None, + ) -> None: + r"""Constructs a new index that uniquely catalogs each node and edge + by id. Not meant to be used directly. + + Args: + nodes (Iterable[Hashable]): Node ids in the graph. + edges (KnowledgeGraphLike): Edge ids in the graph. + node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node + attribute name and list of their values in order of unique node + ids. Defaults to None. + edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge + attribute name and list of their values in order of unique edge + ids. Defaults to None. + """ + self._nodes: Dict[Hashable, int] = dict() + self._edges: Dict[TripletLike, int] = dict() + + self._mapped_node_features: Set[str] = set() + self._mapped_edge_features: Set[str] = set() + + if len(nodes) != len(set(nodes)): + raise AttributeError("Nodes need to be unique") + if len(edges) != len(set(edges)): + raise AttributeError("Edges need to be unique") + + if node_attr is not None: + # TODO: Validity checks btw nodes and node_attr + self.node_attr = node_attr + if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS: + raise AttributeError( + "Invalid node_attr object. Missing " + + f"{NODE_KEYS - set(self.node_attr.keys())}") + elif self.node_attr[NODE_PID] != nodes: + raise AttributeError( + "Nodes provided do not match those in node_attr") + else: + self.node_attr = dict() + self.node_attr[NODE_PID] = nodes + + for i, node in enumerate(self.node_attr[NODE_PID]): + self._nodes[node] = i + + if edge_attr is not None: + # TODO: Validity checks btw edges and edge_attr + self.edge_attr = edge_attr + + if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS: + raise AttributeError( + "Invalid edge_attr object. Missing " + + f"{EDGE_KEYS - set(self.edge_attr.keys())}") + elif self.node_attr[EDGE_PID] != edges: + raise AttributeError( + "Edges provided do not match those in edge_attr") + + else: + self.edge_attr = dict() + for default_key in EDGE_KEYS: + self.edge_attr[default_key] = list() + self.edge_attr[EDGE_PID] = edges + + for i, tup in enumerate(edges): + h, r, t = tup + self.edge_attr[EDGE_HEAD].append(h) + self.edge_attr[EDGE_RELATION].append(r) + self.edge_attr[EDGE_TAIL].append(t) + self.edge_attr[EDGE_INDEX].append( + (self._nodes[h], self._nodes[t])) + + for i, tup in enumerate(edges): + self._edges[tup] = i + + @classmethod + def from_triplets( + cls, + triplets: KnowledgeGraphLike, + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + ) -> "LargeGraphIndexer": + r"""Generate a new index from a series of triplets that represent edge + relations between nodes. + Formatted like (source_node, edge, dest_node). + + Args: + triplets (KnowledgeGraphLike): Series of triplets representing + knowledge graph relations. + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing function to apply to triplets. + Defaults to None. + + Returns: + LargeGraphIndexer: Index of unique nodes and edges. + """ + # NOTE: Right now assumes that all trips can be loaded into memory + nodes = set() + edges = set() + + if pre_transform is not None: + + def apply_transform( + trips: KnowledgeGraphLike) -> Iterator[TripletLike]: + for trip in trips: + yield pre_transform(trip) + + triplets = apply_transform(triplets) + + for h, r, t in triplets: + + for node in (h, t): + nodes.add(node) + + edge_idx = (h, r, t) + edges.add(edge_idx) + + return cls(list(nodes), list(edges)) + + @classmethod + def collate(cls, + graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer": + r"""Combines a series of large graph indexes into a single large graph + index. + + Args: + graphs (Iterable["LargeGraphIndexer"]): Indices to be + combined. + + Returns: + LargeGraphIndexer: Singular unique index for all nodes and edges + in input indices. + """ + # FIXME Needs to merge node attrs and edge attrs? + trips = chain.from_iterable([graph.to_triplets() for graph in graphs]) + return cls.from_triplets(trips) + + def get_unique_node_features( + self, feature_name: str = NODE_PID) -> List[Hashable]: + r"""Get all the unique values for a specific node attribute. + + Args: + feature_name (str, optional): Name of feature to get. + Defaults to NODE_PID. + + Returns: + List[Hashable]: List of unique values for the specified feature. + """ + try: + if feature_name in self._mapped_node_features: + raise IndexError( + "Only non-mapped features can be retrieved uniquely.") + return ordered_set(self.get_node_features(feature_name)) + + except KeyError: + raise AttributeError( + f"Nodes do not have a feature called {feature_name}") + + def add_node_feature( + self, + new_feature_name: str, + new_feature_vals: FeatureValueType, + map_from_feature: str = NODE_PID, + ) -> None: + r"""Adds a new feature that corresponds to each unique node in + the graph. + + Args: + new_feature_name (str): Name to call the new feature. + new_feature_vals (FeatureValueType): Values to map for that + new feature. + map_from_feature (str, optional): Key of feature to map from. + Size must match the number of feature values. + Defaults to NODE_PID. + """ + if new_feature_name in self.node_attr: + raise AttributeError("Features cannot be overridden once created") + if map_from_feature in self._mapped_node_features: + raise AttributeError( + f"{map_from_feature} is already a feature mapping.") + + feature_keys = self.get_unique_node_features(map_from_feature) + if len(feature_keys) != len(new_feature_vals): + raise AttributeError( + "Expected encodings for {len(feature_keys)} unique features," + + f" but got {len(new_feature_vals)} encodings.") + + if map_from_feature == NODE_PID: + self.node_attr[new_feature_name] = new_feature_vals + else: + self.node_attr[new_feature_name] = MappedFeature( + name=map_from_feature, values=new_feature_vals) + self._mapped_node_features.add(new_feature_name) + + def get_node_features( + self, + feature_name: str = NODE_PID, + pids: Optional[Iterable[Hashable]] = None, + ) -> List[Any]: + r"""Get node feature values for a given set of unique node ids. + Returned values are not necessarily unique. + + Args: + feature_name (str, optional): Name of feature to fetch. Defaults + to NODE_PID. + pids (Optional[Iterable[Hashable]], optional): Node ids to fetch + for. Defaults to None, which fetches all nodes. + + Returns: + List[Any]: Node features corresponding to the specified ids. + """ + if feature_name in self._mapped_node_features: + values = self.node_attr[feature_name].values + else: + values = self.node_attr[feature_name] + + # TODO: torch_geometric.utils.select + if isinstance(values, torch.Tensor): + idxs = list( + self.get_node_features_iter(feature_name, pids, + index_only=True)) + return values[idxs] + return list(self.get_node_features_iter(feature_name, pids)) + + def get_node_features_iter( + self, + feature_name: str = NODE_PID, + pids: Optional[Iterable[Hashable]] = None, + index_only: bool = False, + ) -> Iterator[Any]: + """Iterator version of get_node_features. If index_only is True, + yields indices instead of values. + """ + if pids is None: + pids = self.node_attr[NODE_PID] + + if feature_name in self._mapped_node_features: + feature_map_info = self.node_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_node_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._nodes[pid] + from_feature_val = self.node_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._nodes[pid] + if index_only: + yield idx + else: + yield self.node_attr[feature_name][idx] + + def get_unique_edge_features( + self, feature_name: str = EDGE_PID) -> List[Hashable]: + r"""Get all the unique values for a specific edge attribute. + + Args: + feature_name (str, optional): Name of feature to get. + Defaults to EDGE_PID. + + Returns: + List[Hashable]: List of unique values for the specified feature. + """ + try: + if feature_name in self._mapped_edge_features: + raise IndexError( + "Only non-mapped features can be retrieved uniquely.") + return ordered_set(self.get_edge_features(feature_name)) + except KeyError: + raise AttributeError( + f"Edges do not have a feature called {feature_name}") + + def add_edge_feature( + self, + new_feature_name: str, + new_feature_vals: FeatureValueType, + map_from_feature: str = EDGE_PID, + ) -> None: + r"""Adds a new feature that corresponds to each unique edge in + the graph. + + Args: + new_feature_name (str): Name to call the new feature. + new_feature_vals (FeatureValueType): Values to map for that new + feature. + map_from_feature (str, optional): Key of feature to map from. + Size must match the number of feature values. + Defaults to EDGE_PID. + """ + if new_feature_name in self.edge_attr: + raise AttributeError("Features cannot be overridden once created") + if map_from_feature in self._mapped_edge_features: + raise AttributeError( + f"{map_from_feature} is already a feature mapping.") + + feature_keys = self.get_unique_edge_features(map_from_feature) + if len(feature_keys) != len(new_feature_vals): + raise AttributeError( + f"Expected encodings for {len(feature_keys)} unique features, " + + f"but got {len(new_feature_vals)} encodings.") + + if map_from_feature == EDGE_PID: + self.edge_attr[new_feature_name] = new_feature_vals + else: + self.edge_attr[new_feature_name] = MappedFeature( + name=map_from_feature, values=new_feature_vals) + self._mapped_edge_features.add(new_feature_name) + + def get_edge_features( + self, + feature_name: str = EDGE_PID, + pids: Optional[Iterable[Hashable]] = None, + ) -> List[Any]: + r"""Get edge feature values for a given set of unique edge ids. + Returned values are not necessarily unique. + + Args: + feature_name (str, optional): Name of feature to fetch. + Defaults to EDGE_PID. + pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch + for. Defaults to None, which fetches all edges. + + Returns: + List[Any]: Node features corresponding to the specified ids. + """ + if feature_name in self._mapped_edge_features: + values = self.edge_attr[feature_name].values + else: + values = self.edge_attr[feature_name] + + # TODO: torch_geometric.utils.select + if isinstance(values, torch.Tensor): + idxs = list( + self.get_edge_features_iter(feature_name, pids, + index_only=True)) + return values[idxs] + return list(self.get_edge_features_iter(feature_name, pids)) + + def get_edge_features_iter( + self, + feature_name: str = EDGE_PID, + pids: Optional[KnowledgeGraphLike] = None, + index_only: bool = False, + ) -> Iterator[Any]: + """Iterator version of get_edge_features. If index_only is True, + yields indices instead of values. + """ + if pids is None: + pids = self.edge_attr[EDGE_PID] + + if feature_name in self._mapped_edge_features: + feature_map_info = self.edge_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_edge_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._edges[pid] + from_feature_val = self.edge_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._edges[pid] + if index_only: + yield idx + else: + yield self.edge_attr[feature_name][idx] + + def to_triplets(self) -> Iterator[TripletLike]: + return iter(self.edge_attr[EDGE_PID]) + + def save(self, path: str) -> None: + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + with open(path + "/edges", "wb") as f: + pkl.dump(self._edges, f) + with open(path + "/nodes", "wb") as f: + pkl.dump(self._nodes, f) + + with open(path + "/mapped_edges", "wb") as f: + pkl.dump(self._mapped_edge_features, f) + with open(path + "/mapped_nodes", "wb") as f: + pkl.dump(self._mapped_node_features, f) + + node_attr_path = path + "/node_attr" + os.makedirs(node_attr_path, exist_ok=True) + for attr_name, vals in self.node_attr.items(): + torch.save(vals, node_attr_path + f"/{attr_name}.pt") + + edge_attr_path = path + "/edge_attr" + os.makedirs(edge_attr_path, exist_ok=True) + for attr_name, vals in self.edge_attr.items(): + torch.save(vals, edge_attr_path + f"/{attr_name}.pt") + + @classmethod + def from_disk(cls, path: str) -> "LargeGraphIndexer": + indexer = cls(list(), list()) + with open(path + "/edges", "rb") as f: + indexer._edges = pkl.load(f) + with open(path + "/nodes", "rb") as f: + indexer._nodes = pkl.load(f) + + with open(path + "/mapped_edges", "rb") as f: + indexer._mapped_edge_features = pkl.load(f) + with open(path + "/mapped_nodes", "rb") as f: + indexer._mapped_node_features = pkl.load(f) + + node_attr_path = path + "/node_attr" + for fname in os.listdir(node_attr_path): + full_fname = f"{node_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.node_attr[key] = torch.load(full_fname) + + edge_attr_path = path + "/edge_attr" + for fname in os.listdir(edge_attr_path): + full_fname = f"{edge_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.edge_attr[key] = torch.load(full_fname) + + return indexer + + def to_data(self, node_feature_name: str, + edge_feature_name: Optional[str] = None) -> Data: + """Return a Data object containing all the specified node and + edge features and the graph. + + Args: + node_feature_name (str): Feature to use for nodes + edge_feature_name (Optional[str], optional): Feature to use for + edges. Defaults to None. + + Returns: + Data: Data object containing the specified node and + edge features and the graph. + """ + x = torch.Tensor(self.get_node_features(node_feature_name)) + node_id = torch.LongTensor(range(len(x))) + + edge_index = torch.t( + torch.LongTensor(self.get_edge_features(EDGE_INDEX))) + + edge_attr = (self.get_edge_features(edge_feature_name) + if edge_feature_name is not None else None) + edge_id = torch.LongTensor(range(len(edge_attr))) + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + edge_id=edge_id, node_id=node_id) + + def __eq__(self, value: "LargeGraphIndexer") -> bool: + eq = True + eq &= self._nodes == value._nodes + eq &= self._edges == value._edges + eq &= self.node_attr.keys() == value.node_attr.keys() + eq &= self.edge_attr.keys() == value.edge_attr.keys() + eq &= self._mapped_node_features == value._mapped_node_features + eq &= self._mapped_edge_features == value._mapped_edge_features + + for k in self.node_attr: + eq &= isinstance(self.node_attr[k], type(value.node_attr[k])) + if isinstance(self.node_attr[k], torch.Tensor): + eq &= torch.equal(self.node_attr[k], value.node_attr[k]) + else: + eq &= self.node_attr[k] == value.node_attr[k] + for k in self.edge_attr: + eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k])) + if isinstance(self.edge_attr[k], torch.Tensor): + eq &= torch.equal(self.edge_attr[k], value.edge_attr[k]) + else: + eq &= self.edge_attr[k] == value.edge_attr[k] + return eq + + +def get_features_for_triplets_groups( + indexer: LargeGraphIndexer, + triplet_groups: Iterable[KnowledgeGraphLike], + node_feature_name: str = "x", + edge_feature_name: str = "edge_attr", + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + verbose: bool = False, +) -> Iterator[Data]: + """Given an indexer and a series of triplet groups (like a dataset), + retrieve the specified node and edge features for each triplet from the + index. + + Args: + indexer (LargeGraphIndexer): Indexer containing desired features + triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of + triplets to fetch features for + node_feature_name (str, optional): Node feature to fetch. + Defaults to "x". + edge_feature_name (str, optional): edge feature to fetch. + Defaults to "edge_attr". + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing to perform on triplets. + Defaults to None. + verbose (bool, optional): Whether to print progress. Defaults to False. + + Yields: + Iterator[Data]: For each triplet group, yield a data object containing + the unique graph and features from the index. + """ + if pre_transform is not None: + + def apply_transform(trips): + for trip in trips: + yield pre_transform(tuple(trip)) + + # TODO: Make this safe for large amounts of triplets? + triplet_groups = (list(apply_transform(triplets)) + for triplets in triplet_groups) + + node_keys = [] + edge_keys = [] + edge_index = [] + + for triplets in tqdm(triplet_groups, disable=not verbose): + small_graph_indexer = LargeGraphIndexer.from_triplets( + triplets, pre_transform=pre_transform) + + node_keys.append(small_graph_indexer.get_node_features()) + edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets)) + edge_index.append( + small_graph_indexer.get_edge_features(EDGE_INDEX, triplets)) + + node_feats = indexer.get_node_features(feature_name=node_feature_name, + pids=chain.from_iterable(node_keys)) + edge_feats = indexer.get_edge_features(feature_name=edge_feature_name, + pids=chain.from_iterable(edge_keys)) + + last_node_idx, last_edge_idx = 0, 0 + for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index): + nlen, elen = len(nkeys), len(ekeys) + x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen]) + last_node_idx += len(nkeys) + + edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx + + elen]) + last_edge_idx += len(ekeys) + + edge_idx = torch.LongTensor(eidx).T + + data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx) + data_obj[NODE_PID] = node_keys + data_obj[EDGE_PID] = edge_keys + data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys] + data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys] + + yield data_obj + + +def get_features_for_triplets( + indexer: LargeGraphIndexer, + triplets: KnowledgeGraphLike, + node_feature_name: str = "x", + edge_feature_name: str = "edge_attr", + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + verbose: bool = False, +) -> Data: + """For a given set of triplets retrieve a Data object containing the + unique graph and features from the index. + + Args: + indexer (LargeGraphIndexer): Indexer containing desired features + triplets (KnowledgeGraphLike): Triplets to fetch features for + node_feature_name (str, optional): Feature to use for node features. + Defaults to "x". + edge_feature_name (str, optional): Feature to use for edge features. + Defaults to "edge_attr". + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing function for triplets. Defaults to None. + verbose (bool, optional): Whether to print progress. Defaults to False. + + Returns: + Data: Data object containing the unique graph and features from the + index for the given triplets. + """ + gen = get_features_for_triplets_groups(indexer, [triplets], + node_feature_name, + edge_feature_name, pre_transform, + verbose) + return next(gen) diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 266f498a113b..7e83c35befb6 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -22,6 +22,7 @@ from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin +from .rag_loader import RAGQueryLoader __all__ = classes = [ 'DataLoader', @@ -50,6 +51,7 @@ 'PrefetchLoader', 'CachedLoader', 'AffinityMixin', + 'RAGQueryLoader', ] RandomNodeSampler = deprecated( diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py new file mode 100644 index 000000000000..33d6cf0e868e --- /dev/null +++ b/torch_geometric/loader/rag_loader.py @@ -0,0 +1,106 @@ +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +from torch_geometric.data import Data, FeatureStore, HeteroData +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +class RAGFeatureStore(Protocol): + """Feature store for remote GNN RAG backend.""" + @abstractmethod + def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: + """Makes a comparison between the query and all the nodes to get all + the closest nodes. Return the indices of the nodes that are to be seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges: + """Makes a comparison between the query and all the edges to get all + the closest nodes. Returns the edge indices that are to be the seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + """Combines sampled subgraph output with features in a Data object.""" + ... + + +class RAGGraphStore(Protocol): + """Graph store for remote GNN RAG backend.""" + @abstractmethod + def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges, + **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample a subgraph using the seeded nodes and edges.""" + ... + + @abstractmethod + def register_feature_store(self, feature_store: FeatureStore): + """Register a feature store to be used with the sampler. Samplers need + info from the feature store in order to work properly on HeteroGraphs. + """ + ... + + +# TODO: Make compatible with Heterographs + + +class RAGQueryLoader: + def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore], + local_filter: Optional[Callable[[Data, Any], Data]] = None, + seed_nodes_kwargs: Optional[Dict[str, Any]] = None, + seed_edges_kwargs: Optional[Dict[str, Any]] = None, + sampler_kwargs: Optional[Dict[str, Any]] = None, + loader_kwargs: Optional[Dict[str, Any]] = None): + """Loader meant for making queries from a remote backend. + + Args: + data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore + and GraphStore to load from. Assumed to conform to the + protocols listed above. + local_filter (Optional[Callable[[Data, Any], Data]], optional): + Optional local transform to apply to data after retrieval. + Defaults to None. + seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters + to pass into process for fetching seed nodes. Defaults to None. + seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters + to pass into process for fetching seed edges. Defaults to None. + sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for sampling graph. Defaults to None. + loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for loading graph features. Defaults to None. + """ + fstore, gstore = data + self.feature_store = fstore + self.graph_store = gstore + self.graph_store.register_feature_store(self.feature_store) + self.local_filter = local_filter + self.seed_nodes_kwargs = seed_nodes_kwargs or {} + self.seed_edges_kwargs = seed_edges_kwargs or {} + self.sampler_kwargs = sampler_kwargs or {} + self.loader_kwargs = loader_kwargs or {} + + def query(self, query: Any) -> Data: + """Retrieve a subgraph associated with the query with all its feature + attributes. + """ + seed_nodes = self.feature_store.retrieve_seed_nodes( + query, **self.seed_nodes_kwargs) + seed_edges = self.feature_store.retrieve_seed_edges( + query, **self.seed_edges_kwargs) + + subgraph_sample = self.graph_store.sample_subgraph( + seed_nodes, seed_edges, **self.sampler_kwargs) + + data = self.feature_store.load_subgraph(sample=subgraph_sample, + **self.loader_kwargs) + + if self.local_filter: + data = self.local_filter(data, query) + return data diff --git a/torch_geometric/nn/models/g_retriever.py b/torch_geometric/nn/models/g_retriever.py index 6f8fbcc644dc..f7529ae721b7 100644 --- a/torch_geometric/nn/models/g_retriever.py +++ b/torch_geometric/nn/models/g_retriever.py @@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module): (default: :obj:`False`) mlp_out_channels (int, optional): The size of each graph embedding after projection. (default: :obj:`4096`) + mlp_out_tokens (int, optional): Number of LLM prefix tokens to + reserve for GNN output. (default: :obj:`1`) .. warning:: This module has been tested with the following HuggingFace models @@ -43,6 +45,7 @@ def __init__( gnn: torch.nn.Module, use_lora: bool = False, mlp_out_channels: int = 4096, + mlp_out_tokens: int = 1, ) -> None: super().__init__() @@ -77,7 +80,9 @@ def __init__( self.projector = torch.nn.Sequential( torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels), torch.nn.Sigmoid(), - torch.nn.Linear(mlp_hidden_channels, mlp_out_channels), + torch.nn.Linear(mlp_hidden_channels, + mlp_out_channels * mlp_out_tokens), + torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)), ).to(self.llm.device) def encode( @@ -126,6 +131,9 @@ def forward( x = self.projector(x) xs = x.split(1, dim=0) + # Handle case where theres more than one embedding for each sample + xs = [x.squeeze(0) for x in xs] + # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) @@ -182,6 +190,9 @@ def inference( x = self.projector(x) xs = x.split(1, dim=0) + # Handle case where theres more than one embedding for each sample + xs = [x.squeeze(0) for x in xs] + # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) diff --git a/torch_geometric/profile/__init__.py b/torch_geometric/profile/__init__.py index 833ee657d0e7..22d3039f4c83 100644 --- a/torch_geometric/profile/__init__.py +++ b/torch_geometric/profile/__init__.py @@ -20,6 +20,7 @@ get_gpu_memory_from_nvidia_smi, get_model_size, ) +from .nvtx import nvtxit __all__ = [ 'profileit', @@ -38,6 +39,7 @@ 'get_gpu_memory_from_nvidia_smi', 'get_gpu_memory_from_ipex', 'benchmark', + 'nvtxit', ] classes = __all__ diff --git a/torch_geometric/profile/nvtx.py b/torch_geometric/profile/nvtx.py new file mode 100644 index 000000000000..8dbce375ae5a --- /dev/null +++ b/torch_geometric/profile/nvtx.py @@ -0,0 +1,66 @@ +from functools import wraps +from typing import Optional + +import torch + +CUDA_PROFILE_STARTED = False + + +def begin_cuda_profile(): + global CUDA_PROFILE_STARTED + prev_state = CUDA_PROFILE_STARTED + if prev_state is False: + CUDA_PROFILE_STARTED = True + torch.cuda.cudart().cudaProfilerStart() + return prev_state + + +def end_cuda_profile(prev_state: bool): + global CUDA_PROFILE_STARTED + CUDA_PROFILE_STARTED = prev_state + if prev_state is False: + torch.cuda.cudart().cudaProfilerStop() + + +def nvtxit(name: Optional[str] = None, n_warmups: int = 0, + n_iters: Optional[int] = None): + """Enables NVTX profiling for a function. + + Args: + name (Optional[str], optional): Name to give the reference frame for + the function being wrapped. Defaults to the name of the + function in code. + n_warmups (int, optional): Number of iters to call that function + before starting. Defaults to 0. + n_iters (Optional[int], optional): Number of iters of that function to + record. Defaults to all of them. + """ + def nvtx(func): + + nonlocal name + iters_so_far = 0 + if name is None: + name = func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal iters_so_far + if not torch.cuda.is_available(): + return func(*args, **kwargs) + elif iters_so_far < n_warmups: + iters_so_far += 1 + return func(*args, **kwargs) + elif n_iters is None or iters_so_far < n_iters + n_warmups: + prev_state = begin_cuda_profile() + torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}") + result = func(*args, **kwargs) + torch.cuda.nvtx.range_pop() + end_cuda_profile(prev_state) + iters_so_far += 1 + return result + else: + return func(*args, **kwargs) + + return wrapper + + return nvtx From 10047ba010cf17e72ec54ef8c1fb5067c8c54334 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 26 Nov 2024 15:32:55 +0100 Subject: [PATCH 09/45] Check that custom edge types actually exist in `NumNeighbors` definition (#9807) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 1 + test/sampler/test_sampler_base.py | 3 +++ torch_geometric/sampler/base.py | 8 ++++++++ 3 files changed, 12 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 341be665fabf..3867de8b1bbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606)) +- Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807)) ### Deprecated diff --git a/test/sampler/test_sampler_base.py b/test/sampler/test_sampler_base.py index dc8142176bf6..41a7da25534f 100644 --- a/test/sampler/test_sampler_base.py +++ b/test/sampler/test_sampler_base.py @@ -49,6 +49,9 @@ def test_heterogeneous_num_neighbors_dict_and_default(): num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1, -1]) + with pytest.raises(ValueError, match="Not all edge types"): + num_neighbors.get_values([('A', 'C'), ('B', 'A')]) + values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) assert values == {('A', 'B'): [25, 10], ('B', 'A'): [-1, -1]} diff --git a/torch_geometric/sampler/base.py b/torch_geometric/sampler/base.py index d67ddd5af79b..1bd2e4346e1d 100644 --- a/torch_geometric/sampler/base.py +++ b/torch_geometric/sampler/base.py @@ -425,6 +425,14 @@ def _get_values( else: assert False + # Confirm that `values` only hold valid edge types: + if isinstance(self.values, dict): + edge_types_str = {EdgeTypeStr(key) for key in edge_types} + invalid_edge_types = set(self.values.keys()) - edge_types_str + if len(invalid_edge_types) > 0: + raise ValueError("Not all edge types specified in " + "'num_neighbors' exist in the graph") + out = {} for edge_type in edge_types: edge_type_str = EdgeTypeStr(edge_type) From 46705844b39ededc0fcef1de90e73923480a6446 Mon Sep 17 00:00:00 2001 From: abertics Date: Thu, 28 Nov 2024 08:19:36 -0800 Subject: [PATCH 10/45] Fix typo in Dataset docstring (#9813) baching -> batching --- torch_geometric/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/data/dataset.py b/torch_geometric/data/dataset.py index dd5239eec7c2..9df4359f1450 100644 --- a/torch_geometric/data/dataset.py +++ b/torch_geometric/data/dataset.py @@ -383,7 +383,7 @@ def to_datapipe(self) -> Any: r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`. The returned instance can then be used with :pyg:`PyG's` built-in - :class:`DataPipes` for baching graphs as follows: + :class:`DataPipes` for batching graphs as follows: .. code-block:: python From bd5ae45c74a3fbb6b6ff818476f7651d84313d2a Mon Sep 17 00:00:00 2001 From: Santosh Bhavani Date: Fri, 6 Dec 2024 11:27:31 -0800 Subject: [PATCH 11/45] updated Dockerfile based on NGC PyG 24.09 image (#9794) Updated to use new NGC CUDA DL base image. Some differences: 1. /workspace is the working directory 2. Python libs removed that were not included in NGC PyG image: `torch_scatter torch_sparse torch_cluster torch_spline_conv torchnet==0.0.4 h5py torchnet ` 3. Using latest stable versions for graphviz and torch --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rishi Puri --- CHANGELOG.md | 1 + docker/Dockerfile | 172 ++++------------------------------------------ 2 files changed, 16 insertions(+), 157 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3867de8b1bbb..4e6789b9a86d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) diff --git a/docker/Dockerfile b/docker/Dockerfile index d4f37f061d68..d7a879ba1157 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,163 +1,21 @@ -FROM ubuntu:18.04 +FROM nvcr.io/nvidia/cuda-dl-base:24.09-cuda12.6-devel-ubuntu22.04 -# metainformation -LABEL org.opencontainers.image.version = "2.3.1" -LABEL org.opencontainers.image.authors = "Matthias Fey" -LABEL org.opencontainers.image.source = "https://github.com/pyg-team/pytorch_geometric" -LABEL org.opencontainers.image.licenses = "MIT" -LABEL org.opencontainers.image.base.name="docker.io/library/ubuntu:18.04" +# Based on NGC PyG 24.09 image: +# https://docs.nvidia.com/deeplearning/frameworks/pyg-release-notes/rel-24-09.html#rel-24-09 -RUN apt-get update && apt-get install -y apt-transport-https ca-certificates && \ - rm -rf /var/lib/apt/lists/* +# install pip +RUN apt-get update && apt-get install -y python3-pip -RUN apt-get update && apt-get install -y --no-install-recommends apt-utils gnupg2 curl && \ - curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ - echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ - echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list &&\ - apt-get purge --autoremove -y curl && \ -rm -rf /var/lib/apt/lists/* +# install PyTorch - latest stable version +RUN pip install torch torchvision torchaudio -ENV CUDA_VERSION 10.1.243 -ENV NCCL_VERSION 2.4.8 -ENV CUDA_PKG_VERSION 10-1=$CUDA_VERSION-1 -ENV CUDNN_VERSION 7.6.5.32 +# install graphviz - latest stable version +RUN apt-get install -y graphviz graphviz-dev +RUN pip install pygraphviz -RUN apt-get update && apt-get install -y --no-install-recommends \ - cuda-cudart-$CUDA_PKG_VERSION \ - cuda-compat-10-1 && \ - ln -s cuda-10.1 /usr/local/cuda && \ - rm -rf /var/lib/apt/lists/* +# install python packages with NGC PyG 24.09 image versions +RUN pip install torch_geometric==2.6.0 +RUN pip install triton==3.0.0 numba==0.59.0 requests==2.32.3 opencv-python==4.7.0.72 scipy==1.14.0 jupyterlab==4.2.5 -RUN apt-get update && apt-get install -y --allow-unauthenticated --no-install-recommends \ - cuda-libraries-$CUDA_PKG_VERSION \ - cuda-nvtx-$CUDA_PKG_VERSION \ - libcublas10=10.2.1.243-1 \ - libnccl2=$NCCL_VERSION-1+cuda10.1 && \ - apt-mark hold libnccl2 && \ - rm -rf /var/lib/apt/lists/* - -RUN apt-get update && apt-get install -y --allow-unauthenticated --no-install-recommends \ - cuda-libraries-dev-$CUDA_PKG_VERSION \ - cuda-nvml-dev-$CUDA_PKG_VERSION \ - cuda-minimal-build-$CUDA_PKG_VERSION \ - cuda-command-line-tools-$CUDA_PKG_VERSION \ - libnccl-dev=$NCCL_VERSION-1+cuda10.1 \ - libcublas-dev=10.2.1.243-1 \ - && \ - rm -rf /var/lib/apt/lists/* - -RUN apt-get update && apt-get install -y --no-install-recommends \ - libcudnn7=$CUDNN_VERSION-1+cuda10.1 \ - libcudnn7-dev=$CUDNN_VERSION-1+cuda10.1 \ - && \ - apt-mark hold libcudnn7 && \ - rm -rf /var/lib/apt/lists/* - - -ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs - -# NVIDIA docker 1.0. -LABEL com.nvidia.volumes.needed="nvidia_driver" -LABEL com.nvidia.cuda.version="${CUDA_VERSION}" - -RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \ - echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf - -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} -ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 - -# NVIDIA container runtime. -ENV NVIDIA_VISIBLE_DEVICES all -ENV NVIDIA_DRIVER_CAPABILITIES compute,utility -ENV NVIDIA_REQUIRE_CUDA "cuda>=10.0 brand=tesla,driver>=384,driver<385 brand=tesla,driver>=410,driver<411" - -# PyTorch (Geometric) installation -RUN rm /etc/apt/sources.list.d/cuda.list && \ - rm /etc/apt/sources.list.d/nvidia-ml.list - -RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - vim \ - sudo \ - git \ - bzip2 \ - libx11-6 \ - && rm -rf /var/lib/apt/lists/* - -# Create a working directory. -RUN mkdir /app -WORKDIR /app - -# Create a non-root user and switch to it. -RUN adduser --disabled-password --gecos '' --shell /bin/bash user \ - && chown -R user:user /app -RUN echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user -USER user - -# All users can use /home/user as their home directory. -ENV HOME=/home/user -RUN chmod 777 /home/user - -# Install Miniconda. -RUN curl -so ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ - && chmod +x ~/miniconda.sh \ - && ~/miniconda.sh -b -p ~/miniconda \ - && rm ~/miniconda.sh -ENV PATH=/home/user/miniconda/bin:$PATH -ENV CONDA_AUTO_UPDATE_CONDA=false - -# Create a Python 3.6 environment. -RUN /home/user/miniconda/bin/conda install conda-build \ - && /home/user/miniconda/bin/conda create -y --name py36 python=3.6.5 \ - && /home/user/miniconda/bin/conda clean -ya -ENV CONDA_DEFAULT_ENV=py36 -ENV CONDA_PREFIX=/home/user/miniconda/envs/$CONDA_DEFAULT_ENV -ENV PATH=$CONDA_PREFIX/bin:$PATH - -# CUDA 10.0-specific steps. -RUN conda install -y -c pytorch \ - cudatoolkit=10.1 \ - "pytorch=1.4.0=py3.6_cuda10.1.243_cudnn7.6.3_0" \ - torchvision=0.5.0=py36_cu101 \ - && conda clean -ya - -# Install HDF5 Python bindings. -RUN conda install -y h5py=2.8.0 \ - && conda clean -ya -RUN pip install h5py-cache==1.0 - -# Install TorchNet, a high-level framework for PyTorch. -RUN pip install torchnet==0.0.4 - -# Install Requests, a Python library for making HTTP requests. -RUN conda install -y requests=2.19.1 \ - && conda clean -ya - -# Install Graphviz. -RUN conda install -y graphviz=2.40.1 python-graphviz=0.8.4 \ - && conda clean -ya - -# Install OpenCV3 Python bindings. -RUN sudo apt-get update && sudo apt-get install -y --no-install-recommends \ - libgtk2.0-0 \ - libcanberra-gtk-module \ - && sudo rm -rf /var/lib/apt/lists/* -RUN conda install -y -c menpo opencv3=3.1.0 \ - && conda clean -ya - -# Install PyG. -RUN CPATH=/usr/local/cuda/include:$CPATH \ - && LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \ - && DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH - -RUN pip install scipy - -RUN pip install --no-index torch_scatter -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \ - && pip install --no-index torch_sparse -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \ - && pip install --no-index torch_cluster -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \ - && pip install --no-index torch_spline_conv -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \ - && pip install torch-geometric - -# Set the default command to python3. -CMD ["python3"] +# install cugraph +RUN pip install cugraph-cu12 cugraph-pyg-cu12 --extra-index-url=https://pypi.nvidia.com From 1519e9fa9cd0d23ee7b64d80563a45b10599a0c6 Mon Sep 17 00:00:00 2001 From: zaristei Date: Tue, 10 Dec 2024 14:01:29 -0500 Subject: [PATCH 12/45] Fix Docstring Typos for LargeGraphIndexer (#9837) Fix some issues with the docstrings of LargeGraphIndexer. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Zachary Aristei --- torch_geometric/data/large_graph_indexer.py | 38 ++++++++++----------- torch_geometric/loader/__init__.py | 4 ++- torch_geometric/loader/rag_loader.py | 5 +-- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py index 0644e2543303..d2cb30378908 100644 --- a/torch_geometric/data/large_graph_indexer.py +++ b/torch_geometric/data/large_graph_indexer.py @@ -7,7 +7,6 @@ Any, Callable, Dict, - Hashable, Iterable, Iterator, List, @@ -25,12 +24,13 @@ from torch_geometric.data import Data from torch_geometric.typing import WITH_PT24 -TripletLike = Tuple[Hashable, Hashable, Hashable] +# Could be any hashable type +TripletLike = Tuple[str, str, str] KnowledgeGraphLike = Iterable[TripletLike] -def ordered_set(values: Iterable[Hashable]) -> List[Hashable]: +def ordered_set(values: Iterable[str]) -> List[str]: return list(dict.fromkeys(values)) @@ -70,13 +70,13 @@ def __eq__(self, value: "MappedFeature") -> bool: class LargeGraphIndexer: - """For a dataset that consists of mulitiple subgraphs that are assumed to + """For a dataset that consists of multiple subgraphs that are assumed to be part of a much larger graph, collate the values into a large graph store to save resources. """ def __init__( self, - nodes: Iterable[Hashable], + nodes: Iterable[str], edges: KnowledgeGraphLike, node_attr: Optional[Dict[str, List[Any]]] = None, edge_attr: Optional[Dict[str, List[Any]]] = None, @@ -85,7 +85,7 @@ def __init__( by id. Not meant to be used directly. Args: - nodes (Iterable[Hashable]): Node ids in the graph. + nodes (Iterable[str]): Node ids in the graph. edges (KnowledgeGraphLike): Edge ids in the graph. node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node attribute name and list of their values in order of unique node @@ -94,7 +94,7 @@ def __init__( attribute name and list of their values in order of unique edge ids. Defaults to None. """ - self._nodes: Dict[Hashable, int] = dict() + self._nodes: Dict[str, int] = dict() self._edges: Dict[TripletLike, int] = dict() self._mapped_node_features: Set[str] = set() @@ -201,7 +201,7 @@ def collate(cls, index. Args: - graphs (Iterable["LargeGraphIndexer"]): Indices to be + graphs (Iterable[LargeGraphIndexer]): Indices to be combined. Returns: @@ -212,8 +212,8 @@ def collate(cls, trips = chain.from_iterable([graph.to_triplets() for graph in graphs]) return cls.from_triplets(trips) - def get_unique_node_features( - self, feature_name: str = NODE_PID) -> List[Hashable]: + def get_unique_node_features(self, + feature_name: str = NODE_PID) -> List[str]: r"""Get all the unique values for a specific node attribute. Args: @@ -221,7 +221,7 @@ def get_unique_node_features( Defaults to NODE_PID. Returns: - List[Hashable]: List of unique values for the specified feature. + List[str]: List of unique values for the specified feature. """ try: if feature_name in self._mapped_node_features: @@ -272,7 +272,7 @@ def add_node_feature( def get_node_features( self, feature_name: str = NODE_PID, - pids: Optional[Iterable[Hashable]] = None, + pids: Optional[Iterable[str]] = None, ) -> List[Any]: r"""Get node feature values for a given set of unique node ids. Returned values are not necessarily unique. @@ -280,7 +280,7 @@ def get_node_features( Args: feature_name (str, optional): Name of feature to fetch. Defaults to NODE_PID. - pids (Optional[Iterable[Hashable]], optional): Node ids to fetch + pids (Optional[Iterable[str]], optional): Node ids to fetch for. Defaults to None, which fetches all nodes. Returns: @@ -302,7 +302,7 @@ def get_node_features( def get_node_features_iter( self, feature_name: str = NODE_PID, - pids: Optional[Iterable[Hashable]] = None, + pids: Optional[Iterable[str]] = None, index_only: bool = False, ) -> Iterator[Any]: """Iterator version of get_node_features. If index_only is True, @@ -337,8 +337,8 @@ def get_node_features_iter( else: yield self.node_attr[feature_name][idx] - def get_unique_edge_features( - self, feature_name: str = EDGE_PID) -> List[Hashable]: + def get_unique_edge_features(self, + feature_name: str = EDGE_PID) -> List[str]: r"""Get all the unique values for a specific edge attribute. Args: @@ -346,7 +346,7 @@ def get_unique_edge_features( Defaults to EDGE_PID. Returns: - List[Hashable]: List of unique values for the specified feature. + List[str]: List of unique values for the specified feature. """ try: if feature_name in self._mapped_edge_features: @@ -396,7 +396,7 @@ def add_edge_feature( def get_edge_features( self, feature_name: str = EDGE_PID, - pids: Optional[Iterable[Hashable]] = None, + pids: Optional[Iterable[str]] = None, ) -> List[Any]: r"""Get edge feature values for a given set of unique edge ids. Returned values are not necessarily unique. @@ -404,7 +404,7 @@ def get_edge_features( Args: feature_name (str, optional): Name of feature to fetch. Defaults to EDGE_PID. - pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch + pids (Optional[Iterable[str]], optional): Edge ids to fetch for. Defaults to None, which fetches all edges. Returns: diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 7e83c35befb6..75dbe9178681 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -22,7 +22,7 @@ from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin -from .rag_loader import RAGQueryLoader +from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore __all__ = classes = [ 'DataLoader', @@ -52,6 +52,8 @@ 'CachedLoader', 'AffinityMixin', 'RAGQueryLoader', + 'RAGFeatureStore', + 'RAGGraphStore' ] RandomNodeSampler = deprecated( diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py index 33d6cf0e868e..4ab457ef7072 100644 --- a/torch_geometric/loader/rag_loader.py +++ b/torch_geometric/loader/rag_loader.py @@ -7,7 +7,7 @@ class RAGFeatureStore(Protocol): - """Feature store for remote GNN RAG backend.""" + """Feature store template for remote GNN RAG backend.""" @abstractmethod def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: """Makes a comparison between the query and all the nodes to get all @@ -33,7 +33,7 @@ def load_subgraph( class RAGGraphStore(Protocol): - """Graph store for remote GNN RAG backend.""" + """Graph store template for remote GNN RAG backend.""" @abstractmethod def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges, **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]: @@ -52,6 +52,7 @@ def register_feature_store(self, feature_store: FeatureStore): class RAGQueryLoader: + """Loader meant for making RAG queries from a remote backend.""" def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore], local_filter: Optional[Callable[[Data, Any], Data]] = None, seed_nodes_kwargs: Optional[Dict[str, Any]] = None, From 2b1b32719d5d87c9f34ed467bf44778c567c732a Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Wed, 11 Dec 2024 00:25:58 -0800 Subject: [PATCH 13/45] feat: store reverse mapping within `EdgeTypeStr` (#9844) To avoid issues when node types contain the `EDGE_TYPE_STR_SPLIT` delimiter. --------- Co-authored-by: rusty1s --- torch_geometric/typing.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index e63b1849b65c..468f37abfaed 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -307,6 +307,8 @@ class EdgeTypeStr(str): r"""A helper class to construct serializable edge types by merging an edge type tuple into a single string. """ + edge_type: tuple[str, str, str] + def __new__(cls, *args: Any) -> 'EdgeTypeStr': if isinstance(args[0], (list, tuple)): # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`: @@ -314,27 +316,34 @@ def __new__(cls, *args: Any) -> 'EdgeTypeStr': if len(args) == 1 and isinstance(args[0], str): arg = args[0] # An edge type string was passed. + edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT)) + if len(edge_type) != 3: + raise ValueError(f"Cannot convert the edge type '{arg}' to a " + f"tuple since it holds invalid characters") elif len(args) == 2 and all(isinstance(arg, str) for arg in args): # A `(src, dst)` edge type was passed - add `DEFAULT_REL`: - arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1])) + edge_type = (args[0], DEFAULT_REL, args[1]) + arg = EDGE_TYPE_STR_SPLIT.join(edge_type) elif len(args) == 3 and all(isinstance(arg, str) for arg in args): # A `(src, rel, dst)` edge type was passed: + edge_type = tuple(args) arg = EDGE_TYPE_STR_SPLIT.join(args) else: raise ValueError(f"Encountered invalid edge type '{args}'") - return str.__new__(cls, arg) + out = str.__new__(cls, arg) + out.edge_type = edge_type # type: ignore + return out def to_tuple(self) -> EdgeType: r"""Returns the original edge type.""" - out = tuple(self.split(EDGE_TYPE_STR_SPLIT)) - if len(out) != 3: + if len(self.edge_type) != 3: raise ValueError(f"Cannot convert the edge type '{self}' to a " f"tuple since it holds invalid characters") - return out + return self.edge_type # There exist some short-cuts to query edge-types (given that the full triplet From ab2b458f0c0f72d3cb573350b324db563066a7ee Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Tue, 17 Dec 2024 09:02:38 -0800 Subject: [PATCH 14/45] fix: update `__reduce__` for `EdgeTypeStr` (#9876) Supports `deepcopy` on `EdgeTypeStr`. --- torch_geometric/typing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 468f37abfaed..513f041847b1 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -345,6 +345,9 @@ def to_tuple(self) -> EdgeType: f"tuple since it holds invalid characters") return self.edge_type + def __reduce__(self) -> tuple[Any, Any]: + return (self.__class__, (self.edge_type, )) + # There exist some short-cuts to query edge-types (given that the full triplet # can be uniquely reconstructed, e.g.: From 5ea6aec8827eabf2a7569d32780ebf3510ba0f6e Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 30 Dec 2024 20:39:27 +0900 Subject: [PATCH 15/45] Don't iterate over samples in the same order every epoch in a distributed training example (#9900) * Adds a call to `set_epoch` to shuffle samples across epochs. * Applies average instead of sum reduction across ranks to make the scale of the loss consistent across a different number of GPUs. * Calls `optimizer.zero_grad` at the end to release gradients before evaluation loops. * Includes minor decorative changes, e.g., type annotations. --- examples/multi_gpu/distributed_batching.py | 88 ++++++++++++++-------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/examples/multi_gpu/distributed_batching.py b/examples/multi_gpu/distributed_batching.py index f5c05a176823..b242499d3e76 100644 --- a/examples/multi_gpu/distributed_batching.py +++ b/examples/multi_gpu/distributed_batching.py @@ -1,36 +1,35 @@ import os +import os.path as osp import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F -from ogb.graphproppred import Evaluator -from ogb.graphproppred import PygGraphPropPredDataset as Dataset +from ogb.graphproppred import Evaluator, PygGraphPropPredDataset from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder from torch.nn import BatchNorm1d as BatchNorm from torch.nn import Linear, ReLU, Sequential from torch.nn.parallel import DistributedDataParallel from torch.utils.data.distributed import DistributedSampler +from torch_sparse import SparseTensor import torch_geometric.transforms as T from torch_geometric.loader import DataLoader from torch_geometric.nn import GINEConv, global_mean_pool -from torch_geometric.typing import WITH_TORCH_SPARSE - -if not WITH_TORCH_SPARSE: - quit("This example requires 'torch-sparse'") class GIN(torch.nn.Module): - def __init__(self, hidden_channels, out_channels, num_layers=3, - dropout=0.5): + def __init__( + self, + hidden_channels: int, + out_channels: int, + num_layers: int = 3, + dropout: float = 0.5, + ) -> None: super().__init__() - self.dropout = dropout - self.atom_encoder = AtomEncoder(hidden_channels) self.bond_encoder = BondEncoder(hidden_channels) - self.convs = torch.nn.ModuleList() for _ in range(num_layers): nn = Sequential( @@ -45,7 +44,12 @@ def __init__(self, hidden_channels, out_channels, num_layers=3, self.lin = Linear(hidden_channels, out_channels) - def forward(self, x, adj_t, batch): + def forward( + self, + x: torch.Tensor, + adj_t: SparseTensor, + batch: torch.Tensor, + ) -> torch.Tensor: x = self.atom_encoder(x) edge_attr = adj_t.coo()[2] adj_t = adj_t.set_value(self.bond_encoder(edge_attr), layout='coo') @@ -59,21 +63,29 @@ def forward(self, x, adj_t, batch): return x -def run(rank, world_size: int, dataset_name: str, root: str): +def run(rank: int, world_size: int, dataset_name: str, root: str) -> None: os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) - dataset = Dataset(dataset_name, root, - pre_transform=T.ToSparseTensor(attr='edge_attr')) + dataset = PygGraphPropPredDataset( + dataset_name, + root=root, + pre_transform=T.ToSparseTensor(attr='edge_attr'), + ) split_idx = dataset.get_idx_split() evaluator = Evaluator(dataset_name) train_dataset = dataset[split_idx['train']] - train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, - rank=rank) - train_loader = DataLoader(train_dataset, batch_size=128, - sampler=train_sampler) + train_loader = DataLoader( + train_dataset, + batch_size=128, + sampler=DistributedSampler( + train_dataset, + shuffle=True, + drop_last=True, + ), + ) torch.manual_seed(12345) model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank) @@ -87,20 +99,22 @@ def run(rank, world_size: int, dataset_name: str, root: str): for epoch in range(1, 51): model.train() - - total_loss = torch.zeros(2).to(rank) + train_loader.sampler.set_epoch(epoch) + total_loss = torch.zeros(2, device=rank) for data in train_loader: data = data.to(rank) - optimizer.zero_grad() logits = model(data.x, data.adj_t, data.batch) loss = criterion(logits, data.y.to(torch.float)) loss.backward() optimizer.step() - total_loss[0] += float(loss) * logits.size(0) - total_loss[1] += data.num_graphs + optimizer.zero_grad() - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - loss = float(total_loss[0] / total_loss[1]) + with torch.no_grad(): + total_loss[0] += loss * logits.size(0) + total_loss[1] += data.num_graphs + + dist.all_reduce(total_loss, op=dist.ReduceOp.AVG) + train_loss = total_loss[0] / total_loss[1] if rank == 0: # We evaluate on a single GPU for now. model.eval() @@ -127,8 +141,10 @@ def run(rank, world_size: int, dataset_name: str, root: str): 'y_true': torch.cat(y_true, dim=0), })['rocauc'] - print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' - f'Val: {val_rocauc:.4f}, Test: {test_rocauc:.4f}') + print(f'Epoch: {epoch:03d}, ' + f'Loss: {train_loss:.4f}, ' + f'Val: {val_rocauc:.4f}, ' + f'Test: {test_rocauc:.4f}') dist.barrier() @@ -137,11 +153,19 @@ def run(rank, world_size: int, dataset_name: str, root: str): if __name__ == '__main__': dataset_name = 'ogbg-molhiv' - root = '../../data/OGB' - + root = osp.join( + osp.dirname(__file__), + '..', + '..', + 'data', + 'OGB', + ) # Download and process the dataset on main process. - Dataset(dataset_name, root, - pre_transform=T.ToSparseTensor(attr='edge_attr')) + PygGraphPropPredDataset( + dataset_name, + root, + pre_transform=T.ToSparseTensor(attr='edge_attr'), + ) world_size = torch.cuda.device_count() print('Let\'s use', world_size, 'GPUs!') From 076db8403243ad4400da196218f731756293da18 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 30 Dec 2024 20:39:56 +0900 Subject: [PATCH 16/45] Fix broken metrics in a `DistributedDataParallel` example (#9896) Fixes a few issues: * Fixes the same three calls to `dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)` (introduced in #8880) that led to the wrong metrics. * Avoids the `int(cuda_tensor)` call in every eval iteration to get rid of the D2H synchronization. * Calls `optimizer.step()` at the end of each training step to release GPU memory for gradients before evaluation loops. * Minor decorative changes: * Adds type annotations. * Adds a progress bar for the training loop. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- examples/multi_gpu/distributed_sampling.py | 67 +++++++++++++--------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/examples/multi_gpu/distributed_sampling.py b/examples/multi_gpu/distributed_sampling.py index 805e6b305728..d3df91000fd0 100644 --- a/examples/multi_gpu/distributed_sampling.py +++ b/examples/multi_gpu/distributed_sampling.py @@ -1,4 +1,5 @@ import os +import os.path as osp from math import ceil import torch @@ -7,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn.parallel import DistributedDataParallel +from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader @@ -14,10 +16,14 @@ class SAGE(torch.nn.Module): - def __init__(self, in_channels: int, hidden_channels: int, - out_channels: int, num_layers: int = 2): + def __init__( + self, + in_channels: int, + hidden_channels: int, + out_channels: int, + num_layers: int = 2, + ) -> None: super().__init__() - self.convs = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) for _ in range(num_layers - 2): @@ -34,20 +40,25 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: @torch.no_grad() -def test(loader, model, rank): +def test( + loader: NeighborLoader, + model: DistributedDataParallel, + rank: int, +) -> Tensor: model.eval() - - total_correct = total_examples = 0 + total_correct = torch.tensor(0, dtype=torch.long, device=rank) + total_examples = 0 for i, batch in enumerate(loader): out = model(batch.x, batch.edge_index.to(rank)) pred = out[:batch.batch_size].argmax(dim=-1) y = batch.y[:batch.batch_size].to(rank) - total_correct += int((pred == y).sum()) + total_correct += (pred == y).sum() total_examples += batch.batch_size - return torch.tensor(total_correct / total_examples, device=rank) + + return total_correct / total_examples -def run(rank, world_size, dataset): +def run(rank: int, world_size: int, dataset: Reddit) -> None: os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) @@ -94,17 +105,19 @@ def run(rank, world_size, dataset): for epoch in range(1, 21): model.train() - for batch in train_loader: - optimizer.zero_grad() + for batch in tqdm( + train_loader, + desc=f'Epoch {epoch:02d}', + disable=rank != 0, + ): out = model(batch.x, batch.edge_index.to(rank))[:batch.batch_size] loss = F.cross_entropy(out, batch.y[:batch.batch_size]) loss.backward() optimizer.step() - - dist.barrier() + optimizer.zero_grad() if rank == 0: - print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') + print(f'Epoch {epoch:02d}: Train loss: {loss:.4f}') if epoch % 5 == 0: train_acc = test(train_loader, model, rank) @@ -112,25 +125,27 @@ def run(rank, world_size, dataset): test_acc = test(test_loader, model, rank) if world_size > 1: - dist.all_reduce(train_acc, op=dist.ReduceOp.SUM) - dist.all_reduce(train_acc, op=dist.ReduceOp.SUM) - dist.all_reduce(train_acc, op=dist.ReduceOp.SUM) - train_acc /= world_size - val_acc /= world_size - test_acc /= world_size + dist.all_reduce(train_acc, op=dist.ReduceOp.AVG) + dist.all_reduce(val_acc, op=dist.ReduceOp.AVG) + dist.all_reduce(test_acc, op=dist.ReduceOp.AVG) if rank == 0: - print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' - f'Test: {test_acc:.4f}') - - dist.barrier() + print(f'Train acc: {train_acc:.4f}, ' + f'Val acc: {val_acc:.4f}, ' + f'Test acc: {test_acc:.4f}') dist.destroy_process_group() if __name__ == '__main__': - dataset = Reddit('../../data/Reddit') - + path = osp.join( + osp.dirname(__file__), + '..', + '..', + 'data', + 'Reddit', + ) + dataset = Reddit(path) world_size = torch.cuda.device_count() print("Let's use", world_size, "GPUs!") mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True) From c300f38e4f28f550456c65f7e08e052dc40282cd Mon Sep 17 00:00:00 2001 From: kolmiw <72548086+kolmiw@users.noreply.github.com> Date: Mon, 30 Dec 2024 20:27:54 +0100 Subject: [PATCH 17/45] Graphgym documentation completion (#9905) I changed a single line of a comment. The compute loss function's description was cut in half mid-sentence. This change shouldn't change the code by any means though --- torch_geometric/graphgym/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/graphgym/loss.py b/torch_geometric/graphgym/loss.py index db38ce6e1e42..4e3dfd7029a6 100644 --- a/torch_geometric/graphgym/loss.py +++ b/torch_geometric/graphgym/loss.py @@ -10,7 +10,7 @@ def compute_loss(pred, true): Args: pred (torch.tensor): Unnormalized prediction - true (torch.tensor): Grou + true (torch.tensor): Ground truth labels Returns: Loss, normalized prediction score From 5d1b8987e52ddead470973c611c5b6b1bf935d99 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 5 Jan 2025 03:28:54 +0900 Subject: [PATCH 18/45] Update type annotations (#9917) Unblocks merging to master. --- .github/workflows/linting.yml | 2 +- torch_geometric/datasets/git_mol_dataset.py | 4 ++-- torch_geometric/datasets/molecule_gpt_dataset.py | 9 ++++++--- torch_geometric/utils/smiles.py | 4 ++-- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 3653abdea195..f1f16a5e9a9c 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -25,7 +25,7 @@ jobs: # Skip workflow if only certain files have been changed. - name: Get changed files id: changed-files-specific - uses: tj-actions/changed-files@v41 + uses: tj-actions/changed-files@v45 with: files: | benchmark/** diff --git a/torch_geometric/datasets/git_mol_dataset.py b/torch_geometric/datasets/git_mol_dataset.py index 4b7cfa78117c..872b0d1f5d2d 100644 --- a/torch_geometric/datasets/git_mol_dataset.py +++ b/torch_geometric/datasets/git_mol_dataset.py @@ -187,7 +187,7 @@ def process(self) -> None: img = self.img_transform(img).unsqueeze(0) # graph atom_features_list = [] - for atom in mol.GetAtoms(): # type: ignore + for atom in mol.GetAtoms(): atom_feature = [ safe_index( allowable_features['possible_atomic_num_list'], @@ -219,7 +219,7 @@ def process(self) -> None: edges_list = [] edge_features_list = [] - for bond in mol.GetBonds(): # type: ignore + for bond in mol.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_feature = [ safe_index( diff --git a/torch_geometric/datasets/molecule_gpt_dataset.py b/torch_geometric/datasets/molecule_gpt_dataset.py index b1da09f38570..fed2fe503600 100644 --- a/torch_geometric/datasets/molecule_gpt_dataset.py +++ b/torch_geometric/datasets/molecule_gpt_dataset.py @@ -122,7 +122,10 @@ def clean_up_description(description: str) -> str: return first_sentence -def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]: +def extract_name( + name_raw: str, + description: str, +) -> Tuple[Optional[str], str, str]: first_sentence = clean_up_description(description) splitter = ' -- -- ' @@ -446,12 +449,12 @@ def extract_one_SDF_file(block_id: int) -> None: x: torch.Tensor = torch.tensor([ types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 - for atom in m.GetAtoms() # type: ignore + for atom in m.GetAtoms() ]) x = one_hot(x, num_classes=len(types), dtype=torch.float) rows, cols, edge_types = [], [], [] - for bond in m.GetBonds(): # type: ignore + for bond in m.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_types += [bonds[bond.GetBondType()]] * 2 rows += [i, j] diff --git a/torch_geometric/utils/smiles.py b/torch_geometric/utils/smiles.py index 608618fc44d9..6547be968582 100644 --- a/torch_geometric/utils/smiles.py +++ b/torch_geometric/utils/smiles.py @@ -91,7 +91,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data': assert isinstance(mol, Chem.Mol) xs: List[List[int]] = [] - for atom in mol.GetAtoms(): # type: ignore + for atom in mol.GetAtoms(): row: List[int] = [] row.append(x_map['atomic_num'].index(atom.GetAtomicNum())) row.append(x_map['chirality'].index(str(atom.GetChiralTag()))) @@ -108,7 +108,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data': x = torch.tensor(xs, dtype=torch.long).view(-1, 9) edge_indices, edge_attrs = [], [] - for bond in mol.GetBonds(): # type: ignore + for bond in mol.GetBonds(): i = bond.GetBeginAtomIdx() j = bond.GetEndAtomIdx() From 109ec56790c106ccd9738dc034871aa026659c13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 11:56:43 +0100 Subject: [PATCH 19/45] [pre-commit.ci] pre-commit suggestions (#9920) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/asottile/pyupgrade: v3.17.0 → v3.19.1](https://github.com/asottile/pyupgrade/compare/v3.17.0...v3.19.1) - [github.com/google/yapf: v0.40.2 → v0.43.0](https://github.com/google/yapf/compare/v0.40.2...v0.43.0) - [github.com/astral-sh/ruff-pre-commit: v0.6.9 → v0.8.6](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.9...v0.8.6) - [github.com/executablebooks/mdformat: 0.7.17 → 0.7.21](https://github.com/executablebooks/mdformat/compare/0.7.17...0.7.21) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 8 ++++---- CHANGELOG.md | 18 +++++++++--------- README.md | 2 +- examples/distributed/pyg/README.md | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eb97a9a5eb26..5ebbc1291fd6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}'] - repo: https://github.com/asottile/pyupgrade - rev: v3.17.0 + rev: v3.19.1 hooks: - id: pyupgrade name: Upgrade Python syntax @@ -54,7 +54,7 @@ repos: ] - repo: https://github.com/google/yapf - rev: v0.40.2 + rev: v0.43.0 hooks: - id: yapf name: Format code @@ -74,14 +74,14 @@ repos: additional_dependencies: [Flake8-pyproject] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.8.6 hooks: - id: ruff name: Ruff formatting args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.21 hooks: - id: mdformat name: Format Markdown diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e6789b9a86d..7321dbb30413 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## \[2.7.0\] - 2024-MM-DD +## [2.7.0] - 2024-MM-DD ### Added @@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -## \[2.6.0\] - 2024-09-13 +## [2.6.0] - 2024-09-13 ### Added @@ -132,7 +132,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -## \[2.5.0\] - 2024-02-16 +## [2.5.0] - 2024-02-16 ### Added @@ -232,7 +232,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed disabling of extension packages during `torch_geometric.compile` ([#8698](https://github.com/pyg-team/pytorch_geometric/pull/8698)) -## \[2.4.0\] - 2023-10-12 +## [2.4.0] - 2023-10-12 ### Added @@ -288,7 +288,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added back support for PyTorch >= 1.11.0 ([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656)) - Added `Data.sort()` and `HeteroData.sort()` functionalities ([#7649](https://github.com/pyg-team/pytorch_geometric/pull/7649)) - Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647)) -- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700)) +- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700)) - Added a `LightGCN` example on the `AmazonBook` dataset ([7603](https://github.com/pyg-team/pytorch_geometric/pull/7603)) - Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594)) - Enabled different attention modes in `HypergraphConv` via the `attention_mode` argument ([#7601](https://github.com/pyg-team/pytorch_geometric/pull/7601)) @@ -357,7 +357,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942)) - Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737)) - Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955)) -- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956) +- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956) - Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953)) - Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941)) - Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923)) @@ -424,7 +424,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `layer_type` argument in `contrib.explain.GraphMaskExplainer` ([#7445](https://github.com/pyg-team/pytorch_geometric/pull/7445)) - Replaced `FastHGTConv` with `HGTConv` ([#7117](https://github.com/pyg-team/pytorch_geometric/pull/7117)) -## \[2.3.0\] - 2023-03-23 +## [2.3.0] - 2023-03-23 ### Added @@ -582,7 +582,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270)) - Removed `Aggregation.set_validate_args` option ([#6175](https://github.com/pyg-team/pytorch_geometric/pull/6175)) -## \[2.2.0\] - 2022-12-01 +## [2.2.0] - 2022-12-01 ### Added @@ -700,7 +700,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `scatter_reduce` option from experimental mode ([#5399](https://github.com/pyg-team/pytorch_geometric/pull/5399)) -## \[2.1.0\] - 2022-08-17 +## [2.1.0] - 2022-08-17 ### Added diff --git a/README.md b/README.md index f96d720837d1..4f28870b8770 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ More information about evaluating final model performance can be found in the co In addition to the easy application of existing GNNs, PyG makes it simple to implement custom Graph Neural Networks (see [here](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html) for the accompanying tutorial). For example, this is all it takes to implement the [edge convolutional layer](https://arxiv.org/abs/1801.07829) from Wang *et al.*: -$$x_i^{\\prime} ~ = ~ \\max\_{j \\in \\mathcal{N}(i)} ~ \\textrm{MLP}\_{\\theta} \\left( \[ ~ x_i, ~ x_j - x_i ~ \] \\right)$$ +$$x_i^{\\prime} ~ = ~ \\max\_{j \\in \\mathcal{N}(i)} ~ \\textrm{MLP}\_{\\theta} \\left( [ ~ x_i, ~ x_j - x_i ~ ] \\right)$$ ```python import torch diff --git a/examples/distributed/pyg/README.md b/examples/distributed/pyg/README.md index 595d557b1936..94cd696bef19 100644 --- a/examples/distributed/pyg/README.md +++ b/examples/distributed/pyg/README.md @@ -22,7 +22,7 @@ To run the example, please refer to the steps below. - Password-less SSH needs to be set up on all the nodes that you are using (see the [Linux SSH manual](https://linuxize.com/post/how-to-setup-passwordless-ssh-login)). - All nodes need to have a consistent environments installed, specifically `torch` and `pyg-lib` versions must be the same. You might want to consider using docker containers. -- *\[Optional\]* In some cases Linux firewall might be blocking TCP connection issues. +- *[Optional]* In some cases Linux firewall might be blocking TCP connection issues. Ensure that firewall settings allow for all nodes to communicate (see the [Linux firewall manual](https://ubuntu.com/server/docs/security-firewall)). For this example TCP ports `11111`, `11112` and `11113` should be open (*i.e.* `sudo ufw allow 11111`). From 0c97b462a63d2b2d549ec62ba323df13b0e579e0 Mon Sep 17 00:00:00 2001 From: Luke Westfield <11140532+lukedoubleu@users.noreply.github.com> Date: Tue, 7 Jan 2025 05:58:15 -0500 Subject: [PATCH 20/45] Update introduction.rst (#9860) I think this should be "permute". --- docs/source/get_started/introduction.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/get_started/introduction.rst b/docs/source/get_started/introduction.rst index 058e24ae8370..69fa50f9f9f2 100644 --- a/docs/source/get_started/introduction.rst +++ b/docs/source/get_started/introduction.rst @@ -176,7 +176,7 @@ We can even use slices, long or bool tensors to split the dataset. test_dataset = dataset[540:] >>> ENZYMES(60) -If you are unsure whether the dataset is already shuffled before you split, you can randomly permutate it by running: +If you are unsure whether the dataset is already shuffled before you split, you can randomly permute it by running: .. code-block:: python From cb424a6e40fb56477129d893cbd949cd5ab14ec0 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Wed, 8 Jan 2025 03:19:56 +0800 Subject: [PATCH 21/45] Fix `glem` example (#9903) Closes #9899. Co-authored-by: Rishi Puri --- examples/llm/glem.py | 2 ++ torch_geometric/nn/models/glem.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index ec76cef4c010..c6bae703fd33 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -371,9 +371,11 @@ def load_model(em_phase): if gnn_val_acc > lm_val_acc: em_phase = 'gnn' model.gnn = model.gnn.to(device, non_blocking=True) + test_loader = subgraph_loader else: em_phase = 'lm' model.lm = model.lm.to(device, non_blocking=True) + test_loader = text_test_loader test_preds = model.inference(em_phase, test_loader, verbose=verbose) train_acc, val_acc, test_acc = evaluate(test_preds, ['train', 'valid', 'test']) diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index afc8b09d77c7..d30d5f8bd062 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -144,7 +144,8 @@ def train(self, em_phase: str, train_loader: Union[DataLoader, acc (float): training accuracy loss (float): loss value """ - pseudo_labels = pseudo_labels.to(self.device) + if pseudo_labels is not None: + pseudo_labels = pseudo_labels.to(self.device) if em_phase == 'gnn': acc, loss = self.train_gnn(train_loader, optimizer, epoch, pseudo_labels, is_augmented, verbose) From 8a651b3961de29a8c1eda31306b5fce57e343f3d Mon Sep 17 00:00:00 2001 From: Rishi Puri Date: Tue, 7 Jan 2025 12:02:10 -0800 Subject: [PATCH 22/45] Using AI tools to improve commenting of base G-retriever example (#9882) Co-authored-by: riship Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- examples/llm/g_retriever.py | 172 +++++++++++++++++++++++++++++++----- 1 file changed, 148 insertions(+), 24 deletions(-) diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index a48901f1ff0e..1fd654886208 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -31,42 +31,64 @@ def compute_metrics(eval_output): + """Compute evaluation metrics (hit, precision, recall, F1). + + Parameters: + eval_output (list): List of dictionaries containing prediction output. + + Returns: + None (prints metrics to console) + """ + # Concatenate prediction output into a single DataFrame df = pd.concat([pd.DataFrame(d) for d in eval_output]) - all_hit = [] - all_precision = [] - all_recall = [] - all_f1 = [] + # Initialize lists to store metrics + all_hit = [] # Boolean values indicating whether prediction matches label + all_precision = [] # List of precision values + all_recall = [] # List of recall values + all_f1 = [] # List of F1 values + + # Iterate over prediction-label pairs for pred, label in zip(df.pred.tolist(), df.label.tolist()): try: + # Preprocess prediction string pred = pred.split('[/s]')[0].strip().split('|') + + # Check if prediction matches label hit = re.findall(pred[0], label) all_hit.append(len(hit) > 0) + # Compute precision, recall, and F1 label = label.split('|') matches = set(pred).intersection(set(label)) precision = len(matches) / len(set(pred)) recall = len(matches) / len(set(label)) + + # Handle division by zero if recall + precision == 0: f1 = 0 else: f1 = 2 * precision * recall / (precision + recall) + # Store metrics all_precision.append(precision) all_recall.append(recall) all_f1.append(f1) except Exception as e: + # Handle exceptions by printing error message and skipping print(f'Label: {label}') print(f'Pred: {pred}') print(f'Exception: {e}') print('------------------') + # Compute average metrics hit = sum(all_hit) / len(all_hit) precision = sum(all_precision) / len(all_precision) recall = sum(all_recall) / len(all_recall) f1 = sum(all_f1) / len(all_f1) + # Print metrics to console print(f'Hit: {hit:.4f}') print(f'Precision: {precision:.4f}') print(f'Recall: {recall:.4f}') @@ -74,51 +96,132 @@ def compute_metrics(eval_output): def save_params_dict(model, save_path): + """Saves a model's parameters, excluding non-trainable weights. + + Args: + model (torch.nn.Module): The model to save parameters from. + save_path (str): The path to save the parameters to. + """ + # Get the model's state dictionary, which contains all its parameters state_dict = model.state_dict() + + # Create a dictionary mapping parameter names to their requires_grad status param_grad_dict = { k: v.requires_grad for (k, v) in model.named_parameters() } + + # Remove non-trainable parameters from the state dictionary for k in list(state_dict.keys()): if k in param_grad_dict.keys() and not param_grad_dict[k]: del state_dict[k] # Delete parameters that do not require gradient + + # Save the filtered state dictionary to the specified path torch.save(state_dict, save_path) def load_params_dict(model, save_path): + # Load the saved model parameters from the specified file path state_dict = torch.load(save_path) + + # Update the model's parameters with the loaded state dictionary model.load_state_dict(state_dict) + + # Return the model with updated parameters return model -def get_loss(model, batch, model_save_name) -> Tensor: +def get_loss(model, batch, model_save_name: str) -> Tensor: + """Compute the loss for a given model and batch of data. + + Args: + model: The model to compute the loss for. + batch: The batch of data to compute the loss for. + model_save_name: The name of the model being used (e.g. 'llm'). + + Returns: + Tensor: The computed loss. + """ + # Check the type of model being used to determine the input arguments if model_save_name == 'llm': + # For LLM models return model(batch.question, batch.label, batch.desc) - else: - return model(batch.question, batch.x, batch.edge_index, batch.batch, - batch.label, batch.edge_attr, batch.desc) + else: # (GNN+LLM) + return model( + batch.question, + batch.x, # node features + batch.edge_index, # edge indices + batch.batch, # batch indices + batch.label, # answers (labels) + batch.edge_attr, # edge attributes + batch.desc # description + ) def inference_step(model, batch, model_save_name): + """Performs inference on a given batch of data using the provided model. + + Args: + model (nn.Module): The model to use for inference. + batch: The batch of data to process. + model_save_name (str): The name of the model (e.g. 'llm'). + + Returns: + The output of the inference step. + """ + # Check the type of model being used to determine the input arguments if model_save_name == 'llm': + # Perform inference on the question and textual graph description return model.inference(batch.question, batch.desc) - else: - return model.inference(batch.question, batch.x, batch.edge_index, - batch.batch, batch.edge_attr, batch.desc) + else: # (GNN+LLM) + return model.inference( + batch.question, + batch.x, # node features + batch.edge_index, # edge indices + batch.batch, # batch indices + batch.edge_attr, # edge attributes + batch.desc # description + ) def train( - num_epochs, - hidden_channels, - num_gnn_layers, - batch_size, - eval_batch_size, - lr, - checkpointing=False, - tiny_llama=False, + num_epochs, # Total number of training epochs + hidden_channels, # Number of hidden channels in GNN + num_gnn_layers, # Number of GNN layers + batch_size, # Training batch size + eval_batch_size, # Evaluation batch size + lr, # Initial learning rate + checkpointing=False, # Whether to checkpoint model + tiny_llama=False, # Whether to use tiny LLaMA model ): + """Train a GNN+LLM model on WebQSP dataset. + + Args: + num_epochs (int): Total number of training epochs. + hidden_channels (int): Number of hidden channels in GNN. + num_gnn_layers (int): Number of GNN layers. + batch_size (int): Training batch size. + eval_batch_size (int): Evaluation batch size. + lr (float): Initial learning rate. + checkpointing (bool, optional): Whether to checkpoint model. + Defaults to False. + tiny_llama (bool, optional): Whether to use tiny LLaMA model. + Defaults to False. + + Returns: + None + """ def adjust_learning_rate(param_group, LR, epoch): - # Decay the learning rate with half-cycle cosine after warmup + """Decay learning rate with half-cycle cosine after warmup. + + Args: + param_group (dict): Parameter group. + LR (float): Learning rate. + epoch (int): Current epoch. + + Returns: + float: Adjusted learning rate. + """ min_lr = 5e-6 warmup_epochs = 1 if epoch < warmup_epochs: @@ -130,7 +233,10 @@ def adjust_learning_rate(param_group, LR, epoch): param_group['lr'] = lr return lr + # Start training time start_time = time.time() + + # Load dataset and create data loaders path = osp.dirname(osp.realpath(__file__)) path = osp.join(path, '..', '..', 'data', 'WebQSPDataset') train_dataset = WebQSPDataset(path, split='train') @@ -146,9 +252,11 @@ def adjust_learning_rate(param_group, LR, epoch): test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, drop_last=False, pin_memory=True, shuffle=False) - # To clean up after Data Preproc + # Clean up memory gc.collect() torch.cuda.empty_cache() + + # Create GNN model gnn = GAT( in_channels=1024, hidden_channels=hidden_channels, @@ -156,6 +264,8 @@ def adjust_learning_rate(param_group, LR, epoch): num_layers=num_gnn_layers, heads=4, ) + + # Create LLaMA model if tiny_llama: llm = LLM( model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', @@ -166,10 +276,12 @@ def adjust_learning_rate(param_group, LR, epoch): llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7) model = GRetriever(llm=llm, gnn=gnn) + # Set model save name model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm' if model_save_name == 'llm': model = llm + # Create optimizer params = [p for _, p in model.named_parameters() if p.requires_grad] optimizer = torch.optim.AdamW([ { @@ -178,10 +290,12 @@ def adjust_learning_rate(param_group, LR, epoch): 'weight_decay': 0.05 }, ], betas=(0.9, 0.95)) - grad_steps = 2 + # Initialize best epoch and best validation loss best_epoch = 0 best_val_loss = float('inf') + + # Train model for epoch in range(num_epochs): model.train() epoch_loss = 0 @@ -198,18 +312,19 @@ def adjust_learning_rate(param_group, LR, epoch): clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) - if (step + 1) % grad_steps == 0: + if (step + 1) % 2 == 0: adjust_learning_rate(optimizer.param_groups[0], lr, step / len(train_loader) + epoch) optimizer.step() epoch_loss = epoch_loss + float(loss) - if (step + 1) % grad_steps == 0: + if (step + 1) % 2 == 0: lr = optimizer.param_groups[0]['lr'] train_loss = epoch_loss / len(train_loader) print(epoch_str + f', Train Loss: {train_loss:4f}') + # Evaluate model val_loss = 0 eval_output = [] model.eval() @@ -224,9 +339,12 @@ def adjust_learning_rate(param_group, LR, epoch): best_val_loss = val_loss best_epoch = epoch save_params_dict(model, f'{model_save_name}_best_val_loss_ckpt.pt') + + # Clean up memory torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() + # Load best checkpoint if necessary if checkpointing and best_epoch != num_epochs - 1: print("Loading best checkpoint...") model = load_params_dict( @@ -234,6 +352,7 @@ def adjust_learning_rate(param_group, LR, epoch): f'{model_save_name}_best_val_loss_ckpt.pt', ) + # Evaluate model on test set model.eval() eval_output = [] print("Final evaluation...") @@ -250,8 +369,13 @@ def adjust_learning_rate(param_group, LR, epoch): eval_output.append(eval_data) progress_bar_test.update(1) + # Compute metrics compute_metrics(eval_output) + + # Print final training time print(f"Total Training Time: {time.time() - start_time:2f}s") + + # Save model and evaluation output save_params_dict(model, f'{model_save_name}.pt') torch.save(eval_output, f'{model_save_name}_eval_outs.pt') From ef028547ff4459f6e98fe429d1564bd1d513fc31 Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <32910461+drivanov@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:53:31 -0800 Subject: [PATCH 23/45] Fixed bug for `writer` initialized by `Chem.SDWriter(...)`. (#9929) Without the `writer.close()` statement, the file written by `writer` will not be closed properly. As a result, in our test the end of the file `/workspace/data/MoleculeGPT/raw/molecules.sdf` is missing. This is what it looks like: ``` 472184 RDKit 2D 1 0 0 0 0 0 0 0 0 0999 V2000 2.0000 0.0000 0.0000 Os 0 0 0 0 0 15 0 0 0 0 0 0 M CHG 1 1 4 M END > (4303) 472184 > (4303) 1 > (4303) 0 > (4303) 0 > (4303) 0 > (4303) 0 > (4303) AAADcQAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== > ` character, and the last molecule (#4303) is missing. As a result, we get a crash later when running the test: ``` Traceback (most recent call last): File "/workspace/examples/llm/molecule_gpt.py", line 187, in train( File "/workspace/examples/llm/molecule_gpt.py", line 69, in train dataset = MoleculeGPTDataset(path) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/torch_geometric/datasets/molecule_gpt_dataset.py", line 217, in __init__ super().__init__(root, transform, pre_transform, pre_filter, File "/usr/local/lib/python3.12/dist-packages/torch_geometric/data/in_memory_dataset.py", line 81, in __init__ super().__init__(root, transform, pre_transform, pre_filter, log, File "/usr/local/lib/python3.12/dist-packages/torch_geometric/data/dataset.py", line 115, in __init__ self._process() File "/usr/local/lib/python3.12/dist-packages/torch_geometric/data/dataset.py", line 262, in _process self.process() File "/usr/local/lib/python3.12/dist-packages/torch_geometric/datasets/molecule_gpt_dataset.py", line 436, in process CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ KeyError: 'PUBCHEM_OPENEYE_CAN_SMILES' ``` --- torch_geometric/datasets/molecule_gpt_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_geometric/datasets/molecule_gpt_dataset.py b/torch_geometric/datasets/molecule_gpt_dataset.py index fed2fe503600..0fe4c9b9d589 100644 --- a/torch_geometric/datasets/molecule_gpt_dataset.py +++ b/torch_geometric/datasets/molecule_gpt_dataset.py @@ -371,6 +371,7 @@ def extract_one_SDF_file(block_id: int) -> None: writer.write(mol) valid_mol_count += 1 + writer.close() print(f"block id: {block_id}\nfound {valid_mol_count}\n\n") sys.stdout.flush() return @@ -410,6 +411,7 @@ def extract_one_SDF_file(block_id: int) -> None: print(f"block id: {block_id} with 0 valid SDF file") continue + writer.close() print(f"In total: {len(found_CID_set)} molecules") # Step 05. Convert to PyG data format From f46ebda4a83177b6341b4ce84a960209f554a724 Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <32910461+drivanov@users.noreply.github.com> Date: Mon, 13 Jan 2025 10:37:47 -0800 Subject: [PATCH 24/45] [BugFix] Fixing two `lightning` tests. (#9931) The tests `test_lightning_dataset` and `test_lightning_node_data` began failing after the recent update of `pytorch_lightning` to version 2.5. The cause of these failures is the new way for calculation of `str(datamodule)` Since 2024-12-10 they have started to use a new method: ``` ~/Projects/pytorch-lightning$ git blame src/lightning/pytorch/core/datamodule.py -L 249,253 9709c645c8 (Nikita Tatsch 2024-12-10 19:39:25 -0500 249) def __str__(self) -> str: 9709c645c8 (Nikita Tatsch 2024-12-10 19:39:25 -0500 250) """Return a string representation of the datasets that are set up. 9709c645c8 (Nikita Tatsch 2024-12-10 19:39:25 -0500 251) 9709c645c8 (Nikita Tatsch 2024-12-10 19:39:25 -0500 252) Returns: 9709c645c8 (Nikita Tatsch 2024-12-10 19:39:25 -0500 253) A string representation of the datasets that are setup. ``` In 2.4 the operation `str(datamodule)` was performed by ``` 304 def __repr__(self) -> str: 305 -> kwargs = kwargs_repr( 306 train_dataset=self.train_dataset, 307 val_dataset=self.val_dataset, 308 test_dataset=self.test_dataset, 309 pred_dataset=self.pred_dataset, 310 **self.kwargs, 311 ) 312 return f'{self.__class__.__name__}({kwargs})' ``` This PR addresses the issue while maintaining backward compatibility, ensuring that these tests run correctly with `pytorch_lightning` versions earlier than 2.5. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- test/data/lightning/test_datamodule.py | 52 ++++++++++++++++++-------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/test/data/lightning/test_datamodule.py b/test/data/lightning/test_datamodule.py index cb86d810dfa8..8f66e05918ff 100644 --- a/test/data/lightning/test_datamodule.py +++ b/test/data/lightning/test_datamodule.py @@ -18,6 +18,7 @@ MyFeatureStore, MyGraphStore, get_random_edge_index, + has_package, onlyCUDA, onlyFullTest, onlyNeighborSampler, @@ -114,12 +115,19 @@ def expect_rank_zero_user_warning(match: str): num_workers=3, shuffle=True) assert 'shuffle' not in datamodule.kwargs old_x = train_dataset._data.x.clone() - assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), ' - 'val_dataset=MUTAG(30), ' - 'test_dataset=MUTAG(10), ' - 'pred_dataset=MUTAG(98), batch_size=5, ' - 'num_workers=3, pin_memory=True, ' - 'persistent_workers=True)') + new_datamodule_repr = has_package('pytorch_lightning>=2.5.0') + datamodule_repr = ('{Train dataloader: size=50}\n' + '{Validation dataloader: size=30}\n' + '{Test dataloader: size=10}\n' + '{Predict dataloader: size=98}' if new_datamodule_repr + else 'LightningDataset(train_dataset=MUTAG(50), ' + 'val_dataset=MUTAG(30), ' + 'test_dataset=MUTAG(10), ' + 'pred_dataset=MUTAG(98), batch_size=5, ' + 'num_workers=3, pin_memory=True, ' + 'persistent_workers=True)') + assert str(datamodule) == datamodule_repr + trainer.fit(model, datamodule) trainer.test(model, datamodule) new_x = train_dataset._data.x @@ -133,10 +141,15 @@ def expect_rank_zero_user_warning(match: str): log_every_n_steps=1) datamodule = LightningDataset(train_dataset, batch_size=5) - assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), ' - 'batch_size=5, num_workers=0, ' - 'pin_memory=True, ' - 'persistent_workers=False)') + datamodule_repr = ('{Train dataloader: size=50}\n' + '{Validation dataloader: None}\n' + '{Test dataloader: None}\n{' + 'Predict dataloader: None}' if new_datamodule_repr + else 'LightningDataset(train_dataset=MUTAG(50), ' + 'batch_size=5, num_workers=0, ' + 'pin_memory=True, ' + 'persistent_workers=False)') + assert str(datamodule) == datamodule_repr with expect_rank_zero_user_warning("defined a `validation_step`"): trainer.fit(model, datamodule) @@ -231,11 +244,20 @@ def test_lightning_node_data(get_dataset, strategy_type, loader): num_workers=num_workers, **kwargs) old_x = data.x.clone().cpu() - assert str(datamodule) == (f'LightningNodeData(data={data_repr}, ' - f'loader={loader}, batch_size={batch_size}, ' - f'num_workers={num_workers}, {kwargs_repr}' - f'pin_memory={loader != "full"}, ' - f'persistent_workers={loader != "full"})') + new_datamodule_repr = has_package('pytorch_lightning>=2.5.0') + flag = loader != 'full' + datamodule_repr = ( + '{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n' + '{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n' + '{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n' + '{Predict dataloader: ' + f'size={2708 if flag else 1}' + + '}' if new_datamodule_repr else f'LightningNodeData(data={data_repr}, ' + f'loader={loader}, batch_size={batch_size}, ' + f'num_workers={num_workers}, {kwargs_repr}' + f'pin_memory={flag}, ' + f'persistent_workers={flag})') + assert str(datamodule) == datamodule_repr + trainer.fit(model, datamodule) trainer.test(model, datamodule) new_x = data.x.cpu() From f90776218af9ae9c377156dbdc0565a5750ceb5b Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <32910461+drivanov@users.noreply.github.com> Date: Mon, 13 Jan 2025 19:54:26 -0800 Subject: [PATCH 25/45] [FutureWarning] Fixing warning triggered by `torch.cuda.reset_max_memory_allocated()` usage. (#9930) This PR addresses the following warning that occurs during the ``` /usr/local/lib/python3.12/dist-packages/torch/cuda/memory.py:374: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. ``` that occurs during the `molecule_gpt` test: ``` python3 /workspace/examples/llm/molecule_gpt.py ``` --- examples/llm/molecule_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llm/molecule_gpt.py b/examples/llm/molecule_gpt.py index 8f6c6024014d..6f11d87969a4 100644 --- a/examples/llm/molecule_gpt.py +++ b/examples/llm/molecule_gpt.py @@ -167,7 +167,7 @@ def adjust_learning_rate(param_group, LR, epoch): f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 ) torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() print(f"Total Training Time: {time.time() - start_time:2f}s") # Test From f7771568996da2368c68bb4ce92b31972e3f0343 Mon Sep 17 00:00:00 2001 From: Bulat Kerimov <38624058+erytheis@users.noreply.github.com> Date: Tue, 14 Jan 2025 05:17:20 +0100 Subject: [PATCH 26/45] Fix `GlobalStorage.is_node_attr` and `GlobalStorage.is_edge_attr` when passing tuples in `cat_dim` (#9927) The PR fixes the issue when calling is_node_attr and is_edge_attr for Data objects with sparse attributes that assume block-diagonal concatenation ( [#9895](https://github.com/pyg-team/pytorch_geometric/issues/9895) and in [#8709](https://github.com/pyg-team/pytorch_geometric/issues/8709)). These cases raise an error due to cat_dim being a tuple, the change catches this case and looks at the 0-th dimension of the sparse tensor. --------- Co-authored-by: rusty1s --- CHANGELOG.md | 1 + torch_geometric/data/storage.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7321dbb30413..8691f4af5671 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756)) - Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766)) - Fixed `WebQSDataset.process` raising exceptions ([#9665](https://github.com/pyg-team/pytorch_geometric/pull/9665)) +- Fixed `is_node_attr()` and `is_edge_attr()` errors when `cat_dim` is a tuple ([#9895](https://github.com/pyg-team/pytorch_geometric/issues/9895)) ### Removed diff --git a/torch_geometric/data/storage.py b/torch_geometric/data/storage.py index 07a52fdfc21a..e39bcebd6931 100644 --- a/torch_geometric/data/storage.py +++ b/torch_geometric/data/storage.py @@ -806,6 +806,10 @@ def is_node_attr(self, key: str) -> bool: return False cat_dim = self._parent().__cat_dim__(key, value, self) + + if not isinstance(cat_dim, int): + return False + num_nodes, num_edges = self.num_nodes, self.num_edges if value.shape[cat_dim] != num_nodes: @@ -852,6 +856,10 @@ def is_edge_attr(self, key: str) -> bool: return False cat_dim = self._parent().__cat_dim__(key, value, self) + + if not isinstance(cat_dim, int): + return False + num_nodes, num_edges = self.num_nodes, self.num_edges if value.shape[cat_dim] != num_edges: From c9e563c8578c90990a0aae1b32032ca48744f340 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 14 Jan 2025 05:30:40 +0100 Subject: [PATCH 27/45] Introduce `LinkPredMetric._prepare`; allow for less predictions than `k` (#9939) --- readthedocs.yml | 3 + test/data/lightning/test_datamodule.py | 64 ++++++++++---------- test/metrics/test_link_pred_metric.py | 30 ++++++++++ torch_geometric/metrics/link_pred.py | 83 +++++++++++++++----------- 4 files changed, 114 insertions(+), 66 deletions(-) diff --git a/readthedocs.yml b/readthedocs.yml index 44ed0553c035..9c60a9b183ca 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,5 +1,8 @@ version: 2 +sphinx: + configuration: docs/source/conf.py + build: os: ubuntu-22.04 tools: diff --git a/test/data/lightning/test_datamodule.py b/test/data/lightning/test_datamodule.py index 8f66e05918ff..00f97dc3c27a 100644 --- a/test/data/lightning/test_datamodule.py +++ b/test/data/lightning/test_datamodule.py @@ -115,17 +115,18 @@ def expect_rank_zero_user_warning(match: str): num_workers=3, shuffle=True) assert 'shuffle' not in datamodule.kwargs old_x = train_dataset._data.x.clone() - new_datamodule_repr = has_package('pytorch_lightning>=2.5.0') - datamodule_repr = ('{Train dataloader: size=50}\n' - '{Validation dataloader: size=30}\n' - '{Test dataloader: size=10}\n' - '{Predict dataloader: size=98}' if new_datamodule_repr - else 'LightningDataset(train_dataset=MUTAG(50), ' - 'val_dataset=MUTAG(30), ' - 'test_dataset=MUTAG(10), ' - 'pred_dataset=MUTAG(98), batch_size=5, ' - 'num_workers=3, pin_memory=True, ' - 'persistent_workers=True)') + if has_package('pytorch_lightning>=2.5.0'): + datamodule_repr = ('{Train dataloader: size=50}\n' + '{Validation dataloader: size=30}\n' + '{Test dataloader: size=10}\n' + '{Predict dataloader: size=98}') + else: + datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), ' + 'val_dataset=MUTAG(30), ' + 'test_dataset=MUTAG(10), ' + 'pred_dataset=MUTAG(98), batch_size=5, ' + 'num_workers=3, pin_memory=True, ' + 'persistent_workers=True)') assert str(datamodule) == datamodule_repr trainer.fit(model, datamodule) @@ -141,14 +142,16 @@ def expect_rank_zero_user_warning(match: str): log_every_n_steps=1) datamodule = LightningDataset(train_dataset, batch_size=5) - datamodule_repr = ('{Train dataloader: size=50}\n' - '{Validation dataloader: None}\n' - '{Test dataloader: None}\n{' - 'Predict dataloader: None}' if new_datamodule_repr - else 'LightningDataset(train_dataset=MUTAG(50), ' - 'batch_size=5, num_workers=0, ' - 'pin_memory=True, ' - 'persistent_workers=False)') + if has_package('pytorch_lightning>=2.5.0'): + datamodule_repr = ('{Train dataloader: size=50}\n' + '{Validation dataloader: None}\n' + '{Test dataloader: None}\n{' + 'Predict dataloader: None}') + else: + datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), ' + 'batch_size=5, num_workers=0, ' + 'pin_memory=True, ' + 'persistent_workers=False)') assert str(datamodule) == datamodule_repr with expect_rank_zero_user_warning("defined a `validation_step`"): @@ -244,18 +247,19 @@ def test_lightning_node_data(get_dataset, strategy_type, loader): num_workers=num_workers, **kwargs) old_x = data.x.clone().cpu() - new_datamodule_repr = has_package('pytorch_lightning>=2.5.0') flag = loader != 'full' - datamodule_repr = ( - '{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n' - '{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n' - '{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n' - '{Predict dataloader: ' + f'size={2708 if flag else 1}' + - '}' if new_datamodule_repr else f'LightningNodeData(data={data_repr}, ' - f'loader={loader}, batch_size={batch_size}, ' - f'num_workers={num_workers}, {kwargs_repr}' - f'pin_memory={flag}, ' - f'persistent_workers={flag})') + if has_package('pytorch_lightning>=2.5.0'): + datamodule_repr = ( + '{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n' + '{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n' + '{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n' + '{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}') + else: + datamodule_repr = (f'LightningNodeData(data={data_repr}, ' + f'loader={loader}, batch_size={batch_size}, ' + f'num_workers={num_workers}, {kwargs_repr}' + f'pin_memory={flag}, ' + f'persistent_workers={flag})') assert str(datamodule) == datamodule_repr trainer.fit(model, datamodule) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index f4bb5c1afb87..1dd30a352193 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -54,6 +54,11 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k): expected = torch.tensor(values).mean() assert torch.allclose(out, expected) + # Test with `k > pred_mat.size(1)`: + metric.update(top_k_pred_mat[:, :k - 1], edge_label_index) + metric.compute() + metric.reset() + def test_recall(): pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) @@ -66,6 +71,11 @@ def test_recall(): assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5)) + # Test with `k > pred_mat.size(1)`: + metric.update(pred_mat[:, :1], edge_label_index) + metric.compute() + metric.reset() + def test_f1(): pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) @@ -77,6 +87,11 @@ def test_f1(): result = metric.compute() assert float(result) == pytest.approx(0.6500) + # Test with `k > pred_mat.size(1)`: + metric.update(pred_mat[:, :1], edge_label_index) + metric.compute() + metric.reset() + def test_map(): pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) @@ -88,6 +103,11 @@ def test_map(): result = metric.compute() assert float(result) == pytest.approx(0.6250) + # Test with `k > pred_mat.size(1)`: + metric.update(pred_mat[:, :1], edge_label_index) + metric.compute() + metric.reset() + def test_ndcg(): pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) @@ -100,6 +120,11 @@ def test_ndcg(): assert float(result) == pytest.approx(0.6934264) + # Test with `k > pred_mat.size(1)`: + metric.update(pred_mat[:, :1], edge_label_index) + metric.compute() + metric.reset() + def test_mrr(): pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]]) @@ -111,3 +136,8 @@ def test_mrr(): result = metric.compute() assert float(result) == pytest.approx((1 + 0.5 + 0) / 3) + + # Test with `k > pred_mat.size(1)`: + metric.update(pred_mat[:, :1], edge_label_index) + metric.compute() + metric.reset() diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index 8785cf487915..a450bcd1a5e2 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -43,34 +43,15 @@ def __init__(self, k: int) -> None: self.register_buffer('accum', torch.tensor(0.)) self.register_buffer('total', torch.tensor(0)) - def update( - self, + @staticmethod + def _prepare( pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], - ) -> None: - r"""Updates the state variables based on the current mini-batch - prediction. - - :meth:`update` can be repeated multiple times to accumulate the results - of successive predictions, *e.g.*, inside a mini-batch training or - evaluation loop. - - Args: - pred_index_mat (torch.Tensor): The top-:math:`k` predictions of - every example in the mini-batch with shape - :obj:`[batch_size, k]`. - edge_label_index (torch.Tensor): The ground-truth indices for every - example in the mini-batch, given in COO format of shape - :obj:`[2, num_ground_truth_indices]`. - """ - if pred_index_mat.size(1) != self.k: - raise ValueError(f"Expected 'pred_index_mat' to hold {self.k} " - f"many indices for every entry " - f"(got {pred_index_mat.size(1)})") - - # Compute a boolean matrix indicating if the k-th prediction is part of - # the ground-truth. We do this by flattening both prediction and - # target indices, and then determining overlaps via `torch.isin`. + ) -> Tuple[Tensor, Tensor]: + # Compute a boolean matrix indicating if the `k`-th prediction is part + # of the ground-truth, as well as the number of ground-truths for every + # example. We do this by flattening both prediction and ground-truth + # indices, and then determining overlaps via `torch.isin`. max_index = max( # type: ignore pred_index_mat.max() if pred_index_mat.numel() > 0 else 0, edge_label_index[1].max() @@ -88,7 +69,7 @@ def update( pred_isin_mat = torch.isin(flat_pred_index, flat_y_index) pred_isin_mat = pred_isin_mat.view(pred_index_mat.size()) - # Compute the number of targets per example: + # Compute the number of ground-truths per example: y_count = scatter( torch.ones_like(edge_label_index[0]), edge_label_index[0], @@ -97,11 +78,41 @@ def update( reduce='sum', ) - metric = self._compute(pred_isin_mat, y_count) + return pred_isin_mat, y_count + def _update_from_prepared( + self, + pred_isin_mat: Tensor, + y_count: Tensor, + ) -> None: + metric = self._compute(pred_isin_mat[:, :self.k], y_count) self.accum += metric.sum() self.total += (y_count > 0).sum() + def update( + self, + pred_index_mat: Tensor, + edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], + ) -> None: + r"""Updates the state variables based on the current mini-batch + prediction. + + :meth:`update` can be repeated multiple times to accumulate the results + of successive predictions, *e.g.*, inside a mini-batch training or + evaluation loop. + + Args: + pred_index_mat (torch.Tensor): The top-:math:`k` predictions of + every example in the mini-batch with shape + :obj:`[batch_size, k]`. + edge_label_index (torch.Tensor): The ground-truth indices for every + example in the mini-batch, given in COO format of shape + :obj:`[2, num_ground_truth_indices]`. + """ + pred_isin_mat, y_count = self._prepare(pred_index_mat, + edge_label_index) + self._update_from_prepared(pred_isin_mat, y_count) + def compute(self) -> Tensor: r"""Computes the final metric value.""" if self.total == 0: @@ -182,8 +193,9 @@ class LinkPredMAP(LinkPredMetric): higher_is_better: bool = True def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - cum_precision = (torch.cumsum(pred_isin_mat, dim=1) / - torch.arange(1, self.k + 1, device=y_count.device)) + device = pred_isin_mat.device + arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device) + cum_precision = pred_isin_mat.cumsum(dim=1) / arange return ((cum_precision * pred_isin_mat).sum(dim=-1) / y_count.clamp(min=1e-7, max=self.k)) @@ -210,7 +222,8 @@ def __init__(self, k: int): self.register_buffer('idcg', cumsum(multiplier)) def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - dcg = (pred_isin_mat * self.multiplier.view(1, -1)).sum(dim=-1) + multiplier = self.multiplier[:pred_isin_mat.size(1)].view(1, -1) + dcg = (pred_isin_mat * multiplier).sum(dim=-1) idcg = self.idcg[y_count.clamp(max=self.k)] out = dcg / idcg @@ -228,8 +241,6 @@ class LinkPredMRR(LinkPredMetric): higher_is_better: bool = True def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - rank = pred_isin_mat.type(torch.uint8).argmax(dim=-1) - is_correct = pred_isin_mat.gather(1, rank.view(-1, 1)).view(-1) - reciprocals = 1.0 / (rank + 1) - reciprocals[~is_correct] = 0.0 - return reciprocals + device = pred_isin_mat.device + arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device) + return (pred_isin_mat / arange).max(dim=-1)[0] From 8ce58faa9f6689286fa3022e8f8029d3aad72b48 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 14 Jan 2025 06:43:12 +0100 Subject: [PATCH 28/45] Add `LinkPredMetricCollection` (#9941) --- .github/labeler.yml | 4 ++ CHANGELOG.md | 5 +- test/metrics/test_link_pred_metric.py | 90 +++++++++++++++++------- torch_geometric/metrics/__init__.py | 2 + torch_geometric/metrics/link_pred.py | 99 ++++++++++++++++++++++++++- 5 files changed, 172 insertions(+), 28 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index ae2f57cc925a..8a2f60bbd87b 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -34,6 +34,10 @@ transform: - changed-files: - any-glob-to-any-file: "torch_geometric/transforms/**/*" +metrics: + - changed-files: + - any-glob-to-any-file: "torch_geometric/metrics/**/*" + utils: - changed-files: - any-glob-to-any-file: "torch_geometric/utils/**/*" diff --git a/CHANGELOG.md b/CHANGELOG.md index 8691f4af5671..50ff227f428c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) -- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941)) +- Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Updated Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606)) - Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807)) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index 1dd30a352193..6256eeb4d6b3 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -6,6 +6,7 @@ from torch_geometric.metrics import ( LinkPredF1, LinkPredMAP, + LinkPredMetricCollection, LinkPredMRR, LinkPredNDCG, LinkPredPrecision, @@ -25,7 +26,7 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k): pred = torch.rand(num_src_nodes, num_dst_nodes) pred[row, col] += 0.3 # Offset positive links by a little. - top_k_pred_mat = pred.topk(k, dim=1)[1] + pred_index_mat = pred.topk(k, dim=1)[1] metric = LinkPredPrecision(k) assert str(metric) == f'LinkPredPrecision(k={k})' @@ -39,7 +40,7 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k): arange[node_id] = torch.arange(node_id.numel()) y_batch = arange[y_batch] - metric.update(top_k_pred_mat[node_id], (y_batch, y_index)) + metric.update(pred_index_mat[node_id], (y_batch, y_index)) out = metric.compute() metric.reset() @@ -48,96 +49,135 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k): for i in range(num_src_nodes): # Naive computation per node: y_index = col[row == i] if y_index.numel() > 0: - mask = torch.isin(top_k_pred_mat[i], y_index) + mask = torch.isin(pred_index_mat[i], y_index) precision = float(mask.sum() / k) values.append(precision) expected = torch.tensor(values).mean() assert torch.allclose(out, expected) - # Test with `k > pred_mat.size(1)`: - metric.update(top_k_pred_mat[:, :k - 1], edge_label_index) + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :k - 1], edge_label_index) metric.compute() metric.reset() def test_recall(): - pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) + pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) metric = LinkPredRecall(k=2) assert str(metric) == 'LinkPredRecall(k=2)' - metric.update(pred_mat, edge_label_index) + metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5)) - # Test with `k > pred_mat.size(1)`: - metric.update(pred_mat[:, :1], edge_label_index) + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def test_f1(): - pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) + pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) metric = LinkPredF1(k=2) assert str(metric) == 'LinkPredF1(k=2)' - metric.update(pred_mat, edge_label_index) + metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.6500) - # Test with `k > pred_mat.size(1)`: - metric.update(pred_mat[:, :1], edge_label_index) + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def test_map(): - pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) + pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) metric = LinkPredMAP(k=2) assert str(metric) == 'LinkPredMAP(k=2)' - metric.update(pred_mat, edge_label_index) + metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.6250) - # Test with `k > pred_mat.size(1)`: - metric.update(pred_mat[:, :1], edge_label_index) + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def test_ndcg(): - pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) + pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]]) metric = LinkPredNDCG(k=2) assert str(metric) == 'LinkPredNDCG(k=2)' - metric.update(pred_mat, edge_label_index) + metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx(0.6934264) - # Test with `k > pred_mat.size(1)`: - metric.update(pred_mat[:, :1], edge_label_index) + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() def test_mrr(): - pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]]) + pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]]) edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]]) metric = LinkPredMRR(k=2) assert str(metric) == 'LinkPredMRR(k=2)' - metric.update(pred_mat, edge_label_index) + metric.update(pred_index_mat, edge_label_index) result = metric.compute() assert float(result) == pytest.approx((1 + 0.5 + 0) / 3) - # Test with `k > pred_mat.size(1)`: - metric.update(pred_mat[:, :1], edge_label_index) + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :1], edge_label_index) metric.compute() metric.reset() + + +@pytest.mark.parametrize('num_src_nodes', [10]) +@pytest.mark.parametrize('num_dst_nodes', [50]) +@pytest.mark.parametrize('num_edges', [200]) +def test_link_pred_metric_collection(num_src_nodes, num_dst_nodes, num_edges): + metrics = [ + LinkPredMAP(k=10), + LinkPredPrecision(k=100), + LinkPredRecall(k=50), + ] + + row = torch.randint(0, num_src_nodes, (num_edges, )) + col = torch.randint(0, num_dst_nodes, (num_edges, )) + edge_label_index = torch.stack([row, col], dim=0) + + pred = torch.rand(num_src_nodes, num_dst_nodes) + pred[row, col] += 0.3 # Offset positive links by a little. + pred_index_mat = pred.argsort(dim=1) + + metric_collection = LinkPredMetricCollection(metrics) + assert str(metric_collection) == ( + 'LinkPredMetricCollection([\n' + ' LinkPredMAP@10: LinkPredMAP(k=10),\n' + ' LinkPredPrecision@100: LinkPredPrecision(k=100),\n' + ' LinkPredRecall@50: LinkPredRecall(k=50),\n' + '])') + assert metric_collection.max_k == 100 + + expected = {} + for metric in metrics: + metric.update(pred_index_mat[:, :metric.k], edge_label_index) + out = metric.compute() + expected[f'{metric.__class__.__name__}@{metric.k}'] = out + metric.reset() + + metric_collection.update(pred_index_mat, edge_label_index) + assert metric_collection.compute() == expected + metric_collection.reset() diff --git a/torch_geometric/metrics/__init__.py b/torch_geometric/metrics/__init__.py index 1340829b8e42..e142a7dc1152 100644 --- a/torch_geometric/metrics/__init__.py +++ b/torch_geometric/metrics/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa from .link_pred import ( + LinkPredMetricCollection, LinkPredPrecision, LinkPredRecall, LinkPredF1, @@ -10,6 +11,7 @@ ) link_pred_metrics = [ + 'LinkPredMetricCollection', 'LinkPredPrecision', 'LinkPredRecall', 'LinkPredF1', diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index a450bcd1a5e2..fa5ff371a08b 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch import Tensor @@ -144,6 +144,103 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}(k={self.k})' +class LinkPredMetricCollection(torch.nn.ModuleDict): + r"""A collection of metrics to reduce and speed-up computation of link + prediction metrics. + + .. code-block:: python + + from torch_geometric.metrics import ( + LinkPredMAP, + LinkPredMetricCollection, + LinkPredPrecision, + LinkPredRecall, + ) + + metrics = LinkPredMetricCollection([ + LinkPredMAP(k=10), + LinkPredPrecision(k=100), + LinkPredRecall(k=50), + ]) + + metrics.update(pred_index_mat, edge_label_index) + out = metrics.compute() + metrics.reset() + + print(out) + >>> {'LinkPredMAP@10': tensor(0.375), + ... 'LinkPredPrecision@100': tensor(0.127), + ... 'LinkPredRecall@50': tensor(0.483)} + + Args: + metrics: The link prediction metrics. + """ + def __init__( + self, + metrics: Union[ + List[LinkPredMetric], + Dict[str, LinkPredMetric], + ], + ) -> None: + super().__init__() + + if isinstance(metrics, (list, tuple)): + metrics = { + f'{metric.__class__.__name__}@{metric.k}': metric + for metric in metrics + } + assert len(metrics) > 0 + assert isinstance(metrics, dict) + + for name, metric in metrics.items(): + self[name] = metric + + @property + def max_k(self) -> int: + r"""The maximum number of top-:math:`k` predictions to evaluate + against. + """ + return max([metric.k for metric in self.values()]) + + def update( # type: ignore + self, + pred_index_mat: Tensor, + edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], + ) -> None: + r"""Updates the state variables based on the current mini-batch + prediction. + + :meth:`update` can be repeated multiple times to accumulate the results + of successive predictions, *e.g.*, inside a mini-batch training or + evaluation loop. + + Args: + pred_index_mat (torch.Tensor): The top-:math:`k` predictions of + every example in the mini-batch with shape + :obj:`[batch_size, k]`. + edge_label_index (torch.Tensor): The ground-truth indices for every + example in the mini-batch, given in COO format of shape + :obj:`[2, num_ground_truth_indices]`. + """ + pred_isin_mat, y_count = LinkPredMetric._prepare( + pred_index_mat, edge_label_index) + for metric in self.values(): + metric._update_from_prepared(pred_isin_mat, y_count) + + def compute(self) -> Dict[str, Tensor]: + r"""Computes the final metric values.""" + return {name: metric.compute() for name, metric in self.items()} + + def reset(self) -> None: + r"""Reset metric state variables to their default value.""" + for metric in self.values(): + metric.reset() + + def __repr__(self) -> str: + names = [f' {name}: {metric},\n' for name, metric in self.items()] + return f'{self.__class__.__name__}([\n{"".join(names)}])' + + class LinkPredPrecision(LinkPredMetric): r"""A link prediction metric to compute Precision @ :math:`k`. From c828200c6a65f1a32f725d32a95eaee6837616e5 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 14 Jan 2025 07:19:47 +0100 Subject: [PATCH 29/45] Limit concurrency in nightly tests; Add more labels (#9942) --- .github/labeler.yml | 10 +++++++++- .github/workflows/full_testing.yml | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index 8a2f60bbd87b..f166b1cab675 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,6 +1,14 @@ +installation: + - changed-files: + - any-glob-to-any-file: ["pyproject.toml"] + +ci: + - changed-files: + - any-glob-to-any-file: [".github/**/*", "codecov.yaml", ".pre-commit-config.yaml"] + documentation: - changed-files: - - any-glob-to-any-file: "docs/**/*" + - any-glob-to-any-file: ["docs/**/*", "readthedocs.yml", "README.MD"] example: - changed-files: diff --git a/.github/workflows/full_testing.yml b/.github/workflows/full_testing.yml index 28e7b74740a2..1bdb65508924 100644 --- a/.github/workflows/full_testing.yml +++ b/.github/workflows/full_testing.yml @@ -12,6 +12,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: + max-parallel: 10 fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-14] From 60a1141949b113fb068ad085a32758a209a025a2 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 14 Jan 2025 07:20:11 +0100 Subject: [PATCH 30/45] Introduce weighted `LinkPredMetric` interface (#9943) --- torch_geometric/metrics/link_pred.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index fa5ff371a08b..d2f67272cbc3 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -23,6 +23,7 @@ class LinkPredMetric(BaseMetric): is_differentiable: bool = False full_state_update: bool = False higher_is_better: Optional[bool] = None + weighted: bool = False def __init__(self, k: int) -> None: super().__init__() @@ -93,6 +94,7 @@ def update( self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], + edge_label_weight: Optional[Tensor] = None, ) -> None: r"""Updates the state variables based on the current mini-batch prediction. @@ -108,7 +110,14 @@ def update( edge_label_index (torch.Tensor): The ground-truth indices for every example in the mini-batch, given in COO format of shape :obj:`[2, num_ground_truth_indices]`. + edge_label_weight (torch.Tensor, optional): The weight of the + ground-truth indices for every example in the mini-batch of + shape :obj:`[num_ground_truth_indices]`. Required for + weighted metrics and ignored otherwise. (default: :obj:`None`) """ + if self.weighted and edge_label_weight is None: + raise ValueError("'edge_label_weight' required for {self}") + pred_isin_mat, y_count = self._prepare(pred_index_mat, edge_label_index) self._update_from_prepared(pred_isin_mat, y_count) @@ -206,6 +215,7 @@ def update( # type: ignore self, pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], + edge_label_weight: Optional[Tensor] = None, ) -> None: r"""Updates the state variables based on the current mini-batch prediction. @@ -221,6 +231,10 @@ def update( # type: ignore edge_label_index (torch.Tensor): The ground-truth indices for every example in the mini-batch, given in COO format of shape :obj:`[2, num_ground_truth_indices]`. + edge_label_weight (torch.Tensor, optional): The weight of the + ground-truth indices for every example in the mini-batch of + shape :obj:`[num_ground_truth_indices]`. Required for + weighted metrics and ignored otherwise. (default: :obj:`None`) """ pred_isin_mat, y_count = LinkPredMetric._prepare( pred_index_mat, edge_label_index) @@ -248,6 +262,7 @@ class LinkPredPrecision(LinkPredMetric): k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True + weighted: bool = False def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: return pred_isin_mat.sum(dim=-1) / self.k @@ -260,6 +275,7 @@ class LinkPredRecall(LinkPredMetric): k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True + weighted: bool = False def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: return pred_isin_mat.sum(dim=-1) / y_count.clamp(min=1e-7) @@ -272,6 +288,7 @@ class LinkPredF1(LinkPredMetric): k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True + weighted: bool = False def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: isin_count = pred_isin_mat.sum(dim=-1) @@ -288,6 +305,7 @@ class LinkPredMAP(LinkPredMetric): k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True + weighted: bool = False def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: device = pred_isin_mat.device @@ -305,6 +323,7 @@ class LinkPredNDCG(LinkPredMetric): k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True + weighted: bool = False def __init__(self, k: int): super().__init__(k=k) @@ -336,6 +355,7 @@ class LinkPredMRR(LinkPredMetric): k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True + weighted: bool = False def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: device = pred_isin_mat.device From 4406bd767b3bb03ae498661f2d58316a6196b4e3 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 14 Jan 2025 18:05:36 +0900 Subject: [PATCH 31/45] Automatic num_params in LLM + update `GRetriever` default llm (#9938) This PR: - Introduces llm_model_name argument in g_retriever.py to allow specifying the LLM model - Make `num_params` optional in the `LLM` constructor, automatically determining it using `huggingface_hub` metadata if not provided - Change the default LLM model to the more recent `meta-llama/Meta-Llama-3.1-8B-Instruct` Previous model `meta-llama/Llama-2-7b-chat-hf` metrics: ``` Hit: 0.6966 Precision: 0.6250 Recall: 0.5344 F1: 0.5405 Total Training Time: 556.111935s ``` Newer model `meta-llama/Llama-3.1-8B-Instruct` metrics: ``` Hit: 0.7629 Precision: 0.7145 Recall: 0.6027 F1: 0.6190 Total Training Time: 572.248117s ``` --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 1 + examples/llm/g_retriever.py | 16 +++++++++------- torch_geometric/nn/nlp/llm.py | 13 ++++++++++--- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 50ff227f428c..05153fd46005 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606)) - Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807)) +- Automatic num_params in LLM + update `GRetriever` default llm ([#9938](https://github.com/pyg-team/pytorch_geometric/pull/9938)) ### Deprecated diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index 1fd654886208..0c7c10ae5f31 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -191,6 +191,7 @@ def train( batch_size, # Training batch size eval_batch_size, # Evaluation batch size lr, # Initial learning rate + llm_model_name, # `transformers` model name checkpointing=False, # Whether to checkpoint model tiny_llama=False, # Whether to use tiny LLaMA model ): @@ -203,6 +204,7 @@ def train( batch_size (int): Training batch size. eval_batch_size (int): Evaluation batch size. lr (float): Initial learning rate. + llm_model_name (str): The name of the LLM to use. checkpointing (bool, optional): Whether to checkpoint model. Defaults to False. tiny_llama (bool, optional): Whether to use tiny LLaMA model. @@ -267,14 +269,11 @@ def adjust_learning_rate(param_group, LR, epoch): # Create LLaMA model if tiny_llama: - llm = LLM( - model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', - num_params=1, - ) - model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048) + llm = LLM(model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', ) else: - llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7) - model = GRetriever(llm=llm, gnn=gnn) + llm = LLM(model_name=llm_model_name) + model = GRetriever(llm=llm, gnn=gnn, + mlp_out_channels=llm.word_embedding.embedding_dim) # Set model save name model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm' @@ -390,6 +389,8 @@ def adjust_learning_rate(param_group, LR, epoch): parser.add_argument('--eval_batch_size', type=int, default=16) parser.add_argument('--checkpointing', action='store_true') parser.add_argument('--tiny_llama', action='store_true') + parser.add_argument('--llm_model_name', type=str, + default="meta-llama/Meta-Llama-3.1-8B-Instruct") args = parser.parse_args() start_time = time.time() @@ -400,6 +401,7 @@ def adjust_learning_rate(param_group, LR, epoch): args.batch_size, args.eval_batch_size, args.lr, + args.llm_model_name, checkpointing=args.checkpointing, tiny_llama=args.tiny_llama, ) diff --git a/torch_geometric/nn/nlp/llm.py b/torch_geometric/nn/nlp/llm.py index d18aa42382f7..9c39cb2c6bef 100644 --- a/torch_geometric/nn/nlp/llm.py +++ b/torch_geometric/nn/nlp/llm.py @@ -51,17 +51,18 @@ class LLM(torch.nn.Module): model_name (str): The HuggingFace model name, *e.g.*, :obj:`"llama2"` or :obj:`"gemma"`. - num_params (int): An integer representing how many parameters the + num_params (int, optional): An integer representing how many parameters the HuggingFace model has, in billions. This is used to automatically allocate the correct number of GPUs needed, given the available GPU - memory of your GPUs. + memory of your GPUs. If not specified, the number of parameters + is determined using the `huggingface_hub` module. dtype (torch.dtype, optional): The data type to use for the LLM. (default :obj: `torch.bfloat16`) """ def __init__( self, model_name: str, - num_params: int, + num_params: int = None, dtype=torch.bfloat16, ) -> None: super().__init__() @@ -70,6 +71,12 @@ def __init__( from transformers import AutoModelForCausalLM, AutoTokenizer + if num_params is None: + from huggingface_hub import get_safetensors_metadata + safetensors_metadata = get_safetensors_metadata(model_name) + param_count = safetensors_metadata.parameter_count + num_params = list(param_count.values())[0] // 10**9 + # A rough heuristic on GPU memory requirements, e.g., we found that # LLAMA2 (7B parameters) fits on a 85GB GPU. required_memory = 85 * num_params / 7 From 51a5aa090e7ca1d8f947a46caa87817ff68a2908 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 14 Jan 2025 16:12:20 +0100 Subject: [PATCH 32/45] Support for `edge_label_weight` in `LinkPredMetric` (#9944) --- test/metrics/test_link_pred_metric.py | 6 + torch_geometric/metrics/link_pred.py | 250 +++++++++++++++++--------- 2 files changed, 168 insertions(+), 88 deletions(-) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index 6256eeb4d6b3..a3ede9906000 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -152,6 +152,9 @@ def test_link_pred_metric_collection(num_src_nodes, num_dst_nodes, num_edges): LinkPredMAP(k=10), LinkPredPrecision(k=100), LinkPredRecall(k=50), + LinkPredF1(k=20), + LinkPredMRR(k=40), + LinkPredNDCG(k=80), ] row = torch.randint(0, num_src_nodes, (num_edges, )) @@ -168,6 +171,9 @@ def test_link_pred_metric_collection(num_src_nodes, num_dst_nodes, num_edges): ' LinkPredMAP@10: LinkPredMAP(k=10),\n' ' LinkPredPrecision@100: LinkPredPrecision(k=100),\n' ' LinkPredRecall@50: LinkPredRecall(k=50),\n' + ' LinkPredF1@20: LinkPredF1(k=20),\n' + ' LinkPredMRR@40: LinkPredMRR(k=40),\n' + ' LinkPredNDCG@80: LinkPredNDCG(k=80),\n' '])') assert metric_collection.max_k == 100 diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index d2f67272cbc3..c692576ea72f 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union import torch @@ -14,6 +15,73 @@ BaseMetric = torch.nn.Module # type: ignore +@dataclass(repr=False) +class LinkPredMetricData: + pred_index_mat: Tensor + edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]] + edge_label_weight: Optional[Tensor] = None + + @property + def pred_rel_mat(self) -> Tensor: + r"""Returns a matrix indicating the relevance of the `k`-th prediction. + If :obj:`edge_label_weight` is not given, relevance will be denoted as + binary. + """ + if hasattr(self, '_pred_rel_mat'): + return self._pred_rel_mat # type: ignore + + # Flatten both prediction and ground-truth indices, and determine + # overlaps afterwards via `torch.searchsorted`. + max_index = max( # type: ignore + self.pred_index_mat.max() + if self.pred_index_mat.numel() > 0 else 0, + self.edge_label_index[1].max() + if self.edge_label_index[1].numel() > 0 else 0, + ) + 1 + arange = torch.arange( + start=0, + end=max_index * self.pred_index_mat.size(0), # type: ignore + step=max_index, # type: ignore + device=self.pred_index_mat.device, + ).view(-1, 1) + flat_pred_index = (self.pred_index_mat + arange).view(-1) + flat_label_index = max_index * self.edge_label_index[0] + flat_label_index = flat_label_index + self.edge_label_index[1] + flat_label_index, perm = flat_label_index.sort() + edge_label_weight = self.edge_label_weight + if edge_label_weight is not None: + assert edge_label_weight.size() == self.edge_label_index[0].size() + edge_label_weight = edge_label_weight[perm] + + pos = torch.searchsorted(flat_label_index, flat_pred_index) + pos = pos.clamp(max=flat_label_index.size(0) - 1) # Out-of-bounds. + + pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches + if edge_label_weight is not None: + pred_rel_mat = edge_label_weight[pos].where(pred_rel_mat, 0.0) + pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size()) + + self._pred_rel_mat = pred_rel_mat + return pred_rel_mat + + @property + def label_count(self) -> Tensor: + r"""The number of ground-truth labels for every example.""" + if hasattr(self, '_label_count'): + return self._label_count # type: ignore + + label_count = scatter( + torch.ones_like(self.edge_label_index[0]), + self.edge_label_index[0], + dim=0, + dim_size=self.pred_index_mat.size(0), + reduce='sum', + ) + + self._label_count = label_count + return label_count + + class LinkPredMetric(BaseMetric): r"""An abstract class for computing link prediction retrieval metrics. @@ -44,52 +112,6 @@ def __init__(self, k: int) -> None: self.register_buffer('accum', torch.tensor(0.)) self.register_buffer('total', torch.tensor(0)) - @staticmethod - def _prepare( - pred_index_mat: Tensor, - edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], - ) -> Tuple[Tensor, Tensor]: - # Compute a boolean matrix indicating if the `k`-th prediction is part - # of the ground-truth, as well as the number of ground-truths for every - # example. We do this by flattening both prediction and ground-truth - # indices, and then determining overlaps via `torch.isin`. - max_index = max( # type: ignore - pred_index_mat.max() if pred_index_mat.numel() > 0 else 0, - edge_label_index[1].max() - if edge_label_index[1].numel() > 0 else 0, - ) + 1 - arange = torch.arange( - start=0, - end=max_index * pred_index_mat.size(0), # type: ignore - step=max_index, # type: ignore - device=pred_index_mat.device, - ).view(-1, 1) - flat_pred_index = (pred_index_mat + arange).view(-1) - flat_y_index = max_index * edge_label_index[0] + edge_label_index[1] - - pred_isin_mat = torch.isin(flat_pred_index, flat_y_index) - pred_isin_mat = pred_isin_mat.view(pred_index_mat.size()) - - # Compute the number of ground-truths per example: - y_count = scatter( - torch.ones_like(edge_label_index[0]), - edge_label_index[0], - dim=0, - dim_size=pred_index_mat.size(0), - reduce='sum', - ) - - return pred_isin_mat, y_count - - def _update_from_prepared( - self, - pred_isin_mat: Tensor, - y_count: Tensor, - ) -> None: - metric = self._compute(pred_isin_mat[:, :self.k], y_count) - self.accum += metric.sum() - self.total += (y_count > 0).sum() - def update( self, pred_index_mat: Tensor, @@ -112,15 +134,28 @@ def update( :obj:`[2, num_ground_truth_indices]`. edge_label_weight (torch.Tensor, optional): The weight of the ground-truth indices for every example in the mini-batch of - shape :obj:`[num_ground_truth_indices]`. Required for - weighted metrics and ignored otherwise. (default: :obj:`None`) + shape :obj:`[num_ground_truth_indices]`. If given, needs to be + a vector of positive values. Required for weighted metrics, + ignored otherwise. (default: :obj:`None`) """ if self.weighted and edge_label_weight is None: - raise ValueError("'edge_label_weight' required for {self}") + raise ValueError(f"'edge_label_weight' is a required argument for " + f"weighted '{self.__class__.__name__}' metrics") + if not self.weighted: + edge_label_weight = None + + data = LinkPredMetricData( + pred_index_mat=pred_index_mat, + edge_label_index=edge_label_index, + edge_label_weight=edge_label_weight, + ) + self._update(data) - pred_isin_mat, y_count = self._prepare(pred_index_mat, - edge_label_index) - self._update_from_prepared(pred_isin_mat, y_count) + def _update(self, data: LinkPredMetricData) -> None: + metric = self._compute(data) + + self.accum += metric.sum() + self.total += (data.label_count > 0).sum() def compute(self) -> Tensor: r"""Computes the final metric value.""" @@ -129,28 +164,26 @@ def compute(self) -> Tensor: return self.accum / self.total def reset(self) -> None: - r"""Reset metric state variables to their default value.""" + r"""Resets metric state variables to their default value.""" if WITH_TORCHMETRICS: super().reset() else: self.accum.zero_() self.total.zero_() - def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - r"""Compute the specific metric. + def _compute(self, data: LinkPredMetricData) -> Tensor: + r"""Computes the specific metric. To be implemented separately for each metric class. Args: - pred_isin_mat (torch.Tensor): A boolean matrix whose :obj:`(i,k)` - element indicates if the :obj:`k`-th prediction for the - :obj:`i`-th example is correct or not. - y_count (torch.Tensor): A vector indicating the number of - ground-truth labels for each example. + data (LinkPredMetricData): The mini-batch data for computing a link + prediction metric per example. """ raise NotImplementedError def __repr__(self) -> str: - return f'{self.__class__.__name__}(k={self.k})' + weighted_repr = ', weighted=True' if self.weighted else '' + return f'{self.__class__.__name__}(k={self.k}{weighted_repr})' class LinkPredMetricCollection(torch.nn.ModuleDict): @@ -211,6 +244,13 @@ def max_k(self) -> int: """ return max([metric.k for metric in self.values()]) + @property + def weighted(self) -> bool: + r"""Returns :obj:`True` in case the collection holds at least one + weighted link prediction metric. + """ + return any([metric.weighted for metric in self.values()]) + def update( # type: ignore self, pred_index_mat: Tensor, @@ -233,13 +273,37 @@ def update( # type: ignore :obj:`[2, num_ground_truth_indices]`. edge_label_weight (torch.Tensor, optional): The weight of the ground-truth indices for every example in the mini-batch of - shape :obj:`[num_ground_truth_indices]`. Required for - weighted metrics and ignored otherwise. (default: :obj:`None`) + shape :obj:`[num_ground_truth_indices]`. If given, needs to be + a vector of positive values. Required for weighted metrics, + ignored otherwise. (default: :obj:`None`) """ - pred_isin_mat, y_count = LinkPredMetric._prepare( - pred_index_mat, edge_label_index) + if self.weighted and edge_label_weight is None: + raise ValueError(f"'edge_label_weight' is a required argument for " + f"weighted '{self.__class__.__name__}' metrics") + if not self.weighted: + edge_label_weight = None + + data = LinkPredMetricData( # Share metric data across metrics. + pred_index_mat=pred_index_mat, + edge_label_index=edge_label_index, + edge_label_weight=edge_label_weight, + ) + + for metric in self.values(): + if metric.weighted: + metric._update(data) + if WITH_TORCHMETRICS: + metric._update_count += 1 + + data.edge_label_weight = None + if hasattr(data, '_pred_rel_mat'): + data._pred_rel_mat = data._pred_rel_mat != 0.0 + for metric in self.values(): - metric._update_from_prepared(pred_isin_mat, y_count) + if not metric.weighted: + metric._update(data) + if WITH_TORCHMETRICS: + metric._update_count += 1 def compute(self) -> Dict[str, Tensor]: r"""Computes the final metric values.""" @@ -264,8 +328,9 @@ class LinkPredPrecision(LinkPredMetric): higher_is_better: bool = True weighted: bool = False - def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - return pred_isin_mat.sum(dim=-1) / self.k + def _compute(self, data: LinkPredMetricData) -> Tensor: + pred_rel_mat = data.pred_rel_mat[:, :self.k] + return pred_rel_mat.sum(dim=-1) / self.k class LinkPredRecall(LinkPredMetric): @@ -277,8 +342,9 @@ class LinkPredRecall(LinkPredMetric): higher_is_better: bool = True weighted: bool = False - def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - return pred_isin_mat.sum(dim=-1) / y_count.clamp(min=1e-7) + def _compute(self, data: LinkPredMetricData) -> Tensor: + pred_rel_mat = data.pred_rel_mat[:, :self.k] + return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7) class LinkPredF1(LinkPredMetric): @@ -290,10 +356,11 @@ class LinkPredF1(LinkPredMetric): higher_is_better: bool = True weighted: bool = False - def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - isin_count = pred_isin_mat.sum(dim=-1) + def _compute(self, data: LinkPredMetricData) -> Tensor: + pred_rel_mat = data.pred_rel_mat[:, :self.k] + isin_count = pred_rel_mat.sum(dim=-1) precision = isin_count / self.k - recall = isin_count = isin_count / y_count.clamp(min=1e-7) + recall = isin_count / data.label_count.clamp(min=1e-7) return 2 * precision * recall / (precision + recall).clamp(min=1e-7) @@ -307,12 +374,13 @@ class LinkPredMAP(LinkPredMetric): higher_is_better: bool = True weighted: bool = False - def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - device = pred_isin_mat.device - arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device) - cum_precision = pred_isin_mat.cumsum(dim=1) / arange - return ((cum_precision * pred_isin_mat).sum(dim=-1) / - y_count.clamp(min=1e-7, max=self.k)) + def _compute(self, data: LinkPredMetricData) -> Tensor: + pred_rel_mat = data.pred_rel_mat[:, :self.k] + device = pred_rel_mat.device + arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device) + cum_precision = pred_rel_mat.cumsum(dim=1) / arange + return ((cum_precision * pred_rel_mat).sum(dim=-1) / + data.label_count.clamp(min=1e-7, max=self.k)) class LinkPredNDCG(LinkPredMetric): @@ -321,12 +389,16 @@ class LinkPredNDCG(LinkPredMetric): Args: k (int): The number of top-:math:`k` predictions to evaluate against. + weighted (bool, optional): If set to :obj:`True`, assumes sorted lists + of ground-truth items according to a relevance score as given by + :obj:`edge_label_weight`. (default: :obj:`False`) """ higher_is_better: bool = True weighted: bool = False - def __init__(self, k: int): + def __init__(self, k: int, weighted: bool = False): super().__init__(k=k) + self.weighted = weighted dtype = torch.get_default_dtype() multiplier = 1.0 / torch.arange(2, k + 2, dtype=dtype).log2() @@ -337,10 +409,11 @@ def __init__(self, k: int): self.idcg: Tensor self.register_buffer('idcg', cumsum(multiplier)) - def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - multiplier = self.multiplier[:pred_isin_mat.size(1)].view(1, -1) - dcg = (pred_isin_mat * multiplier).sum(dim=-1) - idcg = self.idcg[y_count.clamp(max=self.k)] + def _compute(self, data: LinkPredMetricData) -> Tensor: + pred_rel_mat = data.pred_rel_mat[:, :self.k] + multiplier = self.multiplier[:pred_rel_mat.size(1)].view(1, -1) + dcg = (pred_rel_mat * multiplier).sum(dim=-1) + idcg = self.idcg[data.label_count.clamp(max=self.k)] out = dcg / idcg out[out.isnan() | out.isinf()] = 0.0 @@ -357,7 +430,8 @@ class LinkPredMRR(LinkPredMetric): higher_is_better: bool = True weighted: bool = False - def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor: - device = pred_isin_mat.device - arange = torch.arange(1, pred_isin_mat.size(1) + 1, device=device) - return (pred_isin_mat / arange).max(dim=-1)[0] + def _compute(self, data: LinkPredMetricData) -> Tensor: + pred_rel_mat = data.pred_rel_mat[:, :self.k] + device = pred_rel_mat.device + arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device) + return (pred_rel_mat / arange).max(dim=-1)[0] From 371678c4979575f7e058019370b7d6584a550e3b Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 15 Jan 2025 04:57:54 +0100 Subject: [PATCH 33/45] Support weighted `LinkPredNDCG` metric (#9945) --- CHANGELOG.md | 1 + test/metrics/test_link_pred_metric.py | 16 +++++++++ torch_geometric/metrics/link_pred.py | 50 ++++++++++++++++++++++----- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05153fd46005..fc5ea7b3fbf2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945)) - Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941)) - Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index a3ede9906000..8e1964fc8b0a 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -113,6 +113,7 @@ def test_map(): def test_ndcg(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]]) + edge_label_weight = torch.tensor([1.0, 2.0, 3.0, 0.5]) metric = LinkPredNDCG(k=2) assert str(metric) == 'LinkPredNDCG(k=2)' @@ -126,6 +127,21 @@ def test_ndcg(): metric.compute() metric.reset() + metric = LinkPredNDCG(k=2, weighted=True) + assert str(metric) == 'LinkPredNDCG(k=2, weighted=True)' + with pytest.raises(ValueError, match="'edge_label_weight'"): + metric.update(pred_index_mat, edge_label_index) + + metric.update(pred_index_mat, edge_label_index, edge_label_weight) + result = metric.compute() + + assert float(result) == pytest.approx(0.7854486) + + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight) + metric.compute() + metric.reset() + def test_mrr(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]]) diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index c692576ea72f..fafd0c9d2f3f 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -4,6 +4,7 @@ import torch from torch import Tensor +from torch_geometric.index import index2ptr from torch_geometric.utils import cumsum, scatter try: @@ -58,7 +59,10 @@ def pred_rel_mat(self) -> Tensor: pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches if edge_label_weight is not None: - pred_rel_mat = edge_label_weight[pos].where(pred_rel_mat, 0.0) + pred_rel_mat = edge_label_weight[pos].where( + pred_rel_mat, + pred_rel_mat.new_zeros(1), + ) pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size()) self._pred_rel_mat = pred_rel_mat @@ -401,19 +405,47 @@ def __init__(self, k: int, weighted: bool = False): self.weighted = weighted dtype = torch.get_default_dtype() - multiplier = 1.0 / torch.arange(2, k + 2, dtype=dtype).log2() + discount = torch.arange(2, k + 2, dtype=dtype).log2() - self.multiplier: Tensor - self.register_buffer('multiplier', multiplier) + self.discount: Tensor + self.register_buffer('discount', discount) - self.idcg: Tensor - self.register_buffer('idcg', cumsum(multiplier)) + if not weighted: + self.register_buffer('idcg', cumsum(1.0 / discount)) + else: + self.idcg = None def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] - multiplier = self.multiplier[:pred_rel_mat.size(1)].view(1, -1) - dcg = (pred_rel_mat * multiplier).sum(dim=-1) - idcg = self.idcg[data.label_count.clamp(max=self.k)] + discount = self.discount[:pred_rel_mat.size(1)].view(1, -1) + dcg = (pred_rel_mat / discount).sum(dim=-1) + + if not self.weighted: + assert self.idcg is not None + idcg = self.idcg[data.label_count.clamp(max=self.k)] + else: + assert data.edge_label_weight is not None + # Sort weights in buckets via two sorts: + weight, perm = data.edge_label_weight.sort(descending=True) + batch = data.edge_label_index[0][perm] + batch, perm = torch.sort(batch, stable=True) + weight = weight[perm] + + # Shrink buckets that are larger than `k`: + arange = torch.arange(batch.size(0), device=batch.device) + ptr = index2ptr(batch, size=data.pred_index_mat.size(0)) + batched_arange = arange - ptr[batch] + mask = batched_arange < self.k + batch = batch[mask] + batched_arange = batched_arange[mask] + weight = weight[mask] + + # Compute ideal relevance matrix: + irel_mat = weight.new_zeros(data.pred_index_mat.size(0) * self.k) + irel_mat[batch * self.k + batched_arange] = weight + irel_mat = irel_mat.view(-1, self.k) + + idcg = (irel_mat / self.discount.view(1, -1)).sum(dim=-1) out = dcg / idcg out[out.isnan() | out.isinf()] = 0.0 From 8db0702d5242336f983b11a087fc96a474e09586 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 15 Jan 2025 05:38:45 +0100 Subject: [PATCH 34/45] More efficient way to compute weighted `IDCG` in `LinkPredNDCG` (#9946) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- test/metrics/test_link_pred_metric.py | 12 ++++--- torch_geometric/metrics/link_pred.py | 45 ++++++++++++++------------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index 8e1964fc8b0a..d459c6d8e4b1 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -112,14 +112,13 @@ def test_map(): def test_ndcg(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) - edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]]) - edge_label_weight = torch.tensor([1.0, 2.0, 3.0, 0.5]) + edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) + edge_label_weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5]) metric = LinkPredNDCG(k=2) assert str(metric) == 'LinkPredNDCG(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() - assert float(result) == pytest.approx(0.6934264) # Test with `k > pred_index_mat.size(1)`: @@ -134,9 +133,14 @@ def test_ndcg(): metric.update(pred_index_mat, edge_label_index, edge_label_weight) result = metric.compute() - + metric.reset() assert float(result) == pytest.approx(0.7854486) + perm = torch.randperm(edge_label_weight.size(0)) + metric.update(pred_index_mat, edge_label_index[:, perm], + edge_label_weight[perm]) + assert metric.compute() == result + # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight) metric.compute() diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index fafd0c9d2f3f..7f30429162fb 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -4,7 +4,6 @@ import torch from torch import Tensor -from torch_geometric.index import index2ptr from torch_geometric.utils import cumsum, scatter try: @@ -425,27 +424,29 @@ def _compute(self, data: LinkPredMetricData) -> Tensor: idcg = self.idcg[data.label_count.clamp(max=self.k)] else: assert data.edge_label_weight is not None - # Sort weights in buckets via two sorts: - weight, perm = data.edge_label_weight.sort(descending=True) - batch = data.edge_label_index[0][perm] - batch, perm = torch.sort(batch, stable=True) - weight = weight[perm] - - # Shrink buckets that are larger than `k`: - arange = torch.arange(batch.size(0), device=batch.device) - ptr = index2ptr(batch, size=data.pred_index_mat.size(0)) - batched_arange = arange - ptr[batch] - mask = batched_arange < self.k - batch = batch[mask] - batched_arange = batched_arange[mask] - weight = weight[mask] - - # Compute ideal relevance matrix: - irel_mat = weight.new_zeros(data.pred_index_mat.size(0) * self.k) - irel_mat[batch * self.k + batched_arange] = weight - irel_mat = irel_mat.view(-1, self.k) - - idcg = (irel_mat / self.discount.view(1, -1)).sum(dim=-1) + # Sort weights within example-wise buckets via two sorts to get the + # local index order within buckets: + weight, batch = data.edge_label_weight, data.edge_label_index[0] + perm1 = weight.argsort(descending=True) + perm2 = batch[perm1].argsort(stable=True) + global_index = torch.empty_like(perm1) + global_index[perm1[perm2]] = torch.arange( + global_index.size(0), device=global_index.device) + local_index = global_index - cumsum(data.label_count)[batch] + + # Get the discount per local index: + discount = torch.cat([ + self.discount, + self.discount.new_full((1, ), fill_value=float('inf')), + ]) + discount = discount[local_index.clamp(max=self.k + 1)] + + idcg = scatter( # Apply discount and aggregate: + weight / discount, + batch, + dim_size=data.pred_index_mat.size(0), + reduce='sum', + ) out = dcg / idcg out[out.isnan() | out.isinf()] = 0.0 From 50f56225a54b97a0b432935ec25fcc04e40b7d52 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 15 Jan 2025 07:33:54 +0100 Subject: [PATCH 35/45] Weighted `LinkPredRecall` (#9947) --- CHANGELOG.md | 1 + test/metrics/test_link_pred_metric.py | 16 +++++- torch_geometric/metrics/link_pred.py | 81 +++++++++++++++++++++------ 3 files changed, 79 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc5ea7b3fbf2..8cf9e8b3ba52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947)) - Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945)) - Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941)) - Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index d459c6d8e4b1..13e338cb1002 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -64,12 +64,12 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k): def test_recall(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) + edge_label_weight = torch.tensor([4.0, 1.0, 2.0, 3.0, 0.5]) metric = LinkPredRecall(k=2) assert str(metric) == 'LinkPredRecall(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() - assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5)) # Test with `k > pred_index_mat.size(1)`: @@ -77,6 +77,20 @@ def test_recall(): metric.compute() metric.reset() + metric = LinkPredRecall(k=2, weighted=True) + assert str(metric) == 'LinkPredRecall(k=2, weighted=True)' + with pytest.raises(ValueError, match="'edge_label_weight'"): + metric.update(pred_index_mat, edge_label_index) + + metric.update(pred_index_mat, edge_label_index, edge_label_weight) + result = metric.compute() + assert float(result) == pytest.approx(0.5 * (5.0 / 7.0 + 3.0 / 3.5)) + + # Test with `k > pred_index_mat.size(1)`: + metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight) + metric.compute() + metric.reset() + def test_f1(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index 7f30429162fb..22e6652f16a6 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -84,6 +84,51 @@ def label_count(self) -> Tensor: self._label_count = label_count return label_count + @property + def label_weight_sum(self) -> Tensor: + r"""The sum of edge label weights for every example.""" + if self.edge_label_weight is None: + return self.label_count + + if hasattr(self, '_label_weight_sum'): + return self._label_weight_sum # type: ignore + + label_weight_sum = scatter( + self.edge_label_weight, + self.edge_label_index[0], + dim=0, + dim_size=self.pred_index_mat.size(0), + reduce='sum', + ) + + self._label_weight_sum = label_weight_sum + return label_weight_sum + + @property + def edge_label_weight_pos(self) -> Optional[Tensor]: + r"""Returns the position of edge label weights in descending order + within example-wise buckets. + """ + if self.edge_label_weight is None: + return None + + if hasattr(self, '_edge_label_weight_pos'): + return self._edge_label_weight_pos # type: ignore + + # Get the permutation via two sorts: One globally on the weights, + # followed by a (stable) sort on the example indices. + perm1 = self.edge_label_weight.argsort(descending=True) + perm2 = self.edge_label_index[0][perm1].argsort(stable=True) + perm = perm1[perm2] + # Invert the permutation to get the final position: + pos = torch.empty_like(perm) + pos[perm] = torch.arange(perm.size(0), device=perm.device) + # Normalize position to zero within all buckets: + pos = pos - cumsum(self.label_count)[self.edge_label_index[0]] + + self._edge_label_weight_pos = pos + return pos + class LinkPredMetric(BaseMetric): r"""An abstract class for computing link prediction retrieval metrics. @@ -231,7 +276,9 @@ def __init__( if isinstance(metrics, (list, tuple)): metrics = { - f'{metric.__class__.__name__}@{metric.k}': metric + (f'{"Weighted" if metric.weighted else ""}' + f'{metric.__class__.__name__}@{metric.k}'): + metric for metric in metrics } assert len(metrics) > 0 @@ -301,6 +348,10 @@ def update( # type: ignore data.edge_label_weight = None if hasattr(data, '_pred_rel_mat'): data._pred_rel_mat = data._pred_rel_mat != 0.0 + if hasattr(data, '_label_weight_sum'): + del data._label_weight_sum + if hasattr(data, '_edge_label_weight_pos'): + del data._edge_label_weight_pos for metric in self.values(): if not metric.weighted: @@ -343,11 +394,14 @@ class LinkPredRecall(LinkPredMetric): k (int): The number of top-:math:`k` predictions to evaluate against. """ higher_is_better: bool = True - weighted: bool = False + + def __init__(self, k: int, weighted: bool = False): + super().__init__(k=k) + self.weighted = weighted def _compute(self, data: LinkPredMetricData) -> Tensor: pred_rel_mat = data.pred_rel_mat[:, :self.k] - return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7) + return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7) class LinkPredF1(LinkPredMetric): @@ -397,7 +451,6 @@ class LinkPredNDCG(LinkPredMetric): :obj:`edge_label_weight`. (default: :obj:`False`) """ higher_is_better: bool = True - weighted: bool = False def __init__(self, k: int, weighted: bool = False): super().__init__(k=k) @@ -424,26 +477,18 @@ def _compute(self, data: LinkPredMetricData) -> Tensor: idcg = self.idcg[data.label_count.clamp(max=self.k)] else: assert data.edge_label_weight is not None - # Sort weights within example-wise buckets via two sorts to get the - # local index order within buckets: - weight, batch = data.edge_label_weight, data.edge_label_index[0] - perm1 = weight.argsort(descending=True) - perm2 = batch[perm1].argsort(stable=True) - global_index = torch.empty_like(perm1) - global_index[perm1[perm2]] = torch.arange( - global_index.size(0), device=global_index.device) - local_index = global_index - cumsum(data.label_count)[batch] - - # Get the discount per local index: + pos = data.edge_label_weight_pos + assert pos is not None + discount = torch.cat([ self.discount, self.discount.new_full((1, ), fill_value=float('inf')), ]) - discount = discount[local_index.clamp(max=self.k + 1)] + discount = discount[pos.clamp(max=self.k + 1)] idcg = scatter( # Apply discount and aggregate: - weight / discount, - batch, + data.edge_label_weight / discount, + data.edge_label_index[0], dim_size=data.pred_index_mat.size(0), reduce='sum', ) From 6a1db0744adff36ca8ff1890779383d37460c1e0 Mon Sep 17 00:00:00 2001 From: Rorry Brenner <47458848+RorryB@users.noreply.github.com> Date: Thu, 16 Jan 2025 17:33:18 -0500 Subject: [PATCH 36/45] Graph Sage OGBN Example with Scheduler (#9877) This PR is to update the ogbn example to have a scheduler, improving performance from a test accuracy of 75.52% to 77.19% --- CHANGELOG.md | 1 + examples/ogbn_train.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cf9e8b3ba52..52440b408280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -855,6 +855,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug in which `nn.models.GAT` did not produce `out_channels`-many output channels ([#4299](https://github.com/pyg-team/pytorch_geometric/pull/4299)) - Fixed mini-batching with empty lists as attributes ([#4293](https://github.com/pyg-team/pytorch_geometric/pull/4293)) - Fixed a bug in which `GCNConv` could not be combined with `to_hetero` on heterogeneous graphs with one node type ([#4279](https://github.com/pyg-team/pytorch_geometric/pull/4279)) +- Added a scheduler to the Graph Sage OGBN Example [#9877](https://github.com/pyg-team/pytorch_geometric/pull/9877) ### Removed diff --git a/examples/ogbn_train.py b/examples/ogbn_train.py index 56a5c1c7a538..d975b86f51d0 100644 --- a/examples/ogbn_train.py +++ b/examples/ogbn_train.py @@ -33,7 +33,7 @@ action='store_true', help='Whether or not to use GAT model', ) -parser.add_argument('-e', '--epochs', type=int, default=10) +parser.add_argument('-e', '--epochs', type=int, default=50) parser.add_argument('--num_layers', type=int, default=3) parser.add_argument('--num_heads', type=int, default=2, help='number of heads for GAT model.') @@ -179,6 +179,8 @@ def test(loader: NeighborLoader) -> float: lr=args.lr, weight_decay=args.wd, ) +scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', + patience=5) print(f'Total time before training begins took ' f'{time.perf_counter() - wall_clock_start:.4f}s') @@ -204,6 +206,10 @@ def test(loader: NeighborLoader) -> float: if val_acc > best_val: best_val = val_acc times.append(time.perf_counter() - train_start) + for param_group in optimizer.param_groups: + print('lr:') + print(param_group['lr']) + scheduler.step(val_acc) print(f'Average Epoch Time on training: ' f'{torch.tensor(train_times).mean():.4f}s') From 9e0be20a1126246a5759d1aca7f0a7a5f6a08a3b Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <32910461+drivanov@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:40:32 -0800 Subject: [PATCH 37/45] Address `FutureWarning` on `nx.node_link_graph(..., edges="links")` (#9954) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta --- torch_geometric/datasets/ppi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/datasets/ppi.py b/torch_geometric/datasets/ppi.py index 019317dc34b7..e5e32eb5f9c6 100644 --- a/torch_geometric/datasets/ppi.py +++ b/torch_geometric/datasets/ppi.py @@ -107,7 +107,8 @@ def process(self) -> None: for s, split in enumerate(['train', 'valid', 'test']): path = osp.join(self.raw_dir, f'{split}_graph.json') with open(path) as f: - G = nx.DiGraph(json_graph.node_link_graph(json.load(f))) + G = nx.DiGraph( + json_graph.node_link_graph(json.load(f), edges="links")) x = np.load(osp.join(self.raw_dir, f'{split}_feats.npy')) x = torch.from_numpy(x).to(torch.float) From 5fb2a8eaae3047a9dd430fd58e735e333daed79d Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <32910461+drivanov@users.noreply.github.com> Date: Sat, 18 Jan 2025 04:17:01 -0800 Subject: [PATCH 38/45] Fixed a warning generated in `exemples/llm/glen.py` test. (#9955) --- examples/llm/glem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index c6bae703fd33..8a28d7359de6 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -79,7 +79,7 @@ def main(args): dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) split_idx = dataset.get_idx_split() - data = dataset.data + data = dataset[0] tag_dataset = TAGDataset(root, dataset, hf_model, token_on_disk=token_on_disk) From c31b7f949811a646bb68b063342cf977bbf48dca Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 20 Jan 2025 10:05:21 +0100 Subject: [PATCH 39/45] Expose `LinkPredMetric` (#9961) --- torch_geometric/metrics/__init__.py | 2 ++ torch_geometric/metrics/link_pred.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_geometric/metrics/__init__.py b/torch_geometric/metrics/__init__.py index e142a7dc1152..fd964263f8e8 100644 --- a/torch_geometric/metrics/__init__.py +++ b/torch_geometric/metrics/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa from .link_pred import ( + LinkPredMetric, LinkPredMetricCollection, LinkPredPrecision, LinkPredRecall, @@ -11,6 +12,7 @@ ) link_pred_metrics = [ + 'LinkPredMetric', 'LinkPredMetricCollection', 'LinkPredPrecision', 'LinkPredRecall', diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index 22e6652f16a6..b80ecc706054 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -139,7 +139,7 @@ class LinkPredMetric(BaseMetric): is_differentiable: bool = False full_state_update: bool = False higher_is_better: Optional[bool] = None - weighted: bool = False + weighted: bool def __init__(self, k: int) -> None: super().__init__() From 6d6ba700e6444815c50a47a9423a4377eee9cbb2 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Mon, 20 Jan 2025 18:07:57 +0900 Subject: [PATCH 40/45] Update docstring in GRetriever (#9960) Follow up @akihironitta 's comment in https://github.com/pyg-team/pytorch_geometric/pull/9938#discussion_r1920640308 --- torch_geometric/nn/nlp/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/nlp/llm.py b/torch_geometric/nn/nlp/llm.py index 9c39cb2c6bef..904e9bc78191 100644 --- a/torch_geometric/nn/nlp/llm.py +++ b/torch_geometric/nn/nlp/llm.py @@ -62,8 +62,8 @@ class LLM(torch.nn.Module): def __init__( self, model_name: str, - num_params: int = None, - dtype=torch.bfloat16, + num_params: Optional[int] = None, + dtype: Optional[torch.dtype] = torch.bfloat16, ) -> None: super().__init__() From b481f6c51ced5f07df7fbdd7f0dd4ae02c9ef16b Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 20 Jan 2025 17:08:57 +0800 Subject: [PATCH 41/45] docs: fix typo in installation.rst (#9949) This pull request includes a small change to the `docs/source/install/installation.rst` file. The change updates the PyTorch version in the installation instructions for consistency. * Updated PyTorch version from 2.4 to 2.5 in the installation instructions. --------- Co-authored-by: Matthias Fey --- docs/source/install/installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/install/installation.rst b/docs/source/install/installation.rst index c3a8bd3d9fcb..36da50f9debd 100644 --- a/docs/source/install/installation.rst +++ b/docs/source/install/installation.rst @@ -89,7 +89,7 @@ For ease of installation of these extensions, we provide :obj:`pip` wheels for t where :obj:`${TORCH}` and :obj:`${CUDA}` should be replaced by the specific :pytorch:`PyTorch` and CUDA versions, respectively: - * :pytorch:`PyTorch` 2.4: :obj:`${TORCH}=2.5.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124` + * :pytorch:`PyTorch` 2.5: :obj:`${TORCH}=2.5.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124` * :pytorch:`PyTorch` 2.4: :obj:`${TORCH}=2.4.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124` * :pytorch:`PyTorch` 2.3: :obj:`${TORCH}=2.3.0` and :obj:`${CUDA}=cpu|cu118|cu121` * :pytorch:`PyTorch` 2.2: :obj:`${TORCH}=2.2.0` and :obj:`${CUDA}=cpu|cu118|cu121` From 9c73f9908bfce4d8c47d23fa00b55d83ad4f1399 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 20 Jan 2025 14:11:44 +0100 Subject: [PATCH 42/45] Fix `LinkPredMetric` for empty ground-truths (#9962) --- test/metrics/test_link_pred_metric.py | 17 +++++++++++++++++ torch_geometric/metrics/link_pred.py | 8 ++++++++ 2 files changed, 25 insertions(+) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index 13e338cb1002..67a57683d8a7 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -221,3 +221,20 @@ def test_link_pred_metric_collection(num_src_nodes, num_dst_nodes, num_edges): metric_collection.update(pred_index_mat, edge_label_index) assert metric_collection.compute() == expected metric_collection.reset() + + +def test_empty_ground_truth(): + pred = torch.rand(10, 5) + pred_index_mat = pred.argsort(dim=1) + edge_label_index = torch.empty(2, 0, dtype=torch.long) + edge_label_weight = torch.empty(0) + + metric = LinkPredMAP(k=5) + metric.update(pred_index_mat, edge_label_index) + assert metric.compute() == 0 + metric.reset() + + metric = LinkPredNDCG(k=5, weighted=True) + metric.update(pred_index_mat, edge_label_index, edge_label_weight) + assert metric.compute() == 0 + metric.reset() diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index b80ecc706054..b9ad50a642a5 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -30,6 +30,14 @@ def pred_rel_mat(self) -> Tensor: if hasattr(self, '_pred_rel_mat'): return self._pred_rel_mat # type: ignore + if self.edge_label_index[1].numel() == 0: + self._pred_rel_mat = torch.zeros_like( + self.pred_index_mat, + dtype=torch.bool if self.edge_label_weight is None else + torch.get_default_dtype(), + ) + return self._pred_rel_mat + # Flatten both prediction and ground-truth indices, and determine # overlaps afterwards via `torch.searchsorted`. max_index = max( # type: ignore From 47ac8186010bcd14ce14493f02075962d6b359d5 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 21 Jan 2025 10:34:49 +0100 Subject: [PATCH 43/45] Fix index-out-of-bounds issue in `LinkPredNDCG(weighted=True)` (#9963) --- torch_geometric/metrics/link_pred.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index b9ad50a642a5..95e274dcd619 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -492,7 +492,7 @@ def _compute(self, data: LinkPredMetricData) -> Tensor: self.discount, self.discount.new_full((1, ), fill_value=float('inf')), ]) - discount = discount[pos.clamp(max=self.k + 1)] + discount = discount[pos.clamp(max=self.k)] idcg = scatter( # Apply discount and aggregate: data.edge_label_weight / discount, From 9b794b600d41802edaad26aac18bff958fc8f642 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Thu, 23 Jan 2025 12:57:14 +0900 Subject: [PATCH 44/45] Add ComplexWebQuestions (CWQ) dataset (#9950) This PR adds the CWQ dataset, which is similar to WebQSP but larger and featuring more complex multi-hop questions Comparing the datasets: | Datasets | #Train | #Test | Max #hop | |----------|--------|-------|----------| | WebQSP | 2,826 | 1,628 | 2 | | CWQ | 27,639 | 3,531 | 4 | GRetriever performance: ``` LLM: meta-llama/Llama-3.1-8B-Instruct Hit@1: 0.5675 Precision: 0.5444 Recall: 0.5435 F1: 0.5310 Total Training Time: 4337.614916s ``` --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 +- examples/llm/g_retriever.py | 24 ++++++--- torch_geometric/datasets/__init__.py | 3 +- torch_geometric/datasets/web_qsp_dataset.py | 56 +++++++++++++++++++-- 4 files changed, 72 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52440b408280..73a5a3d2a6fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [2.7.0] - 2024-MM-DD +## [2.7.0] - 2025-MM-DD ### Added @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722)) - Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737)) - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467)) +- Add ComplexWebQuestions (CWQ) dataset ([#9950](https://github.com/pyg-team/pytorch_geometric/pull/9950)) ### Changed diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index 0c7c10ae5f31..608ac2b0db5f 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -24,7 +24,7 @@ from tqdm import tqdm from torch_geometric import seed_everything -from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets import CWQDataset, WebQSPDataset from torch_geometric.loader import DataLoader from torch_geometric.nn.models import GAT, GRetriever from torch_geometric.nn.nlp import LLM @@ -89,7 +89,7 @@ def compute_metrics(eval_output): f1 = sum(all_f1) / len(all_f1) # Print metrics to console - print(f'Hit: {hit:.4f}') + print(f'Hit@1: {hit:.4f}') print(f'Precision: {precision:.4f}') print(f'Recall: {recall:.4f}') print(f'F1: {f1:.4f}') @@ -193,9 +193,10 @@ def train( lr, # Initial learning rate llm_model_name, # `transformers` model name checkpointing=False, # Whether to checkpoint model + cwq=False, # Whether to train on the CWQ dataset tiny_llama=False, # Whether to use tiny LLaMA model ): - """Train a GNN+LLM model on WebQSP dataset. + """Train a GNN+LLM model on WebQSP or CWQ dataset. Args: num_epochs (int): Total number of training epochs. @@ -207,6 +208,8 @@ def train( llm_model_name (str): The name of the LLM to use. checkpointing (bool, optional): Whether to checkpoint model. Defaults to False. + cwq (bool, optional): Whether to train on the CWQ dataset + instead of WebQSP. tiny_llama (bool, optional): Whether to use tiny LLaMA model. Defaults to False. @@ -240,10 +243,16 @@ def adjust_learning_rate(param_group, LR, epoch): # Load dataset and create data loaders path = osp.dirname(osp.realpath(__file__)) - path = osp.join(path, '..', '..', 'data', 'WebQSPDataset') - train_dataset = WebQSPDataset(path, split='train') - val_dataset = WebQSPDataset(path, split='val') - test_dataset = WebQSPDataset(path, split='test') + if not cwq: + path = osp.join(path, '..', '..', 'data', 'WebQSPDataset') + train_dataset = WebQSPDataset(path, split='train') + val_dataset = WebQSPDataset(path, split='val') + test_dataset = WebQSPDataset(path, split='test') + else: + path = osp.join(path, '..', '..', 'data', 'CWQDataset') + train_dataset = CWQDataset(path, split='train') + val_dataset = CWQDataset(path, split='val') + test_dataset = CWQDataset(path, split='test') seed_everything(42) @@ -388,6 +397,7 @@ def adjust_learning_rate(param_group, LR, epoch): parser.add_argument('--batch_size', type=int, default=8) parser.add_argument('--eval_batch_size', type=int, default=16) parser.add_argument('--checkpointing', action='store_true') + parser.add_argument('--cwq', action='store_true') parser.add_argument('--tiny_llama', action='store_true') parser.add_argument('--llm_model_name', type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct") diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 12895ad1dbac..0d48ba9c0e00 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -76,7 +76,7 @@ from .myket import MyketDataset from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset -from .web_qsp_dataset import WebQSPDataset +from .web_qsp_dataset import WebQSPDataset, CWQDataset from .git_mol_dataset import GitMolDataset from .molecule_gpt_dataset import MoleculeGPTDataset from .tag_dataset import TAGDataset @@ -193,6 +193,7 @@ 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', + 'CWQDataset', 'GitMolDataset', 'MoleculeGPTDataset', 'TAGDataset', diff --git a/torch_geometric/datasets/web_qsp_dataset.py b/torch_geometric/datasets/web_qsp_dataset.py index 799eeceb3970..28ce8de3a554 100644 --- a/torch_geometric/datasets/web_qsp_dataset.py +++ b/torch_geometric/datasets/web_qsp_dataset.py @@ -117,12 +117,13 @@ def retrieval_via_pcst( return data, desc -class WebQSPDataset(InMemoryDataset): - r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse - Labeling for Knowledge Base Question Answering" - `_ paper. +class KGQABaseDataset(InMemoryDataset): + r"""Base class for the 2 KGQA datasets used in `"Reasoning on Graphs: + Faithful and Interpretable Large Language Model Reasoning" + `_ paper. Args: + dataset_name (str): HuggingFace `dataset` name. root (str): Root directory where the dataset should be saved. split (str, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. @@ -134,11 +135,14 @@ class WebQSPDataset(InMemoryDataset): """ def __init__( self, + dataset_name: str, root: str, split: str = "train", force_reload: bool = False, use_pcst: bool = True, + use_cwq: bool = True, ) -> None: + self.dataset_name = dataset_name self.use_pcst = use_pcst super().__init__(root, force_reload=force_reload) @@ -156,7 +160,7 @@ def process(self) -> None: import datasets import pandas as pd - datasets = datasets.load_dataset('rmanluo/RoG-webqsp') + datasets = datasets.load_dataset(self.dataset_name) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model_name = 'sentence-transformers/all-roberta-large-v1' @@ -244,3 +248,45 @@ def process(self) -> None: data_list.append(data) self.save(data_list, path) + + +class WebQSPDataset(KGQABaseDataset): + r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse + Labeling for Knowledge Base Question Answering" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + split (str, optional): If :obj:`"train"`, loads the training dataset. + If :obj:`"val"`, loads the validation dataset. + If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + use_pcst (bool, optional): Whether to preprocess the dataset's graph + with PCST or return the full graphs. (default: :obj:`True`) + """ + def __init__(self, root: str, split: str = "train", + force_reload: bool = False, use_pcst: bool = True) -> None: + dataset_name = 'rmanluo/RoG-webqsp' + super().__init__(dataset_name, root, split, force_reload, use_pcst) + + +class CWQDataset(KGQABaseDataset): + r"""The ComplexWebQuestions (CWQ) dataset of the `"The Web as a + Knowledge-base forAnswering Complex Questions" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + split (str, optional): If :obj:`"train"`, loads the training dataset. + If :obj:`"val"`, loads the validation dataset. + If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + use_pcst (bool, optional): Whether to preprocess the dataset's graph + with PCST or return the full graphs. (default: :obj:`True`) + """ + def __init__(self, root: str, split: str = "train", + force_reload: bool = False, use_pcst: bool = True) -> None: + dataset_name = 'rmanluo/RoG-cwq' + super().__init__(dataset_name, root, split, force_reload, use_pcst) From ed89c94904e6b2789c3a6720b365b47ddf90e3df Mon Sep 17 00:00:00 2001 From: xnuohz Date: Fri, 24 Jan 2025 23:50:10 +0800 Subject: [PATCH 45/45] Add `InstructMol` dataset (#9975) ### Issue #9699 ### Detail compare between InstructMol and MoleculeGPT - data: the same data structure but different data sources, molecular graph + smiles sequence + question + answer - model: almost the same model paradigm, multimodal + QA so in this PR I only implemented the InstructMol dataset and added it to the MoleculeGPT model example. --------- Co-authored-by: Rishi Puri --- CHANGELOG.md | 1 + examples/llm/README.md | 2 +- examples/llm/molecule_gpt.py | 14 +- test/datasets/test_instruct_mol_dataset.py | 11 ++ torch_geometric/datasets/__init__.py | 2 + .../datasets/instruct_mol_dataset.py | 134 ++++++++++++++++++ 6 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 test/datasets/test_instruct_mol_dataset.py create mode 100644 torch_geometric/datasets/instruct_mol_dataset.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 73a5a3d2a6fd..addce364df6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975)) - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947)) - Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945)) - Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941)) diff --git a/examples/llm/README.md b/examples/llm/README.md index 4503e28ce6ee..d339b1ea1039 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -6,6 +6,6 @@ | [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. | | [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA | | [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. | -| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | +| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports MoleculeGPT and InstructMol dataset | | [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | | [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | diff --git a/examples/llm/molecule_gpt.py b/examples/llm/molecule_gpt.py index 6f11d87969a4..ceff16e8b1ef 100644 --- a/examples/llm/molecule_gpt.py +++ b/examples/llm/molecule_gpt.py @@ -11,7 +11,7 @@ from tqdm import tqdm from torch_geometric import seed_everything -from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import GINEConv from torch_geometric.nn.models import MoleculeGPT @@ -44,6 +44,7 @@ def eval(model, data_loader): def train( + dataset_name: str, num_epochs: int, lr: float, batch_size: int, @@ -65,8 +66,11 @@ def adjust_learning_rate(param_group, LR, epoch): start_time = time.time() # Load dataset ================================================ path = osp.dirname(osp.realpath(__file__)) - path = osp.join(path, '..', '..', 'data', 'MoleculeGPT') - dataset = MoleculeGPTDataset(path) + path = osp.join(path, '..', '..', 'data', dataset_name) + if dataset_name == 'MoleculeGPT': + dataset = MoleculeGPTDataset(path) + elif dataset_name == 'InstructMol': + dataset = InstructMolDataset(path) train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset)) train_dataset = dataset[:train_size] val_dataset = dataset[train_size:train_size + val_size] @@ -177,6 +181,9 @@ def adjust_learning_rate(param_group, LR, epoch): if __name__ == '__main__': parser = argparse.ArgumentParser() + parser.add_argument("--dataset_name", type=str, default='MoleculeGPT', + choices=['MoleculeGPT', 'InstructMol'], + help='Support MoleculeGPT and InstructMol') parser.add_argument('--epochs', type=int, default=3) parser.add_argument('--lr', type=float, default=1e-5) parser.add_argument('--batch_size', type=int, default=2) @@ -185,6 +192,7 @@ def adjust_learning_rate(param_group, LR, epoch): start_time = time.time() train( + args.dataset_name, args.epochs, args.lr, args.batch_size, diff --git a/test/datasets/test_instruct_mol_dataset.py b/test/datasets/test_instruct_mol_dataset.py new file mode 100644 index 000000000000..b225b48210e3 --- /dev/null +++ b/test/datasets/test_instruct_mol_dataset.py @@ -0,0 +1,11 @@ +from torch_geometric.datasets import InstructMolDataset +from torch_geometric.testing import onlyFullTest, withPackage + + +@onlyFullTest +@withPackage('rdkit') +def test_instruct_mol_dataset(): + dataset = InstructMolDataset(root='./data/InstructMol') + assert len(dataset) == 326689 + assert dataset.num_edge_features == 4 + assert dataset.num_node_features == 6 diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 0d48ba9c0e00..e6468da7ae89 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -79,6 +79,7 @@ from .web_qsp_dataset import WebQSPDataset, CWQDataset from .git_mol_dataset import GitMolDataset from .molecule_gpt_dataset import MoleculeGPTDataset +from .instruct_mol_dataset import InstructMolDataset from .tag_dataset import TAGDataset from .dbp15k import DBP15K @@ -196,6 +197,7 @@ 'CWQDataset', 'GitMolDataset', 'MoleculeGPTDataset', + 'InstructMolDataset', 'TAGDataset', ] diff --git a/torch_geometric/datasets/instruct_mol_dataset.py b/torch_geometric/datasets/instruct_mol_dataset.py new file mode 100644 index 000000000000..af490c6affc9 --- /dev/null +++ b/torch_geometric/datasets/instruct_mol_dataset.py @@ -0,0 +1,134 @@ +import json +import sys +from typing import Callable, List, Optional + +import torch +from tqdm import tqdm + +from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.io import fs +from torch_geometric.utils import one_hot + + +class InstructMolDataset(InMemoryDataset): + r"""The dataset from the `"InstructMol: Multi-Modal Integration for + Building a Versatile and Reliable Molecular Assistant in Drug Discovery" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + """ + raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/blob/main' + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + ): + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return ['all_clean.json'] + + @property + def processed_file_names(self) -> List[str]: + return ['data.pt'] + + def download(self) -> None: + print('downloading dataset...') + fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir) + + def process(self) -> None: + try: + from rdkit import Chem + from rdkit.Chem.rdchem import BondType as BT + WITH_RDKIT = True + + except ImportError: + WITH_RDKIT = False + + if not WITH_RDKIT: + print(("Using a pre-processed version of the dataset. Please " + "install 'rdkit' to alternatively process the raw data."), + file=sys.stderr) + + data_list = fs.torch_load(self.raw_paths[0]) + data_list = [Data(**data_dict) for data_dict in data_list] + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.save(data_list, self.processed_paths[0]) + return + + # types of atom and bond + types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5} + bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} + + # load data + mols = json.load(open(f'{self.raw_dir}/all_clean.json')) + + data_list = [] + for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + continue + + x: torch.Tensor = torch.tensor([ + types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 + for atom in mol.GetAtoms() + ]) + x = one_hot(x, num_classes=len(types), dtype=torch.float) + + rows, cols, edge_types = [], [], [] + for bond in mol.GetBonds(): + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_types += [bonds[bond.GetBondType()]] * 2 + rows += [i, j] + cols += [j, i] + + edge_index = torch.tensor([rows, cols], dtype=torch.long) + edge_type = torch.tensor(edge_types, dtype=torch.long) + edge_attr = one_hot(edge_type, num_classes=len(bonds)) + + for question, answer in qa_pairs: + data = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + smiles=smiles, + instruction=question, + y=answer, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0])