Update training to always resize input
constantinpape committed Jan 1, 2024
1 parent 729cf82 commit 908e765
Showing 2 changed files with 25 additions and 15 deletions.
10 changes: 8 additions & 2 deletions micro_sam/training/sam_trainer.py
@@ -195,10 +195,16 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y):
     # Update Masks Iteratively while Training
     #
     def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_subiter, multimask_output):
-        # estimating the image inputs to make the computations faster for the decoder
-        input_images = torch.stack([self.model.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0)
+        # Precompute the image embeddings only once.
+        input_images, input_size = self.model.preprocess(
+            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device)
+        )
         image_embeddings = self.model.image_embeddings_oft(input_images)
 
+        # Update the input size for each input in the batch.
+        for i in range(len(batched_inputs)):
+            batched_inputs[i]["input_size"] = input_size
+
         loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
 
         # this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch
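To see the intent of the change in one place: the encoder embeddings are now computed once per batch and reused for every prompt sub-iteration, since the image encoder dominates SAM's runtime. Below is a minimal sketch of the resulting loop, assuming a TrainableSAM-style `model` exposing the `preprocess`, `image_embeddings_oft`, and updated `forward` from this commit; the prompt-update step is elided.

import torch


def update_masks_sketch(model, batched_inputs, num_subiter, multimask_output=False):
    """Sketch of the iterative training loop after this change (not the micro_sam code)."""
    # Encode the whole batch once; the encoder output is reused across all sub-iterations.
    input_images, input_size = model.preprocess(
        torch.stack([x["image"] for x in batched_inputs], dim=0).to(model.device)
    )
    image_embeddings = model.image_embeddings_oft(input_images)

    # Record the post-resize shape so mask post-processing can crop correctly.
    for inp in batched_inputs:
        inp["input_size"] = input_size

    outputs = None
    for _ in range(num_subiter):
        # Only the (cheap) prompt encoder and mask decoder run per sub-iteration.
        outputs = model(batched_inputs, image_embeddings, multimask_output)
        # ... derive new point/mask prompts from `outputs` here ...
    return outputs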
30 changes: 17 additions & 13 deletions micro_sam/training/trainable_sam.py
@@ -1,10 +1,11 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 from segment_anything.modeling import Sam
+from segment_anything.utils.transforms import ResizeLongestSide
 
 
 # simple wrapper around SAM in order to keep things trainable
@@ -23,25 +24,32 @@ def __init__(
         super().__init__()
         self.sam = sam
         self.device = device
+        self.transform = ResizeLongestSide(sam.image_encoder.img_size)
 
-    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        """Normalize pixel values and pad to a square input.
+    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
+        """Resize, normalize pixel values and pad to a square input.
 
         Args:
             x: The input tensor.
 
         Returns:
-            The normalized and padded tensor.
+            The resized, normalized and padded tensor.
+            The shape of the image after resizing.
         """
+
+        # Resize longest side to match the image encoder.
+        x = self.transform.apply_image_torch(x)
+        input_size = x.shape[-2:]
+
         # Normalize colors
-        x = (x - self.sam.pixel_mean) / self.sam.pixel_std
+        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)
 
         # Pad
         h, w = x.shape[-2:]
         padh = self.sam.image_encoder.img_size - h
         padw = self.sam.image_encoder.img_size - w
         x = F.pad(x, (0, padw, 0, padh))
-        return x
+        return x, input_size
 
     def image_embeddings_oft(self, input_images):
         """@private"""
@@ -52,23 +60,19 @@ def image_embeddings_oft(self, input_images):
     def forward(
         self,
         batched_inputs: List[Dict[str, Any]],
+        image_embeddings: torch.Tensor,
         multimask_output: bool = False,
-        image_embeddings: Optional[torch.Tensor] = None,
     ) -> List[Dict[str, Any]]:
         """Forward pass.
 
         Args:
             batched_inputs: The batched input images and prompts.
-            multimask_output: Whether to predict multiple or just a single mask.
             image_embeddings: The precomputed image embeddings.
+            multimask_output: Whether to predict multiple or just a single mask.
 
         Returns:
             The predicted segmentation masks and iou values.
         """
-        input_images = torch.stack([self.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0)
-        if image_embeddings is None:
-            image_embeddings = self.sam.image_encoder(input_images)
-
         outputs = []
         for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
             if "point_coords" in image_record:
@@ -102,7 +106,7 @@ def forward(
 
         masks = self.sam.postprocess_masks(
             low_res_masks,
-            input_size=image_record["image"].shape[-2:],
+            input_size=image_record["input_size"],
             original_size=image_record["original_size"],
         )
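Why `input_size` has to change along with `preprocess`: `postprocess_masks` upscales the low-resolution masks to the square encoder resolution, crops the padding away using `input_size`, and only then resizes to `original_size`. Since images are now resized before padding, `image_record["image"].shape[-2:]` no longer describes the unpadded region, but the shape returned by `preprocess` does. Roughly, simplified from the segment_anything implementation:

import torch
from torch.nn import functional as F


def postprocess_masks_sketch(low_res_masks, input_size, original_size, img_size=1024):
    # Upscale decoder output to the (square, padded) encoder resolution.
    masks = F.interpolate(low_res_masks, (img_size, img_size), mode="bilinear", align_corners=False)
    # Crop away the padding; this is only correct if input_size is the shape
    # *after* resizing, which the updated preprocess now returns.
    masks = masks[..., : input_size[0], : input_size[1]]
    # Finally resize back to the original image resolution.
    return F.interpolate(masks, original_size, mode="bilinear", align_corners=False)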

