From 8a605339b8f7fe06ad4c95db10ff57d637b45a49 Mon Sep 17 00:00:00 2001
From: Steph Prince <40640337+stephprince@users.noreply.github.com>
Date: Fri, 24 May 2024 13:48:30 -0700
Subject: [PATCH] add loggers and update checkpoint saving

---
 src/metfish/msa_model/msa_model.py | 37 ++++++++++++++++++++++--------
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/src/metfish/msa_model/msa_model.py b/src/metfish/msa_model/msa_model.py
index b258553..c4a3cb6 100644
--- a/src/metfish/msa_model/msa_model.py
+++ b/src/metfish/msa_model/msa_model.py
@@ -5,8 +5,10 @@
 import logging
 import pandas as pd
 import pytorch_lightning as pl
+import wandb
 
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
 from torch.utils.data import Dataset, DataLoader
 
 from openfold.utils.import_weights import import_jax_weights_
@@ -99,6 +101,13 @@ def _log(self, loss_breakdown, batch, outputs, train=True):
                 on_step=False, on_epoch=True, logger=True,
             )
 
+        # save saxs msa attention weights
+        for name, params in self.named_parameters():
+            if "saxs_msa_attention" in name:
+                wandb.log({f"params/{name}": wandb.Histogram(params.cpu().detach().numpy()), "epoch": self.current_epoch})
+                # self.logger.experiment.add_histogram(f'weights_and_biases/{name}', params, self.current_epoch)  # use this for TensorBoard
+
+
         with torch.no_grad():
             metrics = self._compute_validation_metrics(
                 batch,
@@ -261,7 +270,6 @@ def load_from_jax(self, jax_path):
 
 if __name__ == "__main__":
     # set up data paths and configuration
-    ckpt_path = None
     metfish_dir = "/global/cfs/cdirs/m3513/metfish"
     data_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result"
     msa_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result_subset/"
@@ -269,10 +277,11 @@ def load_from_jax(self, jax_path):
     val_csv = f'{msa_dir}/input_validation.csv'
     pdb_dir = f"{data_dir}/pdb"
     saxs_dir = f"{data_dir}/saxs_r"
-    os.environ["MODEL_DIR"] = "/pscratch/sd/s/smprince/projects/alphaflow/src/alphaflow/working_dir"
-
-    ckpt_path = None
+    working_dir = "/pscratch/sd/s/smprince/projects/metfish/model_outputs"
+    ckpt_path = f"{working_dir}/checkpoints"
     jax_param_path = "/pscratch/sd/s/smprince/projects/alphaflow/params_model_1.npz"  # these are the original AF weights
+
+    resume_ckpt = False
     resume_model_weights_only = False
     deterministic = False
 
@@ -288,22 +297,28 @@
 
     train_dataset = MSASAXSDataset(data_config, training_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir)
     val_dataset = MSASAXSDataset(data_config, val_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir)
-    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
-    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
+    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)
 
     # initialize model and trainer
     msasaxsmodel = MSASAXSModel(config)
+    # logger = CSVLogger(f"{working_dir}/lightning_logs", name="msasaxs")
+    # logger = TensorBoardLogger(f"{working_dir}/lightning_logs", name="msasaxs")
+    logger = WandbLogger(name="msasaxs", save_dir=f"{working_dir}/lightning_logs")
     trainer = pl.Trainer(accelerator="gpu",
                          max_epochs=100,
                          gradient_clip_val=1.,
                          limit_train_batches=1.0,
                          limit_val_batches=1.0,
                          callbacks=[ModelCheckpoint(
-                             dirpath=os.environ["MODEL_DIR"],
+                             dirpath=f"{working_dir}/checkpoints",
                              save_top_k=-1,
-                             every_n_epochs=1,
+                             every_n_epochs=25,
                          )],
-                         check_val_every_n_epoch=1,)  # TODO - add default_root_dir?
+                         check_val_every_n_epoch=1,
+                         logger=logger,
+                         log_every_n_steps=1,
+                         default_root_dir=working_dir)
 
     # load existing weights
     if jax_param_path:
@@ -312,10 +327,12 @@ def load_from_jax(self, jax_path):
 
     if resume_model_weights_only:
         msasaxsmodel.load_state_dict(torch.load(ckpt_path, map_location='cpu')['state_dict'], strict=False)
-        ckpt_path= None
+        ckpt_path = None
         msasaxsmodel.ema = ExponentialMovingAverage(
             model=msasaxsmodel.model, decay=config.ema.decay
         )  # need to initialize EMA this way at the beginning
+
+    ckpt_path = None if not resume_ckpt else ckpt_path
 
     # fit the model
     trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path)
\ No newline at end of file
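
Note on the histogram logging added in _log: the patch calls the global
wandb.log directly. A minimal sketch of an alternative, assuming self.logger
is the WandbLogger configured below (the helper name
_log_attention_histograms is hypothetical, not part of the patch), routes the
same data through the logger's underlying run so the histograms stay attached
to the run Lightning created:

    import wandb

    def _log_attention_histograms(self):
        # one histogram per saxs_msa_attention parameter tensor, keyed by
        # epoch so wandb can track how its distribution evolves over training
        for name, params in self.named_parameters():
            if "saxs_msa_attention" in name:
                self.logger.experiment.log({
                    f"params/{name}": wandb.Histogram(params.detach().cpu().numpy()),
                    "epoch": self.current_epoch,
                })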
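
Note on resuming: with resume_ckpt = False, ckpt_path is reset to None just
before trainer.fit, so training always starts fresh. When resuming for real,
trainer.fit expects a path to a specific .ckpt file rather than the
checkpoints directory that ckpt_path currently holds. A sketch under that
assumption (the glob pattern is illustrative, not part of the patch):

    import glob
    import os

    working_dir = "/pscratch/sd/s/smprince/projects/metfish/model_outputs"
    resume_ckpt = True

    ckpt_path = None
    if resume_ckpt:
        # pick the newest checkpoint written by the ModelCheckpoint callback
        candidates = glob.glob(os.path.join(working_dir, "checkpoints", "*.ckpt"))
        ckpt_path = max(candidates, key=os.path.getmtime) if candidates else None

    # trainer.fit(..., ckpt_path=ckpt_path) then restores model, optimizer,
    # and epoch state from that file whenever ckpt_path is not None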