-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into rebase-txt2kg
- Loading branch information
Showing
6 changed files
with
161 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from torch_geometric.datasets import InstructMolDataset | ||
from torch_geometric.testing import onlyFullTest, withPackage | ||
|
||
|
||
@onlyFullTest
@withPackage('rdkit')
def test_instruct_mol_dataset():
    """Smoke-test the processed InstructMol dataset.

    Builds the dataset under ``./data/InstructMol`` and checks the number of
    examples as well as the edge and node feature dimensionalities.
    """
    ds = InstructMolDataset(root='./data/InstructMol')

    num_examples = len(ds)
    assert num_examples == 326689

    # Edge features are checked before node features, as in the original.
    feature_dims = (ds.num_edge_features, ds.num_node_features)
    assert feature_dims == (4, 6)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import json | ||
import sys | ||
from typing import Callable, List, Optional | ||
|
||
import torch | ||
from tqdm import tqdm | ||
|
||
from torch_geometric.data import Data, InMemoryDataset | ||
from torch_geometric.io import fs | ||
from torch_geometric.utils import one_hot | ||
|
||
|
||
class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    The raw file maps SMILES strings to lists of (question, answer) pairs.
    Every pair becomes one :obj:`torch_geometric.data.Data` object holding:

    * :obj:`x`: one-hot atom types over ``H, C, N, O, F, Unknow`` (6 classes),
    * :obj:`edge_index` / :obj:`edge_attr`: bidirectional bonds with one-hot
      bond types over single, double, triple and aromatic (4 classes),
    * :obj:`smiles`: the SMILES string of the molecule,
    * :obj:`instruction`: the question,
    * :obj:`y`: the answer.

    SMILES strings that :obj:`rdkit` cannot parse are skipped.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    # NOTE: `resolve/` serves the raw file content on the Hugging Face Hub,
    # whereas `blob/` is the HTML file viewer and cannot be parsed as JSON.
    raw_url = ('https://huggingface.co/datasets/OpenMol/PubChemSFT/'
               'resolve/main')

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the files expected inside :obj:`self.raw_dir`."""
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        """Names of the files written to :obj:`self.processed_dir`."""
        return ['data.pt']

    def download(self) -> None:
        """Copy the raw JSON file from :obj:`raw_url` into the raw dir."""
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        """Convert the raw JSON file into graphs and save them to disk.

        Raises:
            ImportError: If :obj:`rdkit` is not installed. The raw file is
                plain JSON and no pre-processed variant is downloaded, so
                there is nothing to fall back to.
        """
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
        except ImportError as e:
            raise ImportError(
                "'InstructMolDataset' requires 'rdkit' to process the raw "
                "data. Please install 'rdkit'.") from e

        # Vocabularies used for the one-hot atom and bond encodings.
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
        unknown_type = types['Unknow']

        # Use a context manager so the file handle is closed deterministically.
        with open(self.raw_paths[0], encoding='utf-8') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # Unparsable SMILES.
                continue

            # An explicit `long` dtype keeps `one_hot` working even when the
            # molecule has no atoms (an empty list would default to float).
            atom_type = torch.tensor(
                [
                    types.get(atom.GetSymbol(), unknown_type)
                    for atom in mol.GetAtoms()
                ],
                dtype=torch.long,
            )
            x = one_hot(atom_type, num_classes=len(types), dtype=torch.float)

            # Each bond is stored in both directions.
            # NOTE(review): a bond type outside `bonds` raises `KeyError`.
            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # One example per (question, answer) pair. The graph tensors are
            # shared between all examples of the same molecule.
            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])