From 3b224638f9a0817accb98f0072160894dea576ab Mon Sep 17 00:00:00 2001 From: Steph Prince <40640337+stephprince@users.noreply.github.com> Date: Mon, 13 May 2024 18:06:58 -0700 Subject: [PATCH] return off distillation during feature processing --- src/metfish/msa_model/msa_model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/metfish/msa_model/msa_model.py b/src/metfish/msa_model/msa_model.py index 0baf745..b258553 100644 --- a/src/metfish/msa_model/msa_model.py +++ b/src/metfish/msa_model/msa_model.py @@ -54,7 +54,7 @@ def __getitem__(self, idx): saxs_features = self.data_pipeline._process_saxs_feats(f'{self.saxs_dir}/{item.name}.pdb.pr.csv') # pdb data - pdb_features = self.data_pipeline.process_pdb_feats(f'{self.pdb_dir}/fixed_{item.name}.pdb') + pdb_features = self.data_pipeline.process_pdb_feats(f'{self.pdb_dir}/fixed_{item.name}.pdb', is_distillation=False) data = {**sequence_feats, **msa_features, **saxs_features, **pdb_features} feats = self.feature_pipeline.process_features(data) @@ -265,7 +265,7 @@ def load_from_jax(self, jax_path): metfish_dir = "/global/cfs/cdirs/m3513/metfish" data_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result" msa_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result_subset/" - training_csv = f'{msa_dir}/input_training.csv' # was input.csv in apo_holo_data + training_csv = f'{msa_dir}/input_training.csv' val_csv = f'{msa_dir}/input_validation.csv' pdb_dir = f"{data_dir}/pdb" saxs_dir = f"{data_dir}/saxs_r" @@ -287,7 +287,6 @@ def load_from_jax(self, jax_path): # set up training and test datasets and dataloaders train_dataset = MSASAXSDataset(data_config, training_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir) val_dataset = MSASAXSDataset(data_config, val_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir) - train_dataset[0] train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) @@ -306,7 +305,7 @@ def load_from_jax(self, jax_path): )], check_val_every_n_epoch=1,) # TODO - add default_root_dir? - # load exisitng weights + # load existing weights if jax_param_path: msasaxsmodel.load_from_jax(jax_param_path) logging.info(f"Successfully loaded JAX parameters at {jax_param_path}...") @@ -319,5 +318,4 @@ def load_from_jax(self, jax_path): ) # need to initialize EMA this way at the beginning # fit the model - trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, ckpt_path=ckpt_path) - # trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path) \ No newline at end of file + trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path) \ No newline at end of file