
Commit

add loggers and update checkpoint saving
stephprince committed May 24, 2024
1 parent 3b22463 commit 8a60533
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions src/metfish/msa_model/msa_model.py
@@ -5,8 +5,10 @@
import logging
import pandas as pd
import pytorch_lightning as pl
import wandb

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
from torch.utils.data import Dataset, DataLoader

from openfold.utils.import_weights import import_jax_weights_
@@ -99,6 +101,13 @@ def _log(self, loss_breakdown, batch, outputs, train=True):
            on_step=False, on_epoch=True, logger=True,
        )

        # save saxs msa attention weights
        for name, params in self.named_parameters():
            if "saxs_msa_attention" in name:
                wandb.log({f"params/{name}": wandb.Histogram(params.cpu().detach().numpy()), "epoch": self.current_epoch})
                # self.logger.experiment.add_histogram(f'weights_and_biases/{name}', params, self.current_epoch) # use this for TensorBoard


        with torch.no_grad():
            metrics = self._compute_validation_metrics(
                batch,
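
For reference, the commented-out add_histogram call above expands to the following TensorBoard variant (a minimal sketch, assuming the model is attached to a TensorBoardLogger so that self.logger.experiment is a torch.utils.tensorboard SummaryWriter):

        # TensorBoard variant: add_histogram accepts the tensor directly, no numpy conversion needed
        for name, params in self.named_parameters():
            if "saxs_msa_attention" in name:
                self.logger.experiment.add_histogram(
                    f"weights_and_biases/{name}", params.cpu().detach(), self.current_epoch
                )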
@@ -261,18 +270,18 @@ def load_from_jax(self, jax_path):

if __name__ == "__main__":
    # set up data paths and configuration
    ckpt_path = None
    metfish_dir = "/global/cfs/cdirs/m3513/metfish"
    data_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result"
    msa_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result_subset/"
    training_csv = f'{msa_dir}/input_training.csv'
    val_csv = f'{msa_dir}/input_validation.csv'
    pdb_dir = f"{data_dir}/pdb"
    saxs_dir = f"{data_dir}/saxs_r"
    os.environ["MODEL_DIR"] = "/pscratch/sd/s/smprince/projects/alphaflow/src/alphaflow/working_dir"

    ckpt_path = None
    working_dir = "/pscratch/sd/s/smprince/projects/metfish/model_outputs"
    ckpt_path = f"{working_dir}/checkpoints"
    jax_param_path = "/pscratch/sd/s/smprince/projects/alphaflow/params_model_1.npz" # these are the original AF weights

    resume_ckpt = False
    resume_model_weights_only = False
    deterministic = False

@@ -288,22 +297,28 @@ def load_from_jax(self, jax_path):
    train_dataset = MSASAXSDataset(data_config, training_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir)
    val_dataset = MSASAXSDataset(data_config, val_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)

    # initialize model and trainer
    msasaxsmodel = MSASAXSModel(config)
    # logger = CSVLogger(f"{working_dir}/lightning_logs", name="msasaxs")
    # logger = TensorBoardLogger(f"{working_dir}/lightning_logs", name="msasaxs")
    logger = WandbLogger(name="msasaxs", save_dir=f"{working_dir}/lightning_logs")
    trainer = pl.Trainer(accelerator="gpu",
                         max_epochs=100,
                         gradient_clip_val=1.,
                         limit_train_batches=1.0,
                         limit_val_batches=1.0,
                         callbacks=[ModelCheckpoint(
                             dirpath=os.environ["MODEL_DIR"],
                             dirpath=f"{working_dir}/checkpoints",
                             save_top_k=-1,
                             every_n_epochs=1,
                             every_n_epochs=25,
                         )],
                         check_val_every_n_epoch=1,) # TODO - add default_root_dir?
                         check_val_every_n_epoch=1,
                         logger=logger,
                         log_every_n_steps=1,
                         default_root_dir=working_dir)

    # load existing weights
    if jax_param_path:
@@ -312,10 +327,12 @@

    if resume_model_weights_only:
        msasaxsmodel.load_state_dict(torch.load(ckpt_path, map_location='cpu')['state_dict'], strict=False)
        ckpt_path= None
        ckpt_path = None
        msasaxsmodel.ema = ExponentialMovingAverage(
            model=msasaxsmodel.model, decay=config.ema.decay
        ) # need to initialize EMA this way at the beginning

    ckpt_path = None if not resume_ckpt else ckpt_path

    # fit the model
    trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path)
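
For a full resume of an interrupted run, the flags above would be flipped roughly as follows (a usage sketch; the checkpoint file name "epoch=49-step=1200.ckpt" is hypothetical, substitute whatever ModelCheckpoint actually wrote under {working_dir}/checkpoints):

    resume_ckpt = True
    ckpt_path = f"{working_dir}/checkpoints/epoch=49-step=1200.ckpt"
    # with resume_ckpt=True the ckpt_path survives the reset above, and trainer.fit
    # restores weights, optimizer state, and the epoch counter from the file
    trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path)

Since save_top_k=-1 keeps every checkpoint and every_n_epochs=25 writes one every 25 epochs, earlier checkpoints remain available for this kind of restart.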
