Commit
Merge remote-tracking branch 'upstream/dev' into dev
caroteu committed Jun 17, 2024
2 parents 200e011 + 401ea50 commit 9d08953
Showing 7 changed files with 354 additions and 10 deletions.
6 changes: 6 additions & 0 deletions finetuning/specialists/README.md
@@ -19,6 +19,12 @@ Code for finetuning Segment Anything on specific microscopy datasets.
- `resource_efficient_finetuning`: The experiments for finetuning a custom dataset on limited resources.


## Experimental Scripts

- `training/histopathology/`: The finetuning scripts for histopathology datasets.
    - `pannuke_finetuning.py`: Finetuning Segment Anything on the PanNuke dataset.


## Outdated Scripts
The scripts located at `outdated/` do not work with the latest version of `micro-sam`.
- They comprise extensive experiments on the "LIVECell" specialist, located at `outdated/livecell/`.
68 changes: 68 additions & 0 deletions finetuning/specialists/training/histopathology/create_dataloaders.py
@@ -0,0 +1,68 @@
import torch

from torch_em.data import MinInstanceSampler
from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_pannuke_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training


def get_dataloaders(patch_shape, data_path):
    """This returns the PanNuke data loaders implemented in torch_em:
    https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/pannuke.py
    It will automatically download the PanNuke data.

    Note: to replace this with another data loader you need to return a torch data loader
    that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
    The labels have to be in a label mask instance segmentation format,
    i.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
    Important: the ID 0 is reserved for background, and the IDs must be consecutive.
    """
    label_transform = PerObjectDistanceTransform(
        distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
    )
    raw_transform = sam_training.identity  # the current workflow avoids rescaling the inputs to [-1, 1]
    sampler = MinInstanceSampler(min_num_instances=3)

    train_loader = get_pannuke_loader(
        path=data_path,
        patch_shape=patch_shape,
        batch_size=2,
        folds=["fold_1"],
        num_workers=16,
        download=True,
        shuffle=True,
        label_transform=label_transform,
        raw_transform=raw_transform,
        label_dtype=torch.float32,
        sampler=sampler,
        ndim=2,
    )
    val_loader = get_pannuke_loader(
        path=data_path,
        patch_shape=patch_shape,
        batch_size=1,
        folds=["fold_2"],
        num_workers=16,
        download=True,
        shuffle=True,
        label_transform=label_transform,
        raw_transform=raw_transform,
        label_dtype=torch.float32,
        sampler=sampler,
        ndim=2,
    )

    return train_loader, val_loader


def visualize_images(data_path):
    train_loader, val_loader = get_dataloaders(patch_shape=(1, 512, 512), data_path=data_path)

    # let's visualize train loader first
    check_loader(train_loader, 8, plt=True, save_path="./fig.png")


if __name__ == "__main__":
    visualize_images(data_path="/scratch/projects/nim00007/sam/data/pannuke")
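The docstring in `get_dataloaders` spells out the contract a replacement loader has to satisfy: it must yield `x, y` pairs where `y` is an instance label mask with background 0 and consecutive object IDs. Below is a minimal sketch of such a custom loader built directly on `torch.utils.data`; the dataset class and argument names are illustrative assumptions, not part of this commit.

import torch
from torch.utils.data import Dataset, DataLoader


class InstanceSegmentationDataset(Dataset):
    """Hypothetical dataset yielding image tensors and instance label masks."""

    def __init__(self, images, instance_masks):
        # `images`: list of arrays with shape (C, H, W); `instance_masks`: list of
        # arrays with shape (H, W), background = 0, objects labeled 1, 2, 3, ...
        self.images = images
        self.instance_masks = instance_masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        x = torch.as_tensor(self.images[idx], dtype=torch.float32)
        y = torch.as_tensor(self.instance_masks[idx], dtype=torch.float32)
        return x, y


def get_custom_dataloaders(train_data, val_data, batch_size=2):
    # Each of `train_data` / `val_data` is an (images, instance_masks) pair.
    train_loader = DataLoader(InstanceSegmentationDataset(*train_data), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(InstanceSegmentationDataset(*val_data), batch_size=1, shuffle=False)
    return train_loader, val_loader

Depending on the training configuration, a label transform and sampler like the ones used above (`PerObjectDistanceTransform`, `MinInstanceSampler`) may still be needed to produce equivalent training targets.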
94 changes: 94 additions & 0 deletions finetuning/specialists/training/histopathology/pannuke_finetuning.py
@@ -0,0 +1,94 @@
import os
import argparse

import torch

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model

from create_dataloaders import get_dataloaders


def finetune_pannuke(args):
    """Example code for finetuning SAM on PanNuke"""
    # override this (below) if you have some more complex set-up and need to specify the exact gpu
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # training settings:
    model_type = args.model_type
    checkpoint_path = None  # override this to start training from a custom checkpoint
    patch_shape = (1, 512, 512)  # the patch shape for training
    n_objects_per_batch = args.n_objects  # the number of objects per batch that will be sampled (default: 25)
    freeze_parts = args.freeze  # override this to freeze different parts of the model
    checkpoint_name = f"{args.model_type}/pannuke_sam"

    # all the stuff we need for training
    train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
    scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}

    # Run training.
    sam_training.train_sam(
        name=checkpoint_name,
        model_type=model_type,
        train_loader=train_loader,
        val_loader=val_loader,
        early_stopping=10,
        n_objects_per_batch=n_objects_per_batch,
        checkpoint_path=checkpoint_path,
        freeze=freeze_parts,
        device=device,
        lr=1e-5,
        n_iterations=args.iterations,
        save_root=args.save_root,
        scheduler_kwargs=scheduler_kwargs,
        save_every_kth_epoch=args.save_every_kth_epoch,
    )

    if args.export_path is not None:
        checkpoint_path = os.path.join(
            "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
        )
        export_custom_sam_model(
            checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path,
        )


def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the PanNuke dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/pannuke/",
        help="The filepath to the PanNuke data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
    )
    parser.add_argument(
        "--save_root", "-s",
        help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
    )
    parser.add_argument(
        "--iterations", type=int, default=int(1e5),
        help="For how many iterations should the model be trained? By default 100k."
    )
    parser.add_argument(
        "--export_path", "-e",
        help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
    )
    parser.add_argument(
        "--freeze", type=str, nargs="+", default=None,
        help="Which parts of the model to freeze for finetuning."
    )
    parser.add_argument(
        "--save_every_kth_epoch", type=int, default=None,
        help="To save every kth epoch while fine-tuning. Expects an integer value."
    )
    parser.add_argument(
        "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
    )
    args = parser.parse_args()
    finetune_pannuke(args)


if __name__ == "__main__":
    main()
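For orientation, the same finetuning run could also be started programmatically by constructing the argument namespace the CLI would produce. The module name, input path, and export path below are placeholders assumed for illustration, not values from this commit.

import argparse

from pannuke_finetuning import finetune_pannuke  # assumes the script above is importable as a module

args = argparse.Namespace(
    input_path="./data/pannuke/",          # downloaded automatically if it does not exist yet
    model_type="vit_b",
    save_root=None,                        # keep checkpoints and logs next to the script
    iterations=int(1e5),
    export_path="./pannuke_sam_vit_b.pth",
    freeze=None,
    save_every_kth_epoch=None,
    n_objects=25,
)
finetune_pannuke(args)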
1 change: 1 addition & 0 deletions micro_sam/training/__init__.py
@@ -4,4 +4,5 @@
from .sam_trainer import SamTrainer, SamLogger
from .util import ConvertToSamInputs, get_trainable_sam_model, identity
from .joint_sam_trainer import JointSamTrainer, JointSamLogger
from .medsam_trainer import MedSAMTrainer
from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS
22 changes: 22 additions & 0 deletions micro_sam/training/medsam_trainer.py
@@ -0,0 +1,22 @@
from . import SamTrainer


class MedSAMTrainer(SamTrainer):
    """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306)
    """
    def __init__(self, **kwargs):
        n_sub_iteration = 1
        mask_prob = 0
        super().__init__(n_sub_iteration=n_sub_iteration, mask_prob=mask_prob, **kwargs)

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        n_pos, n_neg = 0, 0
        get_boxes = True
        multimask_output = False
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        return self._get_prompt_and_multimasking_choices(current_iteration=current_iteration)
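The two overridden methods are the main hooks for changing the prompting strategy: they return the number of positive and negative point prompts, whether to use box prompts, and whether SAM should predict multiple masks. As a hedged illustration of the same pattern (a hypothetical class, not part of this commit), a trainer that prompts with a single positive point instead of a box could look like this:

class SinglePointTrainer(SamTrainer):
    """Hypothetical trainer that uses one positive point prompt per object."""

    def __init__(self, **kwargs):
        # One forward pass per batch and no mask prompts, as in MedSAMTrainer above.
        super().__init__(n_sub_iteration=1, mask_prob=0, **kwargs)

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        n_pos, n_neg = 1, 0        # a single positive point, no negative points
        get_boxes = False          # no box prompts
        multimask_output = True    # let SAM disambiguate the ambiguous point prompt
        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        return self._get_prompt_and_multimasking_choices(current_iteration=current_iteration)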
27 changes: 17 additions & 10 deletions micro_sam/training/sam_trainer.py
@@ -31,6 +31,7 @@ class SamTrainer(torch_em.trainer.DefaultTrainer):
        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training
        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`)
        mask_loss: The loss to compare the predicted masks and the targets.
        **kwargs: The keyword arguments of the DefaultTrainer super class.
    """

@@ -42,12 +43,17 @@ def __init__(
        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
        mask_prob: float = 0.5,
        mask_loss: Optional[torch.nn.Module] = None,
        **kwargs
    ):
        # We have to use the Dice Loss with reduce channel set to None.
        # Hence we hard-code it here to avoid issues by passing wrong options for the loss.
        dice_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        super().__init__(loss=dice_loss, metric=dice_loss, **kwargs)
        if mask_loss is None:
            # We have to use the Dice Loss with reduce channel set to None.
            # Hence we hard-code it here to avoid issues by passing wrong options for the loss.
            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        else:
            self.mask_loss = mask_loss

        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
        self.convert_inputs = convert_inputs
        self.mse_loss = mse_loss
        self.n_objects_per_batch = n_objects_per_batch
@@ -216,12 +222,13 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
            iou_regression_loss += net_iou_regression_loss
            mean_model_iou += net_mean_model_iou

            # Determine the next prompts based on current predictions.
            with torch.no_grad():
                # Get the mask and logit predictions corresponding to the predicted object
                # (per actual object) with the best IOU.
                masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
                batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
            if i < (num_subiter - 1):  # We need not update the prompts for the last iteration.
                # Determine the next prompts based on current predictions.
                with torch.no_grad():
                    # Get the mask and logit predictions corresponding to the predicted object
                    # (per actual object) with the best IOU.
                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)

        loss = loss / num_subiter
        mask_loss = mask_loss / num_subiter
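With `mask_loss` exposed, the mask objective is no longer hard-coded to the Dice loss. Below is a minimal hedged sketch of a custom loss that could be passed via `mask_loss=...`; it is an illustrative assumption rather than something added by this commit, and it presumes predictions and targets arrive in the same shape and value range the default Dice loss sees.

import torch
import torch_em


class DiceWithL1Loss(torch.nn.Module):
    """Hypothetical combination of the default Dice loss with an L1 term."""

    def __init__(self, alpha=0.9):
        super().__init__()
        self.alpha = alpha
        # Keep reduce_channel=None to stay consistent with the default loss in SamTrainer.
        self.dice_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        self.l1_loss = torch.nn.L1Loss()

    def forward(self, prediction, target):
        return self.alpha * self.dice_loss(prediction, target) + (1.0 - self.alpha) * self.l1_loss(prediction, target)

It would then be passed as `SamTrainer(..., mask_loss=DiceWithL1Loss())`, with everything else in the trainer setup unchanged.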