From e67166ff4fd17ae542036285c7ea4ed6592ee540 Mon Sep 17 00:00:00 2001
From: SakhinetiPraveena
Date: Mon, 6 Oct 2025 10:54:13 +0530
Subject: [PATCH] Added default resize

---
 detectionmetrics/models/torch_detection.py | 43 +++++++++++++++++-----
 1 file changed, 34 insertions(+), 9 deletions(-)

diff --git a/detectionmetrics/models/torch_detection.py b/detectionmetrics/models/torch_detection.py
index adbac2c7..5f9c599b 100644
--- a/detectionmetrics/models/torch_detection.py
+++ b/detectionmetrics/models/torch_detection.py
@@ -243,16 +243,21 @@ def __init__(
 
         # Build input transforms (resize, normalize, etc.)
         self.transform_input = []
+        # Default resize to 640x640 if not specified
         if "resize" in self.model_cfg:
-            self.transform_input += [
-                transforms.Resize(
-                    size=(
-                        self.model_cfg["resize"].get("height", None),
-                        self.model_cfg["resize"].get("width", None),
-                    ),
-                    interpolation=transforms.InterpolationMode.BILINEAR,
-                )
-            ]
+            resize_height = self.model_cfg["resize"].get("height", 640)
+            resize_width = self.model_cfg["resize"].get("width", 640)
+        else:
+            # Default to 640x640 when no resize is specified
+            resize_height = 640
+            resize_width = 640
+
+        self.transform_input += [
+            transforms.Resize(
+                size=(resize_height, resize_width),
+                interpolation=transforms.InterpolationMode.BILINEAR,
+            )
+        ]
 
         if "crop" in self.model_cfg:
             crop_size = (
@@ -403,7 +408,27 @@ def eval(
                 print("Skipping batch: empty image tensor detected.")
                 continue
 
+            # Move images to device and ensure consistent shapes for batching
             images = [img.to(self.device) for img in images]
+
+            # For batch processing, we need to stack tensors, but they must have the same shape
+            # Even with resize transforms, there might be slight differences
+            if len(images) > 1:
+                # Get the target shape from the first image
+                target_shape = images[0].shape
+                # Ensure all images have the same shape
+                for i, img in enumerate(images):
+                    if img.shape != target_shape:
+                        # Resize to match the first image's shape
+                        images[i] = torch.nn.functional.interpolate(
+                            img.unsqueeze(0),
+                            size=target_shape[-2:],  # [H, W]
+                            mode='bilinear',
+                            align_corners=False
+                        ).squeeze(0)
+
+                # Stack images for batch processing
+                images = torch.stack(images)
 
             predictions = self.model(images)
             for i in range(len(images)):
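
Note (not part of the patch): below is a minimal standalone sketch of the two behaviours this change introduces, the 640x640 resize fallback and the shape alignment before stacking. The build_resize helper name and the sample tensor sizes are illustrative assumptions, not code from the repository.

# Illustrative only; mirrors the patched logic, not imported from the repo.
import torch
import torch.nn.functional as F
from torchvision import transforms


def build_resize(model_cfg):
    # Mirrors the patched __init__: fall back to 640x640 whenever the config
    # omits "resize" entirely or omits either dimension.
    if "resize" in model_cfg:
        resize_height = model_cfg["resize"].get("height", 640)
        resize_width = model_cfg["resize"].get("width", 640)
    else:
        resize_height, resize_width = 640, 640
    return transforms.Resize(
        size=(resize_height, resize_width),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )


img = torch.rand(3, 375, 500)  # arbitrary input size (C, H, W)
print(build_resize({})(img).shape)                           # torch.Size([3, 640, 640])
print(build_resize({"resize": {"height": 480}})(img).shape)  # torch.Size([3, 480, 640])

# Mirrors the patched eval(): align any stragglers to the first image's
# shape, then stack into a single (N, C, H, W) batch tensor.
images = [torch.rand(3, 640, 640), torch.rand(3, 639, 640)]  # off-by-one shape
target_shape = images[0].shape
images = [
    img if img.shape == target_shape
    else F.interpolate(
        img.unsqueeze(0), size=target_shape[-2:], mode="bilinear",
        align_corners=False,
    ).squeeze(0)
    for img in images
]
batch = torch.stack(images)
print(batch.shape)  # torch.Size([2, 3, 640, 640])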