From b3b7f328482e0b9b5de055c120e4dd2212f87959 Mon Sep 17 00:00:00 2001 From: tanganke Date: Mon, 20 May 2024 17:01:15 +0800 Subject: [PATCH] update gpt2 taskpool --- config/taskpool/gpt-2_glue.yaml | 38 +++++ fusion_bench/dataset/__init__.py | 7 +- fusion_bench/dataset/gpt2_glue.py | 12 +- fusion_bench/modelpool/base_pool.py | 7 + .../huggingface_gpt2_classification.py | 31 +++- fusion_bench/scripts/cli.py | 4 +- fusion_bench/taskpool/__init__.py | 5 +- fusion_bench/taskpool/base_pool.py | 2 +- .../taskpool/gpt2_text_classification.py | 154 ++++++++++++++++++ fusion_bench/tasks/classification.py | 12 +- 10 files changed, 248 insertions(+), 24 deletions(-) create mode 100644 config/taskpool/gpt-2_glue.yaml diff --git a/config/taskpool/gpt-2_glue.yaml b/config/taskpool/gpt-2_glue.yaml new file mode 100644 index 00000000..099942a2 --- /dev/null +++ b/config/taskpool/gpt-2_glue.yaml @@ -0,0 +1,38 @@ +type: GPT2TextClassificationTaskPool +name: gpt2_classification_on_glue + +dataset_type: GPT2ClassificationGLUETask +tasks: + - name: cola + dataset: + name: cola + split: validation + - name: mnli + dataset: + name: mnli + split: validation_matched + - name: mrpc + dataset: + name: mrpc + split: validation + - name: qnli + dataset: + name: qnli + split: validation + - name: qqp + dataset: + name: qqp + split: validation + - name: rte + dataset: + name: rte + split: validation + - name: sst2 + dataset: + name: sst2 + split: validation + +tokenizer: gpt2 +batch_size: 8 +num_workers: 0 +fast_dev_run: ${fast_dev_run} diff --git a/fusion_bench/dataset/__init__.py b/fusion_bench/dataset/__init__.py index 91ea3cef..75f4a65f 100644 --- a/fusion_bench/dataset/__init__.py +++ b/fusion_bench/dataset/__init__.py @@ -10,11 +10,12 @@ def load_dataset_from_config(dataset_config: DictConfig): Load the dataset from the configuration. """ assert hasattr(dataset_config, "type"), "Dataset type not specified" - if dataset_config.type == "huggingface_image_classification": + if dataset_config.type == "instantiate": + return instantiate(dataset_config.object) + elif dataset_config.type == "huggingface_image_classification": if not hasattr(dataset_config, "path"): with open_dict(dataset_config): dataset_config.path = dataset_config.name - dataset = load_dataset( dataset_config.path, **(dataset_config.kwargs if hasattr(dataset_config, "kwargs") else {}), @@ -22,7 +23,5 @@ def load_dataset_from_config(dataset_config: DictConfig): if hasattr(dataset_config, "split"): dataset = dataset[dataset_config.split] return dataset - if dataset_config.type == "instantiate": - return instantiate(dataset_config.object) else: raise ValueError(f"Unknown dataset type: {dataset_config.type}") diff --git a/fusion_bench/dataset/gpt2_glue.py b/fusion_bench/dataset/gpt2_glue.py index 6be94246..6cca204c 100644 --- a/fusion_bench/dataset/gpt2_glue.py +++ b/fusion_bench/dataset/gpt2_glue.py @@ -35,14 +35,14 @@ def wrapper(*args, **kwargs): dataset = load_from_disk(str(cache_path)) else: dataset = func(*args, **kwargs) - dataset.save_to_disk(cache_path) + dataset.save_to_disk(str(cache_path.absolute())) return dataset return wrapper # Tokenize and convert examples to features -def mrpc_tokenize_function(tokenizer, examples): +def mrpc_tokenize_function(examples, tokenizer): inputs = tokenizer( examples["sentence1"], examples["sentence2"], @@ -53,7 +53,7 @@ def mrpc_tokenize_function(tokenizer, examples): return inputs -def mnli_tokenize_function(tokenizer, examples): +def mnli_tokenize_function(examples, tokenizer): inputs = tokenizer( examples["premise"], examples["hypothesis"], @@ -64,7 +64,7 @@ def mnli_tokenize_function(tokenizer, examples): return inputs -def cola_tokenize_function(tokenizer, examples): +def cola_tokenize_function(examples, tokenizer): inputs = tokenizer( examples["sentence"], padding="max_length", @@ -74,7 +74,7 @@ def cola_tokenize_function(tokenizer, examples): return inputs -def qnli_tokenize_function(tokenizer, examples): +def qnli_tokenize_function(examples, tokenizer): inputs = tokenizer( examples["question"], examples["sentence"], @@ -85,7 +85,7 @@ def qnli_tokenize_function(tokenizer, examples): return inputs -def qqp_tokenize_function(tokenizer, examples): +def qqp_tokenize_function(examples, tokenizer): inputs = tokenizer( examples["question1"], examples["question2"], diff --git a/fusion_bench/modelpool/base_pool.py b/fusion_bench/modelpool/base_pool.py index a4ae4f5d..4fedb8ab 100644 --- a/fusion_bench/modelpool/base_pool.py +++ b/fusion_bench/modelpool/base_pool.py @@ -82,6 +82,13 @@ def load_model(self, model_config: Union[str, DictConfig]) -> nn.Module: """ raise NotImplementedError + def setup_taskpool(self, taskpool): + """ + Setup the taskpool before evaluation. + Such as setting the fabric, processor, tokenizer, etc. + """ + pass + class ListModelPool(ModelPool): """ diff --git a/fusion_bench/modelpool/huggingface_gpt2_classification.py b/fusion_bench/modelpool/huggingface_gpt2_classification.py index 6336d049..f0cff504 100644 --- a/fusion_bench/modelpool/huggingface_gpt2_classification.py +++ b/fusion_bench/modelpool/huggingface_gpt2_classification.py @@ -1,10 +1,11 @@ import logging from omegaconf import DictConfig -from transformers import GPT2ForSequenceClassification, GPT2Tokenizer +from transformers import GPT2ForSequenceClassification, GPT2Model, GPT2Tokenizer from fusion_bench.modelpool import ModelPool from fusion_bench.utils import timeit_context +from torch import nn log = logging.getLogger(__name__) @@ -18,26 +19,46 @@ def __init__(self, modelpool_config: DictConfig): @property def tokenizer(self): if self._tokenizer is None: + log.info(f"Loading tokenizer classification model.") if "_pretrained_" in self._model_names: - self._tokenizer = GPT2Tokenizer.from_pretrained( + tokenizer = GPT2Tokenizer.from_pretrained( self.get_model_config("_pretrained_")["path"] ) else: log.warning( "No pretrained model found in the model pool. Returning the first model." ) - self._tokenizer = GPT2Tokenizer.from_pretrained( + tokenizer = GPT2Tokenizer.from_pretrained( self.get_model_config(self.model_names[0])["path"] ) + tokenizer.model_max_length = 512 + if tokenizer.pad_token is None: + if tokenizer.unk_token is not None: + tokenizer.pad_token = tokenizer.unk_token + elif tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + else: + raise ValueError + self._tokenizer = tokenizer + return self._tokenizer - def load_model( + def load_classifier( self, model_config: str | DictConfig ) -> GPT2ForSequenceClassification: if isinstance(model_config, str): model_config = self.get_model_config(model_config) with timeit_context( - f"Loading GPT2 classification model: '{model_config.name}' from '{model_config.path}'." + f"Loading GPT2 classification head from {model_config.path}." ): model = GPT2ForSequenceClassification.from_pretrained(model_config.path) return model + + def load_model(self, model_config: str | DictConfig) -> GPT2Model: + model = self.load_classifier(model_config) + return model.transformer + + def setup_taskpool(self, taskpool): + if getattr(taskpool, "_tokenizer", None) is None: + taskpool._tokenizer = self.tokenizer + taskpool._modelpool = self diff --git a/fusion_bench/scripts/cli.py b/fusion_bench/scripts/cli.py index 863af3b6..3b106bdf 100644 --- a/fusion_bench/scripts/cli.py +++ b/fusion_bench/scripts/cli.py @@ -34,7 +34,9 @@ def run_model_fusion(cfg: DictConfig): if hasattr(cfg, "taskpool") and cfg.taskpool is not None: taskpool = load_taskpool_from_config(cfg.taskpool) if hasattr(modelpool, "_fabric") and hasattr(taskpool, "_fabric"): - taskpool._fabric = modelpool._fabric + if taskpool._fabric is None: + taskpool._fabric = modelpool._fabric + modelpool.setup_taskpool(taskpool) report = taskpool.evaluate(merged_model) if cfg.get("save_report", False): # save report (Dict) to a file diff --git a/fusion_bench/taskpool/__init__.py b/fusion_bench/taskpool/__init__.py index fc10793d..c6dda5e9 100644 --- a/fusion_bench/taskpool/__init__.py +++ b/fusion_bench/taskpool/__init__.py @@ -2,6 +2,7 @@ from .base_pool import TaskPool from .clip_image_classification import CLIPImageClassificationTaskPool +from .gpt2_text_classification import GPT2TextClassificationTaskPool from .dummy import DummyTaskPool @@ -9,8 +10,10 @@ def load_taskpool_from_config(taskpool_config: DictConfig): if hasattr(taskpool_config, "type"): if taskpool_config.type == "dummy": return DummyTaskPool(taskpool_config) - if taskpool_config.type == "clip_vit_classification": + elif taskpool_config.type == "clip_vit_classification": return CLIPImageClassificationTaskPool(taskpool_config) + elif taskpool_config.type == "GPT2TextClassificationTaskPool": + return GPT2TextClassificationTaskPool(taskpool_config) else: raise ValueError(f"Unknown task pool type: {taskpool_config.type}") else: diff --git a/fusion_bench/taskpool/base_pool.py b/fusion_bench/taskpool/base_pool.py index 441c641d..038a9580 100644 --- a/fusion_bench/taskpool/base_pool.py +++ b/fusion_bench/taskpool/base_pool.py @@ -42,7 +42,7 @@ def evaluate(self, model): dict: A dictionary containing the results of the evaluation for each task. """ report = {} - for task_name in self.task_names(): + for task_name in self.task_names: task = self.load_task(task_name) result = task.evaluate(model) report[task_name] = result diff --git a/fusion_bench/taskpool/gpt2_text_classification.py b/fusion_bench/taskpool/gpt2_text_classification.py index 5564cee9..73a03928 100644 --- a/fusion_bench/taskpool/gpt2_text_classification.py +++ b/fusion_bench/taskpool/gpt2_text_classification.py @@ -1,8 +1,162 @@ +import functools +import itertools +import logging +from copy import deepcopy + +import lightning as L +import torch +import torch.nn.functional as F from omegaconf import DictConfig, open_dict +from torch.nn.modules import Module +from torch.utils.data import DataLoader +from torchmetrics import Accuracy, MeanMetric +from tqdm.autonotebook import tqdm +from transformers import ( + GPT2ForSequenceClassification, + GPT2Model, + GPT2Tokenizer, + default_data_collator, +) +import fusion_bench +from fusion_bench.dataset.gpt2_glue import TokenizedGLUE from fusion_bench.taskpool import TaskPool +from fusion_bench.tasks import BaseTask + +log = logging.getLogger(__name__) + + +class GPT2ClassificationTask(BaseTask): + _taskpool: "GPT2TextClassificationTaskPool" = None + + def __init__( + self, task_config: DictConfig, fabric: L.Fabric, tokenizer: GPT2Tokenizer + ): + super().__init__(task_config) + self._fabric = fabric + self._tokenizer = tokenizer + + @property + def num_classes(self): + return len(self.test_dataset.unique("label")) + + @functools.cached_property + def dataset(self): + log.info('Loading dataset: "{}"'.format(self.config.dataset.name)) + dataset = TokenizedGLUE(tokenizer=self._tokenizer).load_dataset( + self.config.dataset.name + ) + return dataset + + @property + def test_dataset(self): + return self.dataset[self.config.dataset.split] + + @property + def test_loader(self): + test_dataset = self.test_dataset + loader = DataLoader( + test_dataset, + collate_fn=default_data_collator, + batch_size=self.config["batch_size"], + num_workers=self.config["num_workers"], + shuffle=False, + ) + if self._fabric is not None: + loader = self._fabric.setup_dataloaders(loader) + return loader + + def get_classifier(self, model: GPT2Model) -> GPT2ForSequenceClassification: + modelpool = self._taskpool._modelpool + classifier = modelpool.load_classifier(self.config.name) + classifier.transformer = deepcopy(model) + return classifier + + @torch.no_grad() + def evaluate(self, model: GPT2Model): + accuracy = Accuracy("multiclass", num_classes=self.num_classes) + loss_metric = MeanMetric() + model: GPT2ForSequenceClassification = self.get_classifier(model) + model = self._fabric.setup(model) + + if self.config.get("fast_dev_run", False): + log.info("Running under fast_dev_run mode, evaluating on a single batch.") + test_loader = itertools.islice(self.test_loader, 1) + else: + test_loader = self.test_loader + + for batch in ( + pbar := tqdm( + test_loader, desc="Evaluating", leave=False, dynamic_ncols=True + ) + ): + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = batch["labels"] + + outputs = model(input_ids, attention_mask=attention_mask) + logits = outputs.logits + loss = F.cross_entropy(logits, labels) + + acc = accuracy(logits.detach().cpu(), labels.detach().cpu()) + loss_metric.update(loss.detach().cpu()) + pbar.set_postfix({"accuracy": acc.item(), "loss": loss.item()}) + + acc = accuracy.compute().item() + loss = loss_metric.compute().item() + results = {"accuracy": acc, "loss": loss} + log.info(f"Results for task {self.config.name}: {results}") + return results class GPT2TextClassificationTaskPool(TaskPool): + _fabric: L.Fabric = None + _tokenizer: GPT2Tokenizer = None + _modelpool: "fusion_bench.modelpool.HuggingFaceGPT2ClassificationPool" = None + def __init__(self, taskpool_config: DictConfig): super().__init__(taskpool_config) + + @property + def fabric(self): + if self._fabric is not None: + return self._fabric + else: + self._fabric = L.Fabric(devices=1) + self._fabric.launch() + return self._fabric + + @property + def tokenizer(self): + if self._tokenizer is not None: + return self._tokenizer + else: + raise ValueError("Tokenizer not set") + + def prepare_dataset_config(self, dataset_config: DictConfig): + if not hasattr(dataset_config, "type"): + with open_dict(dataset_config): + dataset_config["type"] = self.config.dataset_type + return dataset_config + + def prepare_task_config(self, task_config: DictConfig): + for key in ["num_workers", "batch_size", "fast_dev_run"]: + if not hasattr(task_config, key): + with open_dict(task_config): + task_config[key] = self.config[key] + return task_config + + def load_task(self, task_name_or_config: str | DictConfig): + if isinstance(task_name_or_config, str): + task_config = self.get_task_config(task_name_or_config) + else: + task_config = task_name_or_config + task_config = self.prepare_task_config(task_config) + + # load the task from the configuration + task = GPT2ClassificationTask(task_config, self.fabric, self.tokenizer) + task._fabric = self._fabric + task._tokenizer = self._tokenizer + task._taskpool = self + + return task diff --git a/fusion_bench/tasks/classification.py b/fusion_bench/tasks/classification.py index 169e356c..2f7111a3 100644 --- a/fusion_bench/tasks/classification.py +++ b/fusion_bench/tasks/classification.py @@ -37,10 +37,10 @@ def test_loader(self): @torch.no_grad() def evaluate(self, classifier: nn.Module, device=None): - self.accuracy: MulticlassAccuracy = Accuracy( + accuracy: MulticlassAccuracy = Accuracy( task="multiclass", num_classes=self.num_classes ) - self.loss_metric = MeanMetric() + loss_metric = MeanMetric() # if fast_dev_run is set, we only evaluate on a batch of the data if self.config.get("fast_dev_run", False): @@ -60,11 +60,11 @@ def evaluate(self, classifier: nn.Module, device=None): logits: Tensor = classifier(inputs) loss = F.cross_entropy(logits, targets) - self.loss_metric.update(loss.detach().cpu()) - acc = self.accuracy(logits.detach().cpu(), targets.detach().cpu()) + loss_metric.update(loss.detach().cpu()) + acc = accuracy(logits.detach().cpu(), targets.detach().cpu()) pbar.set_postfix({"accuracy": acc.item(), "loss": loss.item()}) - acc = self.accuracy.compute().item() - loss = self.loss_metric.compute().item() + acc = accuracy.compute().item() + loss = loss_metric.compute().item() results = {"accuracy": acc, "loss": loss} return results