Skip to content

Commit

Permalink
udpate doc
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed May 22, 2024
1 parent df93484 commit 6ab4ae2
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 10 deletions.
10 changes: 10 additions & 0 deletions config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
type: huggingface_clip_vision
models:
- name: _pretrained_
path: google/flan-t5-base
- name: _pretrained_
path: openai/clip-vit-base-patch32
- name: svhn
path: tanganke/clip-vit-base-patch32_svhn
- name: mnist
path: tanganke/clip-vit-base-patch32_mnist
File renamed without changes.
24 changes: 24 additions & 0 deletions config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
type: clip_vit_classification
name: clip-vit-base-patch32_svhn_and_mnist # whatever you like

dataset_type: huggingface_image_classification
tasks:
- name: svhn
dataset:
type: instantiate
name: svhn
object:
_target_: datasets.load_dataset
_args_:
- svhn
- cropped_digits
split: test
- name: mnist
dataset:
name: mnist
split: test

clip_model: openai/clip-vit-base-patch32
batch_size: 128
num_workers: 16
fast_dev_run: ${fast_dev_run}
8 changes: 6 additions & 2 deletions config/taskpool/flan-t5_glue_text_generation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ tasks:
name: stsb
split: validation

# all flan-t5 models share the same tokenizer,
# so it is not necessary to change it when you evaluate other models,
# such as flan-t5-large, flan-t5-xxl
tokenizer: google/flan-t5-base
# cache directory for storing the preprocessed data
cache_dir: outputs
batch_size: 8
num_workers: 0
batch_size: 32
num_workers: 4
fast_dev_run: ${fast_dev_run}
77 changes: 77 additions & 0 deletions docs/algorithms/task_arithmetic.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,29 @@ $$ \theta = \theta_0 + \lambda \sum_{i} \tau_i. $$

The choice of the scaling coefficient $\lambda$ plays a crucial role in the final model performance. Typically, $\lambda$ is chosen based on validation set performance.

## Examples

To use the Task Arithmetic algorithm, you can use the `TaskArithmeticAlgorithm` class from the `fusion_bench.method` module.

```python
from fusion_bench.method import TaskArithmeticAlgorithm
from omegaconf import DictConfig

# Instantiate the TaskArithmeticAlgorithm
method_config = {'name': 'task_arithmetic', 'scaling_factor': 0.5}
algorithm = TaskArithmeticAlgorithm(DictConfig(method_config))

# Assume we have a dict of PyTorch models (nn.Module instances) that we want to merge.
# The models should all have the same architecture.
# the dict must contain the pre-trained model with the key '_pretrained_', and arbitrary number of fine-tuned models.
models = {'_pretrained_': nn.Linear(10,10), 'model_1': nn.Linear(10,10), 'model_2': nn.Linear(10,10)}

# Run the algorithm on the models.
# This will return a new model that is the result of task arithmetic on the input models.
merged_model = algorithm.run(models)
```


## Code Integration

Configuration template for the Task Arithmetic algorithm:
Expand All @@ -40,6 +63,60 @@ Use the following command to run the Task Arithmetic algorithm:
fusion_bench method=task_arithmetic ...
```

For example, to run the Task Arithmetic algorithm on two models with scaling factor 0.5:

```bash
fusion_bench method=task_arithmetic \
method.scaling_factor=0.5 \
modelpool=clip-vit-base-patch32_svhn_and_mnist \
taskpool=clip-vit-base-patch32_svhn_and_mnist
```

where the configuration for the model pool is:

```yaml title="config/modelpool/clip-vit-base-patch32_svhn_and_mnist.yaml"
type: huggingface_clip_vision
# the modelpool must contain the pre-trained model with the name '_pretrained_',
# and arbitrary number of fine-tuned models.
models:
- name: _pretrained_
path: google/flan-t5-base
- name: _pretrained_
path: openai/clip-vit-base-patch32
- name: svhn
path: tanganke/clip-vit-base-patch32_svhn
- name: mnist
path: tanganke/clip-vit-base-patch32_mnist
```
and the configuration for the task pool:
```yaml title="config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml"
type: clip_vit_classification

dataset_type: huggingface_image_classification
tasks:
- name: svhn
dataset:
type: instantiate
name: svhn
object:
_target_: datasets.load_dataset
_args_:
- svhn
- cropped_digits
split: test
- name: mnist
dataset:
name: mnist
split: test

...
```


## References

::: fusion_bench.method.TaskArithmeticAlgorithm
options:
members: true
Expand Down
9 changes: 9 additions & 0 deletions docs/taskpool/flan-t5_generation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Flan-T5 Models for Text Generation Tasks

This task pool provides a set of text generation tasks from the GLUE benchmark for the Flan-T5 model.
Each task is associated with a dataset.
We report the exact match accuracy metric for CoLA, MNLI, MRPC, QNLI, QQP, RTE, and SST2, and spearman's rho for STSB.

## References

::: fusion_bench.taskpool.flan_t5_glue_text_generation
9 changes: 9 additions & 0 deletions docs/taskpool/gpt2_classification.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# GPT-2 Sequence Classification Tasks

This task pool provides a set of sequence classification tasks from the GLUE benchmark for the GPT-2 model.
Each task is associated with a dataset and the accuracy metric. The tasks are:
CoLA, MNLI, MRPC, QNLI, QQP, RTE, and SST2.

## References

::: fusion_bench.taskpool.gpt2_text_classification
28 changes: 25 additions & 3 deletions fusion_bench/taskpool/flan_t5_glue_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from omegaconf import DictConfig, open_dict
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, default_data_collator
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
T5ForConditionalGeneration,
default_data_collator,
)

from fusion_bench.tasks import BaseTask
from fusion_bench.tasks.flan_t5_text_generation.glue_evaluation import (
Expand Down Expand Up @@ -98,11 +103,18 @@ def evaluate(self, model):


class FlanT5GLUETextGenerationTaskPool(TaskPool):
"""
A task pool for FlanT5 GLUE text generation tasks.
This class manages the tasks and provides methods for loading and evaluating tasks.
"""
_fabric: L.Fabric = None
_tokenizer = None

@property
def tokenizer(self):
"""
Returns the tokenizer. If it's not already initialized, it initializes it using the config's tokenizer.
"""
if self._tokenizer is None:
self._tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer)
return self._tokenizer
Expand All @@ -117,6 +129,10 @@ def fabric(self):
return self._fabric

def load_task(self, task_name_or_config: str | DictConfig):
"""
Loads a task given a task name or config. If the task name is in `CLASSIFICATION_TASKS`, it creates a `FlanT5GLUETextGenerationClassificationTask`.
If the task name is in `REGRESSION_TASKS`, it creates a `FlanT5GLUETextGenerationRegressionTask`. Otherwise, it raises a `ValueError`.
"""
if isinstance(task_name_or_config, str):
task_config = self.get_task_config(task_name_or_config)
else:
Expand All @@ -133,6 +149,12 @@ def load_task(self, task_name_or_config: str | DictConfig):
else:
raise ValueError(f"Unknown task {task_config.name}")

def evaluate(self, model):
def evaluate(self, model: T5ForConditionalGeneration):
if not isinstance(model, T5ForConditionalGeneration):
log.warning(
f"Model is not an instance of T5ForConditionalGeneration, but {type(model)}"
)
model = self.fabric.setup(model)
return super().evaluate(model)
report = super().evaluate(model)
log.info(f"evaluation report: {report}")
return report
17 changes: 14 additions & 3 deletions fusion_bench/taskpool/gpt2_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,15 @@ def evaluate(self, model: GPT2Model):


class GPT2TextClassificationTaskPool(TaskPool):
"""
A task pool for GPT2 text classification tasks.
This class manages the tasks and provides methods for loading test dataset and evaluation.
"""

_fabric: L.Fabric = None
_tokenizer: GPT2Tokenizer = None
_modelpool: "fusion_bench.modelpool.HuggingFaceGPT2ClassificationPool" = None

def __init__(self, taskpool_config: DictConfig):
super().__init__(taskpool_config)

@property
def fabric(self):
if self._fabric is not None:
Expand All @@ -133,19 +135,28 @@ def tokenizer(self):
raise ValueError("Tokenizer not set")

def prepare_dataset_config(self, dataset_config: DictConfig):
"""
Set default values for dataset configuration.
"""
if not hasattr(dataset_config, "type"):
with open_dict(dataset_config):
dataset_config["type"] = self.config.dataset_type
return dataset_config

def prepare_task_config(self, task_config: DictConfig):
"""
Set default values for task configuration.
"""
for key in ["num_workers", "batch_size", "fast_dev_run"]:
if not hasattr(task_config, key):
with open_dict(task_config):
task_config[key] = self.config[key]
return task_config

def load_task(self, task_name_or_config: str | DictConfig):
"""
Loads a task given a task name or config. It prepares the task configuration and loads the task from it.
"""
if isinstance(task_name_or_config, str):
task_config = self.get_task_config(task_name_or_config)
else:
Expand Down
6 changes: 4 additions & 2 deletions fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def evaluate_accuracy(model, val_loader: DataLoader, tokenizer):

model = model.eval()
for batch_idx, batch in enumerate(
tqdm(val_loader, desc="Evaluate Exact Accuracy", leave=False)
tqdm(
val_loader, desc="Evaluate Exact Accuracy", leave=False, dynamic_ncols=True
)
):
with torch.no_grad():
outputs = model.generate(batch["input_ids"], max_length=10)
Expand Down Expand Up @@ -85,7 +87,7 @@ def evaluate_spearman_rho(model, val_loader: DataLoader, tokenizer):
all_preds: List[str] = []
all_labels: List[str] = []
for batch_idx, batch in enumerate(
tqdm(val_loader, desc="Evaluate Spearman Rho", leave=False)
tqdm(val_loader, desc="Evaluate Spearman Rho", leave=False, dynamic_ncols=True)
):
with torch.no_grad():
outputs = model.generate(batch["input_ids"], max_length=10)
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ nav:
- Dummy: taskpool/dummy.md
- Classification Tasks for CLIP: taskpool/clip_vit_classification.md
- GPT-2 Sequence Classification Tasks: taskpool/gpt2_classification.md
- Flan-T5 Models for Text Generation: taskpool/flan-t5_generation.md
- Command Line Interface:
- fusion_bench: cli/fusion_bench.md

Expand Down

0 comments on commit 6ab4ae2

Please sign in to comment.