Experiments with Mask Inputs Probability (#265)
Use mask prompts stochastically in finetuning
anwai98 authored Nov 15, 2023
1 parent e124595 commit c236df3
Showing 3 changed files with 27 additions and 9 deletions.
finetuning/livecell/evaluation/iterative_prompting.py (9 changes: 7 additions & 2 deletions)
@@ -5,7 +5,7 @@
 from micro_sam.evaluation import inference
 from micro_sam.evaluation.evaluation import run_evaluation
 
-from util import get_paths, get_checkpoint
+from util import get_paths, get_checkpoint, MODELS
 
 LIVECELL_GT_ROOT = "/scratch/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images"
 PREDICTION_ROOT = "/scratch/projects/nim00007/sam/iterative_evaluation"
@@ -78,7 +78,11 @@ def main(args):
     prediction_root = get_prediction_root(start_with_box_prompt, model_description)
 
     # get the model checkpoints and desired model name to initialize the predictor
-    checkpoint, model_type = get_checkpoint(model_description)
+    if args.checkpoint is None and model_description in MODELS.keys():
+        checkpoint, model_type = get_checkpoint(model_description)
+    else:
+        checkpoint = args.checkpoint
+        model_type = model_description[:5]
     # get the predictor to perform inference
     predictor = inference.get_predictor(checkpoint, model_type)
 
@@ -94,5 +98,6 @@
         "-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
         help="Provide the model type to initialize the predictor"
     )
+    parser.add_argument("-c", "--checkpoint", type=str, default=None)
     args = parser.parse_args()
     main(args)
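With the new `-c`/`--checkpoint` argument, a finetuned checkpoint can be evaluated even if the model name is not one of the predefined `MODELS`: when a checkpoint is passed explicitly, it is used directly and the model type is derived from the first five characters of the model description, e.g. "vit_h_specialist"[:5] == "vit_h". An illustrative invocation (the checkpoint path below is made up, not from the repository):

python iterative_prompting.py -m vit_h_specialist -c /path/to/finetuned_checkpoint.pt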
finetuning/livecell_finetuning.py (7 changes: 4 additions & 3 deletions)
@@ -42,7 +42,7 @@ def finetune_livecell(args):
     n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled
 
     # get the trainable segment anything model
-    model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device)
+    model = sam_training.get_trainable_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)
 
     # all the stuff we need for training
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
@@ -72,7 +72,8 @@
         convert_inputs=convert_inputs,
         n_objects_per_batch=n_objects_per_batch,
         n_sub_iteration=8,
-        compile_model=False
+        compile_model=False,
+        mask_prob=0.5 # (optional) overwrite to provide the probability of using mask inputs while training
     )
     trainer.fit(args.iterations)
     if args.export_path is not None:
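Note on the new default: with `mask_prob=0.5` and `n_sub_iteration=8` (as set above), each input receives mask inputs in roughly half of the eight prompting sub-iterations on average, since every sub-iteration draws independently. A quick standalone check of that expectation (illustrative only, not part of the repository):

import random

def expected_mask_subiters(mask_prob=0.5, n_sub_iteration=8, n_trials=10_000):
    """Average number of sub-iterations that receive mask inputs under independent draws."""
    total = 0
    for _ in range(n_trials):
        total += sum(random.random() < mask_prob for _ in range(n_sub_iteration))
    return total / n_trials

print(expected_mask_subiters())  # ~4.0 with mask_prob=0.5 and 8 sub-iterations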
@@ -89,7 +90,7 @@
 def main():
     parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
     parser.add_argument(
-        "--input_path", "-i", default="",
+        "--input_path", "-i", default="/scratch/projects/nim00007/data/LiveCELL/",
         help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
     )
     parser.add_argument(
micro_sam/training/sam_trainer.py (20 changes: 16 additions & 4 deletions)
@@ -1,5 +1,6 @@
 import os
 import time
+import random
 from typing import Optional
 
 import numpy as np
@@ -9,7 +10,7 @@
 from torchvision.utils import make_grid
 from torch_em.trainer.logger_base import TorchEmLogger
 
-from ..prompt_generators import IterativePromptGenerator
+from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
 
 
 class SamTrainer(torch_em.trainer.DefaultTrainer):
@@ -20,14 +21,16 @@ class SamTrainer(torch_em.trainer.DefaultTrainer):
     for details on its usage and implementation.
     Args:
-        convert_inputs: Class that converts the output of the dataloader to the expected input format of SAM.
+        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
             The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
         n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
             In each sub-iteration new point prompts are sampled where the model was wrong.
         n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
             Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
         mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
         sigmoid: The activation function for normalizing the model output.
+        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training
+        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`)
         **kwargs: The keyword arguments of the DefaultTrainer super class.
     """

@@ -38,7 +41,8 @@ def __init__(
         n_objects_per_batch: Optional[int] = None,
         mse_loss: torch.nn.Module = torch.nn.MSELoss(),
         _sigmoid: torch.nn.Module = torch.nn.Sigmoid(),
-        prompt_generator=IterativePromptGenerator(),
+        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
+        mask_prob: float = 0.5,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -48,6 +52,7 @@
         self.n_objects_per_batch = n_objects_per_batch
         self.n_sub_iteration = n_sub_iteration
         self.prompt_generator = prompt_generator
+        self.mask_prob = mask_prob
         self._kwargs = kwargs
 
     def _get_prompt_and_multimasking_choices(self, current_iteration):
@@ -250,7 +255,14 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
 
             _inp["point_coords"] = updated_point_coords
             _inp["point_labels"] = updated_point_labels
-            _inp["mask_inputs"] = logits
+
+            if self.mask_prob > 0:
+                # using mask inputs for iterative prompting while training, with a probability
+                use_mask_inputs = (random.random() < self.mask_prob)
+                if use_mask_inputs:
+                    _inp["mask_inputs"] = logits
+                else: # remove previously existing mask inputs to avoid using them in next sub-iteration
+                    _inp.pop("mask_inputs", None)
 
     #
     # Training Loop
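The core of the change is this per-sub-iteration gate: with probability `mask_prob` the logits predicted in the previous sub-iteration are passed back to the model as mask inputs, otherwise any stale mask inputs are removed so the next sub-iteration relies on the point prompts alone. A self-contained sketch of the pattern (the helper name and the plain dict stand in for the trainer's batched inputs; this is not part of micro_sam):

import random

def maybe_use_mask_inputs(prompt_inputs, logits, mask_prob=0.5):
    # attach the previous logits as mask inputs with probability `mask_prob`,
    # otherwise drop any mask inputs left over from an earlier sub-iteration
    if mask_prob > 0:
        if random.random() < mask_prob:
            prompt_inputs["mask_inputs"] = logits
        else:
            prompt_inputs.pop("mask_inputs", None)
    return prompt_inputs

# tiny usage example with dummy values
inp = {"point_coords": [[42, 17]], "point_labels": [1]}
inp = maybe_use_mask_inputs(inp, logits="dummy-logits", mask_prob=0.5)
print("mask_inputs" in inp)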
