From 908e765b184a92b6a9533ece82804ed02d28f245 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Mon, 1 Jan 2024 12:25:50 +0100
Subject: [PATCH] Update training to always resize input

---
 micro_sam/training/sam_trainer.py   | 10 ++++++++--
 micro_sam/training/trainable_sam.py | 30 +++++++++++++++++-------------
 2 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 94130a88c..a7e95a4ed 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -195,10 +195,16 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y):
     # Update Masks Iteratively while Training
     #
     def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_subiter, multimask_output):
-        # estimating the image inputs to make the computations faster for the decoder
-        input_images = torch.stack([self.model.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0)
+        # Precompute the image embeddings only once.
+        input_images, input_size = self.model.preprocess(
+            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device)
+        )
         image_embeddings = self.model.image_embeddings_oft(input_images)
 
+        # Update the input size for each input in the batch.
+        for i in range(len(batched_inputs)):
+            batched_inputs[i]["input_size"] = input_size
+
         loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
 
         # this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch
diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py
index 99728a1b8..6b773244e 100644
--- a/micro_sam/training/trainable_sam.py
+++ b/micro_sam/training/trainable_sam.py
@@ -1,10 +1,11 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 from segment_anything.modeling import Sam
+from segment_anything.utils.transforms import ResizeLongestSide
 
 
 # simple wrapper around SAM in order to keep things trainable
@@ -23,25 +24,32 @@ def __init__(
         super().__init__()
         self.sam = sam
         self.device = device
+        self.transform = ResizeLongestSide(sam.image_encoder.img_size)
 
-    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        """Normalize pixel values and pad to a square input.
+    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
+        """Resize, normalize pixel values and pad to a square input.
 
         Args:
             x: The input tensor.
 
         Returns:
-            The normalized and padded tensor.
+            The resized, normalized and padded tensor.
+            The shape of the image after resizing.
         """
+
+        # Resize longest side to match the image encoder.
+        x = self.transform.apply_image_torch(x)
+        input_size = x.shape[-2:]
+
         # Normalize colors
-        x = (x - self.sam.pixel_mean) / self.sam.pixel_std
+        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)
 
         # Pad
         h, w = x.shape[-2:]
         padh = self.sam.image_encoder.img_size - h
         padw = self.sam.image_encoder.img_size - w
         x = F.pad(x, (0, padw, 0, padh))
-        return x
+        return x, input_size
 
     def image_embeddings_oft(self, input_images):
         """@private"""
@@ -52,23 +60,19 @@
     def forward(
         self,
         batched_inputs: List[Dict[str, Any]],
+        image_embeddings: torch.Tensor,
         multimask_output: bool = False,
-        image_embeddings: Optional[torch.Tensor] = None,
     ) -> List[Dict[str, Any]]:
         """Forward pass.
 
         Args:
             batched_inputs: The batched input images and prompts.
-            multimask_output: Whether to predict mutiple or just a single mask.
             image_embeddings: The precompute image embeddings. If not passed then they will be computed.
+            multimask_output: Whether to predict multiple or just a single mask.
 
         Returns:
             The predicted segmentation masks and iou values.
         """
-        input_images = torch.stack([self.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0)
-        if image_embeddings is None:
-            image_embeddings = self.sam.image_encoder(input_images)
-
        outputs = []
         for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
             if "point_coords" in image_record:
@@ -102,7 +106,7 @@ def forward(
         masks = self.sam.postprocess_masks(
             low_res_masks,
-            input_size=image_record["image"].shape[-2:],
+            input_size=image_record["input_size"],
             original_size=image_record["original_size"],
         )
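
The net effect of this patch: preprocess now returns both the padded tensor and the post-resize shape, image embeddings become a required argument of forward, and every input record must carry an "input_size" entry for postprocess_masks. Below is a minimal sketch of the resulting call sequence; "model" (a TrainableSAM instance) and "batch" (a list of per-image dicts with an "image" tensor, an "original_size" entry, and the usual prompt keys) are illustrative names, not part of the patch.

    import torch

    # Stack the batch and move it to the model's device. Each x["image"] is
    # assumed to be a (C, H, W) tensor; the spatial size no longer needs to
    # match the encoder input size, since preprocess resizes it.
    images = torch.stack([x["image"] for x in batch], dim=0).to(model.device)

    # preprocess resizes the longest side to the encoder's input size,
    # normalizes, pads to a square, and returns the post-resize shape.
    input_images, input_size = model.preprocess(images)

    # Compute the image embeddings once; the trainer reuses them across all
    # sub-iterations instead of re-encoding the images.
    image_embeddings = model.image_embeddings_oft(input_images)

    # Each record needs the shared "input_size" so that postprocess_masks can
    # strip the padding before rescaling the masks to "original_size".
    for record in batch:
        record["input_size"] = input_size

    # image_embeddings is now a required second positional argument.
    outputs = model(batch, image_embeddings, multimask_output=False)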