NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Indentation and the placement of blank context lines are best guesses and must
be confirmed against the target files before applying. The `index` blob ids
were copied verbatim and may not describe the post-image of this patch.

diff --git a/scooby/utils/utils.py b/scooby/utils/utils.py
index 1262625..63bc24f 100644
--- a/scooby/utils/utils.py
+++ b/scooby/utils/utils.py
@@ -8,6 +8,7 @@
 import anndata as ad
 from anndata.experimental import read_elem, sparse_dataset
 from peft import get_peft_model, LoraConfig
+from torchmetrics.regression import PearsonCorrCoef
 
 def poisson_multinomial_torch(
     y_pred,
@@ -214,7 +215,7 @@ def fix_rev_comp_rna(outputs_rev_comp):
     return test_out
 
 
-def evaluate(accelerator, csb, val_loader, mode='multiome', stop_idx=2):
+def evaluate(accelerator, csb, val_loader, mode='multiome', stop_idx=50):
     """
     Evaluates the model on the validation set.
 
@@ -226,45 +227,46 @@ def evaluate(accelerator, csb, val_loader, mode='multiome', stop_idx=2):
         csb (torch.nn.Module): The model.
         val_loader (torch.utils.data.DataLoader): The validation data loader.
         mode (str, optional): The mode of the model. Either 'multiome' or 'rna'. Defaults to 'multiome'.
-        stop_idx (int, optional): The index at which to stop inference. Defaults to 2.
+        stop_idx (int, optional): The index at which to stop inference. Defaults to 50.
     """
     device = accelerator.device
     csb.eval()
-    output_list, target_list, pearsons_per_track = [], [], []
     if mode == 'multiome':
         range_val = 96
     else:
         range_val = 64
+
+    pearson = PearsonCorrCoef(range_val).to(device)
 
     for i, [inputs, rc_augs, targets, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)):
-        if i < (stop_idx):
-            continue
         if i == (stop_idx + 1):
             break
         inputs = inputs.permute(0, 2, 1).to(device, non_blocking=True)
-        target_list.append(targets.to(device, non_blocking=True))
-        with torch.no_grad():
-            with torch.autocast("cuda"):
-                output_list.append(csb(inputs, cell_emb_idx).detach())
-        break
-    targets = torch.vstack(target_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True)
-    outputs = torch.vstack(output_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True)
-
-    # accelerator.print (outputs.shape)
-    for x in range(0, range_val):
-        pearsons_per_track.append(stats.pearsonr(outputs.T[x].flatten(), targets.T[x].flatten())[0])
-
+        targets = targets.to(device, non_blocking=True)
+        with torch.no_grad(), torch.autocast("cuda"):
+            outputs = csb(inputs, cell_emb_idx).detach()
+
+        pearson.update(
+            accelerator.gather(outputs).reshape(-1, range_val).float(),
+            accelerator.gather(targets).reshape(-1, range_val).float(),
+        )
+
+    pearsons_per_track = pearson.compute().numpy(force=True)
+
     accelerator.log({"val_rnaseq_across_tracks_pearson_r": np.nanmean(pearsons_per_track)})
+    # scipy/matplotlib need host float32 NumPy arrays, not device (possibly half-precision) tensors.
+    outputs = outputs.squeeze().float().numpy(force=True)
+    targets = targets.squeeze().float().numpy(force=True)
     accelerator.log({"val_pearson_r": stats.pearsonr(outputs.flatten(), targets.flatten())[0]})
 
     # Plot 'outputs' in the first subplot
     fig, axes = plt.subplots(1, 2, figsize=(10, 5)) # Create a 1x2 subplot grid
 
-    axes[0].imshow(outputs.T, vmax=1, aspect="auto")
+    axes[0].imshow(outputs.T, vmax=0.2, aspect="auto")
     axes[0].set_title("Outputs") # You can add a title if desired
 
     # Plot 'targets' in the second subplot
-    axes[1].imshow(targets.T, vmax=1, aspect="auto")
+    axes[1].imshow(targets.T, vmax=0.2, aspect="auto")
     axes[1].set_title("Targets") # You can add a title if desired
     # plt.show()
     accelerator.log({"val_sample_viz": fig})
diff --git a/scripts/train_multiome.py b/scripts/train_multiome.py
index 0c06399..aa0611a 100644
--- a/scripts/train_multiome.py
+++ b/scripts/train_multiome.py
@@ -141,8 +141,9 @@ def train(config):
         neighbors=neighbors,
         embedding=embedding,
         ds=val_ds,
-        cell_sample_size=32,
         cell_weights=None,
+        random_cells=False,
+        cells_to_run=list(np.arange(32)),
         normalize_atac=True,
         clip_soft=5,
     )
diff --git a/scripts/train_rna_only.py b/scripts/train_rna_only.py
index ebd34ca..8743b61 100644
--- a/scripts/train_rna_only.py
+++ b/scripts/train_rna_only.py
@@ -139,7 +139,8 @@ def train(config):
         neighbors=neighbors,
         embedding=embedding,
         ds=val_ds,
-        cell_sample_size=32,
+        random_cells=False,
+        cells_to_run=list(np.arange(32)),
         cell_weights=None,
         clip_soft=5,
     )