Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions detectionmetrics/models/torch_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,21 @@ def __init__(
# Build input transforms (resize, normalize, etc.)
self.transform_input = []

# Default resize to 640x640 if not specified
if "resize" in self.model_cfg:
self.transform_input += [
transforms.Resize(
size=(
self.model_cfg["resize"].get("height", None),
self.model_cfg["resize"].get("width", None),
),
interpolation=transforms.InterpolationMode.BILINEAR,
)
]
resize_height = self.model_cfg["resize"].get("height", 640)
resize_width = self.model_cfg["resize"].get("width", 640)
else:
# Default to 640x640 when no resize is specified
resize_height = 640
resize_width = 640

self.transform_input += [
transforms.Resize(
size=(resize_height, resize_width),
interpolation=transforms.InterpolationMode.BILINEAR,
)
]

if "crop" in self.model_cfg:
crop_size = (
Expand Down Expand Up @@ -403,7 +408,27 @@ def eval(
print("Skipping batch: empty image tensor detected.")
continue

# Move images to device and ensure consistent shapes for batching
images = [img.to(self.device) for img in images]

# For batch processing, we need to stack tensors, but they must have the same shape
# Even with resize transforms, there might be slight differences
if len(images) > 1:
# Get the target shape from the first image
target_shape = images[0].shape
# Ensure all images have the same shape
for i, img in enumerate(images):
if img.shape != target_shape:
# Resize to match the first image's shape
images[i] = torch.nn.functional.interpolate(
img.unsqueeze(0),
size=target_shape[-2:], # [H, W]
mode='bilinear',
align_corners=False
).squeeze(0)

# Stack images for batch processing
images = torch.stack(images)
predictions = self.model(images)

for i in range(len(images)):
Expand Down