Add LoRA Implementation #611

Merged on Jun 19, 2024 (19 commits)
184 changes: 184 additions & 0 deletions finetuning/livecell/lora/train_livecell.py
@@ -0,0 +1,184 @@
import os
import argparse

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.data.datasets import get_livecell_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model


def get_dataloaders(patch_shape, data_path, cell_type=None):
"""This returns the livecell data loaders implemented in torch_em:
https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
It will automatically download the livecell data.

Note: to replace this with another data loader you need to return a torch data loader
that retuns `x, y` tensors, where `x` is the image data and `y` are the labels.
The labels have to be in a label mask instance segmentation format.
I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
Important: the ID 0 is reseved for background, and the IDs must be consecutive
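For example, for a `(520, 704)` patch, `y` would be a `(520, 704)` tensor with the value `0`
for background pixels and `1`, `2`, ... for the pixels of the individual cell instances.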
"""
label_transform = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
)
raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1]
train_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)
val_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)

return train_loader, val_loader


def count_parameters(model):
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
params = params / 1e6
return f"The number of trainable parameters for the provided model is {round(params, 2)}M"


def finetune_livecell(args):
"""Code for finetuning SAM (using LoRA) on LIVECell

Initial observations: there is no real memory advantage unless the model is scaled up considerably.
# vit_b
# SAM: 93M (takes ~50GB)
# SAM-LoRA: 4.2M (takes ~49GB)

# vit_l
# SAM: 312M (takes ~63GB)
# SAM-LoRA: 4.4M (takes ~61GB)

# vit_h
# SAM: 641M (takes ~73GB)
# SAM-LoRA: 4.7M (takes ~67GB)

# Q: Would quantization lead to better results? (e.g. QLoRA / DoRA)
"""
# override this (below) if you have a more complex setup and need to specify the exact GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
patch_shape = (520, 704) # the patch shape for training
n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
rank = 4 # the rank of the LoRA decomposition matrices

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
use_lora=True,
rank=rank,
)
model.to(device)

# let's get the UNETR model for the automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
)
unetr.to(device)

# let's check the total number of trainable parameters
print(count_parameters(model))

# let's get the parameters for SAM and the decoder from UNETR
joint_model_params = list(model.parameters()) # the SAM parameters
for name, params in unetr.named_parameters(): # UNETR's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)

optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

trainer = sam_training.JointSamTrainer(
name="livecell_lora",
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
)
trainer.fit(args.iterations)
if args.export_path is not None:
checkpoint_path = os.path.join(
"" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
)


def main():
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
parser.add_argument(
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
)
parser.add_argument(
"--save_root", "-s", default=None,
help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
)
parser.add_argument(
"--iterations", type=int, default=int(1e4),
help="For how many iterations should the model be trained? By default 100k."
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
parser.add_argument(
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
args = parser.parse_args()
finetune_livecell(args)


if __name__ == "__main__":
main()
105 changes: 105 additions & 0 deletions micro_sam/training/peft_sam.py
@@ -0,0 +1,105 @@
import math
from typing import List, Optional

import torch.nn as nn

from segment_anything.modeling import Sam


class LoRASurgery(nn.Module):
"""Operates on the attention layers for performing low-rank adaptation.

(Inspired by: https://github.com/JamesQFreeman/Sam_LoRA/)

In SAM, the attention layer's qkv projection is implemented as:
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
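The LoRA surgery leaves this `qkv` projection untouched and adds trainable low-rank updates
to the query and value projections only, i.e. `q = W_q(x) + B_q(A_q(x))` and
`v = W_v(x) + B_v(A_v(x))`, where each `A_*` maps the embedding dimension to `rank` and
each `B_*` maps `rank` back to the embedding dimension.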
"""
def __init__(
self,
rank: int,
block: nn.Module,
):
super().__init__()
self.qkv = block.attn.qkv
self.dim = self.qkv.in_features

self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

self.reset_parameters()

block.attn.qkv = self

def reset_parameters(self):
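# the A matrices get a Kaiming-uniform initialization, while the B matrices start at zero,
# so the adapted projections initially reproduce the pretrained SAM outputs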
nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
nn.init.zeros_(self.w_b_linear_q.weight)
nn.init.zeros_(self.w_b_linear_v.weight)

def forward(self, x):
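# compute the original qkv projection, then add the low-rank updates to the query
# (first `dim` channels) and value (last `dim` channels) projections; the keys stay unchanged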
qkv = self.qkv(x) # B, N, N, 3 * org_C
new_q = self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.w_b_linear_v(self.w_a_linear_v(x))
qkv[:, :, :, :self.dim] += new_q
qkv[:, :, :, -self.dim:] += new_v
return qkv


class PEFT_Sam(nn.Module):
"""Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/

Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

Args:
model: The Segment Anything model.
rank: The rank for low-rank adaptation.
peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
attention_layers_to_update: The specific attention layers to apply the PEFT method to. By default all attention layers are updated.
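
Example: a minimal usage sketch, assuming `sam_model` is a loaded `Sam` instance:
```python
peft_sam = PEFT_Sam(sam_model, rank=4)
# within the image encoder, only the newly added LoRA parameters remain trainable
trainable = [name for name, p in peft_sam.sam.image_encoder.named_parameters() if p.requires_grad]
```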
"""

def __init__(
self,
model: Sam,
rank: int,
peft_module: nn.Module = LoRASurgery,
attention_layers_to_update: Optional[List[int]] = None
):
super(PEFT_Sam, self).__init__()

assert rank > 0

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
else: # Applies PEFT to the image encoder by default
self.peft_layers = list(
range(len(model.image_encoder.blocks))
)

self.peft_module = peft_module
self.peft_blocks = []

# let's freeze all the pretrained image encoder layers first
for param in model.image_encoder.parameters():
param.requires_grad = False

for t_layer_i, blk in enumerate(model.image_encoder.blocks):
# If we only want specific layers with PEFT instead of all
if t_layer_i not in self.peft_layers:
continue

peft_block = self.peft_module(rank=rank, block=blk)
self.peft_blocks.append(peft_block)

self.peft_blocks = nn.ModuleList(self.peft_blocks)

self.sam = model

def forward(self, batched_input, multimask_output):
return self.sam(batched_input, multimask_output)
11 changes: 11 additions & 0 deletions micro_sam/training/util.py
@@ -12,6 +12,7 @@
get_centers_and_bounding_boxes, get_sam_model, get_device,
segmentation_to_one_hot, _DEFAULT_MODEL,
)
from .peft_sam import PEFT_Sam
from .trainable_sam import TrainableSAM

from torch_em.transform.label import PerObjectDistanceTransform
@@ -42,6 +43,8 @@ def get_trainable_sam_model(
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
freeze: Optional[List[str]] = None,
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
) -> TrainableSAM:
"""Get the trainable sam model.

@@ -54,6 +57,8 @@
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
By default nothing is frozen and the full model is updated.
return_state: Whether to return the full checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.

Returns:
The trainable segment anything model.
@@ -80,8 +85,14 @@
if name.startswith(f"{freeze}"):
param.requires_grad = False

if use_lora:  # overwrites the SAM model by freezing the backbone and allowing low-rank adaptation of the attention layers
if rank is None:
rank = 4  # HACK: in case the user does not pass the rank, we fall back to a default rank of 4
sam = PEFT_Sam(sam, rank=rank).sam

# convert to trainable sam
trainable_sam = TrainableSAM(sam)

if return_state:
return trainable_sam, state
return trainable_sam
11 changes: 11 additions & 0 deletions micro_sam/util.py
@@ -270,6 +270,8 @@ def get_sam_model(
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
return_sam: bool = False,
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
) -> SamPredictor:
r"""Get the SegmentAnything Predictor.

@@ -302,6 +304,8 @@
then `model_type` must be given as "vit_b".
return_sam: Return the sam model object as well as the predictor.
return_state: Return the unpickled checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.

Returns:
The segment anything predictor.
@@ -347,6 +351,13 @@

state, model_state = _load_checkpoint(checkpoint_path)
sam = sam_model_registry[abbreviated_model_type]()

if use_lora:  # overwrites the SAM model by freezing the backbone and allowing low-rank adaptation of the attention layers
from micro_sam.training.peft_sam import PEFT_Sam
if rank is None:
rank = 4  # HACK: in case the user does not pass the rank, we fall back to a default rank of 4
sam = PEFT_Sam(sam, rank=rank).sam

sam.load_state_dict(model_state)
sam.to(device=device)
