Skip to content

Commit

Permalink
turn off distillation during feature processing
Browse files Browse the repository at this point in the history
  • Loading branch information
stephprince committed May 14, 2024
1 parent 5be795e commit 3b22463
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/metfish/msa_model/msa_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __getitem__(self, idx):
saxs_features = self.data_pipeline._process_saxs_feats(f'{self.saxs_dir}/{item.name}.pdb.pr.csv')

# pdb data
pdb_features = self.data_pipeline.process_pdb_feats(f'{self.pdb_dir}/fixed_{item.name}.pdb')
pdb_features = self.data_pipeline.process_pdb_feats(f'{self.pdb_dir}/fixed_{item.name}.pdb', is_distillation=False)
data = {**sequence_feats, **msa_features, **saxs_features, **pdb_features}

feats = self.feature_pipeline.process_features(data)
Expand Down Expand Up @@ -265,7 +265,7 @@ def load_from_jax(self, jax_path):
metfish_dir = "/global/cfs/cdirs/m3513/metfish"
data_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result"
msa_dir = f"{metfish_dir}/PDB70_verB_fixed_data/result_subset/"
training_csv = f'{msa_dir}/input_training.csv' # was input.csv in apo_holo_data
training_csv = f'{msa_dir}/input_training.csv'
val_csv = f'{msa_dir}/input_validation.csv'
pdb_dir = f"{data_dir}/pdb"
saxs_dir = f"{data_dir}/saxs_r"
Expand All @@ -287,7 +287,6 @@ def load_from_jax(self, jax_path):
# set up training and test datasets and dataloaders
train_dataset = MSASAXSDataset(data_config, training_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir)
val_dataset = MSASAXSDataset(data_config, val_csv, msa_dir=msa_dir, saxs_dir=saxs_dir, pdb_dir=pdb_dir)
train_dataset[0]

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
Expand All @@ -306,7 +305,7 @@ def load_from_jax(self, jax_path):
)],
check_val_every_n_epoch=1,) # TODO - add default_root_dir?

# load exisitng weights
# load existing weights
if jax_param_path:
msasaxsmodel.load_from_jax(jax_param_path)
logging.info(f"Successfully loaded JAX parameters at {jax_param_path}...")
Expand All @@ -319,5 +318,4 @@ def load_from_jax(self, jax_path):
) # need to initialize EMA this way at the beginning

# fit the model
trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, ckpt_path=ckpt_path)
# trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path)
trainer.fit(model=msasaxsmodel, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path)

0 comments on commit 3b22463

Please sign in to comment.