From ed89c94904e6b2789c3a6720b365b47ddf90e3df Mon Sep 17 00:00:00 2001 From: xnuohz Date: Fri, 24 Jan 2025 23:50:10 +0800 Subject: [PATCH] Add `InstructMol` dataset (#9975) ### Issue #9699 ### Detail: comparison between InstructMol and MoleculeGPT - data: the same data structure but different data sources, molecular graph + smiles sequence + question + answer - model: almost the same model paradigm, multimodal + QA so in this PR I only implemented the InstructMol dataset and added it to the MoleculeGPT model example. --------- Co-authored-by: Rishi Puri --- CHANGELOG.md | 1 + examples/llm/README.md | 2 +- examples/llm/molecule_gpt.py | 14 +- test/datasets/test_instruct_mol_dataset.py | 11 ++ torch_geometric/datasets/__init__.py | 2 + .../datasets/instruct_mol_dataset.py | 134 ++++++++++++++++++ 6 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 test/datasets/test_instruct_mol_dataset.py create mode 100644 torch_geometric/datasets/instruct_mol_dataset.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 73a5a3d2a6fd..addce364df6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975)) - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947)) - Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945)) - Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941)) diff --git a/examples/llm/README.md b/examples/llm/README.md index 4503e28ce6ee..d339b1ea1039 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -6,6 +6,6 @@ | [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. 
| | [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA | | [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. | -| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | +| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports MoleculeGPT and InstructMol datasets | | [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | | [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | diff --git a/examples/llm/molecule_gpt.py b/examples/llm/molecule_gpt.py index 6f11d87969a4..ceff16e8b1ef 100644 --- a/examples/llm/molecule_gpt.py +++ b/examples/llm/molecule_gpt.py @@ -11,7 +11,7 @@ from tqdm import tqdm from torch_geometric import seed_everything -from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset from torch_geometric.loader import DataLoader from torch_geometric.nn import GINEConv from torch_geometric.nn.models import MoleculeGPT @@ -44,6 +44,7 @@ def eval(model, data_loader): def train( + dataset_name: str, num_epochs: int, lr: float, batch_size: int, @@ -65,8 +66,11 @@ def adjust_learning_rate(param_group, LR, epoch): start_time = time.time() # Load dataset ================================================ path = osp.dirname(osp.realpath(__file__)) - path = osp.join(path, '..', '..', 'data', 'MoleculeGPT') - dataset = MoleculeGPTDataset(path) + path = osp.join(path, '..', '..', 'data', 
dataset_name) + if dataset_name == 'MoleculeGPT': + dataset = MoleculeGPTDataset(path) + elif dataset_name == 'InstructMol': + dataset = InstructMolDataset(path) train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset)) train_dataset = dataset[:train_size] val_dataset = dataset[train_size:train_size + val_size] @@ -177,6 +181,9 @@ def adjust_learning_rate(param_group, LR, epoch): if __name__ == '__main__': parser = argparse.ArgumentParser() + parser.add_argument("--dataset_name", type=str, default='MoleculeGPT', + choices=['MoleculeGPT', 'InstructMol'], + help='Support MoleculeGPT and InstructMol') parser.add_argument('--epochs', type=int, default=3) parser.add_argument('--lr', type=float, default=1e-5) parser.add_argument('--batch_size', type=int, default=2) @@ -185,6 +192,7 @@ def adjust_learning_rate(param_group, LR, epoch): start_time = time.time() train( + args.dataset_name, args.epochs, args.lr, args.batch_size, diff --git a/test/datasets/test_instruct_mol_dataset.py b/test/datasets/test_instruct_mol_dataset.py new file mode 100644 index 000000000000..b225b48210e3 --- /dev/null +++ b/test/datasets/test_instruct_mol_dataset.py @@ -0,0 +1,11 @@ +from torch_geometric.datasets import InstructMolDataset +from torch_geometric.testing import onlyFullTest, withPackage + + +@onlyFullTest +@withPackage('rdkit') +def test_instruct_mol_dataset(): + dataset = InstructMolDataset(root='./data/InstructMol') + assert len(dataset) == 326689 + assert dataset.num_edge_features == 4 + assert dataset.num_node_features == 6 diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 0d48ba9c0e00..e6468da7ae89 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -79,6 +79,7 @@ from .web_qsp_dataset import WebQSPDataset, CWQDataset from .git_mol_dataset import GitMolDataset from .molecule_gpt_dataset import MoleculeGPTDataset +from .instruct_mol_dataset import InstructMolDataset from 
.tag_dataset import TAGDataset from .dbp15k import DBP15K @@ -196,6 +197,7 @@ 'CWQDataset', 'GitMolDataset', 'MoleculeGPTDataset', + 'InstructMolDataset', 'TAGDataset', ] diff --git a/torch_geometric/datasets/instruct_mol_dataset.py b/torch_geometric/datasets/instruct_mol_dataset.py new file mode 100644 index 000000000000..af490c6affc9 --- /dev/null +++ b/torch_geometric/datasets/instruct_mol_dataset.py @@ -0,0 +1,134 @@ +import json +import sys +from typing import Callable, List, Optional + +import torch +from tqdm import tqdm + +from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.io import fs +from torch_geometric.utils import one_hot + + +class InstructMolDataset(InMemoryDataset): + r"""The dataset from the `"InstructMol: Multi-Modal Integration for + Building a Versatile and Reliable Molecular Assistant in Drug Discovery" + <https://arxiv.org/abs/2311.16208>`_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) + """ + raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/blob/main' + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + ): + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return ['all_clean.json'] + + @property + def processed_file_names(self) -> List[str]: + return ['data.pt'] + + def download(self) -> None: + print('downloading dataset...') + fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir) + + def process(self) -> None: + try: + from rdkit import Chem + from rdkit.Chem.rdchem import BondType as BT + WITH_RDKIT = True + + except ImportError: + WITH_RDKIT = False + + if not WITH_RDKIT: + print(("Using a pre-processed version of the dataset. Please " + "install 'rdkit' to alternatively process the raw data."), + file=sys.stderr) + + data_list = fs.torch_load(self.raw_paths[0]) + data_list = [Data(**data_dict) for data_dict in data_list] + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.save(data_list, self.processed_paths[0]) + return + + # types of atom and bond + types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5} + bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} + + # load data + mols = json.load(open(f'{self.raw_dir}/all_clean.json')) + + data_list = [] + for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + continue + + x: torch.Tensor = torch.tensor([ + types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 + for atom in mol.GetAtoms() + ]) + x = one_hot(x, num_classes=len(types), 
dtype=torch.float) + + rows, cols, edge_types = [], [], [] + for bond in mol.GetBonds(): + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_types += [bonds[bond.GetBondType()]] * 2 + rows += [i, j] + cols += [j, i] + + edge_index = torch.tensor([rows, cols], dtype=torch.long) + edge_type = torch.tensor(edge_types, dtype=torch.long) + edge_attr = one_hot(edge_type, num_classes=len(bonds)) + + for question, answer in qa_pairs: + data = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + smiles=smiles, + instruction=question, + y=answer, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0])