Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated pipeline for evaluating pre-computed fingerprints #2

Open
wants to merge 48 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
884f98e
download LargeMix datasets in a single command
Dec 11, 2023
8a35f54
add a script for finetuning on fingerprints on cpu
Dec 21, 2023
1c80129
fix target types for regression
Dec 21, 2023
63e54ad
add lr scheduling
Dec 22, 2023
8ab2887
sweeping script ready
Dec 22, 2023
e521b52
Delete sweep_finetunning_on_fingerprints.sh
blazejba Dec 22, 2023
e4fc5fa
various fixes
Dec 22, 2023
43495d8
Merge branch 'bb/finetuning-on-fingerprints' of github.com:graphcore-…
Dec 22, 2023
a501ad2
full TDC eval for a checkpoint implemented
Dec 22, 2023
b5d0fb6
add config + small fixes
Jan 3, 2024
20485bd
allow disabling wandb + drop last batch if not full
Jan 3, 2024
65cc2da
log number of trainable params to wandb
Jan 3, 2024
dee7d2e
add spearmanr metric to regression tasks
Jan 3, 2024
cbf847f
add fingerprint extraction
Jan 3, 2024
71dfc69
small bug fix + allow to add a suffix to the fingerprint filenames
Jan 4, 2024
a1e8e45
refactoring
Jan 4, 2024
5722964
auto extraction of scores from the fingerprinting sweeps
Jan 4, 2024
30b0a0f
fix a bug with regression eval + add worker to unfinished sweeps inst…
Jan 4, 2024
c3733f2
use args instead of consts for the sweeping script
Jan 4, 2024
e0fc687
fix small bugs
Jan 4, 2024
16c0b71
analyze_results now produces a table in a csv file that can be copied…
Jan 4, 2024
656410b
simplify the script by removing repeating stuff
Jan 5, 2024
9ba63c9
fix a bug in the sweeper
Jan 5, 2024
c4d2109
dump a csv with extracted scores from the sweeps
Jan 6, 2024
4237502
add weight decay + filter out nans
Jan 6, 2024
d5163cb
combine all results correctly
Jan 6, 2024
7a2fcdc
randomly choose with dataset to sweep to avoid workers collision
Jan 6, 2024
98fb812
filter out samples with NaNs in training
Jan 8, 2024
f9451ba
order the csv table and include empty rows to match the excel template
Jan 8, 2024
d2e14d0
a bug with casting targets
Jan 8, 2024
c41a606
run finetuning with 5-fold scaffold split and report min,max,mean and…
Jan 8, 2024
6e54218
analzye 5-fold cross validation results
Jan 9, 2024
a8f5876
script for analyzing the best hparams
Jan 11, 2024
282a944
fixing the metric for the metabolism datasets
Jan 11, 2024
0352b96
load mup in validation/fingerprint extraction
Jan 16, 2024
8880d0d
extract node and edge level features
Jan 17, 2024
41afd4e
yolo.py extends fingerprint training with node and edge level features
Jan 17, 2024
69a0a2c
data analysis
Jan 25, 2024
8b5c2a6
smaller sweeps + improved analytics
Jan 28, 2024
69f1fc1
add "fair" test score selection
Jan 30, 2024
074e009
ensemble evaluation
Jan 30, 2024
051e5a1
allow extracting node features as fingerprints
Jan 30, 2024
c5cbf5f
change program name
Jan 30, 2024
a06c802
remove Gradient link from README
kerstink-GC Mar 12, 2024
3fabe6f
update
Mar 26, 2024
b2cbe3d
cleaning
Apr 8, 2024
76aa26d
move files
Apr 8, 2024
abc4a4b
Merge branch 'bb/finetuning-on-fingerprints' of github.com:graphcore-…
Apr 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions download_datasets.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/bin/bash

# Function to download dataset
download_dataset() {
local dataset_url=$1
local dataset_path=$2

# Create directory if it does not exist
mkdir -p $(dirname "${dataset_path}")

# Download the dataset
wget -O "${dataset_path}" "${dataset_url}"
}

# L1000_VCAP
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz" "graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz"
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt" "graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt"

# l1000_MCF7
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz" "graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz"
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt" "graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt"

# PCBA_1328
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet" "graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet"
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt" "graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt"

# PCQM4M_G25
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet" "graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet"
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt" "graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt"

# PCQM4M_N4
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet" "graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet"
download_dataset "https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt" "graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt"
244 changes: 244 additions & 0 deletions finetune_on_fingerprints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import wandb
import argparse
import torch
import json
from torch.utils.data import DataLoader, Dataset
from tdc.benchmark_group import admet_group
from functools import partial
import datamol as dm
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from sklearn.metrics import roc_auc_score, average_precision_score, r2_score, mean_absolute_error


def train_one_epoch(model, dataloader, loss_fn, optimizer, task_type, epoch):
model.train()
total_loss = 0
for inputs, targets in dataloader:
optimizer.zero_grad()
outputs = model(inputs.float())
targets = targets.long() if task_type == 'classification' else targets.float()
loss = loss_fn(outputs.squeeze(), targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
loss = total_loss / len(dataloader)
wandb.log({'epoch': epoch, 'train_loss': loss})
print(f"Epoch {epoch+1} - Train Loss: {loss}")
return model


def evaluate(model, dataloader, loss_fn, task_type, evaluation_type, epoch):
model.eval()
total_loss = 0
all_outputs = [] # For regression, store raw outputs
all_probs = [] # For classification, store probabilities
all_targets = []

with torch.no_grad():
for inputs, targets in dataloader:
outputs = model(inputs.float())
loss = loss_fn(outputs, targets.long() if task_type == 'classification' else targets.float())
total_loss += loss.item()

if task_type == 'classification':
probs = torch.softmax(outputs, dim=1)[:, 1]
all_probs.extend(probs.tolist())
else:
all_outputs.extend(outputs.squeeze().tolist())

all_targets.extend(targets.tolist())

loss = total_loss / len(dataloader)
metrics = {f'{evaluation_type}_loss': loss}

if task_type == 'classification':
auroc = roc_auc_score(all_targets, all_probs)
avpr = average_precision_score(all_targets, all_probs)
metrics.update({
f'{evaluation_type}_auroc': auroc,
f'{evaluation_type}_avpr': avpr,
})
else:
r2 = r2_score(all_targets, all_outputs)
mae = mean_absolute_error(all_targets, all_outputs)
metrics.update({
f'{evaluation_type}_r2': r2,
f'{evaluation_type}_mae': mae,
})

wandb.log({**metrics, 'epoch': epoch})
print(json.dumps(metrics, indent=5))
print()


class Model(nn.Module):
def __init__(self, input_dim, depth=3, hidden_dim=512, activation_fn='relu', combine_input='concat', num_classes=None, dropout_rate=0.1, **kwargs):
super(Model, self).__init__()

if depth < 2:
raise ValueError("Depth must be at least 2")

if depth == 2 and combine_input == 'concat' and hidden_dim != input_dim:
raise ValueError("When depth is 2 and combine_input is 'concat', hidden_dim must match input_dim")

self.depth = depth
self.hidden_dim = hidden_dim
self.combine_input = combine_input
self.dropout = nn.Dropout(dropout_rate)
self.layers = nn.ModuleList()

# Determine activation function
if activation_fn == 'relu':
self.activation_fn = F.relu
# Add other activation functions if necessary

# Create layers
for i in range(depth):
if self.combine_input == 'concat' and i == depth - 2:
in_dim = input_dim
out_dim = input_dim
elif self.combine_input == 'concat' and i == depth - 1:
in_dim = input_dim + hidden_dim # Doubled due to concatenation
out_dim = num_classes if num_classes is not None else 1
else:
in_dim = input_dim if i == 0 else hidden_dim
out_dim = hidden_dim if i < depth - 1 else (num_classes if num_classes is not None else 1)

self.layers.append(nn.Linear(in_dim, out_dim))

def forward(self, x):
original_x = x
for i, layer in enumerate(self.layers):
x = layer(x)
if i < self.depth - 1:
x = self.activation_fn(x)
x = self.dropout(x)

if self.combine_input == 'concat' and i == self.depth - 2:
x = torch.cat((x, original_x), dim=1)

return x


class SingleInstancePredictionDataset(Dataset):
def __init__(self, samples_df, task_type):
self.samples = samples_df['Drug'].tolist()
self.targets = samples_df['Y'].tolist()
if task_type == 'classification':
self.targets = [float(target) for target in self.targets]

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
sample = torch.tensor(self.samples[idx])
target = torch.tensor(self.targets[idx], dtype=torch.long)
return sample, target

def match_and_replace_input_column(samples_df, i2v):
transformed_df = samples_df.copy()
transformed_df["Drug"] = transformed_df['Drug'].apply(
lambda s: i2v[dm.unique_id(s)].detach().numpy())
return transformed_df

def determine_task_type(samples_df):
if np.issubdtype(samples_df['Y'].dtype, np.integer):
return 'classification', len(samples_df['Y'].unique())
else:
return 'regression', None

def model_summary(model):
print("Model Summary:")
print(model)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable Parameters: {trainable_params / 1e6:.2f} M")
return trainable_params

def dataloader_factory(benchmark, i2v, args):
match_and_replace_input_column_partial = partial(match_and_replace_input_column, i2v=i2v)

# Split the samples into train, val and test
train_samples = match_and_replace_input_column_partial(benchmark['train_val'])
test_samples = match_and_replace_input_column_partial(benchmark['test'])
train_samples = train_samples.sample(frac=1, random_state=42).reset_index(drop=True)
val_size = int(len(train_samples) * args.split)
val_samples = train_samples[:val_size]
train_samples = train_samples[val_size:]

# Create datasets
train_dataset = SingleInstancePredictionDataset(train_samples, args.task_type)
val_dataset = SingleInstancePredictionDataset(val_samples, args.task_type)
test_dataset = SingleInstancePredictionDataset(test_samples, args.task_type)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

input_dim = train_samples['Drug'].iloc[0].shape[0]

return train_loader, val_loader, test_loader, input_dim


def main():
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--fingerprints-path', type=str, default='ids_to_fingerprint.pt', help='Path to ids_to_fingerprint.pt')
parser.add_argument('--bench', type=str, default='Caco2_Wang', help='Name of the benchmark from admet_group')
parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training')
parser.add_argument('--split', type=float, default=0.1, help='Ratio of validation set split')
parser.add_argument('--depth', type=int, default=3, help='Depth of the model')
# Model arch
parser.add_argument('--hidden-dim', type=int, default=512, help='Dimension of hidden layers')
parser.add_argument('--activation-fn', type=str, default='relu', choices=['relu'], help='Activation function')
parser.add_argument('--combine-input', type=str, default='concat', choices=['concat', 'other_option'], help='Method to combine input')
parser.add_argument('--dropout-rate', type=float, default=0.1, help='Dropout rate')

args = parser.parse_args()

# Load the id to fingerprint mapping
i2v = torch.load(args.fingerprints_path)

# Get the TDC data
group = admet_group(path='data/')
benchmark = group.get(args.bench)

# Determine task type and number of classes if classification
args.task_type, args.num_classes = determine_task_type(benchmark['train_val'])

# Construct dataloaders
train_dl, val_dl, test_dl, args.input_dim = dataloader_factory(benchmark, i2v, args)

# Define a model
model = Model(**vars(args))
args.trainable_params = model_summary(model)

# Initialize wandb
wandb.init(project='scaling_mol_gnns', entity='graphcore', name='10M-ipu_fingerprints_bbb-martins')
wandb.config.update(args)

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss() if args.task_type == 'classification' else nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Test random model
epoch = 0
evaluate(model, test_dl, loss_fn, args.task_type, evaluation_type='test', epoch=epoch)

# Training and validation loop
for epoch in range(args.epochs):
print(f"Epoch {epoch+1}/{args.epochs}")
model = train_one_epoch(model, train_dl, loss_fn, optimizer, args.task_type, epoch)
evaluate(model, val_dl, loss_fn, args.task_type, evaluation_type='val', epoch=epoch)

# Test trained model
evaluate(model, test_dl, loss_fn, args.task_type, evaluation_type='test', epoch=epoch)
wandb.finish()

if __name__ == "__main__":
main()