From 3bd88bbc6bc0d90efd26832dbb777d66e9b8b773 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 14 Dec 2023 12:50:00 +0100
Subject: [PATCH 01/16] Update SamTrainer

---
 micro_sam/training/sam_trainer.py | 155 +++++++++++++++++-------------
 1 file changed, 87 insertions(+), 68 deletions(-)

diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 43be9cec..215e57b6 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -275,6 +275,62 @@ def _update_samples_for_gt_instances(self, y, n_samples):
         n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples
         return n_samples
 
+    def _interactive_train_iteration(self, x, y):
+        n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)
+
+        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
+
+        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples)
+
+        assert len(y) == len(sampled_ids)
+        sampled_binary_y = []
+        for i in range(len(y)):
+            _sampled = [torch.isin(y[i], torch.tensor(idx)) for idx in sampled_ids[i]]
+            sampled_binary_y.append(_sampled)
+
+        # The steps below are all done for one reason:
+        # to handle images that contain fewer instances than expected
+        # (e.g. an image with only a single instance).
+        obj_lengths = [len(s) for s in sampled_binary_y]
+        sampled_binary_y = [s[:min(obj_lengths)] for s in sampled_binary_y]
+        sampled_binary_y = [torch.stack(s).to(torch.float32) for s in sampled_binary_y]
+        sampled_binary_y = torch.stack(sampled_binary_y)
+
+        # Once such a mismatch is found, we also need to update the batched inputs;
+        # otherwise masks would still be generated for the mismatching prompts and the
+        # sub-iterations would not work either. Hence we clip the number of input points as well.
+        f_objs = sampled_binary_y.shape[1]
+        batched_inputs = [
+            {k: (v[:f_objs] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
+            for inp in batched_inputs
+        ]
+
+        loss, mask_loss, iou_regression_loss, model_iou = self._update_masks(
+            batched_inputs, y, sampled_binary_y, sampled_ids,
+            num_subiter=self.n_sub_iteration, multimask_output=multimask_output
+        )
+        return loss, mask_loss, iou_regression_loss, model_iou, sampled_binary_y
+
+    def _check_input_normalization(self, x, input_check_done):
+        # The expected data range of the SAM model is 8bit (0-255).
+        # It can easily happen that data is normalized beforehand in training.
+        # For reasons we don't fully understand this still works, but it
+        # should still be avoided, since it is very detrimental in some settings
+        # (e.g. when freezing the image encoder).
+        # We check once per epoch if the data seems to be normalized already and
+        # raise a warning if this is the case.
+        if not input_check_done:
+            data_min, data_max = x.min(), x.max()
+            if (data_min < 0) or (data_max < 1):
+                warnings.warn(
+                    "It looks like you are normalizing the training data. "
+                    "The SAM model takes care of normalization, so it is better not to do this. "
+                    "We recommend removing data normalization and passing input data in the range [0, 255]."
+                )
+            input_check_done = True
+
+        return input_check_done
+
     def _train_epoch_impl(self, progress, forward_context, backprop):
         self.model.train()
 
@@ -283,60 +339,13 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
         n_iter = 0
         t_per_iter = time.time()
         for x, y in self.train_loader:
-
-            # The expected data range of the SAM model is 8bit (0-255).
-            # It can easily happen that data is normalized beforehand in training.
-            # For some reasons we don't fully understand this still works, but it
-            # should still be avoided and is very detrimental in some settings
-            # (e.g. when freezing the image encoder)
-            # We check once per epoch if the data seems to be normalized already and
-            # raise a warning if this is the case.
-            if not input_check_done:
-                data_min, data_max = x.min(), x.max()
-                if (data_min < 0) or (data_max < 1):
-                    warnings.warn(
-                        "It looks like you are normalizing the training data."
-                        "The SAM model takes care of normalization, so it is better to not do this."
-                        "We recommend to remove data normalization and input data in the range [0, 255]."
-                    )
-                input_check_done = True
+            input_check_done = self._check_input_normalization(x, input_check_done)
 
             self.optimizer.zero_grad()
 
             with forward_context():
-                n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)
-
-                n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
-
-                batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples)
-
-                assert len(y) == len(sampled_ids)
-                sampled_binary_y = []
-                for i in range(len(y)):
-                    _sampled = [torch.isin(y[i], torch.tensor(idx)) for idx in sampled_ids[i]]
-                    sampled_binary_y.append(_sampled)
-
-                # the steps below are done for one reason in a gist:
-                # to handle images where there aren't enough instances as expected
-                # (e.g. where one image has only one instance)
-                obj_lengths = [len(s) for s in sampled_binary_y]
-                sampled_binary_y = [s[:min(obj_lengths)] for s in sampled_binary_y]
-                sampled_binary_y = [torch.stack(s).to(torch.float32) for s in sampled_binary_y]
-                sampled_binary_y = torch.stack(sampled_binary_y)
-
-                # gist for below - while we find the mismatch, we need to update the batched inputs
-                # else it would still generate masks using mismatching prompts, and it doesn't help us
-                # with the subiterations again. hence we clip the number of input points as well
-                f_objs = sampled_binary_y.shape[1]
-                batched_inputs = [
-                    {k: (v[:f_objs] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
-                    for inp in batched_inputs
-                ]
-
-                loss, mask_loss, iou_regression_loss, model_iou = self._update_masks(batched_inputs, y,
-                                                                                     sampled_binary_y, sampled_ids,
-                                                                                     num_subiter=self.n_sub_iteration,
-                                                                                     multimask_output=multimask_output)
+                (loss, mask_loss, iou_regression_loss, model_iou,
+                 sampled_binary_y) = self._interactive_train_iteration(self, x, y, self._iteration)
 
             backprop(loss)
 
@@ -355,33 +364,43 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
         t_per_iter = (time.time() - t_per_iter) / n_iter
         return t_per_iter
 
+    def _interactive_val_iteration(self, x, y, val_iteration):
+        n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)
+
+        (n_pos, n_neg, get_boxes,
+         multimask_output) = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
+
+        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples)
+
+        batched_outputs = self.model(batched_inputs, multimask_output=multimask_output)
+
+        assert len(y) == len(sampled_ids)
+        sampled_binary_y = torch.stack(
+            [torch.isin(y[i], torch.tensor(sampled_ids[i])) for i in range(len(y))]
+        ).to(torch.float32)
+
+        loss, mask_loss, iou_regression_loss, model_iou = self._get_net_loss(batched_outputs,
+                                                                             y, sampled_ids)
+
+        metric = self._get_val_metric(batched_outputs, sampled_binary_y)
+
+        return loss, mask_loss, iou_regression_loss, model_iou, sampled_binary_y, metric
+
     def _validate_impl(self, forward_context):
         self.model.eval()
 
+        input_check_done = False
+
         val_iteration = 0
         metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
 
         with torch.no_grad():
             for x, y in self.val_loader:
-                with forward_context():
-                    n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)
-
-                    (n_pos, n_neg,
-                     get_boxes, multimask_output) = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
-
-                    batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples)
+                input_check_done = self._check_input_normalization(x, input_check_done)
 
-                    batched_outputs = self.model(batched_inputs, multimask_output=multimask_output)
-
-                    assert len(y) == len(sampled_ids)
-                    sampled_binary_y = torch.stack(
-                        [torch.isin(y[i], torch.tensor(sampled_ids[i])) for i in range(len(y))]
-                    ).to(torch.float32)
-
-                    loss, mask_loss, iou_regression_loss, model_iou = self._get_net_loss(batched_outputs,
-                                                                                         y, sampled_ids)
-
-                    metric = self._get_val_metric(batched_outputs, sampled_binary_y)
+                with forward_context():
+                    (loss, mask_loss, iou_regression_loss, model_iou,
+                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
 
                 loss_val += loss.item()
                 metric_val += metric.item()

From 7f16cf58232608b47e5b160c0b4feb478f2423f9 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 14 Dec 2023 13:54:19 +0100
Subject: [PATCH 02/16] Fix interactive train iteration

---
 micro_sam/training/sam_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 215e57b6..08e67cfa 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -345,7 +345,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             with forward_context():
                 (loss, mask_loss, iou_regression_loss, model_iou,
-                 sampled_binary_y) = self._interactive_train_iteration(self, x, y, self._iteration)
+                 sampled_binary_y) = self._interactive_train_iteration(x, y, self._iteration)
 
             backprop(loss)

From 93bea33852d7aa04c5d02b1fd55f60da7c69ff70 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 14 Dec 2023 14:53:15 +0100
Subject: [PATCH 03/16] WIP - Add trainer for joint training

---
 micro_sam/training/joint_sam_trainer.py | 214 ++++++++++++++++++++++++
 1 file changed, 214 insertions(+)
 create mode 100644 micro_sam/training/joint_sam_trainer.py

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
new file mode 100644
index 00000000..12cee355
--- /dev/null
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -0,0 +1,214 @@
+import os
+import time
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torchvision.utils import make_grid
+
+from .sam_trainer import SamTrainer
+
+from torch_em.model import UNETR
+from torch_em.trainer.logger_base import TorchEmLogger
+from torch_em.model.unet import Decoder, ConvBlock2d, Upsampler2d
+
+
+class CustomDecoder(Decoder):
+    "To make use of the `V-Net` level logic - as we can't make use of the skip connections"
+    def forward(self, x):
+        for block, sampler in zip(self.blocks, self.samplers):
+            x = sampler(x)
+            x = block(x)
+
+        return x
+
+
+class UNETRForJointTraining(UNETR):
+    def __init__(
+        self,
+        encoder: Optional[nn.Module] = None,
+        out_channels: int = 1,
+        final_activation: Optional[Union[str, nn.Module]] = None,
+        **kwargs
+    ) -> None:
+        super().__init__(encoder, out_channels, **kwargs)
+
+        self.encoder = encoder
+
+        # parameters for the decoder network
+        depth = 3
+        initial_features = 64
+        gain = 2
+        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
+        scale_factors = depth * [2]
+        self.out_channels = out_channels
+
+        self.decoder = Decoder(
+            features=features_decoder,
+            scale_factors=scale_factors[::-1],
+            conv_block_impl=ConvBlock2d,
+            sampler_impl=Upsampler2d
+        )
+
+        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
+        self.final_activation = self._get_activation(final_activation)
+
+    def forward(self, x):
+        org_shape = x.shape[-2:]
+
+        x = torch.stack([self.preprocess(e) for e in x], dim=0)
+
+        x = self.encoder(x)
+        x = self.decoder(x)
+
+        x = self.out_conv(x)
+        if self.final_activation is not None:
+            x = self.final_activation(x)
+
+        x = self.postprocess_masks(x, org_shape, org_shape)
+
+        return x
+
+
+class JointSamTrainer(SamTrainer):
+    def __init__(
+        self, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.unetr = UNETRForJointTraining(
+            img_size=self.model.img_size,
+            backbone="sam",
+            encoder="vit_b",
+            out_channels=self.model.out_channels,
+            use_sam_stats=True,
+            final_activation="Sigmoid"
+        )
+
+    def _instance_train_iteration(self, x, y):
+        # we pass the inputs to the unetr model
+        # get the outputs, calculate the dice loss
+        # return the dice loss
+        instance_loss = ...
+        return instance_loss
+
+    def _train_epoch_impl(self, progress, forward_context, backprop):
+        self.model.train()
+
+        input_check_done = False
+
+        n_iter = 0
+        t_per_iter = time.time()
+        for x, y in self.train_loader:
+            input_check_done = self._check_input_normalization(x, input_check_done)
+
+            self.optimizer.zero_grad()
+
+            with forward_context():
+                # 1. train for the interactive segmentation
+                (loss, mask_loss, iou_regression_loss, model_iou,
+                 sampled_binary_y) = self._interactive_train_iteration(x, y, self._iteration)
+
+            backprop(loss)
+
+            # let's get the unetr decoder for doing the instance segmentation
+            self.unetr.encoder = self.model.encoder  # TODO: revisit
+
+            with forward_context():
+                # 2. train for the automatic instance segmentation
+                instance_loss = self._instance_train_iteration(x, y)
+
+            backprop(...)
+
+            if self.logger is not None:
+                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
+                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
+                self.logger.log_train(
+                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+                )
+
+            self._iteration += 1
+            n_iter += 1
+            if self._iteration >= self.max_iteration:
+                break
+            progress.update(1)
+
+        t_per_iter = (time.time() - t_per_iter) / n_iter
+        return t_per_iter
+
+    def _validate_impl(self, forward_context):
+        self.model.eval()
+
+        input_check_done = False
+
+        val_iteration = 0
+        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
+
+        with torch.no_grad():
+            for x, y in self.val_loader:
+                input_check_done = self._check_input_normalization(x, input_check_done)
+
+                with forward_context():
+                    (loss, mask_loss, iou_regression_loss, model_iou,
+                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
+
+                # TODO: instance segmentation for validation
+
+                loss_val += loss.item()
+                metric_val += metric.item()
+                model_iou_val += model_iou.item()
+                val_iteration += 1
+
+        loss_val /= len(self.val_loader)
+        metric_val /= len(self.val_loader)
+        model_iou_val /= len(self.val_loader)
+        print()
+        print(...)  # provide a message for the respective metric score
+
+        if self.logger is not None:
+            self.logger.log_validation(
+                self._iteration, metric_val, loss_val, x, y, sampled_binary_y,
+                mask_loss, iou_regression_loss, model_iou_val, instance_loss
+            )
+
+        return metric_val
+
+
+class JointSamLogger(TorchEmLogger):
+    """@private"""
+    def __init__(self, trainer, save_root, **unused_kwargs):
+        super().__init__(trainer, save_root)
+        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
+            os.path.join(save_root, "logs", trainer.name)
+        os.makedirs(self.log_dir, exist_ok=True)
+
+        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
+        self.log_image_interval = trainer.log_image_interval
+
+    def add_image(self, x, y, samples, name, step):
+        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
+        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
+        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
+        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)
+
+    def log_train(
+        self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+    ):
+        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
+        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
+        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
+        self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
+        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
+        if step % self.log_image_interval == 0:
+            self.add_image(x, y, samples, "train", step)
+
+    def log_validation(
+        self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+    ):
+        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
+        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
+        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
+        self.tb.add_scalar(tag="validation/instance_loss", scalar_value=instance_loss, global_step=step)
+        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
+        self.add_image(x, y, samples, "validation", step)

From 63e1cfe8e0050dc48e2d01743ebb72ac7f190c62 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 14 Dec 2023 15:42:42 +0100
Subject: [PATCH 04/16] Update mentions

---
 micro_sam/training/joint_sam_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 12cee355..ea9a7c2d 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -78,7 +78,7 @@ def __init__(
         self.unetr = UNETRForJointTraining(
             img_size=self.model.img_size,
             backbone="sam",
-            encoder="vit_b",
+            encoder=self.model.encoder,
             out_channels=self.model.out_channels,
             use_sam_stats=True,
             final_activation="Sigmoid"
@@ -111,7 +111,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
             backprop(loss)
 
             # let's get the unetr decoder for doing the instance segmentation
-            self.unetr.encoder = self.model.encoder  # TODO: revisit
+            # TODO: we need to ship the weights from the encoder of SAM to UNETR
 
             with forward_context():
                 # 2. train for the automatic instance segmentation

From fe940f51894eccf19425a80fbd688a04175f44f1 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sat, 16 Dec 2023 21:27:06 +0100
Subject: [PATCH 05/16] Fix joint trainer

---
 micro_sam/training/joint_sam_trainer.py | 83 ++++---------------------
 1 file changed, 12 insertions(+), 71 deletions(-)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index ea9a7c2d..4784a1d8 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -1,73 +1,14 @@
 import os
 import time
-from typing import Optional, Union
 
 import torch
-import torch.nn as nn
 from torchvision.utils import make_grid
 
 from .sam_trainer import SamTrainer
 
 from torch_em.model import UNETR
+from torch_em.loss import DiceBasedDistanceLoss
 from torch_em.trainer.logger_base import TorchEmLogger
-from torch_em.model.unet import Decoder, ConvBlock2d, Upsampler2d
-
-
-class CustomDecoder(Decoder):
-    "To make use of the `V-Net` level logic - as we can't make use of the skip connections"
-    def forward(self, x):
-        for block, sampler in zip(self.blocks, self.samplers):
-            x = sampler(x)
-            x = block(x)
-
-        return x
-
-
-class UNETRForJointTraining(UNETR):
-    def __init__(
-        self,
-        encoder: Optional[nn.Module] = None,
-        out_channels: int = 1,
-        final_activation: Optional[Union[str, nn.Module]] = None,
-        **kwargs
-    ) -> None:
-        super().__init__(encoder, out_channels, **kwargs)
-
-        self.encoder = encoder
-
-        # parameters for the decoder network
-        depth = 3
-        initial_features = 64
-        gain = 2
-        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
-        scale_factors = depth * [2]
-        self.out_channels = out_channels
-
-        self.decoder = Decoder(
-            features=features_decoder,
-            scale_factors=scale_factors[::-1],
-            conv_block_impl=ConvBlock2d,
-            sampler_impl=Upsampler2d
-        )
-
-        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
-        self.final_activation = self._get_activation(final_activation)
-
-    def forward(self, x):
-        org_shape = x.shape[-2:]
-
-        x = torch.stack([self.preprocess(e) for e in x], dim=0)
-
-        x = self.encoder(x)
-        x = self.decoder(x)
-
-        x = self.out_conv(x)
-        if self.final_activation is not None:
-            x = self.final_activation(x)
-
-        x = self.postprocess_masks(x, org_shape, org_shape)
-
-        return x
 
 
 class JointSamTrainer(SamTrainer):
@@ -75,21 +16,21 @@ class JointSamTrainer(SamTrainer):
     def __init__(
         self, **kwargs
     ):
         super().__init__(**kwargs)
-        self.unetr = UNETRForJointTraining(
-            img_size=self.model.img_size,
+        dist_channels = 3
+        self.unetr = UNETR(
             backbone="sam",
             encoder=self.model.encoder,
-            out_channels=self.model.out_channels,
+            out_channels=dist_channels,
             use_sam_stats=True,
-            final_activation="Sigmoid"
+            final_activation="Sigmoid",
+            use_skip_connection=False
         )
 
     def _instance_train_iteration(self, x, y):
-        # we pass the inputs to the unetr model
-        # get the outputs, calculate the dice loss
-        # return the dice loss
-        instance_loss = ...
-        return instance_loss
+        outputs = self.unetr(x)
+        instance_loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
+        loss = instance_loss(outputs, y)
+        return loss
 
     def _train_epoch_impl(self, progress, forward_context, backprop):
         self.model.train()
@@ -111,7 +52,13 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             backprop(loss)
 
             # let's get the unetr decoder for doing the instance segmentation
-            # TODO: we need to ship the weights from the encoder of SAM to UNETR
+            self.unetr.encoder = self.model.encoder
 
             with forward_context():
                 # 2. train for the automatic instance segmentation
                 instance_loss = self._instance_train_iteration(x, y)
 
-            backprop(...)
+            backprop(instance_loss)

From cb8209c9899621f4d0eeff49cb168af33abfb4e7 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sat, 16 Dec 2023 21:29:31 +0100
Subject: [PATCH 06/16] Update mentions

---
 micro_sam/training/joint_sam_trainer.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 4784a1d8..b3b4139b 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -51,9 +51,6 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             backprop(loss)
 
-            # let's get the unetr decoder for doing the instance segmentation
-            self.unetr.encoder = self.model.encoder
-
             with forward_context():
                 # 2. train for the automatic instance segmentation
                 instance_loss = self._instance_train_iteration(x, y)

From 592e18235988505c982d245f534abf4277467562 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sat, 16 Dec 2023 22:26:58 +0100
Subject: [PATCH 07/16] Fix image encoder init

---
 micro_sam/training/joint_sam_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index b3b4139b..28c91f71 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -19,7 +19,7 @@ def __init__(
         dist_channels = 3
         self.unetr = UNETR(
             backbone="sam",
-            encoder=self.model.encoder,
+            encoder=self.model.image_encoder,
             out_channels=dist_channels,
             use_sam_stats=True,
             final_activation="Sigmoid",

From 8b0a2aaf4e30ab11ab80c1ba304151266059556d Mon Sep 17 00:00:00 2001
From: anwai98
Date: Mon, 18 Dec 2023 18:52:18 +0100
Subject: [PATCH 08/16] Update joint training

---
 .../joint_training/joint_finetuning.py        | 152 ++++++++++++++++++
 micro_sam/training/__init__.py                |   1 +
 micro_sam/training/joint_sam_trainer.py       |  55 ++++---
 3 files changed, 185 insertions(+), 23 deletions(-)
 create mode 100644 finetuning/livecell/joint_training/joint_finetuning.py

diff --git a/finetuning/livecell/joint_training/joint_finetuning.py b/finetuning/livecell/joint_training/joint_finetuning.py
new file mode 100644
index 00000000..71c8e265
--- /dev/null
+++ b/finetuning/livecell/joint_training/joint_finetuning.py
@@ -0,0 +1,152 @@
+import os
+import argparse
+
+import torch
+
+import torch_em
+from torch_em.model import UNETR
+from torch_em.data.datasets import get_livecell_loader
+from torch_em.transform.label import PerObjectDistanceTransform
+
+import micro_sam.training as sam_training
+from micro_sam.util import export_custom_sam_model
+
+
+def get_dataloaders(patch_shape, data_path, cell_type=None):
+    """This returns the livecell data loaders implemented in torch_em:
+    https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
+    It will automatically download the livecell data.
+
+    Note: to replace this with another data loader you need to return a torch data loader
+    that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
+    The labels have to be in a label mask instance segmentation format.
+    I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
+    Important: the ID 0 is reserved for background, and the IDs must be consecutive.
+    """
+    label_transform = PerObjectDistanceTransform(
+        distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
+    )
+    raw_transform = sam_training.identity  # the current workflow avoids rescaling the inputs to [-1, 1]
+    train_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="train", batch_size=2,
+                                       num_workers=16, cell_types=cell_type, download=True, shuffle=True,
+                                       label_transform=label_transform, raw_transform=raw_transform)
+    val_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="val", batch_size=1,
+                                     num_workers=16, cell_types=cell_type, download=True, shuffle=True,
+                                     label_transform=label_transform, raw_transform=raw_transform)
+
+    return train_loader, val_loader
+
+
+def finetune_livecell(args):
+    """Example code for finetuning SAM on LiveCELL"""
+    # override this (below) if you have some more complex set-up and need to specify the exact gpu
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # training settings:
+    model_type = args.model_type
+    checkpoint_path = None  # override this to start training from a custom checkpoint
+    patch_shape = (520, 704)  # the patch shape for training
+    n_objects_per_batch = 25  # this is the number of objects per batch that will be sampled
+    freeze_parts = args.freeze  # override this to freeze different parts of the model
+
+    # get the trainable segment anything model
+    model = sam_training.get_trainable_sam_model(
+        model_type=model_type,
+        device=device,
+        checkpoint_path=checkpoint_path,
+        freeze=freeze_parts
+    )
+    model.to(device)
+
+    # let's get the UNETR model for the automatic instance segmentation pipeline
+    unetr = UNETR(
+        backbone="sam",
+        encoder=model.sam.image_encoder,
+        out_channels=3,
+        use_sam_stats=True,
+        final_activation="Sigmoid",
+        use_skip_connection=False
+    )
+    unetr.to(device)
+
+    joint_model_params = [params for params in model.parameters()]  # sam parameters
+    for name, params in unetr.named_parameters():  # unetr's decoder parameters
+        if not name.startswith("encoder"):
+            joint_model_params.append(params)
+
+    # all the stuff we need for training
+    optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
+    train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
+
+    # this class creates all the training data for a batch (inputs, prompts and labels)
+    convert_inputs = sam_training.ConvertToSamInputs()
+
+    checkpoint_name = "livecell_sam"
+    trainer = sam_training.JointSamTrainer(
+        name=checkpoint_name,
+        save_root=args.save_root,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        model=model,
+        optimizer=optimizer,
+        # we currently compute the loss batch-wise; pass channelwise=True to compute it channel-wise instead
+        loss=torch_em.loss.DiceLoss(channelwise=False),
+        metric=torch_em.loss.DiceLoss(),
+        device=device,
+        lr_scheduler=scheduler,
+        logger=sam_training.JointSamLogger,
+        log_image_interval=100,
+        mixed_precision=True,
+        convert_inputs=convert_inputs,
+        n_objects_per_batch=n_objects_per_batch,
+        n_sub_iteration=8,
+        compile_model=False,
+        mask_prob=0.5,  # (optional) overwrite to provide the probability of using mask inputs while training
+        unetr=unetr
+
+    )
+    trainer.fit(args.iterations)
+    if args.export_path is not None:
+        checkpoint_path = os.path.join(
+            "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
+        )
+        export_custom_sam_model(
+            checkpoint_path=checkpoint_path,
+            model_type=model_type,
+            save_path=args.export_path,
+        )
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
+    parser.add_argument(
+        "--input_path", "-i", default="/scratch/usr/nimanwai/data/livecell/",
+        help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
+    )
+    parser.add_argument(
+        "--model_type", "-m", default="vit_b",
+        help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
+    )
+    parser.add_argument(
+        "--save_root", "-s",
+        help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
+    )
+    parser.add_argument(
+        "--iterations", type=int, default=int(1e5),
+        help="For how many iterations should the model be trained? By default 100k."
+    )
+    parser.add_argument(
+        "--export_path", "-e",
+        help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
+    )
+    parser.add_argument(
+        "--freeze", type=str, nargs="+", default=None,
+        help="Which parts of the model to freeze for finetuning."
+    )
+    args = parser.parse_args()
+    finetune_livecell(args)
+
+
+if __name__ == "__main__":
+    main()

diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py
index a4396af3..a50db8af 100644
--- a/micro_sam/training/__init__.py
+++ b/micro_sam/training/__init__.py
@@ -3,3 +3,4 @@
 from .sam_trainer import SamTrainer, SamLogger
 from .util import ConvertToSamInputs, get_trainable_sam_model, identity
+from .joint_sam_trainer import JointSamTrainer, JointSamLogger

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 28c91f71..3a441210 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -1,35 +1,30 @@
 import os
 import time
+import numpy as np
 
 import torch
 from torchvision.utils import make_grid
 
 from .sam_trainer import SamTrainer
 
-from torch_em.model import UNETR
 from torch_em.loss import DiceBasedDistanceLoss
 from torch_em.trainer.logger_base import TorchEmLogger
+from torch_em.trainer.tensorboard_logger import normalize_im
 
 
 class JointSamTrainer(SamTrainer):
     def __init__(
-        self, **kwargs
+        self,
+        unetr: torch.nn.Module,
+        **kwargs
     ):
         super().__init__(**kwargs)
-        dist_channels = 3
-        self.unetr = UNETR(
-            backbone="sam",
-            encoder=self.model.image_encoder,
-            out_channels=dist_channels,
-            use_sam_stats=True,
-            final_activation="Sigmoid",
-            use_skip_connection=False
-        )
+        self.unetr = unetr
 
     def _instance_train_iteration(self, x, y):
-        outputs = self.unetr(x)
+        outputs = self.unetr(x.to(self.device))
         instance_loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
-        loss = instance_loss(outputs, y)
+        loss = instance_loss(outputs, y.to(self.device))
         return loss
 
@@ -40,6 +35,9 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
         n_iter = 0
         t_per_iter = time.time()
         for x, y in self.train_loader:
+            labels_instances = y[:, 0, ...].unsqueeze(1)
+            labels_for_unetr = y[:, 1:, ...]
+
             input_check_done = self._check_input_normalization(x, input_check_done)
 
             self.optimizer.zero_grad()
@@ -47,15 +45,17 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
             with forward_context():
                 # 1. train for the interactive segmentation
                 (loss, mask_loss, iou_regression_loss, model_iou,
-                 sampled_binary_y) = self._interactive_train_iteration(x, y, self._iteration)
+                 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)
 
             backprop(loss)
 
+            self.optimizer.zero_grad()
+
             with forward_context():
                 # 2. train for the automatic instance segmentation
-                instance_loss = self._instance_train_iteration(x, y)
+                instance_loss = self._instance_train_iteration(x, labels_for_unetr)
 
             backprop(instance_loss)
 
@@ -61,7 +61,8 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                 samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                 self.logger.log_train(
-                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+                    self._iteration, loss, lr, x, labels_instances, samples,
+                    mask_loss, iou_regression_loss, model_iou, instance_loss
                 )
 
@@ -83,13 +84,19 @@ def _validate_impl(self, forward_context):
 
         with torch.no_grad():
             for x, y in self.val_loader:
+                labels_instances = y[:, 0, ...].unsqueeze(1)
+                labels_for_unetr = y[:, 1:, ...]
+
                 input_check_done = self._check_input_normalization(x, input_check_done)
 
                 with forward_context():
+                    # 1. validate for the interactive segmentation
                     (loss, mask_loss, iou_regression_loss, model_iou,
-                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
+                     sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)
 
-                # TODO: instance segmentation for validation
+                with forward_context():
+                    # 2. validate for the automatic instance segmentation
+                    instance_loss = self._instance_train_iteration(x, labels_for_unetr)
 
                 loss_val += loss.item()
                 metric_val += metric.item()
@@ -99,12 +106,10 @@ def _validate_impl(self, forward_context):
         loss_val /= len(self.val_loader)
         metric_val /= len(self.val_loader)
         model_iou_val /= len(self.val_loader)
-        print()
-        print(...)  # provide a message for the respective metric score
 
         if self.logger is not None:
             self.logger.log_validation(
-                self._iteration, metric_val, loss_val, x, y, sampled_binary_y,
+                self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
                 mask_loss, iou_regression_loss, model_iou_val, instance_loss
             )
 
@@ -123,8 +128,12 @@ def __init__(self, trainer, save_root, **unused_kwargs):
         self.log_image_interval = trainer.log_image_interval
 
     def add_image(self, x, y, samples, name, step):
-        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
-        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
+        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
+
+        image = normalize_im(x[selection].cpu())
+
+        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
+        self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
         sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
         self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

From 0cdd94561edf7eedfd56817e6bc73e235032402b Mon Sep 17 00:00:00 2001
From: anwai98
Date: Mon, 18 Dec 2023 19:54:52 +0100
Subject: [PATCH 09/16] Fix label type in dataloaders

---
 .../joint_training/joint_finetuning.py        | 21 ++++++++++++-------
 micro_sam/training/joint_sam_trainer.py       | 16 +++++++-------
 2 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/finetuning/livecell/joint_training/joint_finetuning.py b/finetuning/livecell/joint_training/joint_finetuning.py
index 71c8e265..cd0a7172 100644
--- a/finetuning/livecell/joint_training/joint_finetuning.py
+++ b/finetuning/livecell/joint_training/joint_finetuning.py
@@ -5,6 +5,7 @@
 
 import torch_em
 from torch_em.model import UNETR
+from torch_em.loss import DiceBasedDistanceLoss
 from torch_em.data.datasets import get_livecell_loader
 from torch_em.transform.label import PerObjectDistanceTransform
 
@@ -27,12 +28,16 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):
         distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
     )
     raw_transform = sam_training.identity  # the current workflow avoids rescaling the inputs to [-1, 1]
-    train_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="train", batch_size=2,
-                                       num_workers=16, cell_types=cell_type, download=True, shuffle=True,
-                                       label_transform=label_transform, raw_transform=raw_transform)
-    val_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="val", batch_size=1,
-                                     num_workers=16, cell_types=cell_type, download=True, shuffle=True,
-                                     label_transform=label_transform, raw_transform=raw_transform)
+    train_loader = get_livecell_loader(
+        path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
+        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
+        raw_transform=raw_transform, label_dtype=torch.float32
+    )
+    val_loader = get_livecell_loader(
+        path=data_path, patch_shape=patch_shape, split="val", batch_size=1, num_workers=16,
+        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
+        raw_transform=raw_transform, label_dtype=torch.float32
+    )
 
     return train_loader, val_loader
 
@@ -103,8 +108,8 @@ def finetune_livecell(args):
         n_sub_iteration=8,
         compile_model=False,
         mask_prob=0.5,  # (optional) overwrite to provide the probability of using mask inputs while training
-        unetr=unetr
-
+        unetr=unetr,
+        instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True)
     )
     trainer.fit(args.iterations)
     if args.export_path is not None:

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 3a441210..54ec76d5 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -7,7 +7,6 @@
 
 from .sam_trainer import SamTrainer
 
-from torch_em.loss import DiceBasedDistanceLoss
 from torch_em.trainer.logger_base import TorchEmLogger
 from torch_em.trainer.tensorboard_logger import normalize_im
 
@@ -16,15 +15,16 @@ class JointSamTrainer(SamTrainer):
     def __init__(
         self,
         unetr: torch.nn.Module,
+        instance_loss: torch.nn.Module,
         **kwargs
     ):
         super().__init__(**kwargs)
         self.unetr = unetr
+        self.instance_loss = instance_loss
 
     def _instance_train_iteration(self, x, y):
         outputs = self.unetr(x.to(self.device))
-        instance_loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
-        loss = instance_loss(outputs, y.to(self.device))
+        loss = self.instance_loss(outputs, y.to(self.device))
         return loss
 
     def _train_epoch_impl(self, progress, forward_context, backprop):
@@ -53,16 +53,16 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             with forward_context():
                 # 2. train for the automatic instance segmentation
-                instance_loss = self._instance_train_iteration(x, labels_for_unetr)
+                unetr_loss = self._instance_train_iteration(x, labels_for_unetr)
 
-            backprop(instance_loss)
+            backprop(unetr_loss)
 
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                 samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                 self.logger.log_train(
                     self._iteration, loss, lr, x, labels_instances, samples,
-                    mask_loss, iou_regression_loss, model_iou, instance_loss
+                    mask_loss, iou_regression_loss, model_iou, unetr_loss
                 )
 
@@ -96,7 +96,7 @@ def _validate_impl(self, forward_context):
 
                 with forward_context():
                     # 2. validate for the automatic instance segmentation
-                    instance_loss = self._instance_train_iteration(x, labels_for_unetr)
+                    unetr_loss = self._instance_train_iteration(x, labels_for_unetr)
 
                 loss_val += loss.item()
                 metric_val += metric.item()
@@ -110,7 +110,7 @@ def _validate_impl(self, forward_context):
         if self.logger is not None:
             self.logger.log_validation(
                 self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
-                mask_loss, iou_regression_loss, model_iou_val, instance_loss
+                mask_loss, iou_regression_loss, model_iou_val, unetr_loss
             )
 
         return metric_val

From 9f0b31fa7bcac49910f80adab4b57066ff50afe3 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Tue, 19 Dec 2023 13:24:45 +0100
Subject: [PATCH 10/16] Add checkpoint saving for unetr model state

---
 micro_sam/training/joint_sam_trainer.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 54ec76d5..61ca8f4d 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -22,6 +22,15 @@ def __init__(
         self.unetr = unetr
         self.instance_loss = instance_loss
 
+    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+        super().save_checkpoint(name, best_metric, unetr_state=self.unetr.state_dict(), **extra_save_dict)
+
+    def load_checkpoint(self, checkpoint="best"):
+        save_dict = super().load_checkpoint(checkpoint)
+        self.unetr.load_state_dict(save_dict["unetr_state"])
+        self.unetr.to(self.device)
+        return save_dict
+
     def _instance_train_iteration(self, x, y):
         outputs = self.unetr(x.to(self.device))
         loss = self.instance_loss(outputs, y.to(self.device))

From 23f1b4d387e1ffc46dbe7d9be8ce386645afb043 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Tue, 19 Dec 2023 21:13:32 +0100
Subject: [PATCH 11/16] Add unetr inference script

---
 .../joint_training/unetr_inference.py         | 122 ++++++++++++++++++
 micro_sam/training/joint_sam_trainer.py       |   3 ++-
 micro_sam/training/sam_trainer.py             |   2 +-
 3 files changed, 125 insertions(+), 2 deletions(-)
 create mode 100644 finetuning/livecell/joint_training/unetr_inference.py

diff --git a/finetuning/livecell/joint_training/unetr_inference.py b/finetuning/livecell/joint_training/unetr_inference.py
new file mode 100644
index 00000000..0ba13fb2
--- /dev/null
+++ b/finetuning/livecell/joint_training/unetr_inference.py
@@ -0,0 +1,122 @@
+import os
+import h5py
+import argparse
+import numpy as np
+import pandas as pd
+from glob import glob
+from tqdm import tqdm
+from pathlib import Path
+import imageio.v3 as imageio
+
+import torch
+
+from torch_em.model import UNETR
+from torch_em.util import segmentation
+from torch_em.util.prediction import predict_with_padding
+
+from elf.evaluation import mean_segmentation_accuracy
+
+from micro_sam.util import get_sam_model
+
+
+def get_unetr_model(model_type, checkpoint, device):
+    # let's get the sam finetuned model
+    predictor = get_sam_model(
+        model_type=model_type
+    )
+
+    # load the model with the respective unetr model state
+    model = UNETR(
+        encoder=predictor.model.image_encoder,
+        out_channels=3,
+        use_sam_stats=True,
+        final_activation="Sigmoid",
+        use_skip_connection=False
+    )
+
+    # FIXME: ideally, we would like to merge the params of encoder from SAM and decoder from unetr state
+    unetr_state = torch.load(checkpoint, map_location="cpu")["unetr_state"]
+    model.load_state_dict(unetr_state)
+    model.to(device)
+    model.eval()
+
+    return model
+
+
+def predict_for_unetr(inputs, save_dir, model, device):
+    save_dir = os.path.join(save_dir, "results")
+    os.makedirs(save_dir, exist_ok=True)
+
+    with torch.no_grad():
+        for img_path in tqdm(glob(os.path.join(inputs, "images", "livecell_test_images", "*")),
+                             desc="Run unetr inference"):
+            fname = Path(img_path).stem
+            save_path = os.path.join(save_dir, f"{fname}.h5")
+            if os.path.exists(save_path):
+                continue
+
+            input_ = imageio.imread(img_path)
+
+            outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
+            fg, cdist, bdist = outputs.squeeze()
+            dm_seg = segmentation.watershed_from_center_and_boundary_distances(
+                cdist, bdist, fg, min_size=50,
+                center_distance_threshold=0.5,
+                boundary_distance_threshold=0.6,
+                distance_smoothing=1.0
+            )
+
+            with h5py.File(save_path, "a") as f:
+                ds = f.require_dataset("segmentation", shape=dm_seg.shape, compression="gzip", dtype=dm_seg.dtype)
+                ds[:] = dm_seg
+
+
+def evaluation_for_unetr(inputs, save_dir, csv_path):
+    if os.path.exists(csv_path):
+        return
+
+    msa_list, sa50_list = [], []
+    for gt_path in tqdm(glob(os.path.join(inputs, "annotations", "livecell_test_images", "*", "*")),
+                        desc="Run unetr evaluation"):
+        gt = imageio.imread(gt_path)
+        fname = Path(gt_path).stem
+
+        output_file = os.path.join(save_dir, "results", f"{fname}.h5")
+        with h5py.File(output_file, "r") as f:
+            instances = f["segmentation"][:]
+
+        msa, sa_acc = mean_segmentation_accuracy(instances, gt, return_accuracies=True)
+        msa_list.append(msa)
+        sa50_list.append(sa_acc[0])
+
+    res_dict = {
+        "LiveCELL": "Metrics",
+        "mSA": np.mean(msa_list),
+        "SA50": np.mean(sa50_list)
+    }
+    df = pd.DataFrame.from_dict([res_dict])
+    df.to_csv(csv_path)
+
+
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # let's get the unetr model (initialized with the joint training setup)
+    model = get_unetr_model(model_type=args.model_type, checkpoint=args.checkpoint, device=device)
+
+    # let's get the predictions
+    predict_for_unetr(inputs=args.inputs, save_dir=args.save_dir, model=model, device=device)
+
+    # let's evaluate the predictions
+    evaluation_for_unetr(inputs=args.inputs, save_dir=args.save_dir, csv_path=args.csv_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--inputs", default="/scratch/usr/nimanwai/data/livecell/")
+    parser.add_argument("-c", "--checkpoint", type=str, required=True)
+    parser.add_argument("-m", "--model_type", type=str, default="vit_b")
+    parser.add_argument("--save_dir", type=str, required=True)
+    parser.add_argument("--csv_path", type=str, default="livecell_joint_training.csv")
+    args = parser.parse_args()
+    main(args)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 61ca8f4d..c925f95d 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -23,6 +23,7 @@ def __init__(
         self.instance_loss = instance_loss
 
     def save_checkpoint(self, name, best_metric, **extra_save_dict):
+        # FIXME: in case of unetr, save state dict only for the decoder
         super().save_checkpoint(name, best_metric, unetr_state=self.unetr.state_dict(), **extra_save_dict)
 
     def load_checkpoint(self, checkpoint="best"):
@@ -109,7 +110,7 @@ def _validate_impl(self, forward_context):
 
                 loss_val += loss.item()
-                metric_val += metric.item()
+                metric_val += metric.item()  # FIXME: update the metric to consider unetr metric as well
                 model_iou_val += model_iou.item()
                 val_iteration += 1

diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 08e67cfa..94130a88 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -345,7 +345,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             with forward_context():
                 (loss, mask_loss, iou_regression_loss, model_iou,
-                 sampled_binary_y) = self._interactive_train_iteration(x, y, self._iteration)
+                 sampled_binary_y) = self._interactive_train_iteration(x, y)
 
             backprop(loss)

From de11c1524c14ae50f2b49c130a49d13f766bb668 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Wed, 20 Dec 2023 15:18:11 +0100
Subject: [PATCH 12/16] Update metric tracking for instance iteration

---
 .../joint_training/joint_finetuning.py        |  4 +++-
 micro_sam/training/joint_sam_trainer.py       | 16 +++++++++++-----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/finetuning/livecell/joint_training/joint_finetuning.py b/finetuning/livecell/joint_training/joint_finetuning.py
index cd0a7172..6c7c57ff 100644
--- a/finetuning/livecell/joint_training/joint_finetuning.py
+++ b/finetuning/livecell/joint_training/joint_finetuning.py
@@ -74,6 +74,7 @@ def finetune_livecell(args):
     )
     unetr.to(device)
 
+    # let's get the parameters for SAM and the decoder from UNETR
     joint_model_params = [params for params in model.parameters()]  # sam parameters
     for name, params in unetr.named_parameters():  # unetr's decoder parameters
         if not name.startswith("encoder"):
@@ -109,7 +110,8 @@ def finetune_livecell(args):
         mask_prob=0.5,  # (optional) overwrite to provide the probability of using mask inputs while training
         unetr=unetr,
-        instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True)
+        instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
+        instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
     )

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index c925f95d..1314b713 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -16,11 +16,13 @@ class JointSamTrainer(SamTrainer):
     def __init__(
         self,
         unetr: torch.nn.Module,
         instance_loss: torch.nn.Module,
+        instance_metric: torch.nn.Module,
         **kwargs
     ):
         super().__init__(**kwargs)
         self.unetr = unetr
         self.instance_loss = instance_loss
+        self.instance_metric = instance_metric
 
     def save_checkpoint(self, name, best_metric, **extra_save_dict):
         # FIXME: in case of unetr, save state dict only for the decoder
@@ -34,10 +36,14 @@ def load_checkpoint(self, checkpoint="best"):
         self.unetr.to(self.device)
         return save_dict
 
-    def _instance_train_iteration(self, x, y):
+    def _instance_iteration(self, x, y, metrc_for_val=False):
         outputs = self.unetr(x.to(self.device))
         loss = self.instance_loss(outputs, y.to(self.device))
-        return loss
+        if metrc_for_val:
+            metric = self.instance_metric(outputs, y.to(self.device))
+            return loss, metric
+        else:
+            return loss
 
     def _train_epoch_impl(self, progress, forward_context, backprop):
         self.model.train()
@@ -69,7 +75,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             with forward_context():
                 # 2. train for the automatic instance segmentation
-                unetr_loss = self._instance_train_iteration(x, labels_for_unetr)
+                unetr_loss = self._instance_iteration(x, labels_for_unetr)
 
             backprop(unetr_loss)
 
@@ -112,10 +118,10 @@ def _validate_impl(self, forward_context):
 
                 with forward_context():
                     # 2. validate for the automatic instance segmentation
-                    unetr_loss = self._instance_train_iteration(x, labels_for_unetr)
+                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metrc_for_val=True)
 
                 loss_val += loss.item()
-                metric_val += metric.item()  # FIXME: update the metric to consider unetr metric as well
+                metric_val += metric.item() + (unetr_metric.item() / 3)
                 model_iou_val += model_iou.item()
                 val_iteration += 1

From be0adbfe47e438ed99a05241400c9f4292709642 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Wed, 20 Dec 2023 15:30:14 +0100
Subject: [PATCH 13/16] Fix argument spelling

---
 micro_sam/training/joint_sam_trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 1314b713..23d6238f 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -34,10 +34,10 @@ def load_checkpoint(self, checkpoint="best"):
         self.unetr.to(self.device)
         return save_dict
 
-    def _instance_iteration(self, x, y, metrc_for_val=False):
+    def _instance_iteration(self, x, y, metric_for_val=False):
         outputs = self.unetr(x.to(self.device))
         loss = self.instance_loss(outputs, y.to(self.device))
-        if metrc_for_val:
+        if metric_for_val:
             metric = self.instance_metric(outputs, y.to(self.device))
             return loss, metric
         else:
@@ -112,7 +112,7 @@ def _validate_impl(self, forward_context):
 
                 with forward_context():
                     # 2. validate for the automatic instance segmentation
-                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metrc_for_val=True)
+                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metric_for_val=True)
 
                 loss_val += loss.item()
                 metric_val += metric.item() + (unetr_metric.item() / 3)

From 098c9fa5bab534f1dad1bf21db3c574424d4b84e Mon Sep 17 00:00:00 2001
From: anwai98
Date: Wed, 20 Dec 2023 16:00:13 +0100
Subject: [PATCH 14/16] (tmp) Update unetr to save decoder params

---
 .../joint_training/unetr_inference.py         | 15 +++++++++++++++
 micro_sam/training/joint_sam_trainer.py       | 10 ++++++++++
 2 files changed, 25 insertions(+)

diff --git a/finetuning/livecell/joint_training/unetr_inference.py b/finetuning/livecell/joint_training/unetr_inference.py
index 0ba13fb2..764da5d5 100644
--- a/finetuning/livecell/joint_training/unetr_inference.py
+++ b/finetuning/livecell/joint_training/unetr_inference.py
@@ -7,6 +7,7 @@
 from tqdm import tqdm
 from pathlib import Path
 import imageio.v3 as imageio
+from collections import OrderedDict
 
 import torch
 
@@ -34,6 +35,20 @@ def get_unetr_model(model_type, checkpoint, device):
         use_skip_connection=False
     )
 
+    """
+    sam_state = torch.load(checkpoint, map_location="cpu")["model_state"]
+    # let's get the vit parameters from sam
+    encoder_state = []
+    for k, v in sam_state.items():
+        if k.startswith("image_encoder"):
+            encoder_state.append((k, v))
+    encoder_state = OrderedDict(encoder_state)
+
+    decoder_state = torch.load(checkpoint, map_location="cpu")["unetr_state"]
+
+    unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))
+    """
+
     # FIXME: ideally, we would like to merge the params of encoder from SAM and decoder from unetr state
     unetr_state = torch.load(checkpoint, map_location="cpu")["unetr_state"]
     model.load_state_dict(unetr_state)

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 23d6238f..5dfeb048 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -1,6 +1,7 @@
 import os
 import time
 import numpy as np
+from collections import OrderedDict
 
 import torch
 from torchvision.utils import make_grid
@@ -25,6 +26,15 @@ def __init__(
         self.instance_metric = instance_metric
 
     def save_checkpoint(self, name, best_metric, **extra_save_dict):
+        """
+        current_unetr_state = self.unetr.state_dict()
+        decoder_state = []
+        for k, v in current_unetr_state.items():
+            if not k.startswith("encoder"):
+                decoder_state.append((k, v))
+        decoder_state = OrderedDict(decoder_state)
+        """
+
         # FIXME: in case of unetr, save state dict only for the decoder
         super().save_checkpoint(name, best_metric, unetr_state=self.unetr.state_dict(), **extra_save_dict)

From 4decb3e21d24fb36676f7363152570f393d557c8 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Wed, 20 Dec 2023 17:27:18 +0100
Subject: [PATCH 15/16] Update unetr model param loading

---
 .../joint_training/unetr_inference.py         |  7 +-----
 micro_sam/training/joint_sam_trainer.py       | 22 ++++++++++++++-----
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/finetuning/livecell/joint_training/unetr_inference.py b/finetuning/livecell/joint_training/unetr_inference.py
index 764da5d5..d34e0372 100644
--- a/finetuning/livecell/joint_training/unetr_inference.py
+++ b/finetuning/livecell/joint_training/unetr_inference.py
@@ -35,7 +35,6 @@ def get_unetr_model(model_type, checkpoint, device):
         use_skip_connection=False
     )
 
-    """
     sam_state = torch.load(checkpoint, map_location="cpu")["model_state"]
     # let's get the vit parameters from sam
     encoder_state = []
@@ -44,13 +43,9 @@ def get_unetr_model(model_type, checkpoint, device):
             encoder_state.append((k, v))
     encoder_state = OrderedDict(encoder_state)
 
-    decoder_state = torch.load(checkpoint, map_location="cpu")["unetr_state"]
+    decoder_state = torch.load(checkpoint, map_location="cpu")["decoder_state"]
 
     unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))
-    """
-
-    # FIXME: ideally, we would like to merge the params of encoder from SAM and decoder from unetr state
-    unetr_state = torch.load(checkpoint, map_location="cpu")["unetr_state"]
     model.load_state_dict(unetr_state)
     model.to(device)
     model.eval()

diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 5dfeb048..832399d2 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -26,21 +26,33 @@ def __init__(
         self.instance_metric = instance_metric
 
     def save_checkpoint(self, name, best_metric, **extra_save_dict):
-        """
         current_unetr_state = self.unetr.state_dict()
         decoder_state = []
         for k, v in current_unetr_state.items():
             if not k.startswith("encoder"):
                 decoder_state.append((k, v))
         decoder_state = OrderedDict(decoder_state)
-        """
 
-        # FIXME: in case of unetr, save state dict only for the decoder
-        super().save_checkpoint(name, best_metric, unetr_state=self.unetr.state_dict(), **extra_save_dict)
+        super().save_checkpoint(name, best_metric, decoder_state=decoder_state, **extra_save_dict)
 
     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)
-        self.unetr.load_state_dict(save_dict["unetr_state"])
+
+        # let's get the image encoder params from sam
+        sam_state = save_dict["model_state"]
+        encoder_state = []
+        for k, v in sam_state.items():
+            if k.startswith("image_encoder"):
+                encoder_state.append((k, v))
+        encoder_state = OrderedDict(encoder_state)
+
+        # let's get the decoder params from unetr
+        decoder_state = save_dict["decoder_state"]
+
+        # now let's merge the two to get the params for the unetr
+        unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))
+
+        self.unetr.load_state_dict(unetr_state)
         self.unetr.to(self.device)
         return save_dict

From f1db4051922b178880ebfb58ed216be2b37170d3 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Wed, 20 Dec 2023 17:45:21 +0100
Subject: [PATCH 16/16] Update unetr automatic instance seg inference

---
 finetuning/livecell/joint_training/unetr_inference.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/finetuning/livecell/joint_training/unetr_inference.py b/finetuning/livecell/joint_training/unetr_inference.py
index d34e0372..3dfc2e59 100644
--- a/finetuning/livecell/joint_training/unetr_inference.py
+++ b/finetuning/livecell/joint_training/unetr_inference.py
@@ -38,9 +38,10 @@ def get_unetr_model(model_type, checkpoint, device):
     sam_state = torch.load(checkpoint, map_location="cpu")["model_state"]
     # let's get the vit parameters from sam
     encoder_state = []
+    prune_prefix = "sam.image_"
     for k, v in sam_state.items():
-        if k.startswith("image_encoder"):
-            encoder_state.append((k, v))
+        if k.startswith(prune_prefix):
+            encoder_state.append((k[len(prune_prefix):], v))
     encoder_state = OrderedDict(encoder_state)
 
     decoder_state = torch.load(checkpoint, map_location="cpu")["decoder_state"]
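
A note on the checkpoint layout that patches 14-16 converge on: `save_checkpoint` stores only the UNETR decoder parameters under a `decoder_state` key, since the image encoder weights are already part of the SAM `model_state`; at load time the two are merged back into a full UNETR state dict, with the `sam.image_` prefix stripped from the encoder keys so that they match the UNETR key layout. The standalone sketch below illustrates just this merge logic; the toy state dicts are hypothetical stand-ins for the real checkpoint contents, and only the key prefixes and `prune_prefix` follow the series.

```python
from collections import OrderedDict

import torch

# Hypothetical toy checkpoint, standing in for torch.load("best.pt").
checkpoint = {
    "model_state": OrderedDict([
        ("sam.image_encoder.blocks.0.attn.qkv.weight", torch.zeros(1)),  # ViT weight: kept, prefix pruned
        ("sam.mask_decoder.output_upscaling.0.weight", torch.zeros(1)),  # SAM decoder weight: dropped
    ]),
    "decoder_state": OrderedDict([
        ("decoder.blocks.0.conv1.weight", torch.zeros(1)),  # UNETR decoder weight: kept as is
    ]),
}

# Mirrors the merge in get_unetr_model (patch 16): take the image encoder weights
# from the SAM state, strip the "sam.image_" prefix so "sam.image_encoder.*"
# becomes "encoder.*", then append the separately stored decoder weights.
prune_prefix = "sam.image_"
encoder_state = OrderedDict(
    (k[len(prune_prefix):], v)
    for k, v in checkpoint["model_state"].items()
    if k.startswith(prune_prefix)
)
unetr_state = OrderedDict(list(encoder_state.items()) + list(checkpoint["decoder_state"].items()))

print(list(unetr_state))
# ['encoder.blocks.0.attn.qkv.weight', 'decoder.blocks.0.conv1.weight']
```

Storing only the decoder parameters keeps the checkpoint small and avoids duplicating the large image encoder, at the cost of requiring this merge before `model.load_state_dict(unetr_state)` can be called.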