
Simplify training implementations WIP
constantinpape committed Jan 2, 2024
1 parent ae46a20 commit 2812497
Showing 3 changed files with 59 additions and 65 deletions.
micro_sam/prompt_generators.py: 1 addition & 2 deletions
@@ -204,7 +204,6 @@ def _sample_points(self, segmentation, bbox_coordinates, center_coordinates):

return all_coords, all_labels

# TODO make compatible with exact same input shape
def __call__(
self,
segmentation: torch.Tensor,
@@ -220,7 +219,7 @@ def __call__(
"""Generate the prompts for one object in the segmentation.
Args:
segmentation: Instance segmentation masks .
The groundtruth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W.
bbox_coordinates: The precomputed bounding boxes of particular object in the segmentation.
center_coordinates: The precomputed center coordinates of particular object in the segmentation.
If passed, these coordinates will be used as the first positive point prompt.
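The updated docstring above pins down the target layout that the training code now expects: a float tensor of shape NUM_OBJECTS x 1 x H x W. For reference, here is a minimal sketch of how such a tensor could be built from a 2D instance segmentation; the helper name and the example values are illustrative and not part of micro_sam.

import torch

def to_per_object_masks(instance_seg: torch.Tensor, object_ids) -> torch.Tensor:
    """Turn a 2D instance segmentation (H x W) into a float tensor of shape NUM_OBJECTS x 1 x H x W."""
    masks = torch.stack([instance_seg == obj_id for obj_id in object_ids])
    return masks.unsqueeze(1).float()

# Example with three labeled objects in a 256 x 256 instance segmentation.
seg = torch.zeros(256, 256, dtype=torch.int64)
seg[10:50, 10:50] = 1
seg[60:90, 60:90] = 2
seg[100:140, 100:140] = 3
per_object = to_per_object_masks(seg, object_ids=[1, 2, 3])
print(per_object.shape)  # torch.Size([3, 1, 256, 256])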
micro_sam/training/sam_trainer.py: 56 additions & 61 deletions
@@ -108,62 +108,62 @@ def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):

return n_pos, n_neg, get_boxes, multimask_output

def _get_dice(self, input_, target):
"""Using the default "DiceLoss" called by the trainer from "torch_em"
"""
dice_loss = self.loss(input_, target)
return dice_loss

def _get_iou(self, pred, true, eps=1e-7):
"""Getting the IoU score for the predicted and true labels
def _compute_iou(self, pred, true, eps=1e-7):
"""Compute the IoU score for the predicted and true labels.
"""
pred_mask = pred > 0.5 # binarizing the output predictions
overlap = pred_mask.logical_and(true).sum()
union = pred_mask.logical_or(true).sum()
overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
iou = overlap / (union + eps)
return iou

def _get_net_loss(self, batched_outputs, y, sampled_ids):
def _get_net_loss(self, batched_outputs, one_hot_targets):
"""What do we do here? two **separate** things
1. compute the mask loss: loss between the predicted and ground-truth masks
for this we just use the dice of the prediction vs. the gt (binary) mask
2. compute the mask for the "IOU Regression Head": so we want the iou output from the decoder to
match the actual IOU between predicted and (binary) ground-truth mask. And we use L2Loss / MSE for this.
"""
masks = [m["masks"] for m in batched_outputs]
predicted_iou_values = [m["iou_predictions"] for m in batched_outputs]
batched_masks = [m["masks"] for m in batched_outputs]
batched_predicted_ious = [m["iou_predictions"] for m in batched_outputs]

# FIXME it's unclear why we need to do this here, it's unrelated to the loss computation
# and would simplify things to move it further up in the code so we don't need to
# return it several times
with torch.no_grad():
mean_model_iou = torch.mean(torch.stack([p.mean() for p in predicted_iou_values]))
mean_model_iou = torch.mean(torch.stack([p.mean() for p in batched_predicted_ious]))

mask_loss = 0.0 # this is the loss term for 1.
iou_regression_loss = 0.0 # this is the loss term for 2.

# outer loop is over the batch (different image/patch predictions)
for m_, y_, ids_, predicted_iou_ in zip(masks, y, sampled_ids, predicted_iou_values):
per_object_dice_scores, per_object_iou_scores = [], []

# inner loop is over the channels, this corresponds to the different predicted objects
for i, (predicted_obj, predicted_iou) in enumerate(zip(m_, predicted_iou_)):
predicted_obj = self._sigmoid(predicted_obj).to(self.device)
true_obj = (y_ == ids_[i]).to(self.device)

# this is computing the LOSS for 1.)
_dice_score = min([self._get_dice(p[None], true_obj) for p in predicted_obj])
per_object_dice_scores.append(_dice_score)

# now we need to compute the loss for 2.)
with torch.no_grad():
true_iou = torch.stack([self._get_iou(p[None], true_obj) for p in predicted_obj])
_iou_score = self.mse_loss(true_iou, predicted_iou)
per_object_iou_scores.append(_iou_score)

mask_loss = mask_loss + torch.mean(torch.stack(per_object_dice_scores))
iou_regression_loss = iou_regression_loss + torch.mean(torch.stack(per_object_iou_scores))
# Loop over the batch.
for masks, targets, predicted_iou in zip(batched_masks, one_hot_targets, batched_predicted_ious):
# TODO consider hard-coding the dice loss to make sure the reduction is set correctly.
# Compute the dice scores for the 1/3 predicted masks per object.
# FIXME why do we have the _sigmoid? Doesn't make sense?!
# TODO make a note on flipping the axes and the shapes after the dice
predicted_objects = self._sigmoid(masks)
dice_scores = torch.stack([
self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
for i in range(predicted_objects.shape[1])
])
dice_scores, _ = torch.min(dice_scores, dim=0)

# TODO explain this in comment
with torch.no_grad():
true_iou = torch.stack([
self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
])
iou_score = self.mse_loss(true_iou.swapaxes(0, 1), predicted_iou)

mask_loss = mask_loss + torch.mean(dice_scores)
iou_regression_loss = iou_regression_loss + iou_score

loss = mask_loss + iou_regression_loss

return loss, mask_loss, iou_regression_loss, mean_model_iou

# TODO simplify and check where this is used
def _postprocess_outputs(self, masks):
""" "masks" look like -> (B, 1, X, Y)
where, B is the number of objects, (X, Y) is the input image shape
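The reworked _get_net_loss above computes two terms per batch element: a dice term that keeps, for every object, only the best of the multimask predictions, and an MSE term that regresses the predicted IoUs onto the actual IoUs of the binarized masks. The following standalone sketch reproduces that logic for predictions of shape NUM_OBJECTS x NUM_MULTIMASK x H x W and binary targets of shape NUM_OBJECTS x 1 x H x W; the dice helper is a stand-in for the trainer's self.loss, so the details of its reduction are an assumption rather than torch_em's exact behavior.

import torch

def per_object_dice(pred, target, eps=1e-7):
    """Soft dice loss per object for tensors of shape NUM_OBJECTS x 1 x H x W."""
    numerator = 2 * (pred * target).sum(dim=(1, 2, 3))
    denominator = (pred * pred).sum(dim=(1, 2, 3)) + (target * target).sum(dim=(1, 2, 3))
    return 1.0 - numerator / denominator.clip(min=eps)

def per_object_iou(pred, target, eps=1e-7):
    """IoU per object after binarizing the prediction at 0.5."""
    pred_mask = pred > 0.5
    overlap = pred_mask.logical_and(target).sum(dim=(1, 2, 3))
    union = pred_mask.logical_or(target).sum(dim=(1, 2, 3))
    return overlap / (union + eps)

def mask_and_iou_loss(mask_logits, predicted_iou, targets):
    """mask_logits: NUM_OBJECTS x NUM_MULTIMASK x H x W, predicted_iou: NUM_OBJECTS x NUM_MULTIMASK,
    targets: NUM_OBJECTS x 1 x H x W with binary values."""
    probs = torch.sigmoid(mask_logits)
    # Dice per multimask channel; keep the best (lowest) value for each object.
    dice = torch.stack([per_object_dice(probs[:, i:i+1], targets) for i in range(probs.shape[1])])
    mask_loss = torch.min(dice, dim=0).values.mean()
    # Regress the predicted IoUs onto the actual IoUs of the binarized masks (no gradient through the targets).
    with torch.no_grad():
        true_iou = torch.stack([per_object_iou(probs[:, i:i+1], targets) for i in range(probs.shape[1])])
    iou_regression_loss = torch.nn.functional.mse_loss(true_iou.swapaxes(0, 1), predicted_iou)
    return mask_loss + iou_regression_loss, mask_loss, iou_regression_loss

Penalizing only the lowest-loss mask per object follows the usual SAM multimask training recipe, while the MSE term trains the IoU regression head described in the docstring above.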
@@ -194,7 +194,8 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y):
#
# Update Masks Iteratively while Training
#
def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_subiter, multimask_output):
# TODO change this name, it does not match the function
def _update_masks(self, batched_inputs, sampled_binary_y, num_subiter, multimask_output):
image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
@@ -208,14 +209,18 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
multimask_output=multimask_output if i == 0 else False)

# we want to average the loss and then backprop over the net sub-iterations
net_loss, net_mask_loss, net_iou_regression_loss, net_mean_model_iou = self._get_net_loss(batched_outputs,
y, sampled_ids)
net_loss, net_mask_loss, net_iou_regression_loss, net_mean_model_iou = self._get_net_loss(
batched_outputs, sampled_binary_y
)

loss += net_loss
mask_loss += net_mask_loss
iou_regression_loss += net_iou_regression_loss
mean_model_iou += net_mean_model_iou

masks, logits_masks = [], []

# TODO simplify
# the loop below gets us the masks and logits from the batch-level outputs
for m in batched_outputs:
mask, l_mask = [], []
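The hunk above shows only part of the sub-iteration logic, so here is a rough, self-contained sketch of the flow that _update_masks implements: compute the image embeddings once, then repeatedly predict, accumulate the losses, and derive new prompts from the current predictions. Method and attribute names are taken from the diff; the model call signature and the prompt-update step are abridged assumptions, not the actual micro_sam implementation.

def update_masks_sketch(model, batched_inputs, sampled_binary_y, num_subiter, multimask_output, get_net_loss):
    """Sketch of the iterative training step: image embeddings are computed once and reused."""
    image_embeddings, batched_inputs = model.image_embeddings_oft(batched_inputs)

    loss = mask_loss = iou_regression_loss = mean_model_iou = 0.0
    for i in range(num_subiter):
        # Only the first sub-iteration may request multiple masks per prompt.
        batched_outputs = model(
            batched_inputs, image_embeddings=image_embeddings,
            multimask_output=multimask_output if i == 0 else False,
        )
        # Accumulate the loss terms; per the comment in the diff, the intent is to
        # average them over the sub-iterations before backprop.
        net_loss, net_mask_loss, net_iou_loss, net_model_iou = get_net_loss(batched_outputs, sampled_binary_y)
        loss = loss + net_loss
        mask_loss = mask_loss + net_mask_loss
        iou_regression_loss = iou_regression_loss + net_iou_loss
        mean_model_iou = mean_model_iou + net_model_iou

        # Collect the predicted masks and logits from batched_outputs, derive new prompts
        # from them, and update batched_inputs for the next sub-iteration (omitted in this sketch).
    return loss, mask_loss, iou_regression_loss, mean_model_iou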
@@ -280,31 +285,22 @@ def _interactive_train_iteration(self, x, y):

batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples)

# TODO potentially refactor this so that we can use it in val
# TODO explain what's going on
assert len(y) == len(sampled_ids)
sampled_binary_y = []
for i in range(len(y)):
_sampled = [torch.isin(y[i], torch.tensor(idx)) for idx in sampled_ids[i]]
sampled_binary_y.append(_sampled)

# the steps below are done for one reason in a gist:
# to handle images where there aren't enough instances as expected
# (e.g. where one image has only one instance)
obj_lengths = [len(s) for s in sampled_binary_y]
sampled_binary_y = [s[:min(obj_lengths)] for s in sampled_binary_y]
sampled_binary_y = [torch.stack(s).to(torch.float32) for s in sampled_binary_y]
sampled_binary_y = torch.stack(sampled_binary_y)

# gist for below - while we find the mismatch, we need to update the batched inputs
# else it would still generate masks using mismatching prompts, and it doesn't help us
# with the subiterations again. hence we clip the number of input points as well
f_objs = sampled_binary_y.shape[1]
n_objects = min(len(ids) for ids in sampled_ids)
sampled_binary_y = torch.stack([
torch.stack([target == seg_id for seg_id in ids[:n_objects]])
for target, ids in zip(y, sampled_ids)
]).float()

batched_inputs = [
{k: (v[:f_objs] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
{k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
for inp in batched_inputs
]

loss, mask_loss, iou_regression_loss, model_iou = self._update_masks(
batched_inputs, y, sampled_binary_y, sampled_ids,
batched_inputs, sampled_binary_y,
num_subiter=self.n_sub_iteration, multimask_output=multimask_output
)
return loss, mask_loss, iou_regression_loss, model_iou, sampled_binary_y
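The new target construction above replaces the per-image Python loops: it clips every image to the smallest number of sampled objects, so that the per-object binary masks can be stacked into a single tensor, and the prompts in batched_inputs are clipped to the same count. A small self-contained example of the same pattern, with tensor sizes chosen arbitrarily for illustration:

import torch

# Two label images and the object ids that were sampled for each of them.
y = [torch.randint(0, 4, (1, 64, 64)), torch.randint(0, 3, (1, 64, 64))]
sampled_ids = [[1, 2, 3], [1, 2]]

# Clip to the smallest number of objects so that the per-image stacks have the same size.
n_objects = min(len(ids) for ids in sampled_ids)
sampled_binary_y = torch.stack([
    torch.stack([target == seg_id for seg_id in ids[:n_objects]])
    for target, ids in zip(y, sampled_ids)
]).float()
print(sampled_binary_y.shape)  # torch.Size([2, 2, 1, 64, 64])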
@@ -376,15 +372,14 @@ def _interactive_val_iteration(self, x, y, val_iteration):
multimask_output=multimask_output,
)

# FIXME why don't we need to restrict the number of objects here? Does this only work for batch_size 1?
# (should re-use the functionality from the training iteration here)
assert len(y) == len(sampled_ids)
sampled_binary_y = torch.stack(
[torch.isin(y[i], torch.tensor(sampled_ids[i])) for i in range(len(y))]
).to(torch.float32)

loss, mask_loss, iou_regression_loss, model_iou = self._get_net_loss(
batched_outputs, y, sampled_ids
)

loss, mask_loss, iou_regression_loss, model_iou = self._get_net_loss(batched_outputs, sampled_binary_y)
metric = self._get_val_metric(batched_outputs, sampled_binary_y)

return loss, mask_loss, iou_regression_loss, model_iou, sampled_binary_y, metric
test/test_training.py: 2 additions & 2 deletions
@@ -97,8 +97,8 @@ def _train_model(self, model_type, device):
train_loader=train_loader,
val_loader=val_loader,
model=model,
loss=torch_em.loss.DiceLoss(),
metric=torch_em.loss.DiceLoss(),
loss=torch_em.loss.DiceLoss(reduce_channel=None),
metric=torch_em.loss.DiceLoss(reduce_channel=None),
optimizer=optimizer,
lr_scheduler=scheduler,
device=device,
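The test now builds the loss and metric as torch_em.loss.DiceLoss(reduce_channel=None). In the updated trainer the dice loss is applied with the objects stacked along the channel axis and a minimum is then taken over the multimask predictions, so the per-channel dice values have to survive the loss instead of being reduced to one scalar. The snippet below illustrates that difference in plain PyTorch; it is an assumption about what reduce_channel=None means, not torch_em's actual implementation.

import torch

def channelwise_dice(pred, target, eps=1e-7, reduce_channel=torch.mean):
    """Dice loss computed per channel, optionally reduced to a single scalar."""
    dims = (0,) + tuple(range(2, pred.ndim))  # reduce over batch and spatial axes, keep the channels
    numerator = 2 * (pred * target).sum(dim=dims)
    denominator = (pred * pred).sum(dim=dims) + (target * target).sum(dim=dims)
    per_channel = 1.0 - numerator / denominator.clip(min=eps)
    return per_channel if reduce_channel is None else reduce_channel(per_channel)

pred = torch.rand(1, 3, 32, 32)                      # e.g. one prompt with three objects as channels
target = (torch.rand(1, 3, 32, 32) > 0.5).float()
print(channelwise_dice(pred, target))                             # a single scalar
print(channelwise_dice(pred, target, reduce_channel=None).shape)  # torch.Size([3])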
