Add LoRA Implementation #611

Merged 19 commits on Jun 19, 2024

Changes from 6 commits
182 changes: 182 additions & 0 deletions finetuning/livecell/lora/train_livecell.py
@@ -0,0 +1,182 @@
import os
import argparse

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.data.datasets import get_livecell_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model


def get_dataloaders(patch_shape, data_path, cell_type=None):
"""This returns the livecell data loaders implemented in torch_em:
https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py
It will automatically download the livecell data.

Note: to replace this with another data loader you need to return a torch data loader
that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
The labels have to be in a label mask instance segmentation format,
i.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
Important: the ID 0 is reserved for background, and the IDs must be consecutive.
(A minimal sketch of such a replacement loader follows this function.)
"""
label_transform = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
)
raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1]
train_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)
val_loader = get_livecell_loader(
path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16,
cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
raw_transform=raw_transform, label_dtype=torch.float32,
)

return train_loader, val_loader
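
# A minimal sketch of what a replacement for `get_dataloaders` could look like. This is
# illustrative only: `MyInstanceSegmentationDataset` is a hypothetical dataset that yields
# `x, y` pairs, where `y` is an instance label mask (background ID 0, consecutive object IDs).
#
# from torch.utils.data import DataLoader
#
# def get_custom_dataloaders(batch_size=2, num_workers=4):
#     train_loader = DataLoader(
#         MyInstanceSegmentationDataset(split="train"), batch_size=batch_size, shuffle=True, num_workers=num_workers
#     )
#     val_loader = DataLoader(
#         MyInstanceSegmentationDataset(split="val"), batch_size=batch_size, shuffle=False, num_workers=num_workers
#     )
#     return train_loader, val_loader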


def count_parameters(model):
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
params = params / 1e6
return f"The number of trainable parameters for the provided model is {round(params, 2)}M"


def finetune_livecell(args):
"""Code for finetuning SAM (using LoRA) on LIVECell

Initial observations: There's no real memory advantage actually unless it's "truly" scaled up
# vit_b
# SAM: 93M (takes ~50GB)
# SAM-LoRA: 4.2M (takes ~49GB)

# vit_l
# SAM: 312M (takes ~63GB)
# SAM-LoRA: 4.4M (takes ~61GB)

# vit_h
# SAM: 641M (takes ~73GB)
# SAM-LoRA: 4.7M (takes ~67GB)

# Q: Would quantization lead to better results? (e.g. QLoRA / DoRA)
"""
# override this (below) if you have a more complex setup and need to specify the exact GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
patch_shape = (520, 704) # the patch shape for training
n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
get_lora=True,
)
model.to(device)

# let's get the UNETR model for the automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
)
unetr.to(device)

# let's check the total number of trainable parameters
print(count_parameters(model))

# let's get the parameters for SAM and the decoder from UNETR
joint_model_params = model.parameters()

joint_model_params = list(joint_model_params) # sam parameters
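# note: the unetr shares its image encoder with the (LoRA-wrapped) SAM model above, so only
# the decoder parameters are collected here, to avoid passing the shared encoder weights to the optimizer twice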
for name, params in unetr.named_parameters(): # unetr's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)

optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

trainer = sam_training.JointSamTrainer(
name="livecell_lora",
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
)
trainer.fit(args.iterations)
if args.export_path is not None:
checkpoint_path = os.path.join(
"" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
)


def main():
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
parser.add_argument(
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
)
parser.add_argument(
"--save_root", "-s", default=None,
help="Where to save the checkpoint and logs. By default they will be saved where this script is run."
)
parser.add_argument(
"--iterations", type=int, default=int(1e4),
help="For how many iterations should the model be trained? By default 100k."
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
parser.add_argument(
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
args = parser.parse_args()
finetune_livecell(args)


if __name__ == "__main__":
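# example invocation (the data path is a placeholder):
#   python train_livecell.py -i /path/to/livecell -m vit_b -s ./runs --iterations 10000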
main()
199 changes: 199 additions & 0 deletions micro_sam/training/trainable_sam.py
@@ -1,8 +1,10 @@
import math
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F
# from torch.nn.parameter import Parameter

from segment_anything.modeling import Sam
from segment_anything.utils.transforms import ResizeLongestSide
@@ -126,3 +128,200 @@ def forward(
)

return outputs


class LoRA_qkv(nn.Module):
"""Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/

In SAM, it is implemented as:
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
"""

def __init__(
self,
qkv: nn.Module,
linear_a_q: nn.Module,
linear_b_q: nn.Module,
linear_a_v: nn.Module,
linear_b_v: nn.Module,
):
super().__init__()
self.qkv = qkv
self.linear_a_q = linear_a_q
self.linear_b_q = linear_b_q
self.linear_a_v = linear_a_v
self.linear_b_v = linear_b_v
self.dim = qkv.in_features
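# kept from the reference implementation; `w_identity` is not used in the forward pass below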
self.w_identity = torch.eye(qkv.in_features)

def forward(self, x):
qkv = self.qkv(x) # B, H, W, 3 * org_C
new_q = self.linear_b_q(self.linear_a_q(x))
new_v = self.linear_b_v(self.linear_a_v(x))
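# only the query (first third) and value (last third) slices of the fused qkv output receive
# the low-rank update; the key slice in the middle is left unchanged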
qkv[:, :, :, : self.dim] += new_q
qkv[:, :, :, -self.dim:] += new_v
return qkv


# TODO: the mask decoder has some attention blocks as well; we need to decide whether to apply LoRA to them too.
# reference: SAMed and Maceij-SAM perform these experiments.
class LoRA_Sam(nn.Module):
"""Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/

Applies low-rank adaptation to the Segment Anything model's image encoder.

Args:
sam_model: a vision transformer model.
rank: rank of LoRA.
lora_layer: which specific layers we apply LoRA to.

Examples::
>>> model = ViT('B_16_imagenet1k')
>>> lora_model = LoRA_ViT(model, rank=4)
>>> preds = lora_model(img)
>>> print(preds.shape)
torch.Size([1, 1000])
"""

def __init__(
self,
sam_model: Sam,
rank: int,
lora_layer=None
):
super(LoRA_Sam, self).__init__()

assert rank > 0

if lora_layer:
self.lora_layer = lora_layer
else: # by default, apply LoRA to all blocks of the image encoder
self.lora_layer = list(
range(len(sam_model.image_encoder.blocks))
)

# create storage for the LoRA layers, so that we can initialize them or load weights later
self.w_As = [] # These are linear layers
self.w_Bs = []

# let's freeze all the image encoder layers first
for param in sam_model.image_encoder.parameters():
param.requires_grad = False

# Here, we do the surgery
for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks):
# skip blocks that LoRA should not be applied to
if t_layer_i not in self.lora_layer:
continue

w_qkv_linear = blk.attn.qkv
self.dim = w_qkv_linear.in_features

w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

self.w_As.append(w_a_linear_q)
self.w_Bs.append(w_b_linear_q)
self.w_As.append(w_a_linear_v)
self.w_Bs.append(w_b_linear_v)

blk.attn.qkv = LoRA_qkv(
w_qkv_linear,
w_a_linear_q,
w_b_linear_q,
w_a_linear_v,
w_b_linear_v,
)

self.reset_parameters()
self.sam = sam_model

def reset_parameters(self) -> None:
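# standard LoRA initialization: A gets a Kaiming-uniform init while B starts at zero, so the
# low-rank update B @ A is initially zero and the wrapped model behaves exactly like the frozen SAM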
for w_A in self.w_As:
nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
for w_B in self.w_Bs:
nn.init.zeros_(w_B.weight)

def forward(self, batched_input, multimask_output):
return self.sam(batched_input, multimask_output)

# TODO: the codebase below is not required here (part of original LoRA implementation)
# we should port the relevant parts from below to `get_sam_model` (if required) for loading the model.

# def save_lora_parameters(self, filename: str) -> None:
# """Only safetensors is supported now.

# Please install safetensor using: `pip install safetensor`, if you do not have one installed yet.

# This function saves both lora and fc parameters.
# """

# assert filename.endswith(".pt") or filename.endswith('.pth')

# num_layer = len(self.w_As) # actually, it is half
# a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
# b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
# prompt_encoder_tensors = {}
# mask_decoder_tensors = {}

# # save prompt encoder, only `state_dict`, the `named_parameter` is not permitted
# if isinstance(self.sam, torch.nn.DataParallel) or isinstance(
# self.sam, torch.nn.parallel.DistributedDataParallel
# ):
# state_dict = self.sam.module.state_dict()
# else:
# state_dict = self.sam.state_dict()

# for key, value in state_dict.items():
# if 'prompt_encoder' in key:
# prompt_encoder_tensors[key] = value
# if 'mask_decoder' in key:
# mask_decoder_tensors[key] = value

# merged_dict = {**a_tensors, **b_tensors, **prompt_encoder_tensors, **mask_decoder_tensors}
# torch.save(merged_dict, filename)

# def load_lora_parameters(self, filename: str) -> None:
# r"""Only safetensors is supported now.

# Please install safetensor using: `pip install safetensor`, if you do not have one installed yet.

# This function loads both lora and fc parameters.
# """

# assert filename.endswith(".pt") or filename.endswith('.pth')

# state_dict = torch.load(filename)

# for i, w_A_linear in enumerate(self.w_As):
# saved_key = f"w_a_{i:03d}"
# saved_tensor = state_dict[saved_key]
# w_A_linear.weight = Parameter(saved_tensor)

# for i, w_B_linear in enumerate(self.w_Bs):
# saved_key = f"w_b_{i:03d}"
# saved_tensor = state_dict[saved_key]
# w_B_linear.weight = Parameter(saved_tensor)

# sam_dict = self.sam.state_dict()
# sam_keys = sam_dict.keys()

# # load prompt encoder
# prompt_encoder_keys = [k for k in sam_keys if 'prompt_encoder' in k]
# prompt_encoder_values = [state_dict[k] for k in prompt_encoder_keys]
# prompt_encoder_new_state_dict = {k: v for k, v in zip(prompt_encoder_keys, prompt_encoder_values)}
# sam_dict.update(prompt_encoder_new_state_dict)

# # load mask decoder
# mask_decoder_keys = [k for k in sam_keys if 'mask_decoder' in k]
# mask_decoder_values = [state_dict[k] for k in mask_decoder_keys]
# mask_decoder_new_state_dict = {k: v for k, v in zip(mask_decoder_keys, mask_decoder_values)}
# sam_dict.update(mask_decoder_new_state_dict)
# self.sam.load_state_dict(sam_dict)