Skip to content

Commit

Permalink
stuff for the ablation study
Browse files Browse the repository at this point in the history
  • Loading branch information
Blazej Banaszewski committed Aug 6, 2024
1 parent bf5e6d6 commit 30447c1
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 31 deletions.
8 changes: 8 additions & 0 deletions ablation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Ablation runs: invoke the TDC leaderboard script once per checkpoint name,
# appending its stdout to results/<name>.txt (one file per checkpoint).
for ckpt in g25_n4 l1000s l1000s_pcba_g25 l1000s_pcba_n4; do
    python tdc_leaderboard_submission.py "${ckpt}" >> "results/${ckpt}.txt"
done
# Currently disabled checkpoints (add to the list above to run them):
#   pcba pcba_g25_n4 pcba_l1000s pcba_n4
10 changes: 7 additions & 3 deletions minimol/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@

class Minimol:

def __init__(self, batch_size: int = 100):
def __init__(self, batch_size: int = 100, ckpt_folder: str = None):
self.batch_size = batch_size
# handle the paths
state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth')
config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml')
if ckpt_folder is not None:
state_dict_path = os.path.join(ckpt_folder, 'state_dict.pth')
config_path = os.path.join(ckpt_folder, 'config.yaml')
else:
state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth')
config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml')
base_shape_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'base_shape.yaml')
# Load the config
cfg = self.load_config(os.path.basename(config_path))
Expand Down
70 changes: 42 additions & 28 deletions tdc_leaderboard_submission.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import builtins

# Keep a handle on the built-in so the wrapper below can delegate to it.
original_print = print


def print(*args, **kwargs):
    """Print like the built-in, but flush the output stream by default.

    Shadows the built-in ``print`` for this module so that each message is
    written out immediately rather than sitting in a buffer.

    Args:
        *args: Objects to print, forwarded unchanged.
        **kwargs: Keyword arguments accepted by the built-in ``print``.
            ``flush`` defaults to ``True``; an explicit value is respected.
    """
    # setdefault rather than a hard-coded flush=True: with the latter, a
    # caller passing flush=... hit "got multiple values for keyword argument".
    kwargs.setdefault("flush", True)
    original_print(*args, **kwargs)


from minimol import Minimol

import os
Expand Down Expand Up @@ -139,34 +147,39 @@ def __getitem__(self, idx):
return sample, target


import sys
ckpt_name = sys.argv[1]
base = '/home/blazejb/minimol/minimol/ckpts/'

EPOCHS = 25
REPETITIONS = 5
REPETITIONS = 3
ENSEMBLE_SIZE = 5
RESULTS_FILE_PATH = 'results_best_val.pkl'
SWEEP_RESULTS = {
'caco2_wang': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'hia_hou': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
'pgp_broccatelli': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003},
'bioavailability_ma': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
'bbb_martins': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0001},
'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
'vdss_lombardo': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
'cyp2d6_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp3a4_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp2c9_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp2d6_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp3a4_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0005},
'half_life_obach': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0003},
'clearance_microsome_az': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'dili': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005},
'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
}
RESULTS_FILE_PATH = f'results_best_{ckpt_name}.pkl'
DEFAULT_HPARAMS = {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003}
# SWEEP_RESULTS = {
# 'caco2_wang': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
# 'hia_hou': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
# 'pgp_broccatelli': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003},
# 'bioavailability_ma': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
# 'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
# 'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
# 'bbb_martins': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0001},
# 'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
# 'vdss_lombardo': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
# 'cyp2d6_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
# 'cyp3a4_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
# 'cyp2c9_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
# 'cyp2d6_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
# 'cyp3a4_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
# 'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0005},
# 'half_life_obach': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0003},
# 'clearance_microsome_az': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
# 'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
# 'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
# 'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
# 'dili': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005},
# 'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
# }

if os.path.exists(RESULTS_FILE_PATH):
with open(RESULTS_FILE_PATH, 'rb') as f:
Expand All @@ -175,7 +188,7 @@ def __getitem__(self, idx):
predictions_list = []

group = admet_group(path='admet_data/')
featuriser = Minimol()
featuriser = Minimol(ckpt_folder=os.path.join(base, ckpt_name))

# LOOP 1: repetitions
for rep_i, seed1 in enumerate(range(1, REPETITIONS+1)):
Expand Down Expand Up @@ -208,7 +221,8 @@ def __getitem__(self, idx):
val_loader = DataLoader(AdmetDataset(mols_valid), batch_size=128, shuffle=False)
train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True)

hparams = SWEEP_RESULTS[dataset_name]
# hparams = SWEEP_RESULTS[dataset_name]
hparams = DEFAULT_HPARAMS
model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task)

best_epoch = {"model": None, "result": None}
Expand Down

0 comments on commit 30447c1

Please sign in to comment.