diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py
new file mode 100644
index 00000000..fa887437
--- /dev/null
+++ b/finetuning/livecell/lora/train_livecell.py
@@ -0,0 +1,184 @@
+import os
+import argparse
+
+import torch
+
+from torch_em.model import UNETR
+from torch_em.loss import DiceBasedDistanceLoss
+from torch_em.data.datasets import get_livecell_loader
+from torch_em.transform.label import PerObjectDistanceTransform
+
+import micro_sam.training as sam_training
+from micro_sam.util import export_custom_sam_model
+
+
+def get_dataloaders(patch_shape, data_path, cell_type=None):
+    """This returns the LIVECell data loaders implemented in torch_em:
+    https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
+    It will automatically download the LIVECell data.
+
+    Note: to replace this with another data loader you need to return a torch data loader
+    that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
+    The labels have to be in a label mask instance segmentation format,
+    i.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
+    Important: the ID 0 is reserved for background, and the IDs must be consecutive.
+    """
+    label_transform = PerObjectDistanceTransform(
+        distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
+    )
+    raw_transform = sam_training.identity  # the current workflow avoids rescaling the inputs to [-1, 1]
+    train_loader = get_livecell_loader(
+        path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
+        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
+        raw_transform=raw_transform, label_dtype=torch.float32,
+    )
+    val_loader = get_livecell_loader(
+        path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16,
+        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
+        raw_transform=raw_transform, label_dtype=torch.float32,
+    )
+
+    return train_loader, val_loader
+
+
+def count_parameters(model):
+    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    params = params / 1e6
+    return f"The number of trainable parameters for the provided model is {round(params, 2)}M"
+
+
+def finetune_livecell(args):
+    """Code for finetuning SAM (using LoRA) on LIVECell.
+
+    Initial observations: despite the much smaller number of trainable parameters,
+    there is no real memory advantage unless the setup is "truly" scaled up:
+    # vit_b
+    # SAM: 93M (takes ~50GB)
+    # SAM-LoRA: 4.2M (takes ~49GB)
+
+    # vit_l
+    # SAM: 312M (takes ~63GB)
+    # SAM-LoRA: 4.4M (takes ~61GB)
+
+    # vit_h
+    # SAM: 641M (takes ~73GB)
+    # SAM-LoRA: 4.7M (takes ~67GB)
+
+    # Q: Would quantization lead to better results? (e.g. QLoRA / DoRA)
+    """
+    # override this (below) if you have some more complex set-up and need to specify the exact gpu
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # training settings:
+    model_type = args.model_type
+    checkpoint_path = None  # override this to start training from a custom checkpoint
+    patch_shape = (520, 704)  # the patch shape for training
+    n_objects_per_batch = 5  # the number of objects per batch that will be sampled
+    freeze_parts = args.freeze  # override this to freeze different parts of the model
+    rank = 4  # the rank of the LoRA decomposition matrices
+
+    # get the trainable segment anything model
+    model = sam_training.get_trainable_sam_model(
+        model_type=model_type,
+        device=device,
+        checkpoint_path=checkpoint_path,
+        freeze=freeze_parts,
+        use_lora=True,
+        rank=rank,
+    )
+    model.to(device)
+
+    # let's get the UNETR model for the automatic instance segmentation pipeline
+    unetr = UNETR(
+        backbone="sam",
+        encoder=model.sam.image_encoder,
+        out_channels=3,
+        use_sam_stats=True,
+        final_activation="Sigmoid",
+        use_skip_connection=False,
+        resize_input=True,
+    )
+    unetr.to(device)
+
+    # let's check the total number of trainable parameters
+    print(count_parameters(model))
+
+    # let's get the parameters for SAM and the decoder from UNETR
+    joint_model_params = model.parameters()
+
+    joint_model_params = [params for params in joint_model_params]  # sam parameters
+    for name, params in unetr.named_parameters():  # unetr's decoder parameters
+        if not name.startswith("encoder"):
+            joint_model_params.append(params)
+
+    optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
+    train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
+
+    # this class creates all the training data for a batch (inputs, prompts and labels)
+    convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
+
+    trainer = sam_training.JointSamTrainer(
+        name="livecell_lora",
+        save_root=args.save_root,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        model=model,
+        optimizer=optimizer,
+        device=device,
+        lr_scheduler=scheduler,
+        logger=sam_training.JointSamLogger,
+        log_image_interval=100,
+        mixed_precision=True,
+        convert_inputs=convert_inputs,
+        n_objects_per_batch=n_objects_per_batch,
+        n_sub_iteration=8,
+        compile_model=False,
+        mask_prob=0.5,  # (optional) overwrite to provide the probability of using mask inputs while training
+        unetr=unetr,
+        instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
+        instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
+    )
+    trainer.fit(args.iterations)
+    if args.export_path is not None:
+        checkpoint_path = os.path.join(
+            "" if args.save_root is None else args.save_root, "checkpoints", "livecell_lora", "best.pt"
+        )
+        export_custom_sam_model(
+            checkpoint_path=checkpoint_path,
+            model_type=model_type,
+            save_path=args.export_path,
+        )
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
+    parser.add_argument(
+        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
+        help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
+    )
+    parser.add_argument(
+        "--model_type", "-m", default="vit_b",
+        help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
+    )
+    parser.add_argument(
+        "--save_root", "-s", default=None,
+        help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
+    )
+    parser.add_argument(
+        "--iterations", type=int, default=int(1e4),
+        help="For how many iterations should the model be trained? By default 10k."
+    )
+    parser.add_argument(
+        "--export_path", "-e",
+        help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
+    )
+    parser.add_argument(
+        "--freeze", type=str, nargs="+", default=None,
+        help="Which parts of the model to freeze for finetuning."
+    )
+    args = parser.parse_args()
+    finetune_livecell(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/micro_sam/training/peft_sam.py b/micro_sam/training/peft_sam.py
new file mode 100644
index 00000000..c67db7cb
--- /dev/null
+++ b/micro_sam/training/peft_sam.py
@@ -0,0 +1,105 @@
+import math
+from typing import List, Optional
+
+import torch.nn as nn
+
+from segment_anything.modeling import Sam
+
+
+class LoRASurgery(nn.Module):
+    """Operates on the attention layers for performing low-rank adaptation.
+
+    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)
+
+    In SAM, the qkv projection of the attention layers is implemented as:
+    ```python
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+    ```
+    """
+    def __init__(
+        self,
+        rank: int,
+        block: nn.Module,
+    ):
+        super().__init__()
+        self.qkv = block.attn.qkv
+        self.dim = self.qkv.in_features
+
+        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)  # low-rank factors for the q-projection
+        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
+        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)  # low-rank factors for the v-projection
+        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)
+
+        self.reset_parameters()
+
+        block.attn.qkv = self
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.w_b_linear_q.weight)
+        nn.init.zeros_(self.w_b_linear_v.weight)
+
+    def forward(self, x):
+        qkv = self.qkv(x)  # B, N, N, 3 * org_C
+        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
+        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
+        qkv[:, :, :, :self.dim] += new_q  # add the low-rank update to the q-projection
+        qkv[:, :, :, -self.dim:] += new_v  # add the low-rank update to the v-projection (k stays unchanged)
+        return qkv
+
+
+class PEFT_Sam(nn.Module):
+    """Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/
+
+    Wraps the image encoder of the Segment Anything model with different parameter-efficient finetuning methods.
+
+    Args:
+        model: The Segment Anything model.
+        rank: The rank for low-rank adaptation.
+        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
+        attention_layers_to_update: Which specific attention layers to apply the PEFT method to. By default, all layers are updated.
+    """
+
+    def __init__(
+        self,
+        model: Sam,
+        rank: int,
+        peft_module: nn.Module = LoRASurgery,
+        attention_layers_to_update: Optional[List[int]] = None
+    ):
+        super(PEFT_Sam, self).__init__()
+
+        assert rank > 0
+
+        if attention_layers_to_update:
+            self.peft_layers = attention_layers_to_update
+        else:  # applies PEFT to all attention blocks of the image encoder by default
+            self.peft_layers = list(
+                range(len(model.image_encoder.blocks))
+            )
+
+        self.peft_module = peft_module
+        self.peft_blocks = []
+
+        # let's freeze all the pretrained image encoder layers first
+        for param in model.image_encoder.parameters():
+            param.requires_grad = False
+
+        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
+            # if we only want to apply PEFT to specific layers, skip the others
+            if t_layer_i not in self.peft_layers:
+                continue
+
+            peft_block = self.peft_module(rank=rank, block=blk)
+            self.peft_blocks.append(peft_block)
+
+        self.peft_blocks = nn.ModuleList(self.peft_blocks)
+
+        self.sam = model
+
+    def forward(self, batched_input, multimask_output):
+        return self.sam(batched_input, multimask_output)
diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index b58cd9a6..ac9bda9b 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -12,6 +12,7 @@
     get_centers_and_bounding_boxes, get_sam_model, get_device,
     segmentation_to_one_hot, _DEFAULT_MODEL,
 )
+from .peft_sam import PEFT_Sam
 from .trainable_sam import TrainableSAM
 
 from torch_em.transform.label import PerObjectDistanceTransform
@@ -42,6 +43,8 @@ def get_trainable_sam_model(
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     freeze: Optional[List[str]] = None,
     return_state: bool = False,
+    use_lora: bool = False,
+    rank: Optional[int] = None,
 ) -> TrainableSAM:
     """Get the trainable sam model.
 
@@ -54,6 +57,8 @@ def get_trainable_sam_model(
         freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
             By default nothing is frozen and the full model is updated.
         return_state: Whether to return the full checkpoint state.
+        use_lora: Whether to use low-rank adaptation (LoRA) for finetuning.
+        rank: The rank of the decomposition matrices for updating weights in each attention layer.
 
     Returns:
         The trainable segment anything model.
@@ -80,8 +85,14 @@ def get_trainable_sam_model(
                 if name.startswith(f"{freeze}"):
                     param.requires_grad = False
 
+    if use_lora:  # overwrites the SAM model by freezing the image encoder and adding low-rank adaptation to its attention layers
+        if rank is None:
+            rank = 4  # HACK: fall back to a default rank if the user does not pass one
+        sam = PEFT_Sam(sam, rank=rank).sam
+
     # convert to trainable sam
     trainable_sam = TrainableSAM(sam)
+
     if return_state:
         return trainable_sam, state
     return trainable_sam
diff --git a/micro_sam/util.py b/micro_sam/util.py
index c9768dc8..e61a28f7 100644
--- a/micro_sam/util.py
+++ b/micro_sam/util.py
@@ -270,6 +270,8 @@ def get_sam_model(
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     return_sam: bool = False,
     return_state: bool = False,
+    use_lora: bool = False,
+    rank: Optional[int] = None,
 ) -> SamPredictor:
     r"""Get the SegmentAnything Predictor.
 
@@ -302,6 +304,8 @@ def get_sam_model(
             then `model_type` must be given as "vit_b".
         return_sam: Return the sam model object as well as the predictor.
        return_state: Return the unpickled checkpoint state.
+        use_lora: Whether to use low-rank adaptation (LoRA) for finetuning.
+        rank: The rank of the decomposition matrices for updating weights in each attention layer.
 
     Returns:
         The segment anything predictor.
@@ -347,6 +351,13 @@ def get_sam_model(
         state, model_state = _load_checkpoint(checkpoint_path)
 
     sam = sam_model_registry[abbreviated_model_type]()
+
+    if use_lora:  # overwrites the SAM model by freezing the image encoder and adding low-rank adaptation to its attention layers
+        from micro_sam.training.peft_sam import PEFT_Sam
+        if rank is None:
+            rank = 4  # HACK: fall back to a default rank if the user does not pass one
+        sam = PEFT_Sam(sam, rank=rank).sam
+
     sam.load_state_dict(model_state)
     sam.to(device=device)
diff --git a/test/test_peft_training.py b/test/test_peft_training.py
new file mode 100644
index 00000000..7c2f1270
--- /dev/null
+++ b/test/test_peft_training.py
@@ -0,0 +1,49 @@
+import unittest
+
+import torch
+
+from micro_sam.util import get_sam_model
+from micro_sam.training.peft_sam import PEFT_Sam
+
+
+class TestPEFTModule(unittest.TestCase):
+    """Integration test for instantiating a PEFT SAM model.
+    """
+    def _fetch_sam_model(self, model_type, device):
+        _, sam_model = get_sam_model(model_type=model_type, device=device, return_sam=True)
+        return sam_model
+
+    def _create_dummy_inputs(self, shape):
+        input_image = torch.ones(shape)
+        return input_image
+
+    def test_peft_sam(self):
+        model_type = "vit_b"
+        device = "cpu"
+
+        # Create the dummy inputs.
+        input_shape = (1, 512, 512)
+        inputs = self._create_dummy_inputs(shape=input_shape)
+
+        # Convert them to the inputs expected by Segment Anything.
+        batched_inputs = [
+            {"image": inputs, "original_size": input_shape[1:]}
+        ]
+
+        # Load the Segment Anything model.
+        sam_model = self._fetch_sam_model(model_type=model_type, device=device)
+
+        # Wrap the Segment Anything model with the PEFT method.
+        peft_sam_model = PEFT_Sam(model=sam_model, rank=4)
+
+        # Get the model outputs.
+        outputs = peft_sam_model(batched_input=batched_inputs, multimask_output=False)
+
+        # Check the expected shape of the output masks.
+        mask_shapes = [output["masks"].shape[-2:] for output in outputs]
+        for shape in mask_shapes:
+            self.assertEqual(shape, input_shape[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()
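
For reviewers, a minimal usage sketch of the new `use_lora` / `rank` arguments added to `get_trainable_sam_model` in this patch; the model type, device, and rank values below are illustrative choices, not prescribed by the patch.

```python
import micro_sam.training as sam_training

# Sketch only: wrap SAM's image encoder with rank-4 LoRA adapters via the new arguments.
model = sam_training.get_trainable_sam_model(
    model_type="vit_b",  # illustrative backbone choice
    device="cpu",        # illustrative device
    use_lora=True,       # freezes the image encoder and injects the LoRA layers (PEFT_Sam)
    rank=4,              # rank of the LoRA decomposition matrices
)

# Only the LoRA matrices, the prompt encoder and the mask decoder remain trainable.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_trainable / 1e6:.2f}M")
```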