Add LoRA Implementation (computational-cell-analytics#611)
Add LoRA-based PEFT finetuning
5 changed files with 360 additions and 0 deletions.
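The core of the change is `LoRASurgery` (in the new PEFT module below), which injects a rank-`r` correction on top of the frozen query and value projections in the image encoder's attention blocks. A minimal standalone sketch of that update rule, with illustrative dimensions and names, looks roughly like this:

```python
import math

import torch
import torch.nn as nn

# Sketch of the low-rank update applied by LoRASurgery (illustrative dimensions).
# The pretrained projection W stays frozen; LoRA learns A (dim -> rank) and
# B (rank -> dim), so the effective projection becomes W x + B(A x).
dim, rank = 768, 4
frozen_proj = nn.Linear(dim, dim, bias=False)
frozen_proj.weight.requires_grad = False

lora_a = nn.Linear(dim, rank, bias=False)  # Kaiming-initialized, trainable
lora_b = nn.Linear(rank, dim, bias=False)  # zero-initialized, trainable
nn.init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5))
nn.init.zeros_(lora_b.weight)

x = torch.randn(2, 16, dim)
# Because B starts at zero, the adapted output initially equals the pretrained one.
out = frozen_proj(x) + lora_b(lora_a(x))
```

Because `B` is zero-initialized, training starts from exactly the pretrained SAM weights, and only the small `A`/`B` matrices are updated inside the encoder.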
@@ -0,0 +1,184 @@
import os
import argparse

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.data.datasets import get_livecell_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model

def get_dataloaders(patch_shape, data_path, cell_type=None):
    """This returns the livecell data loaders implemented in torch_em:
    https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
    It will automatically download the livecell data.

    Note: to replace this with another data loader you need to return a torch data loader
    that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
    The labels have to be in a label mask instance segmentation format.
    I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
    Important: the ID 0 is reserved for background, and the IDs must be consecutive.
    """
    label_transform = PerObjectDistanceTransform(
        distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
    )
    raw_transform = sam_training.identity  # the current workflow avoids rescaling the inputs to [-1, 1]
    train_loader = get_livecell_loader(
        path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
        raw_transform=raw_transform, label_dtype=torch.float32,
    )
    val_loader = get_livecell_loader(
        path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16,
        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
        raw_transform=raw_transform, label_dtype=torch.float32,
    )

    return train_loader, val_loader


def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params = params / 1e6
    return f"The number of trainable parameters for the provided model is {round(params, 2)}M"

def finetune_livecell(args):
    """Code for finetuning SAM (using LoRA) on LIVECell.

    Initial observations: there is no real memory advantage unless the model is truly scaled up, e.g.:
        vit_b
            SAM: 93M (takes ~50GB)
            SAM-LoRA: 4.2M (takes ~49GB)
        vit_l
            SAM: 312M (takes ~63GB)
            SAM-LoRA: 4.4M (takes ~61GB)
        vit_h
            SAM: 641M (takes ~73GB)
            SAM-LoRA: 4.7M (takes ~67GB)

    Q: Would quantization lead to better results? (e.g. QLoRA / DoRA)
    """
    # override this (below) if you have some more complex set-up and need to specify the exact gpu
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # training settings:
    model_type = args.model_type
    checkpoint_path = None  # override this to start training from a custom checkpoint
    patch_shape = (520, 704)  # the patch shape for training
    n_objects_per_batch = 5  # this is the number of objects per batch that will be sampled
    freeze_parts = args.freeze  # override this to freeze different parts of the model
    rank = 4  # the rank of the low-rank adaptation (LoRA) update matrices

    # get the trainable segment anything model
    model = sam_training.get_trainable_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        freeze=freeze_parts,
        use_lora=True,
        rank=rank,
    )
    model.to(device)

    # let's get the UNETR model for the automatic instance segmentation pipeline
    unetr = UNETR(
        backbone="sam",
        encoder=model.sam.image_encoder,
        out_channels=3,
        use_sam_stats=True,
        final_activation="Sigmoid",
        use_skip_connection=False,
        resize_input=True,
    )
    unetr.to(device)

    # let's check the total number of trainable parameters
    print(count_parameters(model))

    # let's get the parameters for SAM and the decoder from UNETR
    joint_model_params = model.parameters()

    joint_model_params = [params for params in joint_model_params]  # sam parameters
    for name, params in unetr.named_parameters():  # unetr's decoder parameters
        if not name.startswith("encoder"):
            joint_model_params.append(params)

    optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
    train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)

    # this class creates all the training data for a batch (inputs, prompts and labels)
    convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

    checkpoint_name = "livecell_lora"  # use the same name for training and for locating the best checkpoint
    trainer = sam_training.JointSamTrainer(
        name=checkpoint_name,
        save_root=args.save_root,
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        logger=sam_training.JointSamLogger,
        log_image_interval=100,
        mixed_precision=True,
        convert_inputs=convert_inputs,
        n_objects_per_batch=n_objects_per_batch,
        n_sub_iteration=8,
        compile_model=False,
        mask_prob=0.5,  # (optional) overwrite to provide the probability of using mask inputs while training
        unetr=unetr,
        instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
        instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True),
    )
    trainer.fit(args.iterations)
    if args.export_path is not None:
        checkpoint_path = os.path.join(
            "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
        )
        export_custom_sam_model(
            checkpoint_path=checkpoint_path,
            model_type=model_type,
            save_path=args.export_path,
        )

def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
        help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
    )
    parser.add_argument(
        "--save_root", "-s", default=None,
        help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
    )
    parser.add_argument(
        "--iterations", type=int, default=int(1e4),
        help="For how many iterations should the model be trained? By default 10k."
    )
    parser.add_argument(
        "--export_path", "-e",
        help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
    )
    parser.add_argument(
        "--freeze", type=str, nargs="+", default=None,
        help="Which parts of the model to freeze for finetuning."
    )
    args = parser.parse_args()
    finetune_livecell(args)


if __name__ == "__main__":
    main()
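As the `get_dataloaders` docstring above notes, the LIVECell loader can be swapped for any torch data loader that yields `x, y` pairs where `y` is an instance label mask (ID 0 for background, consecutive object IDs). A minimal sketch of such a loader, built on hypothetical in-memory arrays, might look like this:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class MyInstanceDataset(Dataset):
    """Hypothetical dataset returning an image and a matching instance label mask."""
    def __init__(self, images, labels):
        self.images = images  # list of 2D numpy arrays (raw image data)
        self.labels = labels  # list of 2D numpy arrays (instance masks, 0 = background)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.images[idx]).float()[None]  # add a channel axis
        y = torch.from_numpy(self.labels[idx]).float()[None]
        return x, y


# Dummy data: one 512 x 512 image with two objects (consecutive IDs 1 and 2).
image = np.random.rand(512, 512).astype("float32")
mask = np.zeros((512, 512), dtype="float32")
mask[100:150, 100:150] = 1
mask[300:380, 200:260] = 2

train_loader = DataLoader(MyInstanceDataset([image], [mask]), batch_size=1, shuffle=True)
```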
@@ -0,0 +1,105 @@
import math
from typing import List, Optional

import torch.nn as nn

from segment_anything.modeling import Sam

class LoRASurgery(nn.Module):
    """Operates on the attention layers for performing low-rank adaptation.

    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, it is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the low-rank update matrices.
        block: The image encoder block whose attention layer is adapted.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
    ):
        super().__init__()
        self.qkv = block.attn.qkv
        self.dim = self.qkv.in_features

        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

        self.reset_parameters()

        block.attn.qkv = self

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

    def forward(self, x):
        qkv = self.qkv(x)  # shape: (B, H, W, 3 * dim)
        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
        qkv[:, :, :, :self.dim] += new_q  # add the low-rank update to the query projection
        qkv[:, :, :, -self.dim:] += new_v  # add the low-rank update to the value projection
        return qkv

class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning methods.

    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)

    Args:
        model: The Segment Anything model.
        rank: The rank for low-rank adaptation.
        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: Which specific layers we apply PEFT methods to.
    """

    def __init__(
        self,
        model: Sam,
        rank: int,
        peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None
    ):
        super(PEFT_Sam, self).__init__()

        assert rank > 0, "The rank must be a positive integer."

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:  # Applies PEFT to the image encoder by default
            self.peft_layers = list(
                range(len(model.image_encoder.blocks))
            )

        self.peft_module = peft_module
        self.peft_blocks = []

        # let's freeze all the pretrained image encoder layers first
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
            # If we only want specific layers with PEFT instead of all
            if t_layer_i not in self.peft_layers:
                continue

            peft_block = self.peft_module(rank=rank, block=blk)
            self.peft_blocks.append(peft_block)

        self.peft_blocks = nn.ModuleList(self.peft_blocks)

        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
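A short usage sketch for the wrapper above, mirroring the unit test that follows: load a SAM model, wrap it with `PEFT_Sam`, and inspect how many parameters remain trainable (the image encoder is frozen except for the injected LoRA layers, while the prompt encoder and mask decoder stay trainable):

```python
from micro_sam.util import get_sam_model
from micro_sam.training.peft_sam import PEFT_Sam

# Load the underlying Segment Anything model (as done in the test below).
_, sam = get_sam_model(model_type="vit_b", device="cpu", return_sam=True)

# Wrap it: all image encoder weights are frozen, then LoRA layers are injected
# into the qkv projection of every attention block.
peft_sam = PEFT_Sam(model=sam, rank=4)

trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_sam.parameters())
print(f"{trainable / 1e6:.2f}M trainable parameters out of {total / 1e6:.2f}M")
```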
@@ -0,0 +1,49 @@
import unittest

import torch

from micro_sam.util import get_sam_model
from micro_sam.training.peft_sam import PEFT_Sam

class TestPEFTModule(unittest.TestCase):
    """Integration test for instantiating a PEFT SAM model.
    """
    def _fetch_sam_model(self, model_type, device):
        _, sam_model = get_sam_model(model_type=model_type, device=device, return_sam=True)
        return sam_model

    def _create_dummy_inputs(self, shape):
        input_image = torch.ones(shape)
        return input_image

    def test_peft_sam(self):
        model_type = "vit_b"
        device = "cpu"

        # Load the dummy inputs.
        input_shape = (1, 512, 512)
        inputs = self._create_dummy_inputs(shape=input_shape)

        # Convert to the inputs expected by Segment Anything.
        batched_inputs = [
            {"image": inputs, "original_size": input_shape[1:]}
        ]

        # Load the Segment Anything model.
        sam_model = self._fetch_sam_model(model_type=model_type, device=device)

        # Wrap the Segment Anything model with PEFT methods.
        peft_sam_model = PEFT_Sam(model=sam_model, rank=4)

        # Get the model outputs.
        outputs = peft_sam_model(batched_input=batched_inputs, multimask_output=False)

        # Check the expected shape of the outputs.
        mask_shapes = [output["masks"].shape[-2:] for output in outputs]
        for shape in mask_shapes:
            self.assertEqual(shape, input_shape[1:])


if __name__ == "__main__":
    unittest.main()