Commit acb1772
align predictions_list with group.evaluate_many expectation
Blazej Banaszewski committed Aug 1, 2024
1 parent c2ad51c commit acb1772
Showing 1 changed file with 29 additions and 32 deletions.
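The point of the change: TDC's group.evaluate_many computes mean and standard deviation across repeated runs, so it expects predictions_list to hold one dict per repetition, each keyed by benchmark name and covering all datasets. A minimal sketch of that shape (the two benchmark names are real TDC ADMET tasks; the prediction values are made-up placeholders):

    # Shape expected by group.evaluate_many: one complete dict per repetition.
    predictions_list = [
        {'caco2_wang': [0.31, 0.58], 'hia_hou': [0.92, 0.11]},  # repetition 1
        {'caco2_wang': [0.29, 0.61], 'hia_hou': [0.90, 0.15]},  # repetition 2
        # ... five dicts in total, one per repetition
    ]

The previous loop order built one single-dataset dict per (dataset, repetition) pair instead, which is why the loops are swapped in the diff below.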
61 changes: 29 additions & 32 deletions tdc_leaderboard_submission.py
@@ -59,7 +59,6 @@ def forward(self, x):
 def model_factory(hidden_dim, depth, combine, task, lr, epochs=25, warmup=5, weight_decay=0.0001):
     model = TaskHead(hidden_dim=hidden_dim, depth=depth, combine=combine)
     optimiser = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
-
     loss_fn = nn.BCELoss() if task == 'classification' else nn.MSELoss()
 
     def lr_fn(epoch):
@@ -139,45 +138,43 @@ def __getitem__(self, idx):
         target = torch.tensor(self.targets[idx])
         return sample, target
 
-EPOCHS = 25
-RESULTS_FILE_PATH = 'predictions_list_largermodel.pkl'
-
-group = admet_group(path='admet_data/')
-featuriser = Minimol()
+EPOCHS = 25
+REPETITIONS = 5
+ENSEMBLE_SIZE = 5
+RESULTS_FILE_PATH = 'predictions_list_final.pkl'
+TASK_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003}
 
 if os.path.exists(RESULTS_FILE_PATH):
     with open(RESULTS_FILE_PATH, 'rb') as f:
         predictions_list = pickle.load(f)
 else:
     predictions_list = []
 
-# LOOP 1: datasets
-for dataset_i, dataset_name in enumerate(group.dataset_names):
-    print(f"Dataset {dataset_name}, {dataset_i + 1} / {len(group.dataset_names)}")
-    if [list(d.keys())[0] for d in predictions_list].count(dataset_name) >= 5:
-        print(f"There are already 5 scores for the `{dataset_name}` dataset. Skipping.")
-        continue
-
-    benchmark = group.get(dataset_name)
-    name = benchmark['name']
-    mols_test = benchmark['test']
-    with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull):  # suppress output
-        mols_test['Embedding'] = featuriser(list(mols_test['Drug']))
-    test_loader = DataLoader(AdmetDataset(mols_test), batch_size=128, shuffle=False)
-
-    # LOOP 2: repetitions
-    for rep_i, seed1 in enumerate([1, 2, 3, 4, 5]):
-        print(f"\tRepetition {rep_i + 1} / 5")
-
-        predictions = {}
+group = admet_group(path='admet_data/')
+featuriser = Minimol()
+
+# LOOP 1: repetitions
+for rep_i, seed1 in enumerate(range(1, REPETITIONS+1)):
+    print(f"Repetition {rep_i + 1} / 5")
+    predictions = {}
+
+    # LOOP 2: datasets
+    for dataset_i, dataset_name in enumerate(group.dataset_names):
+        print(f"\tDataset {dataset_name}, {dataset_i + 1} / {len(group.dataset_names)}")
+
+        benchmark = group.get(dataset_name)
+        name = benchmark['name']
+        mols_test = benchmark['test']
+        with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull):  # suppress output
+            mols_test['Embedding'] = featuriser(list(mols_test['Drug']))
+        test_loader = DataLoader(AdmetDataset(mols_test), batch_size=128, shuffle=False)
 
         task = 'classification' if len(benchmark['test']['Y'].unique()) == 2 else 'regression'
-        hparams = {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001}
-        model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task)
 
         best_models = []
 
         # LOOP3: ensemble on folds
-        for fold_i, seed2 in enumerate([6, 7, 8, 9, 10]):
+        for fold_i, seed2 in enumerate(range(REPETITIONS+1, REPETITIONS+ENSEMBLE_SIZE+1)):
             print(f"\t\tFold {fold_i + 1} / 5")
             seed = cantor_pairing(seed1, seed2)

@@ -188,7 +185,7 @@ def __getitem__(self, idx):
             val_loader = DataLoader(AdmetDataset(mols_valid), batch_size=128, shuffle=False)
             train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True)
 
-            model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task)
+            model, optimiser, lr_scheduler, loss_fn = model_factory(**TASK_HEAD_HPARAMS, task=task)
 
             best_epoch = {"model": None, "result": None}

@@ -209,9 +206,9 @@ def __getitem__(self, idx):
             y_pred_test = evaluate_ensemble(best_models, test_loader, task)
 
         predictions[name] = y_pred_test
-        predictions_list.append(predictions)
-
-    with open(RESULTS_FILE_PATH, 'wb') as f:
-        pickle.dump(predictions_list, f)
+    predictions_list.append(predictions)
+    with open(RESULTS_FILE_PATH, 'wb') as f: pickle.dump(predictions_list, f)
 
 results = group.evaluate_many(predictions_list)
 print(results)
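Each fold derives its RNG seed from cantor_pairing(seed1, seed2). That helper sits earlier in the file, outside the hunks shown here; assuming it implements the standard Cantor pairing function, it would look like the sketch below, which guarantees every (repetition, fold) pair gets a distinct seed:

    def cantor_pairing(k1: int, k2: int) -> int:
        # Bijection from pairs of non-negative integers to single integers,
        # e.g. (1, 6) -> 34 and (2, 6) -> 42; distinct pairs never collide.
        return (k1 + k2) * (k1 + k2 + 1) // 2 + k2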
