
Commit

Merge branch 'master' into rebase-txt2kg
puririshi98 authored Jan 24, 2025
2 parents c7a5e4d + ed89c94 commit 024cabe
Showing 6 changed files with 161 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Adds TXT2KG class with example on HotPotQA ([#9846](https://github.com/pyg-team/pytorch_geometric/pull/9846))
+- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
5 changes: 2 additions & 3 deletions examples/llm/README.md
@@ -4,10 +4,9 @@
| -------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
-| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
+| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports the MoleculeGPT and InstructMol datasets |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
| [`hotpot_qa.py`](./hotpot_qa.py) | Example for adapting the retrieval step of conventional Retrieval-Augmented Generation (RAG) for use with G-Retriever, and for approximating the precision/recall of a subgraph retrieval method. Uses the HotPotQA dataset from [Hugging Face](https://huggingface.co/datasets/hotpotqa/hotpot_qa), which is multi-hop in nature. |
-| [`tech_qa.py`](./hotpot_qa.py) | Full end 2 end GraphRAG pipeline combining txt2kg and retrieval from `hotpot_qa.py` and training/testing from g_retriever.py. Uses the techQA dataset from [Hugging Face](https://huggingface.co/datasets/rojagtap/tech-qa)|
+| [`tech_qa.py`](./tech_qa.py) | Full end-to-end GraphRAG pipeline combining txt2kg and retrieval from `hotpot_qa.py` with training/testing from `g_retriever.py`. Uses the TechQA dataset from [Hugging Face](https://huggingface.co/datasets/rojagtap/tech-qa) |
14 changes: 11 additions & 3 deletions examples/llm/molecule_gpt.py
@@ -11,7 +11,7 @@
from tqdm import tqdm

from torch_geometric import seed_everything
-from torch_geometric.datasets import MoleculeGPTDataset
+from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv
from torch_geometric.nn.models import MoleculeGPT
@@ -44,6 +44,7 @@ def eval(model, data_loader):


def train(
+    dataset_name: str,
num_epochs: int,
lr: float,
batch_size: int,
@@ -65,8 +66,11 @@ def adjust_learning_rate(param_group, LR, epoch):
start_time = time.time()
# Load dataset ================================================
path = osp.dirname(osp.realpath(__file__))
-    path = osp.join(path, '..', '..', 'data', 'MoleculeGPT')
-    dataset = MoleculeGPTDataset(path)
+    path = osp.join(path, '..', '..', 'data', dataset_name)
+    if dataset_name == 'MoleculeGPT':
+        dataset = MoleculeGPTDataset(path)
+    elif dataset_name == 'InstructMol':
+        dataset = InstructMolDataset(path)
train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset))
train_dataset = dataset[:train_size]
val_dataset = dataset[train_size:train_size + val_size]
@@ -177,6 +181,9 @@ def adjust_learning_rate(param_group, LR, epoch):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default='MoleculeGPT',
choices=['MoleculeGPT', 'InstructMol'],
help='Support MoleculeGPT and InstructMol')
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--batch_size', type=int, default=2)
@@ -185,6 +192,7 @@ def adjust_learning_rate(param_group, LR, epoch):

start_time = time.time()
train(
+        args.dataset_name,
args.epochs,
args.lr,
args.batch_size,
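Not part of the diff: a minimal sketch of the dataset dispatch that `molecule_gpt.py` now performs, factored into a hypothetical `load_dataset` helper (the helper name and default root are illustrative):

```python
import os.path as osp

from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset


def load_dataset(dataset_name: str, root: str = './data') -> InMemoryDataset:
    # Mirror the script's layout: one sub-directory per dataset.
    path = osp.join(root, dataset_name)
    if dataset_name == 'MoleculeGPT':
        return MoleculeGPTDataset(path)
    elif dataset_name == 'InstructMol':
        return InstructMolDataset(path)
    raise ValueError(f"Unknown dataset '{dataset_name}'")
```

In the script itself, `argparse`'s `choices=['MoleculeGPT', 'InstructMol']` already rejects unknown names before `train()` runs, so the trailing `raise` only matters if the helper is reused elsewhere.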
11 changes: 11 additions & 0 deletions test/datasets/test_instruct_mol_dataset.py
@@ -0,0 +1,11 @@
from torch_geometric.datasets import InstructMolDataset
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('rdkit')
def test_instruct_mol_dataset():
dataset = InstructMolDataset(root='./data/InstructMol')
assert len(dataset) == 326689
assert dataset.num_edge_features == 4
assert dataset.num_node_features == 6
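Beyond the smoke test above, a hedged usage sketch (the root path is illustrative; the field names match the `Data` objects built in `instruct_mol_dataset.py` below):

```python
from torch_geometric.datasets import InstructMolDataset

# First use downloads and processes the raw data (see `process()` below;
# `rdkit` enables processing from scratch).
dataset = InstructMolDataset(root='./data/InstructMol')

data = dataset[0]
print(data.smiles)       # SMILES string of the molecule
print(data.instruction)  # the question posed about it
print(data.y)            # the reference answer
print(data.x.shape)      # [num_atoms, 6] one-hot atom types
```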
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
@@ -79,6 +79,7 @@
from .web_qsp_dataset import WebQSPDataset, CWQDataset
from .git_mol_dataset import GitMolDataset
from .molecule_gpt_dataset import MoleculeGPTDataset
+from .instruct_mol_dataset import InstructMolDataset
from .tag_dataset import TAGDataset

from .dbp15k import DBP15K
@@ -196,6 +197,7 @@
'CWQDataset',
'GitMolDataset',
'MoleculeGPTDataset',
+    'InstructMolDataset',
'TAGDataset',
]

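With the registration above, the new class is importable from the top-level `torch_geometric.datasets` namespace — a quick sanity check, assuming the quoted list shown above is the module's `__all__` (as in upstream PyG):

```python
from torch_geometric import datasets

# Verify the dataset is exported and resolvable by name:
assert 'InstructMolDataset' in datasets.__all__
print(datasets.InstructMolDataset.__name__)  # 'InstructMolDataset'
```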
134 changes: 134 additions & 0 deletions torch_geometric/datasets/instruct_mol_dataset.py
@@ -0,0 +1,134 @@
import json
import sys
from typing import Callable, List, Optional

import torch
from tqdm import tqdm

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs
from torch_geometric.utils import one_hot


class InstructMolDataset(InMemoryDataset):
r"""The dataset from the `"InstructMol: Multi-Modal Integration for
Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
<https://arxiv.org/pdf/2311.16208>`_ paper.

    Args:
root (str): Root directory where the dataset should be saved.
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""
raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/blob/main'

def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform, pre_filter,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
def raw_file_names(self) -> List[str]:
return ['all_clean.json']

@property
def processed_file_names(self) -> List[str]:
return ['data.pt']

def download(self) -> None:
print('downloading dataset...')
fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

def process(self) -> None:
try:
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
WITH_RDKIT = True

except ImportError:
WITH_RDKIT = False

if not WITH_RDKIT:
print(("Using a pre-processed version of the dataset. Please "
"install 'rdkit' to alternatively process the raw data."),
file=sys.stderr)

data_list = fs.torch_load(self.raw_paths[0])
data_list = [Data(**data_dict) for data_dict in data_list]

if self.pre_filter is not None:
data_list = [d for d in data_list if self.pre_filter(d)]

if self.pre_transform is not None:
data_list = [self.pre_transform(d) for d in data_list]

self.save(data_list, self.processed_paths[0])
return

        # Types of atoms and bonds:
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknown': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        # Load data:
        with open(f'{self.raw_dir}/all_clean.json') as f:
            mols = json.load(f)

data_list = []
for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
continue

x: torch.Tensor = torch.tensor([
types[atom.GetSymbol()] if atom.GetSymbol() in types else 5
for atom in mol.GetAtoms()
])
x = one_hot(x, num_classes=len(types), dtype=torch.float)

rows, cols, edge_types = [], [], []
for bond in mol.GetBonds():
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
edge_types += [bonds[bond.GetBondType()]] * 2
rows += [i, j]
cols += [j, i]

edge_index = torch.tensor([rows, cols], dtype=torch.long)
edge_type = torch.tensor(edge_types, dtype=torch.long)
edge_attr = one_hot(edge_type, num_classes=len(bonds))

for question, answer in qa_pairs:
data = Data(
x=x,
edge_index=edge_index,
edge_attr=edge_attr,
smiles=smiles,
instruction=question,
y=answer,
)

if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

self.save(data_list, self.processed_paths[0])
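Not part of the commit: a standalone sketch of the featurization logic from `process()` above, applied to a single SMILES string to make the tensor shapes concrete (requires `rdkit`; ethanol is an arbitrary example):

```python
import torch
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT

from torch_geometric.utils import one_hot

types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknown': 5}
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

mol = Chem.MolFromSmiles('CCO')  # ethanol: 3 heavy atoms, 2 single bonds

# One-hot atom types; symbols outside `types` fall back to index 5:
x = one_hot(
    torch.tensor([types.get(atom.GetSymbol(), 5) for atom in mol.GetAtoms()]),
    num_classes=len(types), dtype=torch.float)

# Each chemical bond becomes two directed edges:
rows, cols, edge_types = [], [], []
for bond in mol.GetBonds():
    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    edge_types += [bonds[bond.GetBondType()]] * 2
    rows += [i, j]
    cols += [j, i]

edge_index = torch.tensor([rows, cols], dtype=torch.long)
edge_attr = one_hot(torch.tensor(edge_types), num_classes=len(bonds))

print(x.shape, edge_index.shape, edge_attr.shape)
# torch.Size([3, 6]) torch.Size([2, 4]) torch.Size([4, 4])
```

The shapes are consistent with the test above: 6 node features (one-hot atom types) and 4 edge features (one-hot bond types).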
