Commit

better evaluator thanks to torchmetrics?
johahi committed Dec 4, 2024
1 parent 9d9f1c1 commit 4e84412
Showing 3 changed files with 19 additions and 21 deletions.
34 changes: 15 additions & 19 deletions scooby/utils/utils.py
@@ -8,6 +8,7 @@
 import anndata as ad
 from anndata.experimental import read_elem, sparse_dataset
 from peft import get_peft_model, LoraConfig
+from torchmetrics.regression import PearsonCorrCoef

 def poisson_multinomial_torch(
     y_pred,
@@ -214,7 +215,7 @@ def fix_rev_comp_rna(outputs_rev_comp):
     return test_out


-def evaluate(accelerator, csb, val_loader, mode='multiome', stop_idx=2):
+def evaluate(accelerator, csb, val_loader, mode='multiome', stop_idx=50):
     """
     Evaluates the model on the validation set.
@@ -226,45 +227,40 @@ def evaluate(accelerator, csb, val_loader, mode='multiome', stop_idx=2):
         csb (torch.nn.Module): The model.
         val_loader (torch.utils.data.DataLoader): The validation data loader.
         mode (str, optional): The mode of the model. Either 'multiome' or 'rna'. Defaults to 'multiome'.
-        stop_idx (int, optional): The index at which to stop inference. Defaults to 2.
+        stop_idx (int, optional): The index at which to stop inference. Defaults to 50.
     """
     device = accelerator.device
     csb.eval()
-    output_list, target_list, pearsons_per_track = [], [], []

     if mode == 'multiome':
         range_val = 96
     else:
         range_val = 64

+    pearson = PearsonCorrCoef(range_val).to(device)
+
     for i, [inputs, rc_augs, targets, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)):
-        if i < (stop_idx):
-            continue
         if i == (stop_idx + 1):
             break
         inputs = inputs.permute(0, 2, 1).to(device, non_blocking=True)
-        target_list.append(targets.to(device, non_blocking=True))
-        with torch.no_grad():
-            with torch.autocast("cuda"):
-                output_list.append(csb(inputs, cell_emb_idx).detach())
-        break
-    targets = torch.vstack(target_list).squeeze().numpy(force=True)  # [reindex].flatten(0,1).numpy(force =True)
-    outputs = torch.vstack(output_list).squeeze().numpy(force=True)  # [reindex].flatten(0,1).numpy(force =True)
-
-    # accelerator.print (outputs.shape)
-    for x in range(0, range_val):
-        pearsons_per_track.append(stats.pearsonr(outputs.T[x].flatten(), targets.T[x].flatten())[0])
+        targets = targets.to(device, non_blocking=True)
+        with torch.no_grad(), torch.autocast("cuda"):
+            outputs = csb(inputs, cell_emb_idx).detach()
+
+        pearson.update(accelerator.gather(outputs).reshape(-1, range_val).float(), accelerator.gather(targets).reshape(-1, range_val).float())
+
+    pearsons_per_track = pearson.compute().numpy(force=True)

     accelerator.log({"val_rnaseq_across_tracks_pearson_r": np.nanmean(pearsons_per_track)})
-    accelerator.log({"val_pearson_r": stats.pearsonr(outputs.flatten(), targets.flatten())[0]})

     # Plot 'outputs' in the first subplot
     fig, axes = plt.subplots(1, 2, figsize=(10, 5))  # Create a 1x2 subplot grid
-    axes[0].imshow(outputs.T, vmax=1, aspect="auto")
+    axes[0].imshow(outputs.squeeze().numpy(force=True).T, vmax=0.2, aspect="auto")
     axes[0].set_title("Outputs")  # You can add a title if desired

     # Plot 'targets' in the second subplot
-    axes[1].imshow(targets.T, vmax=1, aspect="auto")
+    axes[1].imshow(targets.squeeze().numpy(force=True).T, vmax=0.2, aspect="auto")
     axes[1].set_title("Targets")  # You can add a title if desired
     # plt.show()
     accelerator.log({"val_sample_viz": fig})
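For context, a minimal standalone sketch of what the switch to torchmetrics buys: PearsonCorrCoef(num_outputs=...) accumulates statistics batch by batch via update() and returns one correlation per track from compute(), replacing the old pattern of stacking all predictions and looping over tracks with scipy.stats.pearsonr. This is not scooby code; the track count, batch shapes, and the five-batch loop are illustrative, and the cross-check against scipy is only there to show the two approaches agree.

import torch
from scipy import stats
from torchmetrics.regression import PearsonCorrCoef

num_tracks = 96                       # range_val for the 'multiome' mode
pearson = PearsonCorrCoef(num_outputs=num_tracks)

preds_all, targets_all = [], []
for _ in range(5):                    # stand-in for the validation batches
    preds = torch.randn(32, num_tracks)             # (positions, tracks), i.e. after reshape(-1, range_val)
    targets = preds + 0.5 * torch.randn(32, num_tracks)
    pearson.update(preds, targets)    # streaming update; the evaluator additionally
                                      # runs accelerator.gather(...) before this call
    preds_all.append(preds)
    targets_all.append(targets)

per_track = pearson.compute()         # shape (num_tracks,): one Pearson r per track

# cross-check against the old per-track scipy loop
preds_np = torch.cat(preds_all).numpy()
targets_np = torch.cat(targets_all).numpy()
scipy_r = [stats.pearsonr(preds_np[:, t], targets_np[:, t])[0] for t in range(num_tracks)]
assert torch.allclose(per_track, torch.tensor(scipy_r, dtype=torch.float32), atol=1e-3)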
3 changes: 2 additions & 1 deletion scripts/train_multiome.py
@@ -141,8 +141,9 @@ def train(config):
         neighbors=neighbors,
         embedding=embedding,
         ds=val_ds,
-        cell_sample_size=32,
         cell_weights=None,
+        random_cells=False,
+        cells_to_run=list(np.arange(32)),
         normalize_atac=True,
         clip_soft=5,
     )
3 changes: 2 additions & 1 deletion scripts/train_rna_only.py
@@ -139,7 +139,8 @@ def train(config):
         neighbors=neighbors,
         embedding=embedding,
         ds=val_ds,
-        cell_sample_size=32,
+        random_cells=False,
+        cells_to_run=list(np.arange(32)),
         cell_weights=None,
         clip_soft=5,
     )
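Both training scripts make the same swap on their validation dataset: rather than letting it draw cell_sample_size=32 cells at random, they set random_cells=False and cells_to_run=list(np.arange(32)), presumably so every evaluation scores the same 32 cells and the logged Pearson values stay comparable across runs. A rough sketch of that selection logic, with a hypothetical pick_cells helper standing in for whatever the scooby dataset does internally:

import numpy as np

def pick_cells(n_cells, cell_sample_size=32, random_cells=True, cells_to_run=None, seed=None):
    """Hypothetical stand-in for the dataset's cell selection.

    random_cells=True  -> a fresh random subset each time, good for training
    random_cells=False -> the fixed list in cells_to_run, so validation metrics are comparable
    """
    if random_cells:
        rng = np.random.default_rng(seed)
        return rng.choice(n_cells, size=cell_sample_size, replace=False)
    return np.asarray(cells_to_run)

# validation now always scores cell indices 0..31
val_cells = pick_cells(n_cells=10000, random_cells=False, cells_to_run=list(np.arange(32)))
print(val_cells)  # the fixed indices 0..31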
