forked from computational-cell-analytics/micro-sam
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/dev' into dev
- Loading branch information
Showing
7 changed files
with
354 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
finetuning/specialists/training/histopathology/create_dataloaders.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
|
||
from torch_em.data import MinInstanceSampler | ||
from torch_em.util.debug import check_loader | ||
from torch_em.data.datasets import get_pannuke_loader | ||
from torch_em.transform.label import PerObjectDistanceTransform | ||
|
||
import micro_sam.training as sam_training | ||
|
||
|
||
def get_dataloaders(patch_shape, data_path): | ||
"""This returns the pannuke data loaders implemented in torch_em: | ||
https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/pannuke.py | ||
It will automatically download the pannuke data. | ||
Note: to replace this with another data loader you need to return a torch data loader | ||
that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. | ||
The labels have to be in a label mask instance segmentation format. | ||
I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. | ||
Important: the ID 0 is reseved for background, and the IDs must be consecutive | ||
""" | ||
label_transform = PerObjectDistanceTransform( | ||
distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 | ||
) | ||
raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] | ||
sampler = MinInstanceSampler(min_num_instances=3) | ||
|
||
train_loader = get_pannuke_loader( | ||
path=data_path, | ||
patch_shape=patch_shape, | ||
batch_size=2, | ||
folds=["fold_1"], | ||
num_workers=16, | ||
download=True, | ||
shuffle=True, | ||
label_transform=label_transform, | ||
raw_transform=raw_transform, | ||
label_dtype=torch.float32, | ||
sampler=sampler, | ||
ndim=2, | ||
) | ||
val_loader = get_pannuke_loader( | ||
path=data_path, | ||
patch_shape=patch_shape, | ||
batch_size=1, | ||
folds=["fold_2"], | ||
num_workers=16, | ||
download=True, | ||
shuffle=True, | ||
label_transform=label_transform, | ||
raw_transform=raw_transform, | ||
label_dtype=torch.float32, | ||
sampler=sampler, | ||
ndim=2, | ||
) | ||
|
||
return train_loader, val_loader | ||
|
||
|
||
def visualize_images(data_path): | ||
train_loader, val_loader = get_dataloaders(patch_shape=(1, 512, 512), data_path=data_path) | ||
|
||
# let's visualize train loader first | ||
check_loader(train_loader, 8, plt=True, save_path="./fig.png") | ||
|
||
|
||
if __name__ == "__main__": | ||
visualize_images(data_path="/scratch/projects/nim00007/sam/data/pannuke") |
94 changes: 94 additions & 0 deletions
94
finetuning/specialists/training/histopathology/pannuke_finetuning.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import os | ||
import argparse | ||
|
||
import torch | ||
|
||
import micro_sam.training as sam_training | ||
from micro_sam.util import export_custom_sam_model | ||
|
||
from create_dataloaders import get_dataloaders | ||
|
||
|
||
def finetune_pannuke(args): | ||
"""Example code for finetuning SAM on PanNuke""" | ||
# override this (below) if you have some more complex set-up and need to specify the exact gpu | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
# training settings: | ||
model_type = args.model_type | ||
checkpoint_path = None # override this to start training from a custom checkpoint | ||
patch_shape = (1, 512, 512) # the patch shape for training | ||
n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled (default: 25) | ||
freeze_parts = args.freeze # override this to freeze different parts of the model | ||
checkpoint_name = f"{args.model_type}/pannuke_sam" | ||
|
||
# all the stuff we need for training | ||
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) | ||
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} | ||
|
||
# Run training. | ||
sam_training.train_sam( | ||
name=checkpoint_name, | ||
model_type=model_type, | ||
train_loader=train_loader, | ||
val_loader=val_loader, | ||
early_stopping=10, | ||
n_objects_per_batch=n_objects_per_batch, | ||
checkpoint_path=checkpoint_path, | ||
freeze=freeze_parts, | ||
device=device, | ||
lr=1e-5, | ||
n_iterations=args.iterations, | ||
save_root=args.save_root, | ||
scheduler_kwargs=scheduler_kwargs, | ||
save_every_kth_epoch=args.save_every_kth_epoch, | ||
) | ||
|
||
if args.export_path is not None: | ||
checkpoint_path = os.path.join( | ||
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" | ||
) | ||
export_custom_sam_model( | ||
checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path, | ||
) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the PanNuke dataset.") | ||
parser.add_argument( | ||
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/pannuke/", | ||
help="The filepath to the PanNuke data. If the data does not exist yet it will be downloaded." | ||
) | ||
parser.add_argument( | ||
"--model_type", "-m", default="vit_b", | ||
help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." | ||
) | ||
parser.add_argument( | ||
"--save_root", "-s", | ||
help="Where to save the checkpoint and logs. By default they will be saved where this script is run." | ||
) | ||
parser.add_argument( | ||
"--iterations", type=int, default=int(1e5), | ||
help="For how many iterations should the model be trained? By default 100k." | ||
) | ||
parser.add_argument( | ||
"--export_path", "-e", | ||
help="Where to export the finetuned model to. The exported model can be used in the annotation tools." | ||
) | ||
parser.add_argument( | ||
"--freeze", type=str, nargs="+", default=None, | ||
help="Which parts of the model to freeze for finetuning." | ||
) | ||
parser.add_argument( | ||
"--save_every_kth_epoch", type=int, default=None, | ||
help="To save every kth epoch while fine-tuning. Expects an integer value." | ||
) | ||
parser.add_argument( | ||
"--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." | ||
) | ||
args = parser.parse_args() | ||
finetune_pannuke(args) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from . import SamTrainer | ||
|
||
|
||
class MedSAMTrainer(SamTrainer): | ||
"""Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306) | ||
""" | ||
def __init__( | ||
self, | ||
**kwargs | ||
): | ||
n_sub_iteration = 1 | ||
mask_prob = 0 | ||
super().__init__(n_sub_iteration=n_sub_iteration, mask_prob=mask_prob, **kwargs) | ||
|
||
def _get_prompt_and_multimasking_choices(self, current_iteration): | ||
n_pos, n_neg = 0, 0 | ||
get_boxes = True | ||
multimask_output = False | ||
return n_pos, n_neg, get_boxes, multimask_output | ||
|
||
def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): | ||
return self._get_prompt_and_multimasking_choices(current_iteration=current_iteration) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.