From f08c3fbeca605a5a74898ae37c13b881eee2a2b9 Mon Sep 17 00:00:00 2001 From: YousefMetwally Date: Tue, 11 Jun 2024 15:20:08 +0200 Subject: [PATCH] train speed trial 1 --- tomotwin/modules/training/torchtrainer.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tomotwin/modules/training/torchtrainer.py b/tomotwin/modules/training/torchtrainer.py index 0a8de74..4ce6522 100644 --- a/tomotwin/modules/training/torchtrainer.py +++ b/tomotwin/modules/training/torchtrainer.py @@ -220,12 +220,14 @@ def classification_f1_score(self, test_loader: DataLoader) -> float: anchor_vol = batch["anchor"].to(self.device, non_blocking=True) positive_vol = batch["positive"].to(self.device, non_blocking=True) negative_vol = batch["negative"].to(self.device, non_blocking=True) + full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0) filenames = batch["filenames"] with autocast(): - # TODO: Probably concat anchor, positive and vol into one batch and run only one forward pass is enough. - anchor_out = self.model.forward(anchor_vol) - positive_out = self.model.forward(positive_vol) - negative_out = self.model.forward(negative_vol) + out = self.model.forward(full_input) + out = torch.split(out, anchor_vol.shape[0], dim=0) + anchor_out = out[0] + positive_out = out[1] + negative_out = out[2] anchor_out_np = anchor_out.cpu().detach().numpy() for i, anchor_filename in enumerate(filenames[0]): @@ -258,16 +260,14 @@ def run_batch(self, batch: Dict): anchor_vol = batch["anchor"].to(self.device, non_blocking=True) positive_vol = batch["positive"].to(self.device, non_blocking=True) negative_vol = batch["negative"].to(self.device, non_blocking=True) + full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0) with autocast(): - # TODO: Probably concat anchor, positive and vol into one batch and run only on forward pass is enough. - anchor_out = self.model.forward(anchor_vol) - positive_out = self.model.forward(positive_vol) - negative_out = self.model.forward(negative_vol) - + out = self.model.forward(full_input) + out = torch.split(out, anchor_vol.shape[0], dim=0) loss = self.criterion( - anchor_out, - positive_out, - negative_out, + out[0], + out[1], + out[2], label_anchor=batch["label_anchor"], label_positive=batch["label_positive"], label_negative=batch["label_negative"],