Commit 5b27d7b: update fisher_merging

tanganke committed May 30, 2024
1 parent 0b298ae
Showing 7 changed files with 648 additions and 7 deletions.
11 changes: 11 additions & 0 deletions config/method/fisher_merging.yaml
@@ -0,0 +1,11 @@
name: fisher_merging

# a list of regular expressions; parameters whose names match any of them are excluded from Fisher merging
exclude_param_names_regex: []
# boolean, whether to normalize the Fisher weights (by their L2 norm)
normalize_fisher_weight: true
# float, the minimal value of the Fisher weights, used to avoid potential numerical issues
minimal_fisher_weight: 1e-6
# number of examples used to estimate the Fisher information; common choices: 256, 512, 1024, 2048
num_fisher_examples: 256
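For orientation, the options above feed the Fisher-weighted averaging of per-model parameters. A minimal sketch of that averaging step, assuming one parameter dict and one Fisher dict per fine-tuned model (illustrative names, not the code added in this commit; exclude_param_names_regex and normalize_fisher_weight are omitted for brevity):

from typing import Dict, List

import torch
from torch import Tensor


def fisher_weighted_average(
    params: List[Dict[str, Tensor]],   # one {param_name: value} dict per fine-tuned model
    fishers: List[Dict[str, Tensor]],  # matching {param_name: Fisher weight} dicts
    minimal_fisher_weight: float = 1e-6,
) -> Dict[str, Tensor]:
    merged = {}
    for name in params[0]:
        # floor the Fisher weights so no model gets exactly zero weight
        weights = torch.stack([f[name] + minimal_fisher_weight for f in fishers])
        values = torch.stack([p[name] for p in params])
        # Fisher-weighted average of the parameter across models
        merged[name] = (weights * values).sum(dim=0) / weights.sum(dim=0)
    return merged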

42 changes: 41 additions & 1 deletion config/modelpool/clip-vit-large-patch14_TA8.yaml
@@ -26,7 +26,7 @@ tta_datasets:
    dataset:
      type: instantiate
      name: svhn
      object:
        _target_: datasets.load_dataset
        _args_:
          - svhn
@@ -60,3 +60,43 @@ tta_datasets:
    dataset:
      name: tanganke/sun397
      split: test

train_datasets:
  - name: svhn
    dataset:
      type: instantiate
      name: svhn
      object:
        _target_: datasets.load_dataset
        _args_:
          - svhn
          - cropped_digits
        split: train
  - name: stanford_cars
    dataset:
      name: tanganke/stanford_cars
      split: train
  - name: resisc45
    dataset:
      name: tanganke/resisc45
      split: train
  - name: eurosat
    dataset:
      name: tanganke/eurosat
      split: train
  - name: gtsrb
    dataset:
      name: tanganke/gtsrb
      split: train
  - name: mnist
    dataset:
      name: mnist
      split: train
  - name: dtd
    dataset:
      name: tanganke/dtd
      split: train
  - name: sun397
    dataset:
      name: tanganke/sun397
      split: train
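For reference, the svhn entry above instantiates its _target_ with the listed _args_ and split, while the remaining entries give only a dataset name and split and are resolved by load_dataset_from_config. The plain load_dataset calls below are an illustrative assumption of what that resolution amounts to:

from datasets import load_dataset

# the "type: instantiate" entry calls datasets.load_dataset with its _args_ and split
svhn_train = load_dataset("svhn", "cropped_digits", split="train")

# name/split entries presumably reduce to a load_dataset call on the given repository
gtsrb_train = load_dataset("tanganke/gtsrb", split="train")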
@@ -123,7 +123,7 @@ def on_test_time_adaptation_start(self):
                    self.zeroshot_weights[task]
                )

-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

154 changes: 154 additions & 0 deletions fusion_bench/method/fisher_merging/clip_fisher_merging.py
@@ -0,0 +1,154 @@
import logging
import os
from copy import deepcopy
from functools import cache
from typing import Dict, List

import lightning as L
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.nn.modules import Module
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel

from fusion_bench.dataset import CLIPDataset, load_dataset_from_config
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
from fusion_bench.utils import timeit_context

from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients

log = logging.getLogger(__name__)


class FisherMergingAlgorithmForCLIP(FisherMergingAlgorithm):
    _fabric: L.Fabric = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        # setup fabric
        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.devices)
            self._fabric.launch()

    def on_fisher_merging_start(self):
        clip_model_config = self.modelpool.get_model_config("_pretrained_")

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
            clip_model = CLIPModel.from_pretrained(clip_model_config.path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale = self._fabric.to_device(self.logit_scale)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                # TODO: Construct zero shot classification head for task
                raise NotImplementedError

            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )
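The else branch above is left as a TODO, so cached zero-shot weights are expected on disk. One plausible way such weights could be built, sketched under the assumption that get_classnames_and_templates returns the class names for a task; the prompt format is illustrative and this helper is not part of the commit:

def build_zeroshot_weights(
    clip_model: CLIPModel, processor: CLIPProcessor, task: str
) -> Tensor:
    # hypothetical helper, not part of this commit
    classnames, _templates = get_classnames_and_templates(task)
    # a single generic prompt per class; the task-specific templates would be used in practice
    prompts = [f"a photo of a {name}." for name in classnames]
    inputs = processor(text=prompts, return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embeds = clip_model.get_text_features(**inputs)
    # shape (num_classes, projection_dim), L2-normalized to match compute_logits below
    return text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)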

    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
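As a quick illustration of the shapes flowing through compute_logits above (illustrative sizes for CLIP ViT-L/14 on SUN397; not part of the commit):

text_embeds = torch.randn(397, 768)   # (num_classes, projection_dim)
image_embeds = torch.randn(8, 768)    # (batch_size, projection_dim), after visual_projection
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
logits_per_image = torch.matmul(text_embeds, image_embeds.t()).t()
assert logits_per_image.shape == (8, 397)  # (batch_size, num_classes)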

    def get_fisher_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset,
        param_names_to_merge: List[str],
    ) -> Dict[str, Tensor]:
        # set up the dataloader
        train_dataset = CLIPDataset(train_dataset, self._clip_processor)
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            train_dataloader = self._fabric.setup_dataloaders(train_dataloader)
            model = self._fabric.setup(deepcopy(model))
        num_fisher_examples = self.config.num_fisher_examples
        if num_fisher_examples % train_dataloader.batch_size != 0:
            print(
                "warning: the number of examples for computing the Fisher information is not divisible by the batch size, "
                "so the actual number of examples used may differ slightly."
            )
        num_computed_examples = 0
        batches_fisher_weights_list = []
        for step, batch in tqdm(
            enumerate(train_dataloader), desc="computing fisher weights"
        ):
            if num_computed_examples >= num_fisher_examples:
                break
            logits = self.compute_logits(model, batch, model_name)
            # Tensor, shape (batch_size, num_label_classes)

            # compute the Fisher weights for the classification task
            # use detach() to detach from the computation graph
            # Tensor, shape (batch_size, num_label_classes)
            labels_probabilities = torch.softmax(logits, dim=-1).detach()
            labels_log_probabilities = torch.log_softmax(logits, dim=-1)
            # take the square root of labels_probabilities, since it is squared again
            # by the squared gradients below
            labels_expectations = (
                torch.sqrt(labels_probabilities) * labels_log_probabilities
            )
            # sum over the label classes and the batch dimension
            sum_labels_expectations = labels_expectations.sum(dim=-1).sum(dim=0)
            model.zero_grad()
            sum_labels_expectations.backward()
            # dict, Fisher weights of a single batch
            batch_fisher_weights = get_param_squared_gradients(
                model=model, param_names_to_merge=param_names_to_merge
            )

            batches_fisher_weights_list.append(batch_fisher_weights)
            num_computed_examples += batch[0].size(0)

        # accumulate the per-batch Fisher weights
        model_to_merge_fisher_weights = {}
        for batch_fisher_weights in batches_fisher_weights_list:
            for key in batch_fisher_weights:
                if key not in model_to_merge_fisher_weights:
                    model_to_merge_fisher_weights[key] = batch_fisher_weights[key]
                else:
                    model_to_merge_fisher_weights[key] += batch_fisher_weights[key]

        # average over the number of examples actually used
        for key in model_to_merge_fisher_weights:
            model_to_merge_fisher_weights[key] /= num_computed_examples
        return model_to_merge_fisher_weights
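get_param_squared_gradients is imported from the sibling fisher_merging module and is not shown in this diff. A plausible minimal form, included only to make the per-batch step above self-contained; this is an assumption, not the actual implementation:

def get_param_squared_gradients(
    model: nn.Module, param_names_to_merge: List[str]
) -> Dict[str, Tensor]:
    # diagonal Fisher approximation: squared gradients of the selected parameters
    return {
        name: param.grad.detach() ** 2
        for name, param in model.named_parameters()
        if name in param_names_to_merge and param.grad is not None
    }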