Skip to content

Commit

Permalink
enable batched featurisation + add batch_size parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Blazej Banaszewski committed Aug 5, 2024
1 parent c259bd4 commit c3fc3a3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
19 changes: 11 additions & 8 deletions minimol/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from omegaconf import OmegaConf
from typing import Union
import pkg_resources
from contextlib import redirect_stdout, redirect_stderr

from torch_geometric.nn import global_max_pool

Expand All @@ -23,12 +24,12 @@

class Minimol:

def __init__(self):
def __init__(self, batch_size: int = 100):
self.batch_size = batch_size
# handle the paths
state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth')
config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml')
base_shape_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'base_shape.yaml')

# Load the config
cfg = self.load_config(os.path.basename(config_path))
cfg = OmegaConf.to_container(cfg, resolve=True)
Expand Down Expand Up @@ -58,22 +59,24 @@ def __init__(self):
self.predictor = Fingerprinter(predictor, 'gnn:15')
self.predictor.setup()


def load_config(self, config_name):
hydra.initialize('ckpts/minimol_v1/', version_base=None)
cfg = hydra.compose(config_name=config_name)
return cfg

def __call__(self, smiles: Union[str,list]) -> torch.Tensor:
smiles = [smiles] if not isinstance(smiles, list) else smiles

input_features, _ = self.datamodule._featurize_molecules(smiles)
input_features = self.to_fp32(input_features)

batch_size = min(100, len(input_features))
batch_size = min(self.batch_size, len(smiles))

results = []
for i in tqdm(range(0, len(input_features), batch_size)):
batch = Batch.from_data_list(input_features[i:(i + batch_size)])
for i in tqdm(range(0, len(smiles), batch_size)):
with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output
input_features, _ = self.datamodule._featurize_molecules(smiles[i:(i + batch_size)])
input_features = self.to_fp32(input_features)

batch = Batch.from_data_list(input_features)
batch = {"features": batch, "batch_indices": batch.batch}
node_features = self.predictor.get_fingerprints_for_batch(batch)
fingerprint_graph = global_max_pool(node_features, batch['batch_indices'])
Expand Down
20 changes: 7 additions & 13 deletions notebooks/downstream_adaptation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"Found local copy...\n",
"generating training, validation splits...\n",
"generating training, validation splits...\n",
"100%|██████████| 461/461 [00:00<00:00, 3529.71it/s]\n"
"100%|██████████| 461/461 [00:00<00:00, 3648.38it/s]\n"
]
}
],
Expand Down Expand Up @@ -112,33 +112,27 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from minimol import Minimol\n",
"\n",
"featuriser = Minimol()"
"featuriser = Minimol(batch_size=50)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"featurizing_smiles, batch=1: 100%|██████████| 58/58 [00:04<00:00, 11.97it/s]\n",
"Casting to FP32: 100%|██████████| 58/58 [00:00<00:00, 11466.87it/s]\n",
"100%|██████████| 1/1 [00:00<00:00, 5.32it/s]\n",
"featurizing_smiles, batch=3: 100%|██████████| 39/39 [00:00<00:00, 416.25it/s]\n",
"Casting to FP32: 100%|██████████| 117/117 [00:00<00:00, 20868.96it/s]\n",
"100%|██████████| 2/2 [00:00<00:00, 8.08it/s]\n",
"featurizing_smiles, batch=13: 100%|██████████| 31/31 [00:00<00:00, 104.34it/s]\n",
"Casting to FP32: 100%|██████████| 403/403 [00:00<00:00, 15718.65it/s]\n",
"100%|██████████| 5/5 [00:00<00:00, 6.81it/s]\n"
"100%|██████████| 12/12 [00:25<00:00, 2.14s/it]\n",
"100%|██████████| 24/24 [00:01<00:00, 14.06it/s]\n",
"100%|██████████| 81/81 [00:05<00:00, 13.51it/s]\n"
]
}
],
Expand Down

0 comments on commit c3fc3a3

Please sign in to comment.