Skip to content

Commit

Permalink
taking embeddings from a different place
Browse files (browse the repository at this point in the history)
  • Loading branch information
Blazej Banaszewski committed Aug 13, 2024
1 parent 7ac5ba7 commit 038d857
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 23 deletions.
54 changes: 33 additions & 21 deletions minimol/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from torch_geometric.nn import global_max_pool

from graphium.trainer.predictor import PredictorModule
from graphium.finetuning.fingerprinting import Fingerprinter
from graphium.config._loader import (
load_accelerator,
Expand All @@ -24,12 +25,17 @@

class Minimol:

def __init__(self, batch_size: int = 100):
def __init__(self, batch_size: int = 100, base_path: str = None):
self.batch_size = batch_size
# handle the paths
state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth')
config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml')
base_shape_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'base_shape.yaml')
if base_path is None:
state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth')
config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml')
base_shape_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'base_shape.yaml')
else:
ckpt_path = os.path.join(base_path, 'model.ckpt')
config_path = os.path.join(base_path, 'config.yaml')
base_shape_path = os.path.join(base_path, 'base_shape.yaml')
# Load the config
cfg = self.load_config(os.path.basename(config_path))
cfg = OmegaConf.to_container(cfg, resolve=True)
Expand All @@ -42,21 +48,26 @@ def __init__(self, batch_size: int = 100):
# Load the model
model_class, model_kwargs = load_architecture(cfg, in_dims=self.datamodule.in_dims)
metrics = load_metrics(self.cfg)
predictor = load_predictor(
config=self.cfg,
model_class=model_class,
model_kwargs=model_kwargs,
metrics=metrics,
task_levels=self.datamodule.get_task_levels(),
accelerator_type=accelerator_type,
featurization=self.datamodule.featurization,
task_norms=self.datamodule.task_norms,
replicas=1,
gradient_acc=1,
global_bs=self.datamodule.batch_size_training,
)
predictor.load_state_dict(torch.load(state_dict_path), strict=False)
self.predictor = Fingerprinter(predictor, 'gnn:15')

if base_path is None:
predictor = load_predictor(
config=self.cfg,
model_class=model_class,
model_kwargs=model_kwargs,
metrics=metrics,
task_levels=self.datamodule.get_task_levels(),
accelerator_type=accelerator_type,
featurization=self.datamodule.featurization,
task_norms=self.datamodule.task_norms,
replicas=1,
gradient_acc=1,
global_bs=self.datamodule.batch_size_training,
)
predictor.load_state_dict(torch.load(state_dict_path), strict=False)
else:
predictor = PredictorModule.load_pretrained_model(name_or_path=ckpt_path, device=accelerator_type)

self.predictor = Fingerprinter(predictor, 'graph_output_nn-graph:0')
self.predictor.setup()


Expand All @@ -78,8 +89,9 @@ def __call__(self, smiles: Union[str,list]) -> torch.Tensor:

batch = Batch.from_data_list(input_features)
batch = {"features": batch, "batch_indices": batch.batch}
node_features = self.predictor.get_fingerprints_for_batch(batch)
fingerprint_graph = global_max_pool(node_features, batch['batch_indices'])
fingerprint_graph = self.predictor.get_fingerprints_for_batch(batch)
# node_features = self.predictor.get_fingerprints_for_batch(batch)
# fingerprint_graph = global_max_pool(node_features, batch['batch_indices'])
num_molecules = min(batch_size, fingerprint_graph.shape[0])
results += [fingerprint_graph[i] for i in range(num_molecules)]

Expand Down
4 changes: 2 additions & 2 deletions tdc_leaderboard_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __getitem__(self, idx):
EPOCHS = 25
REPETITIONS = 5
ENSEMBLE_SIZE = 5
RESULTS_FILE_PATH = 'results.pkl'
RESULTS_FILE_PATH = 'results_gpu.pkl'
SWEEP_RESULTS = {
'caco2_wang': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
'hia_hou': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
Expand Down Expand Up @@ -175,7 +175,7 @@ def __getitem__(self, idx):
predictions_list = []

group = admet_group(path='admet_data/')
featuriser = Minimol()
featuriser = Minimol(base_path='/nethome/blazejb/minimol/minimol/ckpts/minimol_v1_gpu')

# LOOP 1: repetitions
for rep_i, seed1 in enumerate(range(1, REPETITIONS+1)):
Expand Down

0 comments on commit 038d857

Please sign in to comment.