Add LoRA Implementation (#611)
Add LoRA based PEFT finetuning
anwai98 authored Jun 19, 2024
1 parent a75d581 commit 22edc30
Showing 5 changed files with 360 additions and 0 deletions.
184 changes: 184 additions & 0 deletions finetuning/livecell/lora/train_livecell.py
@@ -0,0 +1,184 @@
import os
import argparse

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.data.datasets import get_livecell_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model


def get_dataloaders(patch_shape, data_path, cell_type=None):
"""This returns the livecell data loaders implemented in torch_em:
https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
It will automatically download the livecell data.
Note: to replace this with another data loader you need to return a torch data loader
    that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
The labels have to be in a label mask instance segmentation format.
I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
    Important: the ID 0 is reserved for background, and the IDs must be consecutive.
"""
label_transform = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
)
raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1]
train_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)
val_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)

return train_loader, val_loader


def count_parameters(model):
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
params = params / 1e6
return f"The number of trainable parameters for the provided model is {round(params, 2)}M"


def finetune_livecell(args):
"""Code for finetuning SAM (using LoRA) on LIVECell
Initial observations: There's no real memory advantage actually unless it's "truly" scaled up
# vit_b
# SAM: 93M (takes ~50GB)
# SAM-LoRA: 4.2M (takes ~49GB)
# vit_l
# SAM: 312M (takes ~63GB)
# SAM-LoRA: 4.4M (takes ~61GB)
# vit_h
# SAM: 641M (takes ~73GB)
# SAM-LoRA: 4.7M (takes ~67GB)
# Q: Would quantization lead to better results? (eg. QLoRA / DoRA)
"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
patch_shape = (520, 704) # the patch shape for training
n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
    rank = 4  # the rank of the low-rank decomposition used for LoRA

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
use_lora=True,
rank=rank,
)
model.to(device)

# let's get the UNETR model for automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
)
unetr.to(device)

# let's check the total number of trainable parameters
print(count_parameters(model))

# let's get the parameters for SAM and the decoder from UNETR
joint_model_params = model.parameters()

joint_model_params = [params for params in joint_model_params] # sam parameters
for name, params in unetr.named_parameters(): # unetr's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)

optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

trainer = sam_training.JointSamTrainer(
name="livecell_lora",
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
)
trainer.fit(args.iterations)
if args.export_path is not None:
checkpoint_path = os.path.join(
"" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
)


def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
parser.add_argument(
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
)
parser.add_argument(
"--save_root", "-s", default=None,
help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
)
parser.add_argument(
"--iterations", type=int, default=int(1e4),
help="For how many iterations should the model be trained? By default 100k."
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
parser.add_argument(
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
args = parser.parse_args()
finetune_livecell(args)


if __name__ == "__main__":
main()
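
The `get_dataloaders` docstring above spells out the contract a replacement loader has to fulfil: each batch must yield `x, y` tensors, with `y` an instance label mask (background ID 0, consecutive object IDs). A minimal sketch of such a loader, using plain PyTorch and hypothetical `images` / `instance_labels` arrays (not part of this commit), could look like this:

```python
# Minimal sketch of a custom loader satisfying the `x, y` contract described in
# `get_dataloaders`. `images` and `instance_labels` are hypothetical placeholders.
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class InstanceSegDataset(Dataset):
    def __init__(self, images, instance_labels):
        # images: list of 2D float arrays; instance_labels: list of 2D integer masks
        # with 0 reserved for background and consecutive object IDs (1, 2, 3, ...).
        self.images = images
        self.instance_labels = instance_labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.images[idx][None].astype("float32"))  # add a channel axis
        y = torch.from_numpy(self.instance_labels[idx][None].astype("float32"))
        return x, y


# Placeholder data just to show the expected shapes.
images = [np.random.rand(520, 704) for _ in range(4)]
labels = [np.zeros((520, 704), dtype="int64") for _ in range(4)]
train_loader = DataLoader(InstanceSegDataset(images, labels), batch_size=2, shuffle=True)
```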
105 changes: 105 additions & 0 deletions micro_sam/training/peft_sam.py
@@ -0,0 +1,105 @@
import math
from typing import List, Optional

import torch.nn as nn

from segment_anything.modeling import Sam


class LoRASurgery(nn.Module):
"""Operates on the attention layers for performing low-rank adaptation.
(Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)
In SAM, it is implemented as:
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
"""
def __init__(
self,
rank: int,
block: nn.Module,
):
super().__init__()
self.qkv = block.attn.qkv
self.dim = self.qkv.in_features

self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

self.reset_parameters()

block.attn.qkv = self

def reset_parameters(self):
nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
nn.init.zeros_(self.w_b_linear_q.weight)
nn.init.zeros_(self.w_b_linear_v.weight)

def forward(self, x):
qkv = self.qkv(x) # B, N, N, 3 * org_C
new_q = self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.w_b_linear_v(self.w_a_linear_v(x))
qkv[:, :, :, :self.dim] += new_q
qkv[:, :, :, -self.dim:] += new_v
return qkv


class PEFT_Sam(nn.Module):
"""Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/
Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
Args:
model: The Segment Anything model.
rank: The rank for low-rank adaptation.
peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: The specific attention layers to apply the PEFT method to. By default all layers are updated.
"""

def __init__(
self,
model: Sam,
rank: int,
peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None
):
super(PEFT_Sam, self).__init__()

assert rank > 0

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
else: # Applies PEFT to the image encoder by default
self.peft_layers = list(
range(len(model.image_encoder.blocks))
)

self.peft_module = peft_module
self.peft_blocks = []

# let's freeze all the pretrained image encoder layers first
for param in model.image_encoder.parameters():
param.requires_grad = False

for t_layer_i, blk in enumerate(model.image_encoder.blocks):
# If we only want specific layers with PEFT instead of all
if t_layer_i not in self.peft_layers:
continue

peft_block = self.peft_module(rank=rank, block=blk)
self.peft_blocks.append(peft_block)

self.peft_blocks = nn.ModuleList(self.peft_blocks)

self.sam = model

def forward(self, batched_input, multimask_output):
return self.sam(batched_input, multimask_output)
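
For readers who want to use the wrapper outside of `get_trainable_sam_model`, here is a hedged usage sketch. It assumes a locally downloaded vit_b checkpoint (the path is a placeholder) and relies on `sam_model_registry` from `segment_anything`. `LoRASurgery` adds the low-rank update `B(A(x))` to the query and value projections, so after wrapping, only these LoRA matrices (plus everything outside the image encoder) remain trainable:

```python
# Usage sketch for PEFT_Sam; the checkpoint path is a placeholder.
from segment_anything import sam_model_registry
from micro_sam.training.peft_sam import PEFT_Sam

sam = sam_model_registry["vit_b"](checkpoint="/path/to/sam_vit_b_01ec64.pth")  # placeholder path
peft_sam = PEFT_Sam(sam, rank=4)

# The image encoder weights are frozen; only the LoRA projections (and the prompt
# encoder / mask decoder) still require gradients.
trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable / 1e6:.2f}M")
```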
11 changes: 11 additions & 0 deletions micro_sam/training/util.py
@@ -12,6 +12,7 @@
get_centers_and_bounding_boxes, get_sam_model, get_device,
segmentation_to_one_hot, _DEFAULT_MODEL,
)
from .peft_sam import PEFT_Sam
from .trainable_sam import TrainableSAM

from torch_em.transform.label import PerObjectDistanceTransform
@@ -42,6 +43,8 @@ def get_trainable_sam_model(
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
freeze: Optional[List[str]] = None,
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
) -> TrainableSAM:
"""Get the trainable sam model.
@@ -54,6 +57,8 @@
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
By default nothing is frozen and the full model is updated.
return_state: Whether to return the full checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.
Returns:
The trainable segment anything model.
@@ -80,8 +85,14 @@
if name.startswith(f"{freeze}"):
param.requires_grad = False

    if use_lora:  # overwrites the SAM model by freezing the backbone and allowing low-rank adaptation of the attention layers
        if rank is None:
            rank = 4  # default to a rank of 4 if the user does not pass one
        sam = PEFT_Sam(sam, rank=rank).sam

# convert to trainable sam
trainable_sam = TrainableSAM(sam)

if return_state:
return trainable_sam, state
return trainable_sam
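
A short sketch of how the new `use_lora` and `rank` arguments are meant to be called (model type, device, and rank are example values):

```python
# Sketch: request a trainable SAM with LoRA enabled; rank=4 matches the default used above.
import micro_sam.training as sam_training

model = sam_training.get_trainable_sam_model(
    model_type="vit_b",
    device="cuda",
    use_lora=True,  # freezes the image encoder and adds low-rank adapters to its attention layers
    rank=4,
)
```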
11 changes: 11 additions & 0 deletions micro_sam/util.py
@@ -270,6 +270,8 @@ def get_sam_model(
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
return_sam: bool = False,
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
) -> SamPredictor:
r"""Get the SegmentAnything Predictor.
@@ -302,6 +304,8 @@
then `model_type` must be given as "vit_b".
return_sam: Return the sam model object as well as the predictor.
return_state: Return the unpickled checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.
Returns:
The segment anything predictor.
@@ -347,6 +351,13 @@

state, model_state = _load_checkpoint(checkpoint_path)
sam = sam_model_registry[abbreviated_model_type]()

    if use_lora:  # overwrites the SAM model by freezing the backbone and allowing low-rank adaptation of the attention layers
        from micro_sam.training.peft_sam import PEFT_Sam
        if rank is None:
            rank = 4  # default to a rank of 4 if the user does not pass one
        sam = PEFT_Sam(sam, rank=rank).sam

sam.load_state_dict(model_state)
sam.to(device=device)
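
Because `PEFT_Sam` is applied before `load_state_dict`, a checkpoint finetuned with LoRA should be loaded with `use_lora=True` (and the same rank) so that the wrapped qkv layers in the state dict can be matched. A hedged sketch, with a placeholder checkpoint path:

```python
# Sketch: load a LoRA-finetuned checkpoint for prediction; the path is a placeholder.
from micro_sam.util import get_sam_model

predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="/path/to/finetuned_lora_checkpoint.pt",
    use_lora=True,  # re-apply the LoRA surgery so the checkpoint keys match
    rank=4,
)
```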
