Commit

add regmean
tanganke committed May 30, 2024
1 parent d6842d6 commit a4b33fb
Showing 13 changed files with 579 additions and 13 deletions.
2 changes: 1 addition & 1 deletion config/method/clip_fisher_merging.yaml
@@ -12,4 +12,4 @@ num_fisher_examples: 256
 devices: 1
 cache_dir: outputs
 batch_size: 32
-num_workers: 0
+num_workers: 0
12 changes: 12 additions & 0 deletions config/method/clip_regmean.yaml
@@ -0,0 +1,12 @@
name: clip_regmean

# list, regular expression of names of parameters that need to be excluded
exclude_param_names_regex: []
# numbers of examples to compute regmean weights
num_regmean_examples: 256
# float, reduce non-diagonal elements in regmean weights by multiplying this scalar
reduce_non_diagonal_ratio: 0.6
devices: 1
cache_dir: outputs
batch_size: 32
num_workers: 0
4 changes: 4 additions & 0 deletions config/method/regmean.yaml
@@ -0,0 +1,4 @@
name: regmean

num_regmean_examples: 256
reduce_non_diagonal_ratio: 0.1
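
The config comment describes `reduce_non_diagonal_ratio` as a scalar that multiplies the non-diagonal elements of the RegMean weights (0.6 in `clip_regmean.yaml`, 0.1 in `regmean.yaml`). A minimal sketch of that operation follows, assuming the weights are a square matrix stored as nested lists. The helper name `reduce_off_diagonal` is hypothetical and is not part of FusionBench.

```python
def reduce_off_diagonal(gram, ratio):
    """Scale every off-diagonal entry of a square matrix by `ratio`.

    Hypothetical helper for illustration; diagonal entries are kept as-is.
    """
    n = len(gram)
    return [
        [gram[i][j] if i == j else ratio * gram[i][j] for j in range(n)]
        for i in range(n)
    ]


gram = [[4.0, 2.0], [2.0, 3.0]]
reduced = reduce_off_diagonal(gram, 0.5)
print(reduced)  # [[4.0, 1.0], [1.0, 3.0]]
```

A ratio of 1.0 leaves the matrix unchanged, and a ratio of 0.0 keeps only the diagonal.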
30 changes: 20 additions & 10 deletions docs/README.md
@@ -160,20 +160,14 @@ FusionBench is a pioneering project that provides a comprehensive benchmark for

 The general structure of the FusionBench project can be visualized through its modular framework, which is divided into several key components:
 
-1. **Analysis Tools**: This includes utilities for analyzing the Loss Landscape, Task Similarity, and Routing Analysis. These tools help researchers understand and interpret the underlying mechanics and performance metrics of different model fusion strategies.
-
+1. **Fusion Algorithm**: The core component where Model Fusion takes place. It integrates models from the Model Pool and adjusts them according to the specified fusion algorithms. The output is then evaluated for performance and effectiveness.
 2. **Model Pool**: A repository of various pre-trained models that can be accessed and utilized for fusion. This pool serves as the foundation for creating new, fused models by leveraging the strengths of each individual model.
-
-3. **Fusion Algorithm**: The core component where Model Fusion takes place. It integrates models from the Model Pool and adjusts them according to the specified fusion algorithms. The output is then evaluated for performance and effectiveness.
-
-4. **Task Pool**: A collection of tasks that the fused models are evaluated on. These tasks help in assessing the practical applicability and robustness of the fused models.
-
-5. **YAML Configurations**: Central to the project's modularity, YAML files are used to configure models, datasets, and metrics, allowing seamless customization and scalability.
-
-6. **Models & Warpers, Datasets, and Metrics**: These underlying modules include:
+3. **Task Pool**: A collection of tasks that the fused models are evaluated on. These tasks help in assessing the practical applicability and robustness of the fused models.
+4. **Models & Warpers, Datasets, and Metrics**: These underlying modules include:
     - Models & Warpers: Tools and scripts for model loading, wrapping, and pre-processing.
     - Datasets: The datasets used for training, validation, and testing the fused models.
     - Metrics: The performance metrics used to evaluate the models, providing a comprehensive understanding of their capabilities.
+5. **YAML Configurations**: Central to the project's modularity, YAML files are used to configure models, datasets, and metrics, allowing seamless customization and scalability.
 
 By organizing these components into a structured and modular codebase, FusionBench ensures flexibility, ease of use, and scalability for researchers and developers. The project not only serves as a benchmark but also as a robust platform for innovation in the realm of deep model fusion.

@@ -200,4 +194,20 @@ fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \
OPTION_1=VALUE_1 OPTION_2=VALUE_2 ...
```

This program will load the configuration file specified by `--config-path` and `--config-name`, and run the fusion algorithm on the model pool.
The pseudocode is as follows:

```python
# instantiate an algorithm, a modelpool object that manages the models,
# and a taskpool object that manages the tasks (dataset + metrics)
algorithm = load_algorithm(config.algorithm)
modelpool = load_modelpool(config.modelpool)
taskpool = load_taskpool(config.taskpool)

# run the fusion algorithm on the model pool
merged_model = algorithm.run(modelpool)
# evaluate the merged model on the tasks
report = taskpool.evaluate(merged_model)
```

For detailed information on the options available, you can refer to this [page](cli/fusion_bench.md).
Empty file added docs/algorithms/regmean.md
Empty file.
13 changes: 12 additions & 1 deletion docs/modelpool/clip_vit.md
@@ -126,6 +126,16 @@ fusion_bench method=clip_fisher_merging \
taskpool=clip-vit-classification_TA8
```

#### RegMean

merge CLIP-ViT-B/32 models using RegMean and evaluate on the eight tasks

```bash
fusion_bench method=clip_regmean \
modelpool=clip-vit-base-patch32_TA8 \
taskpool=clip-vit-classification_TA8
```
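
For orientation, RegMean is usually stated as a closed-form merge applied to each linear layer. The formula below is background and does not appear in this commit; the symbols $K$, $W_i$, $X_i$, $G_i$, and $\alpha$ are introduced here for illustration, with $\alpha$ assumed to correspond to `reduce_non_diagonal_ratio`. It assumes each layer computes $y = x W_i$, so a PyTorch `nn.Linear` weight would be the transpose of $W_i$.

$$
G_i = X_i^\top X_i, \qquad
\tilde{G}_i = \alpha\, G_i + (1 - \alpha)\, \operatorname{diag}(G_i), \qquad
W_{\text{merged}} = \Big( \sum_{i=1}^{K} \tilde{G}_i \Big)^{-1} \sum_{i=1}^{K} \tilde{G}_i W_i
$$

Here $X_i$ stacks the inputs that the layer receives in model $i$ on its own training data, so $G_i$ is the quantity that `num_regmean_examples` samples are used to estimate.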

#### Task Arithmetic

merge CLIP-ViT-B/32 models using task arithmetic and evaluate on the eight tasks
@@ -319,7 +329,7 @@ We provide the experimental results of the CLIP-ViT models for open vocabulary i


=== "Table: Mutli-task model merging methods using CLIP-ViT-B/32 models."

| Model | SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD | Average |
| ------------------------------------- | ------ | ---- | -------- | ------- | ---- | ----- | ----- | ---- | ------- |
| Reference Results | | | | | | | | | |
@@ -328,6 +338,7 @@ We provide the experimental results of the CLIP-ViT models for open vocabulary i
| Model Merging | | | | | | | | | |
| Simple Averaging | 65.4 | 62.6 | 70.8 | 76.9 | 64.5 | 54.9 | 86.3 | 50.9 | 66.5 |
| Fisher Merging | 66.7 | 64.0 | 72.2 | 91.6 | 69.0 | 64.3 | 83.5 | 53.7 | 70.6 |
| RegMean | 67.8 | 68.9 | 82.5 | 94.4 | 90.6 | 79.2 | 97.6 | 63.2 | 80.5 |
| Task Arithmetic ($\lambda=0.3$) | 57.1 | 55.7 | 64.9 | 76.7 | 77.9 | 68.5 | 96.1 | 47.2 | 68.0 |
| Ties-Merging ($\lambda=0.3$) | 67.1 | 64.2 | 74.1 | 76.8 | 77.7 | 69.4 | 94.1 | 54.0 | 72.2 |
| Task-wise AdaMerging ($\lambda=0.3$) | 58.6 | 56.9 | 69.8 | 82.4 | 70.3 | 58.9 | 97.2 | 55.3 | 68.7 |
3 changes: 3 additions & 0 deletions fusion_bench/method/__init__.py
@@ -25,6 +25,7 @@
from .we_moe.clip_we_moe import CLIPWeightEnsemblingMoEAlgorithm
from .weighted_average import WeightedAverageAlgorithm
from .fisher_merging.clip_fisher_merging import FisherMergingAlgorithmForCLIP
from .regmean.clip_regmean import RegMeanAlgorithmForCLIP


def load_algorithm_from_config(method_config: DictConfig):
@@ -59,6 +60,8 @@ def load_algorithm_from_config(method_config: DictConfig):
         return WeightedAverageAlgorithm(method_config)
     elif method_config.name == "clip_fisher_merging":
         return FisherMergingAlgorithmForCLIP(method_config)
+    elif method_config.name == "clip_regmean":
+        return RegMeanAlgorithmForCLIP(method_config)
     elif method_config.name == "task_arithmetic":
         return TaskArithmeticAlgorithm(method_config)
     elif method_config.name == "ties_merging":
1 change: 1 addition & 0 deletions fusion_bench/method/fisher_merging/__init__.py
@@ -0,0 +1 @@
from .fisher_merging import FisherMergingAlgorithm, get_param_names_to_merge
2 changes: 1 addition & 1 deletion fusion_bench/method/fisher_merging/clip_fisher_merging.py
@@ -109,7 +109,7 @@ def get_fisher_weights(
         )
         if self._fabric is not None:
             train_dataloader = self._fabric.setup_dataloaders(train_dataloader)
-            model = self._fabric.setup(deepcopy(model))
+            model = self._fabric.setup(model)
         num_fisher_examples = self.config.num_fisher_examples
         if num_fisher_examples % train_dataloader.batch_size != 0:
             print(
1 change: 1 addition & 0 deletions fusion_bench/method/regmean/__init__.py
@@ -0,0 +1 @@
from .regmean import RegMeanAlgorithm
178 changes: 178 additions & 0 deletions fusion_bench/method/regmean/clip_regmean.py
@@ -0,0 +1,178 @@
import logging
import os
from copy import deepcopy
from functools import cache
from typing import Dict, List, cast

import lightning as L
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.nn.modules import Module
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel

from fusion_bench.modelpool.huggingface_clip_vision import HuggingFaceClipVisionPool
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
from fusion_bench.utils import timeit_context

from .regmean import RegMeanAlgorithm

log = logging.getLogger(__name__)


class RegMeanAlgorithmForCLIP(RegMeanAlgorithm):
    _fabric: L.Fabric = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        # setup fabric
        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.devices)
            self._fabric.launch()

    def on_regmean_start(self):
        clip_model_config = self.modelpool.get_model_config("_pretrained_")

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
            clip_model = CLIPModel.from_pretrained(clip_model_config.path)

        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
        self.logit_scale = clip_model.logit_scale.exp()
        if self._fabric is not None:
            self.visual_projection = self._fabric.to_device(self.visual_projection)
            self.logit_scale = self._fabric.to_device(self.logit_scale)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(
                    cast(HuggingFaceClipVisionPool, self.modelpool)
                    .get_train_dataset_config(task)["dataset"]
                    .name
                )
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def get_regmean_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset,
        param_names_to_merge: List[str],
        linear_modules_to_merge: Dict[str, Module],
    ):
        # setup dataloader
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            train_dataloader = self._fabric.setup_dataloaders(train_dataloader)
            model = self._fabric.setup(model)

        def compute_regmean_weights(module_name: str):
            """
            compute the regmean weights, a hook function to deal with each module's input
            :param module_name: str, module name
            :return:
            """

            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
                # Tensor, shape (batch_size, sequence_length, hidden_dim)
                x = cast(Tensor, input[0]).detach()
                batch_num_actual_examples = x.shape[0]
                # Tensor, shape (batch_size * sequence_length, hidden_dim)
                x = x.reshape(-1, x.shape[-1])
                # Tensor, shape (hidden_dim, hidden_dim)
                xtx = torch.matmul(x.transpose(0, 1), x)
                # store the averaged weights in regmean_weights
                if module_name not in regmean_weights.keys():
                    regmean_weights[module_name] = xtx / x.shape[0]
                    num_computed_examples[module_name] = x.shape[0]
                    num_actual_examples[module_name] = batch_num_actual_examples
                else:
                    regmean_weights[module_name] = (
                        regmean_weights[module_name]
                        * num_computed_examples[module_name]
                        + xtx
                    ) / (num_computed_examples[module_name] + x.shape[0])
                    num_computed_examples[module_name] += x.shape[0]
                    num_actual_examples[module_name] += batch_num_actual_examples

            return hook

        handles = []
        # dictionary, regmean matrices for each linear module inputs
        regmean_weights = {}
        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
        num_computed_examples = {}
        # dictionary, number of actual examples used for computing regmean matrices
        num_actual_examples = {}

        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
            # register a hook in the forward process
            handle = linear_module_to_merge.register_forward_hook(
                compute_regmean_weights(module_name=module_name)
            )
            handles.append(handle)
        for step, batch in tqdm(
            enumerate(train_dataloader),
            desc=f"computing regmean weights for model {model_name}",
        ):
            if (
                len(num_actual_examples) > 0
                and list(num_actual_examples.values())[0]
                >= self.config.num_regmean_examples
            ):
                break
            logits = self.compute_logits(model, batch, model_name)

        # remove the added hook
        for handle in handles:
            handle.remove()

        for module_name in regmean_weights.keys():
            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()

        return regmean_weights
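
The forward hook in `get_regmean_weights` keeps a running average of `x^T x` per module instead of storing activations. The sketch below isolates that update rule in pure Python and compares it with a single pass over all rows. The names `gram` and `RunningGram` are hypothetical, and nested lists stand in for tensors; this is an illustration of the arithmetic, not the FusionBench implementation.

```python
def gram(rows):
    """Return X^T X for a batch given as a list of equal-length rows."""
    dim = len(rows[0])
    return [[sum(r[i] * r[j] for r in rows) for j in range(dim)] for i in range(dim)]


class RunningGram:
    """Running average of X^T X over the rows seen so far (hypothetical helper)."""

    def __init__(self):
        self.mean = None  # averaged Gram matrix
        self.count = 0    # number of rows folded into the average

    def update(self, rows):
        xtx = gram(rows)
        n = len(rows)
        if self.mean is None:
            self.mean = [[v / n for v in row] for row in xtx]
        else:
            total = self.count + n
            # undo the previous normalisation, add the new sum, renormalise
            self.mean = [
                [(m * self.count + v) / total for m, v in zip(m_row, x_row)]
                for m_row, x_row in zip(self.mean, xtx)
            ]
        self.count += n


batch_1 = [[1.0, 2.0], [3.0, 4.0]]
batch_2 = [[5.0, 6.0]]

running = RunningGram()
running.update(batch_1)
running.update(batch_2)

# Reference: one pass over every row at once.
all_rows = batch_1 + batch_2
one_pass = [[v / len(all_rows) for v in row] for row in gram(all_rows)]
```

In the method above, `num_computed_examples` plays the role of `count` (rows after flattening the batch and sequence dimensions), while `num_actual_examples` counts whole inputs and is what the loop compares against `num_regmean_examples` to stop early.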