Skip to content

Commit

Permalink
udpate documents
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed May 23, 2024
1 parent 6ab4ae2 commit c8f9c02
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 12 deletions.
2 changes: 0 additions & 2 deletions config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
type: huggingface_clip_vision
models:
- name: _pretrained_
path: google/flan-t5-base
- name: _pretrained_
path: openai/clip-vit-base-patch32
- name: svhn
Expand Down
4 changes: 4 additions & 0 deletions docs/algorithms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ def run_model_fusion(cfg: DictConfig):
```

In summary, the Fusion Algorithm module is vital for the model merging operations within FusionBench, leveraging sophisticated techniques to ensure optimal fusion and performance evaluation of deep learning models. This capability makes it an indispensable tool for researchers and practitioners focusing on model fusion strategies.

### References

::: fusion_bench.method.load_algorithm_from_config
2 changes: 1 addition & 1 deletion docs/algorithms/task_arithmetic.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ In the rapidly advancing field of machine learning, multi-task learning has emer

<figure markdown="span">
![Image title](images/Task Arithmetic.png){ width="450" }
<figcaption>Task Arithmetic. Credit to <sup id="fnref:2"><a class="footnote-ref" href="#fn:2">2</a></sup></figcaption>
<figcaption>Task Arithmetic. This figure credited to <sup id="fnref:2"><a class="footnote-ref" href="#fn:2">2</a></sup></figcaption>
</figure>

**Task Vector**. A task vector is used to encapsulate the adjustments needed by a model to specialize in a specific task.
Expand Down
67 changes: 61 additions & 6 deletions docs/cli/fusion_bench.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,77 @@ fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \

| **Option** | **Default** | **Description** |
| ------------- | --------------------------- | --------------------------------------------------------------------------------------------------------------------------------- |
| **modelpool** | `clip-vit-base-patch32_TA8` | The pool of models to be fused. See [modelpool](../modelpool/README.md) for more information. |
| **method** | `dummy` | The fusion method to be used. See [fusion algorithms](../algorithms/README.md) for more information. |
| **taskpool** | `dummy` | The pool of tasks to be evaluated. See [taskpool](../taskpool/README.md) for more information. |
| **modelpool** | `clip-vit-base-patch32_TA8` | The pool of models to be fused. See [modelpool](../modelpool/README.md) for more information. |
| **method** | `dummy` | The fusion method to be used. See [fusion algorithms](../algorithms/README.md) for more information. |
| **taskpool** | `dummy` | The pool of tasks to be evaluated. See [taskpool](../taskpool/README.md) for more information. |
| print_config | `true` | Whether to print the configuration to the console. |
| save_report | `false` | the path to save the report. If not specified or is `false`, the report will not be saved. The report will be saved as json file. |
| --cfg, -c | | show the configuration instead of runing. |
| --help, -h | | show this help message and exit. |

## Basic Examples

merge multiple CLIP models using simple averaging:
merge two CLIP models using task arithmetic:

```bash
fusion_bench method=simple_average modelpool=clip-vit-base-patch32_TA8.yaml taskpool=dummy
fusion_bench method=task_arithmetic \
modelpool=clip-vit-base-patch32_svhn_and_mnist \
taskpool=clip-vit-base-patch32_svhn_and_mnist
```

The overall configuration is as follows:

```{.yaml .anotate}
method: # (1)
name: task_arithmetic
scaling_factor: 0.5
modelpool: # (2)
type: huggingface_clip_vision
models:
- name: _pretrained_
path: openai/clip-vit-base-patch32
- name: svhn
path: tanganke/clip-vit-base-patch32_svhn
- name: mnist
path: tanganke/clip-vit-base-patch32_mnist
taskpool: # (3)
type: clip_vit_classification
name: clip-vit-base-patch32_svhn_and_mnist
dataset_type: huggingface_image_classification
tasks:
- name: svhn
dataset:
type: instantiate
name: svhn
object:
_target_: datasets.load_dataset
_args_:
- svhn
- cropped_digits
split: test
- name: mnist
dataset:
name: mnist
split: test
clip_model: openai/clip-vit-base-patch32
batch_size: 128
num_workers: 16
fast_dev_run: ${fast_dev_run}
fast_dev_run: false
print_config: true
save_report: false
```

1. Configuration for method, `fusion_bench.method.load_algorithm_from_config` checks the 'name' attribute of the configuration and returns an instance of the corresponding algorithm.
2. Configuration for model pool, `fusion_bench.modelpool.load_modelpool_from_config` checks the 'type' attribute of the configuration and returns an instance of the corresponding model pool.
3. Configuration for task pool, `fusion_bench.taskpool.load_taskpool_from_config` checks the 'type' attribute of the configuration and returns an instance of the corresponding task pool.


## Options
merge multiple CLIP models using simple averaging:

```bash
fusion_bench method=simple_average modelpool=clip-vit-base-patch32_TA8.yaml taskpool=dummy
```


## References
Expand Down
2 changes: 2 additions & 0 deletions docs/modelpool/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ model = modelpool.load_model('model_name')

## References

::: fusion_bench.modelpool.load_modelpool_from_config

::: fusion_bench.modelpool.ModelPool

[^1]: AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
3 changes: 3 additions & 0 deletions docs/taskpool/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@ A taskpool is specified by a `yaml` configuration file, which often contains the
- `dataset`: The dataset used for the task.
- `metric`: The metric used to evaluate the performance of the model on the task.

### References

::: fusion_bench.taskpool.load_taskpool_from_config

::: fusion_bench.taskpool.TaskPool
15 changes: 15 additions & 0 deletions fusion_bench/method/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@


def load_algorithm_from_config(method_config: DictConfig):
"""
Loads an algorithm based on the provided configuration.
The function checks the 'name' attribute of the configuration and returns an instance of the corresponding algorithm.
If the 'name' attribute is not found or does not match any known algorithm names, a ValueError is raised.
Args:
method_config (DictConfig): The configuration for the algorithm. Must contain a 'name' attribute that specifies the type of the algorithm.
Returns:
An instance of the specified algorithm.
Raises:
ValueError: If 'name' attribute is not found in the configuration or does not match any known algorithm names.
"""
if method_config.name == "dummy":
return DummyAlgorithm(method_config)
# model ensemble methods
Expand Down
7 changes: 4 additions & 3 deletions fusion_bench/method/ties_merging/ties_merging.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
from copy import deepcopy
from typing import List, Mapping, Union
from typing import Dict, List, Mapping, Union

import torch
from torch import Tensor, nn

from ...modelpool import ModelPool
from ...modelpool import ModelPool, to_modelpool
from ...utils.type import _StateDict
from ..base_algorithm import ModelFusionAlgorithm
from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_state_dict
Expand All @@ -15,8 +15,9 @@

class TiesMergingAlgorithm(ModelFusionAlgorithm):
@torch.no_grad()
def run(self, modelpool: ModelPool):
def run(self, modelpool: ModelPool | Dict[str, nn.Module]):
log.info("Fusing models using ties merging.")
modelpool = to_modelpool(modelpool)
remove_keys = self.config.get("remove_keys", [])
merge_func = self.config.get("merge_func", "sum")
scaling_factor = self.config.scaling_factor
Expand Down
15 changes: 15 additions & 0 deletions fusion_bench/modelpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@


def load_modelpool_from_config(modelpool_config: DictConfig):
"""
Loads a model pool based on the provided configuration.
The function checks the 'type' attribute of the configuration and returns an instance of the corresponding model pool.
If the 'type' attribute is not found or does not match any known model pool types, a ValueError is raised.
Args:
modelpool_config (DictConfig): The configuration for the model pool. Must contain a 'type' attribute that specifies the type of the model pool.
Returns:
An instance of the specified model pool.
Raises:
ValueError: If 'type' attribute is not found in the configuration or does not match any known model pool types.
"""
if hasattr(modelpool_config, "type"):
if modelpool_config.type == "huggingface_clip_vision":
return HuggingFaceClipVisionPool(modelpool_config)
Expand Down
15 changes: 15 additions & 0 deletions fusion_bench/taskpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@


def load_taskpool_from_config(taskpool_config: DictConfig):
"""
Loads a task pool based on the provided configuration.
The function checks the 'type' attribute of the configuration and returns an instance of the corresponding task pool.
If the 'type' attribute is not found or does not match any known task pool types, a ValueError is raised.
Args:
taskpool_config (DictConfig): The configuration for the task pool. Must contain a 'type' attribute that specifies the type of the task pool.
Returns:
An instance of the specified task pool.
Raises:
ValueError: If 'type' attribute is not found in the configuration or does not match any known task pool types.
"""
if hasattr(taskpool_config, "type"):
if taskpool_config.type == "dummy":
return DummyTaskPool(taskpool_config)
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ theme:
name: material
features:
- toc.follow
- content.code.annotate

repo_url: https://github.com/tanganke/fusion_bench

0 comments on commit c8f9c02

Please sign in to comment.