Commit

Add LoRA Implementation (computational-cell-analytics#611)
Add LoRA-based PEFT finetuning
anwai98 authored Jun 19, 2024

1 parent a75d581 commit 22edc30
Showing 5 changed files with 360 additions and 0 deletions.
184 changes: 184 additions & 0 deletions finetuning/livecell/lora/train_livecell.py
@@ -0,0 +1,184 @@
import os
import argparse

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.data.datasets import get_livecell_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model


def get_dataloaders(patch_shape, data_path, cell_type=None):
"""This returns the livecell data loaders implemented in torch_em:
https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
It will automatically download the livecell data.
Note: to replace this with another data loader you need to return a torch data loader
that retuns `x, y` tensors, where `x` is the image data and `y` are the labels.
The labels have to be in a label mask instance segmentation format.
I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
Important: the ID 0 is reseved for background, and the IDs must be consecutive
"""
label_transform = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
)
raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1]
train_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)
val_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)

return train_loader, val_loader


def count_parameters(model):
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
params = params / 1e6
return f"The number of trainable parameters for the provided model is {round(params, 2)}M"


def finetune_livecell(args):
"""Code for finetuning SAM (using LoRA) on LIVECell
Initial observations: There's no real memory advantage actually unless it's "truly" scaled up
# vit_b
# SAM: 93M (takes ~50GB)
# SAM-LoRA: 4.2M (takes ~49GB)
# vit_l
# SAM: 312M (takes ~63GB)
# SAM-LoRA: 4.4M (takes ~61GB)
# vit_h
# SAM: 641M (takes ~73GB)
# SAM-LoRA: 4.7M (takes ~67GB)
# Q: Would quantization lead to better results? (eg. QLoRA / DoRA)
"""
    # override this (below) if you have a more complex setup and need to specify the exact GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
patch_shape = (520, 704) # the patch shape for training
n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
    rank = 4  # the rank of the LoRA decomposition matrices

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
use_lora=True,
rank=rank,
)
model.to(device)

    # let's get the UNETR model for the automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
)
unetr.to(device)

# let's check the total number of trainable parameters
print(count_parameters(model))

# let's get the parameters for SAM and the decoder from UNETR
    joint_model_params = [params for params in model.parameters()]  # sam parameters
for name, params in unetr.named_parameters(): # unetr's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)

optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

trainer = sam_training.JointSamTrainer(
name="livecell_lora",
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
)
trainer.fit(args.iterations)
if args.export_path is not None:
        checkpoint_path = os.path.join(
            "" if args.save_root is None else args.save_root, "checkpoints", "livecell_lora", "best.pt"
        )
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
)


def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
        help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
)
parser.add_argument(
"--save_root", "-s", default=None,
help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
)
parser.add_argument(
"--iterations", type=int, default=int(1e4),
help="For how many iterations should the model be trained? By default 100k."
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
parser.add_argument(
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
args = parser.parse_args()
finetune_livecell(args)


if __name__ == "__main__":
main()
105 changes: 105 additions & 0 deletions micro_sam/training/peft_sam.py
@@ -0,0 +1,105 @@
import math
from typing import List, Optional

import torch.nn as nn

from segment_anything.modeling import Sam


class LoRASurgery(nn.Module):
"""Operates on the attention layers for performing low-rank adaptation.
(Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)
In SAM, it is implemented as:
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
"""
def __init__(
self,
rank: int,
block: nn.Module,
):
super().__init__()
self.qkv = block.attn.qkv
self.dim = self.qkv.in_features

self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

self.reset_parameters()

block.attn.qkv = self

def reset_parameters(self):
nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
nn.init.zeros_(self.w_b_linear_q.weight)
nn.init.zeros_(self.w_b_linear_v.weight)

def forward(self, x):
        qkv = self.qkv(x)  # shape: (B, H, W, 3 * dim) for SAM's image encoder attention
new_q = self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.w_b_linear_v(self.w_a_linear_v(x))
qkv[:, :, :, :self.dim] += new_q
qkv[:, :, :, -self.dim:] += new_v
return qkv


class PEFT_Sam(nn.Module):
"""Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/
Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
Args:
model: The Segment Anything model.
rank: The rank for low-rank adaptation.
peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
attention_layers_to_update: Which specific layers we apply PEFT methods to.
"""

def __init__(
self,
model: Sam,
rank: int,
peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None
):
super(PEFT_Sam, self).__init__()

assert rank > 0

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
else: # Applies PEFT to the image encoder by default
self.peft_layers = list(
range(len(model.image_encoder.blocks))
)

self.peft_module = peft_module
self.peft_blocks = []

# let's freeze all the pretrained image encoder layers first
for param in model.image_encoder.parameters():
param.requires_grad = False

for t_layer_i, blk in enumerate(model.image_encoder.blocks):
# If we only want specific layers with PEFT instead of all
if t_layer_i not in self.peft_layers:
continue

peft_block = self.peft_module(rank=rank, block=blk)
self.peft_blocks.append(peft_block)

self.peft_blocks = nn.ModuleList(self.peft_blocks)

self.sam = model

def forward(self, batched_input, multimask_output):
return self.sam(batched_input, multimask_output)
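For orientation, here is a minimal usage sketch of the new wrapper (mirroring the test added in this commit; `vit_b` and rank 4 are illustrative choices, not fixed requirements):

```python
from micro_sam.util import get_sam_model
from micro_sam.training.peft_sam import PEFT_Sam

# Load the plain Segment Anything model; the predictor returned alongside it is not needed here.
_, sam = get_sam_model(model_type="vit_b", device="cpu", return_sam=True)

# Wrap it: all image encoder parameters are frozen and trainable LoRA matrices
# are injected into the qkv projection of every attention block.
peft_sam = PEFT_Sam(model=sam, rank=4)

# Only the LoRA matrices (plus the prompt encoder and mask decoder) remain trainable.
n_trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_trainable / 1e6:.2f}M")
```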
11 changes: 11 additions & 0 deletions micro_sam/training/util.py
@@ -12,6 +12,7 @@
get_centers_and_bounding_boxes, get_sam_model, get_device,
segmentation_to_one_hot, _DEFAULT_MODEL,
)
from .peft_sam import PEFT_Sam
from .trainable_sam import TrainableSAM

from torch_em.transform.label import PerObjectDistanceTransform
@@ -42,6 +43,8 @@ def get_trainable_sam_model(
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
freeze: Optional[List[str]] = None,
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
) -> TrainableSAM:
"""Get the trainable sam model.
@@ -54,6 +57,8 @@ def get_trainable_sam_model(
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
By default nothing is frozen and the full model is updated.
return_state: Whether to return the full checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.
Returns:
The trainable segment anything model.
@@ -80,8 +85,14 @@ def get_trainable_sam_model(
if name.startswith(f"{freeze}"):
param.requires_grad = False

    if use_lora:  # overwrites the SAM model by freezing the backbone and allowing low-rank adaptation of the attention layers
        if rank is None:
            rank = 4  # HACK: if the user does not pass a rank, we fall back to a default rank of 4
        sam = PEFT_Sam(sam, rank=rank).sam

# convert to trainable sam
trainable_sam = TrainableSAM(sam)

if return_state:
return trainable_sam, state
return trainable_sam
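The new arguments slot into the existing helper as sketched below (a minimal sketch; the argument values are illustrative, and `rank` falls back to 4 when it is not passed):

```python
import micro_sam.training as sam_training

# Get a trainable SAM whose image encoder is LoRA-wrapped instead of fully finetuned.
model = sam_training.get_trainable_sam_model(
    model_type="vit_b",
    device="cpu",
    checkpoint_path=None,  # or a custom checkpoint to start from
    use_lora=True,         # wraps the model with PEFT_Sam internally
    rank=4,                # rank of the LoRA decomposition matrices
)
```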
11 changes: 11 additions & 0 deletions micro_sam/util.py
@@ -270,6 +270,8 @@ def get_sam_model(
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
return_sam: bool = False,
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
) -> SamPredictor:
r"""Get the SegmentAnything Predictor.
@@ -302,6 +304,8 @@ def get_sam_model(
then `model_type` must be given as "vit_b".
return_sam: Return the sam model object as well as the predictor.
return_state: Return the unpickled checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.
Returns:
The segment anything predictor.
@@ -347,6 +351,13 @@ def get_sam_model(

state, model_state = _load_checkpoint(checkpoint_path)
sam = sam_model_registry[abbreviated_model_type]()

    if use_lora:  # overwrites the SAM model by freezing the backbone and allowing low-rank adaptation of the attention layers
        from micro_sam.training.peft_sam import PEFT_Sam
        if rank is None:
            rank = 4  # HACK: if the user does not pass a rank, we fall back to a default rank of 4
        sam = PEFT_Sam(sam, rank=rank).sam

sam.load_state_dict(model_state)
sam.to(device=device)
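
Correspondingly, a LoRA-finetuned checkpoint can be loaded through `get_sam_model` as sketched here (the checkpoint path is a placeholder; `use_lora` and `rank` must match the configuration used for training):

```python
from micro_sam.util import get_sam_model

predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="path/to/checkpoints/livecell_lora/best.pt",  # placeholder path
    use_lora=True,  # re-applies the PEFT surgery before the state dict is loaded
    rank=4,
)
```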

49 changes: 49 additions & 0 deletions test/test_peft_training.py
@@ -0,0 +1,49 @@
import unittest

import torch

from micro_sam.util import get_sam_model
from micro_sam.training.peft_sam import PEFT_Sam


class TestPEFTModule(unittest.TestCase):
"""Integraton test for instantiating a PEFT SAM model.
"""
def _fetch_sam_model(self, model_type, device):
_, sam_model = get_sam_model(model_type=model_type, device=device, return_sam=True)
return sam_model

def _create_dummy_inputs(self, shape):
input_image = torch.ones(shape)
return input_image

def test_peft_sam(self):
model_type = "vit_b"
device = "cpu"

        # Create the dummy inputs.
input_shape = (1, 512, 512)
inputs = self._create_dummy_inputs(shape=input_shape)

# Convert to the inputs expected by Segment Anything
batched_inputs = [
{"image": inputs, "original_size": input_shape[1:]}
]

# Load the Segment Anything model.
sam_model = self._fetch_sam_model(model_type=model_type, device=device)

# Wrap the Segment Anything model with PEFT methods.
peft_sam_model = PEFT_Sam(model=sam_model, rank=4)

# Get the model outputs
outputs = peft_sam_model(batched_input=batched_inputs, multimask_output=False)

# Check the expected shape of the outputs
mask_shapes = [output["masks"].shape[-2:] for output in outputs]
for shape in mask_shapes:
self.assertEqual(shape, input_shape[1:])


if __name__ == "__main__":
unittest.main()
