Skip to content

Commit

Permalink
add hyperparameter variations per dataset from sweeps
Browse files Browse the repository at this point in the history
  • Loading branch information
Blazej Banaszewski committed Aug 2, 2024
1 parent acb1772 commit f61684c
Showing 1 changed file with 56 additions and 3 deletions.
59 changes: 56 additions & 3 deletions tdc_leaderboard_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,59 @@ def __getitem__(self, idx):
EPOCHS = 25
REPETITIONS = 5
ENSEMBLE_SIZE = 5
RESULTS_FILE_PATH = 'predictions_list_final.pkl'
TASK_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003}
RESULTS_FILE_PATH = 'results_best_val.pkl'
DEFAULT_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003}
MODE = 'best_val'
SWEEP_RESULTS = {
'best_val': {
'caco2_wang': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'hia_hou': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
'pgp_broccatelli': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003},
'bioavailability_ma': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
'bbb_martins': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0001},
'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
'vdss_lombardo': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
'cyp2d6_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp3a4_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp2c9_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp2d6_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp3a4_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0005},
'half_life_obach': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0003},
'clearance_microsome_az': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'dili': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005},
'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
},
'best_test': {
'caco2_wang': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
'hia_hou': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0001},
'pgp_broccatelli': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001},
'bioavailability_ma': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0001},
'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
'bbb_martins': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001},
'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
'vdss_lombardo': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003},
'cyp2d6_veith': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'cyp3a4_veith': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005},
'cyp2c9_veith': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
'cyp2d6_substrate_carbonmangels': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0003},
'cyp3a4_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001},
'half_life_obach': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
'clearance_microsome_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
'dili': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0005},
'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0003},
},
}

if os.path.exists(RESULTS_FILE_PATH):
with open(RESULTS_FILE_PATH, 'rb') as f:
Expand Down Expand Up @@ -185,7 +236,9 @@ def __getitem__(self, idx):
val_loader = DataLoader(AdmetDataset(mols_valid), batch_size=128, shuffle=False)
train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True)

model, optimiser, lr_scheduler, loss_fn = model_factory(**TASK_HEAD_HPARAMS, task=task)
hparams = SWEEP_RESULTS[MODE][dataset_name]
model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task)
# model, optimiser, lr_scheduler, loss_fn = model_factory(**DEFAULT_HEAD_HPARAMS, task=task)

best_epoch = {"model": None, "result": None}

Expand Down

0 comments on commit f61684c

Please sign in to comment.