Implement joint training for interactive and automatic instance segmentation
Showing 5 changed files with 579 additions and 68 deletions.
@@ -0,0 +1,159 @@
import os
import argparse

import torch

import torch_em
from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.data.datasets import get_livecell_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model


def get_dataloaders(patch_shape, data_path, cell_type=None):
    """This returns the livecell data loaders implemented in torch_em:
    https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
    It will automatically download the livecell data.

    Note: to replace this with another data loader you need to return a torch data loader
    that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
    The labels have to be in a label mask instance segmentation format,
    i.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
    Important: the ID 0 is reserved for background, and the IDs must be consecutive.
    (A minimal sketch of such a loader follows after this file.)
    """
    label_transform = PerObjectDistanceTransform(
        distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
    )
    raw_transform = sam_training.identity  # the current workflow avoids rescaling the inputs to [-1, 1]
    train_loader = get_livecell_loader(
        path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
        raw_transform=raw_transform, label_dtype=torch.float32
    )
    val_loader = get_livecell_loader(
        path=data_path, patch_shape=patch_shape, split="val", batch_size=1, num_workers=16,
        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
        raw_transform=raw_transform, label_dtype=torch.float32
    )

    return train_loader, val_loader


def finetune_livecell(args):
    """Example code for finetuning SAM on LiveCELL."""
    # override this below if you have a more complex set-up and need to specify the exact gpu
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # training settings:
    model_type = args.model_type
    checkpoint_path = None  # override this to start training from a custom checkpoint
    patch_shape = (520, 704)  # the patch shape for training
    n_objects_per_batch = 25  # the number of objects per batch that will be sampled
    freeze_parts = args.freeze  # override this to freeze different parts of the model

    # get the trainable segment anything model
    model = sam_training.get_trainable_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        freeze=freeze_parts
    )
    model.to(device)

    # let's get the UNETR model for the automatic instance segmentation pipeline
    unetr = UNETR(
        backbone="sam",
        encoder=model.sam.image_encoder,
        out_channels=3,
        use_sam_stats=True,
        final_activation="Sigmoid",
        use_skip_connection=False
    )
    unetr.to(device)

    # let's get the parameters for SAM and the decoder from UNETR
    joint_model_params = list(model.parameters())  # sam parameters
    for name, params in unetr.named_parameters():  # unetr's decoder parameters
        if not name.startswith("encoder"):
            joint_model_params.append(params)

    # all the stuff we need for training
    optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
    train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)

||
# this class creates all the training data for a batch (inputs, prompts and labels) | ||
convert_inputs = sam_training.ConvertToSamInputs() | ||
|
||
checkpoint_name = "livecell_sam" | ||
trainer = sam_training.JointSamTrainer( | ||
name=checkpoint_name, | ||
save_root=args.save_root, | ||
train_loader=train_loader, | ||
val_loader=val_loader, | ||
model=model, | ||
optimizer=optimizer, | ||
# currently we compute loss batch-wise, else we pass channelwise True | ||
loss=torch_em.loss.DiceLoss(channelwise=False), | ||
metric=torch_em.loss.DiceLoss(), | ||
device=device, | ||
lr_scheduler=scheduler, | ||
logger=sam_training.JointSamLogger, | ||
log_image_interval=100, | ||
mixed_precision=True, | ||
convert_inputs=convert_inputs, | ||
n_objects_per_batch=n_objects_per_batch, | ||
n_sub_iteration=8, | ||
compile_model=False, | ||
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training | ||
unetr=unetr, | ||
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), | ||
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) | ||
) | ||
trainer.fit(args.iterations) | ||
if args.export_path is not None: | ||
checkpoint_path = os.path.join( | ||
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" | ||
) | ||
export_custom_sam_model( | ||
checkpoint_path=checkpoint_path, | ||
model_type=model_type, | ||
save_path=args.export_path, | ||
) | ||
|
||
|
||
def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/usr/nimanwai/data/livecell/",
        help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
    )
    parser.add_argument(
        "--save_root", "-s",
        help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
    )
    parser.add_argument(
        "--iterations", type=int, default=int(1e5),
        help="For how many iterations should the model be trained? By default 100k."
    )
    parser.add_argument(
        "--export_path", "-e",
        help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
    )
    parser.add_argument(
        "--freeze", type=str, nargs="+", default=None,
        help="Which parts of the model to freeze for finetuning."
    )
    args = parser.parse_args()
    finetune_livecell(args)


if __name__ == "__main__":
    main()
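
As mentioned in the docstring of `get_dataloaders`, any replacement loader just has to yield `x, y` tensor pairs with instance label masks, background as ID 0, and consecutive IDs. A minimal sketch of such a loader (not part of this commit; `images` and `labels` are placeholder lists of matching numpy arrays you would fill in):

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class InstanceSegDataset(Dataset):
    """Hypothetical dataset returning (x, y) pairs in the format described above."""

    def __init__(self, images, labels):
        # images and labels are matching lists of 2d numpy arrays
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # image as a float tensor with an explicit channel axis
        x = torch.from_numpy(self.images[idx].astype("float32"))[None]
        # relabel the instance mask so the IDs are consecutive;
        # np.unique sorts ascending, so background keeps the ID 0 as long as it is present
        _, y = np.unique(self.labels[idx], return_inverse=True)
        y = y.reshape(self.labels[idx].shape).astype("float32")
        return x, torch.from_numpy(y)[None]


# train_loader = DataLoader(InstanceSegDataset(images, labels), batch_size=2, shuffle=True)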
@@ -0,0 +1,133 @@
import os
import h5py
import argparse
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
from pathlib import Path
import imageio.v3 as imageio
from collections import OrderedDict

import torch

from torch_em.model import UNETR
from torch_em.util import segmentation
from torch_em.util.prediction import predict_with_padding

from elf.evaluation import mean_segmentation_accuracy

from micro_sam.util import get_sam_model


def get_unetr_model(model_type, checkpoint, device):
    # get the sam model; we only need its image encoder architecture here
    predictor = get_sam_model(
        model_type=model_type
    )

    # build the unetr model that will be loaded with the respective finetuned state
    model = UNETR(
        encoder=predictor.model.image_encoder,
        out_channels=3,
        use_sam_stats=True,
        final_activation="Sigmoid",
        use_skip_connection=False
    )

    # load the checkpoint once and pull out both state dicts
    state = torch.load(checkpoint, map_location="cpu")
    sam_state = state["model_state"]
    # let's get the vit parameters from sam
    encoder_state = []
    prune_prefix = "sam.image_"
    for k, v in sam_state.items():
        if k.startswith(prune_prefix):
            encoder_state.append((k[len(prune_prefix):], v))
    encoder_state = OrderedDict(encoder_state)

    decoder_state = state["decoder_state"]

    unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))
    model.load_state_dict(unetr_state)
    model.to(device)
    model.eval()

    return model


|
||
def predict_for_unetr(inputs, save_dir, model, device): | ||
save_dir = os.path.join(save_dir, "results") | ||
os.makedirs(save_dir, exist_ok=True) | ||
|
||
with torch.no_grad(): | ||
for img_path in tqdm(glob(os.path.join(inputs, "images", "livecell_test_images", "*")), | ||
desc="Run unetr inference"): | ||
fname = Path(img_path).stem | ||
save_path = os.path.join(save_dir, f"{fname}.h5") | ||
if os.path.exists(save_path): | ||
continue | ||
|
||
input_ = imageio.imread(img_path) | ||
|
||
outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16)) | ||
fg, cdist, bdist = outputs.squeeze() | ||
dm_seg = segmentation.watershed_from_center_and_boundary_distances( | ||
cdist, bdist, fg, min_size=50, | ||
center_distance_threshold=0.5, | ||
boundary_distance_threshold=0.6, | ||
distance_smoothing=1.0 | ||
) | ||
|
||
with h5py.File(save_path, "a") as f: | ||
ds = f.require_dataset("segmentation", shape=dm_seg.shape, compression="gzip", dtype=dm_seg.dtype) | ||
ds[:] = dm_seg | ||
|
||
|
||
def evaluation_for_unetr(inputs, save_dir, csv_path): | ||
if os.path.exists(csv_path): | ||
return | ||
|
||
msa_list, sa50_list = [], [] | ||
for gt_path in tqdm(glob(os.path.join(inputs, "annotations", "livecell_test_images", "*", "*")), | ||
desc="Run unetr evaluation"): | ||
gt = imageio.imread(gt_path) | ||
fname = Path(gt_path).stem | ||
|
||
output_file = os.path.join(save_dir, "results", f"{fname}.h5") | ||
with h5py.File(output_file, "r") as f: | ||
instances = f["segmentation"][:] | ||
|
||
msa, sa_acc = mean_segmentation_accuracy(instances, gt, return_accuracies=True) | ||
msa_list.append(msa) | ||
sa50_list.append(sa_acc[0]) | ||
|
||
res_dict = { | ||
"LiveCELL": "Metrics", | ||
"mSA": np.mean(msa_list), | ||
"SA50": np.mean(sa50_list) | ||
} | ||
df = pd.DataFrame.from_dict([res_dict]) | ||
df.to_csv(csv_path) | ||
|
||
|
||
def main(args): | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
# let's get the unetr model (initialized with the joint training setup) | ||
model = get_unetr_model(model_type=args.model_type, checkpoint=args.checkpoint, device=device) | ||
|
||
# let's get the predictions | ||
predict_for_unetr(inputs=args.inputs, save_dir=args.save_dir, model=model, device=device) | ||
|
||
# let's evaluate the predictions | ||
evaluation_for_unetr(inputs=args.inputs, save_dir=args.save_dir, csv_path=args.csv_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-i", "--inputs", default="/scratch/usr/nimanwai/data/livecell/") | ||
parser.add_argument("-c", "--checkpoint", type=str, required=True) | ||
parser.add_argument("-m", "--model_type", type=str, default="vit_b") | ||
parser.add_argument("--save_dir", type=str, required=True) | ||
parser.add_argument("--csv_path", type=str, default="livecell_joint_training.csv") | ||
args = parser.parse_args() | ||
main(args) |
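
For reference, one possible way to run the two scripts end to end. The script file names and paths below are placeholders (the diff view does not show the file names); the flags correspond to the argument parsers above, and the checkpoint location follows the save_root/checkpoints/livecell_sam/best.pt layout constructed by the training script:

python finetune_livecell.py -i /path/to/livecell -m vit_b -s ./runs -e ./finetuned_livecell_model.pt
python evaluate_unetr.py -i /path/to/livecell -c ./runs/checkpoints/livecell_sam/best.pt --save_dir ./eval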