diff --git a/config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml b/config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml index 0950a67b..41bcb28f 100644 --- a/config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml +++ b/config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml @@ -1,7 +1,5 @@ type: huggingface_clip_vision models: - - name: _pretrained_ - path: google/flan-t5-base - name: _pretrained_ path: openai/clip-vit-base-patch32 - name: svhn diff --git a/docs/algorithms/README.md b/docs/algorithms/README.md index ce0b6330..43a044bf 100644 --- a/docs/algorithms/README.md +++ b/docs/algorithms/README.md @@ -44,3 +44,7 @@ def run_model_fusion(cfg: DictConfig): ``` In summary, the Fusion Algorithm module is vital for the model merging operations within FusionBench, leveraging sophisticated techniques to ensure optimal fusion and performance evaluation of deep learning models. This capability makes it an indispensable tool for researchers and practitioners focusing on model fusion strategies. + +### References + +::: fusion_bench.method.load_algorithm_from_config diff --git a/docs/algorithms/task_arithmetic.md b/docs/algorithms/task_arithmetic.md index a6059cd4..3929d0dd 100644 --- a/docs/algorithms/task_arithmetic.md +++ b/docs/algorithms/task_arithmetic.md @@ -4,7 +4,7 @@ In the rapidly advancing field of machine learning, multi-task learning has emer
![Image title](images/Task Arithmetic.png){ width="450" } -
Task Arithmetic. Credit to 2
+
Task Arithmetic. This figure credited to 2
**Task Vector**. A task vector is used to encapsulate the adjustments needed by a model to specialize in a specific task. diff --git a/docs/cli/fusion_bench.md b/docs/cli/fusion_bench.md index cb122d1a..4e10391d 100644 --- a/docs/cli/fusion_bench.md +++ b/docs/cli/fusion_bench.md @@ -14,22 +14,77 @@ fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \ | **Option** | **Default** | **Description** | | ------------- | --------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | -| **modelpool** | `clip-vit-base-patch32_TA8` | The pool of models to be fused. See [modelpool](../modelpool/README.md) for more information. | -| **method** | `dummy` | The fusion method to be used. See [fusion algorithms](../algorithms/README.md) for more information. | -| **taskpool** | `dummy` | The pool of tasks to be evaluated. See [taskpool](../taskpool/README.md) for more information. | +| **modelpool** | `clip-vit-base-patch32_TA8` | The pool of models to be fused. See [modelpool](../modelpool/README.md) for more information. | +| **method** | `dummy` | The fusion method to be used. See [fusion algorithms](../algorithms/README.md) for more information. | +| **taskpool** | `dummy` | The pool of tasks to be evaluated. See [taskpool](../taskpool/README.md) for more information. | | print_config | `true` | Whether to print the configuration to the console. | | save_report | `false` | the path to save the report. If not specified or is `false`, the report will not be saved. The report will be saved as json file. | +| --cfg, -c | | show the configuration instead of runing. | +| --help, -h | | show this help message and exit. | ## Basic Examples -merge multiple CLIP models using simple averaging: +merge two CLIP models using task arithmetic: ```bash -fusion_bench method=simple_average modelpool=clip-vit-base-patch32_TA8.yaml taskpool=dummy +fusion_bench method=task_arithmetic \ + modelpool=clip-vit-base-patch32_svhn_and_mnist \ + taskpool=clip-vit-base-patch32_svhn_and_mnist +``` + +The overall configuration is as follows: + +```{.yaml .anotate} +method: # (1) + name: task_arithmetic + scaling_factor: 0.5 +modelpool: # (2) + type: huggingface_clip_vision + models: + - name: _pretrained_ + path: openai/clip-vit-base-patch32 + - name: svhn + path: tanganke/clip-vit-base-patch32_svhn + - name: mnist + path: tanganke/clip-vit-base-patch32_mnist +taskpool: # (3) + type: clip_vit_classification + name: clip-vit-base-patch32_svhn_and_mnist + dataset_type: huggingface_image_classification + tasks: + - name: svhn + dataset: + type: instantiate + name: svhn + object: + _target_: datasets.load_dataset + _args_: + - svhn + - cropped_digits + split: test + - name: mnist + dataset: + name: mnist + split: test + clip_model: openai/clip-vit-base-patch32 + batch_size: 128 + num_workers: 16 + fast_dev_run: ${fast_dev_run} +fast_dev_run: false +print_config: true +save_report: false ``` +1. Configuration for method, `fusion_bench.method.load_algorithm_from_config` checks the 'name' attribute of the configuration and returns an instance of the corresponding algorithm. +2. Configuration for model pool, `fusion_bench.modelpool.load_modelpool_from_config` checks the 'type' attribute of the configuration and returns an instance of the corresponding model pool. +3. Configuration for task pool, `fusion_bench.taskpool.load_taskpool_from_config` checks the 'type' attribute of the configuration and returns an instance of the corresponding task pool. + -## Options +merge multiple CLIP models using simple averaging: + +```bash +fusion_bench method=simple_average modelpool=clip-vit-base-patch32_TA8.yaml taskpool=dummy +``` ## References diff --git a/docs/modelpool/README.md b/docs/modelpool/README.md index ad89f106..c03d28a0 100644 --- a/docs/modelpool/README.md +++ b/docs/modelpool/README.md @@ -36,6 +36,8 @@ model = modelpool.load_model('model_name') ## References +::: fusion_bench.modelpool.load_modelpool_from_config + ::: fusion_bench.modelpool.ModelPool [^1]: AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575 diff --git a/docs/taskpool/README.md b/docs/taskpool/README.md index 793fbe9c..774c1731 100644 --- a/docs/taskpool/README.md +++ b/docs/taskpool/README.md @@ -12,5 +12,8 @@ A taskpool is specified by a `yaml` configuration file, which often contains the - `dataset`: The dataset used for the task. - `metric`: The metric used to evaluate the performance of the model on the task. +### References + +::: fusion_bench.taskpool.load_taskpool_from_config ::: fusion_bench.taskpool.TaskPool \ No newline at end of file diff --git a/fusion_bench/method/__init__.py b/fusion_bench/method/__init__.py index 761f9c1a..5cb20ad8 100644 --- a/fusion_bench/method/__init__.py +++ b/fusion_bench/method/__init__.py @@ -27,6 +27,21 @@ def load_algorithm_from_config(method_config: DictConfig): + """ + Loads an algorithm based on the provided configuration. + + The function checks the 'name' attribute of the configuration and returns an instance of the corresponding algorithm. + If the 'name' attribute is not found or does not match any known algorithm names, a ValueError is raised. + + Args: + method_config (DictConfig): The configuration for the algorithm. Must contain a 'name' attribute that specifies the type of the algorithm. + + Returns: + An instance of the specified algorithm. + + Raises: + ValueError: If 'name' attribute is not found in the configuration or does not match any known algorithm names. + """ if method_config.name == "dummy": return DummyAlgorithm(method_config) # model ensemble methods diff --git a/fusion_bench/method/ties_merging/ties_merging.py b/fusion_bench/method/ties_merging/ties_merging.py index 9e61c23f..5870a42d 100644 --- a/fusion_bench/method/ties_merging/ties_merging.py +++ b/fusion_bench/method/ties_merging/ties_merging.py @@ -1,11 +1,11 @@ import logging from copy import deepcopy -from typing import List, Mapping, Union +from typing import Dict, List, Mapping, Union import torch from torch import Tensor, nn -from ...modelpool import ModelPool +from ...modelpool import ModelPool, to_modelpool from ...utils.type import _StateDict from ..base_algorithm import ModelFusionAlgorithm from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_state_dict @@ -15,8 +15,9 @@ class TiesMergingAlgorithm(ModelFusionAlgorithm): @torch.no_grad() - def run(self, modelpool: ModelPool): + def run(self, modelpool: ModelPool | Dict[str, nn.Module]): log.info("Fusing models using ties merging.") + modelpool = to_modelpool(modelpool) remove_keys = self.config.get("remove_keys", []) merge_func = self.config.get("merge_func", "sum") scaling_factor = self.config.scaling_factor diff --git a/fusion_bench/modelpool/__init__.py b/fusion_bench/modelpool/__init__.py index 3c43b348..010e82f6 100644 --- a/fusion_bench/modelpool/__init__.py +++ b/fusion_bench/modelpool/__init__.py @@ -8,6 +8,21 @@ def load_modelpool_from_config(modelpool_config: DictConfig): + """ + Loads a model pool based on the provided configuration. + + The function checks the 'type' attribute of the configuration and returns an instance of the corresponding model pool. + If the 'type' attribute is not found or does not match any known model pool types, a ValueError is raised. + + Args: + modelpool_config (DictConfig): The configuration for the model pool. Must contain a 'type' attribute that specifies the type of the model pool. + + Returns: + An instance of the specified model pool. + + Raises: + ValueError: If 'type' attribute is not found in the configuration or does not match any known model pool types. + """ if hasattr(modelpool_config, "type"): if modelpool_config.type == "huggingface_clip_vision": return HuggingFaceClipVisionPool(modelpool_config) diff --git a/fusion_bench/taskpool/__init__.py b/fusion_bench/taskpool/__init__.py index f6e26d62..15b1b2d4 100644 --- a/fusion_bench/taskpool/__init__.py +++ b/fusion_bench/taskpool/__init__.py @@ -8,6 +8,21 @@ def load_taskpool_from_config(taskpool_config: DictConfig): + """ + Loads a task pool based on the provided configuration. + + The function checks the 'type' attribute of the configuration and returns an instance of the corresponding task pool. + If the 'type' attribute is not found or does not match any known task pool types, a ValueError is raised. + + Args: + taskpool_config (DictConfig): The configuration for the task pool. Must contain a 'type' attribute that specifies the type of the task pool. + + Returns: + An instance of the specified task pool. + + Raises: + ValueError: If 'type' attribute is not found in the configuration or does not match any known task pool types. + """ if hasattr(taskpool_config, "type"): if taskpool_config.type == "dummy": return DummyTaskPool(taskpool_config) diff --git a/mkdocs.yml b/mkdocs.yml index 51751383..c8056768 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -94,6 +94,7 @@ theme: name: material features: - toc.follow + - content.code.annotate repo_url: https://github.com/tanganke/fusion_bench