Commit b3b7f32: update gpt2 taskpool

tanganke committed May 20, 2024
1 parent 0b3a7c7 · commit b3b7f32
Showing 10 changed files with 248 additions and 24 deletions.
38 changes: 38 additions & 0 deletions config/taskpool/gpt-2_glue.yaml
@@ -0,0 +1,38 @@
type: GPT2TextClassificationTaskPool
name: gpt2_classification_on_glue

dataset_type: GPT2ClassificationGLUETask
tasks:
  - name: cola
    dataset:
      name: cola
      split: validation
  - name: mnli
    dataset:
      name: mnli
      split: validation_matched
  - name: mrpc
    dataset:
      name: mrpc
      split: validation
  - name: qnli
    dataset:
      name: qnli
      split: validation
  - name: qqp
    dataset:
      name: qqp
      split: validation
  - name: rte
    dataset:
      name: rte
      split: validation
  - name: sst2
    dataset:
      name: sst2
      split: validation

tokenizer: gpt2
batch_size: 8
num_workers: 0
fast_dev_run: ${fast_dev_run}
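
Note: `${fast_dev_run}` is a Hydra/OmegaConf interpolation resolved from the top-level config. A minimal sketch of how such a config reaches the new taskpool, with the interpolation replaced by a literal value (names follow the files in this commit, but the inline config is a stand-in, not the YAML above verbatim):

    from omegaconf import OmegaConf

    from fusion_bench.taskpool import load_taskpool_from_config

    # Inline stand-in for config/taskpool/gpt-2_glue.yaml; ${fast_dev_run}
    # is replaced by a literal since there is no parent config here.
    cfg = OmegaConf.create(
        {
            "type": "GPT2TextClassificationTaskPool",
            "name": "gpt2_classification_on_glue",
            "dataset_type": "GPT2ClassificationGLUETask",
            "tasks": [{"name": "cola", "dataset": {"name": "cola", "split": "validation"}}],
            "tokenizer": "gpt2",
            "batch_size": 8,
            "num_workers": 0,
            "fast_dev_run": False,
        }
    )
    taskpool = load_taskpool_from_config(cfg)  # dispatches on cfg.type (see taskpool/__init__.py below)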
7 changes: 3 additions & 4 deletions fusion_bench/dataset/__init__.py
@@ -10,19 +10,18 @@ def load_dataset_from_config(dataset_config: DictConfig):
     Load the dataset from the configuration.
     """
     assert hasattr(dataset_config, "type"), "Dataset type not specified"
-    if dataset_config.type == "huggingface_image_classification":
+    if dataset_config.type == "instantiate":
+        return instantiate(dataset_config.object)
+    elif dataset_config.type == "huggingface_image_classification":
         if not hasattr(dataset_config, "path"):
             with open_dict(dataset_config):
                 dataset_config.path = dataset_config.name

         dataset = load_dataset(
             dataset_config.path,
             **(dataset_config.kwargs if hasattr(dataset_config, "kwargs") else {}),
         )
         if hasattr(dataset_config, "split"):
             dataset = dataset[dataset_config.split]
         return dataset
-
-    if dataset_config.type == "instantiate":
-        return instantiate(dataset_config.object)
     else:
         raise ValueError(f"Unknown dataset type: {dataset_config.type}")
12 changes: 6 additions & 6 deletions fusion_bench/dataset/gpt2_glue.py
@@ -35,14 +35,14 @@ def wrapper(*args, **kwargs):
             dataset = load_from_disk(str(cache_path))
         else:
             dataset = func(*args, **kwargs)
-            dataset.save_to_disk(cache_path)
+            dataset.save_to_disk(str(cache_path.absolute()))
         return dataset

     return wrapper


 # Tokenize and convert examples to features
-def mrpc_tokenize_function(tokenizer, examples):
+def mrpc_tokenize_function(examples, tokenizer):
     inputs = tokenizer(
         examples["sentence1"],
         examples["sentence2"],
@@ -53,7 +53,7 @@ def mrpc_tokenize_function(tokenizer, examples):
     return inputs


-def mnli_tokenize_function(tokenizer, examples):
+def mnli_tokenize_function(examples, tokenizer):
     inputs = tokenizer(
         examples["premise"],
         examples["hypothesis"],
@@ -64,7 +64,7 @@ def mnli_tokenize_function(tokenizer, examples):
     return inputs


-def cola_tokenize_function(tokenizer, examples):
+def cola_tokenize_function(examples, tokenizer):
     inputs = tokenizer(
         examples["sentence"],
         padding="max_length",
@@ -74,7 +74,7 @@ def cola_tokenize_function(tokenizer, examples):
     return inputs


-def qnli_tokenize_function(tokenizer, examples):
+def qnli_tokenize_function(examples, tokenizer):
     inputs = tokenizer(
         examples["question"],
         examples["sentence"],
@@ -85,7 +85,7 @@ def qnli_tokenize_function(tokenizer, examples):
     return inputs


-def qqp_tokenize_function(tokenizer, examples):
+def qqp_tokenize_function(examples, tokenizer):
     inputs = tokenizer(
         examples["question1"],
         examples["question2"],
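
The `(tokenizer, examples)` → `(examples, tokenizer)` swap lines these functions up with `datasets.Dataset.map`, which passes the batch as the first positional argument. A short sketch of the intended call pattern (the direct GLUE load here is illustrative; the taskpool itself goes through `TokenizedGLUE` instead):

    from functools import partial

    from datasets import load_dataset
    from transformers import GPT2Tokenizer

    from fusion_bench.dataset.gpt2_glue import mrpc_tokenize_function

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    # Dataset.map passes each batch as the first positional argument, so the
    # tokenizer must now be bound as a keyword:
    dataset = load_dataset("glue", "mrpc", split="validation")
    tokenized = dataset.map(partial(mrpc_tokenize_function, tokenizer=tokenizer), batched=True)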
7 changes: 7 additions & 0 deletions fusion_bench/modelpool/base_pool.py
@@ -82,6 +82,13 @@ def load_model(self, model_config: Union[str, DictConfig]) -> nn.Module:
         """
         raise NotImplementedError

+    def setup_taskpool(self, taskpool):
+        """
+        Setup the taskpool before evaluation.
+        Such as setting the fabric, processor, tokenizer, etc.
+        """
+        pass
+

 class ListModelPool(ModelPool):
     """
31 changes: 26 additions & 5 deletions fusion_bench/modelpool/huggingface_gpt2_classification.py
@@ -1,10 +1,11 @@
 import logging

 from omegaconf import DictConfig
-from transformers import GPT2ForSequenceClassification, GPT2Tokenizer
+from transformers import GPT2ForSequenceClassification, GPT2Model, GPT2Tokenizer

 from fusion_bench.modelpool import ModelPool
 from fusion_bench.utils import timeit_context
+from torch import nn

 log = logging.getLogger(__name__)

@@ -18,26 +19,46 @@ def __init__(self, modelpool_config: DictConfig):
     @property
     def tokenizer(self):
         if self._tokenizer is None:
             log.info(f"Loading tokenizer classification model.")
             if "_pretrained_" in self._model_names:
-                self._tokenizer = GPT2Tokenizer.from_pretrained(
+                tokenizer = GPT2Tokenizer.from_pretrained(
                     self.get_model_config("_pretrained_")["path"]
                 )
             else:
                 log.warning(
                     "No pretrained model found in the model pool. Returning the first model."
                 )
-                self._tokenizer = GPT2Tokenizer.from_pretrained(
+                tokenizer = GPT2Tokenizer.from_pretrained(
                     self.get_model_config(self.model_names[0])["path"]
                 )
+            tokenizer.model_max_length = 512
+            if tokenizer.pad_token is None:
+                if tokenizer.unk_token is not None:
+                    tokenizer.pad_token = tokenizer.unk_token
+                elif tokenizer.eos_token is not None:
+                    tokenizer.pad_token = tokenizer.eos_token
+                else:
+                    raise ValueError
+            self._tokenizer = tokenizer
         return self._tokenizer

-    def load_model(
+    def load_classifier(
         self, model_config: str | DictConfig
     ) -> GPT2ForSequenceClassification:
         if isinstance(model_config, str):
             model_config = self.get_model_config(model_config)

         with timeit_context(
-            f"Loading GPT2 classification model: '{model_config.name}' from '{model_config.path}'."
+            f"Loading GPT2 classification head from {model_config.path}."
         ):
             model = GPT2ForSequenceClassification.from_pretrained(model_config.path)
             return model
+
+    def load_model(self, model_config: str | DictConfig) -> GPT2Model:
+        model = self.load_classifier(model_config)
+        return model.transformer
+
+    def setup_taskpool(self, taskpool):
+        if getattr(taskpool, "_tokenizer", None) is None:
+            taskpool._tokenizer = self.tokenizer
+        taskpool._modelpool = self
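
The pad-token fallback matters because the stock GPT-2 tokenizer defines no pad token, while the GLUE tokenize functions pad to `max_length`. Standard transformers behavior (not code from this commit):

    from transformers import GPT2Tokenizer

    tok = GPT2Tokenizer.from_pretrained("gpt2")
    print(tok.pad_token)  # None: GPT-2 has no pad token out of the box
    print(tok.unk_token)  # '<|endoftext|>', so the unk_token branch fires first
    tok.pad_token = tok.unk_token  # the fallback applied above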
4 changes: 3 additions & 1 deletion fusion_bench/scripts/cli.py
@@ -34,7 +34,9 @@ def run_model_fusion(cfg: DictConfig):
     if hasattr(cfg, "taskpool") and cfg.taskpool is not None:
         taskpool = load_taskpool_from_config(cfg.taskpool)
         if hasattr(modelpool, "_fabric") and hasattr(taskpool, "_fabric"):
-            taskpool._fabric = modelpool._fabric
+            if taskpool._fabric is None:
+                taskpool._fabric = modelpool._fabric
+        modelpool.setup_taskpool(taskpool)
         report = taskpool.evaluate(merged_model)
         if cfg.get("save_report", False):
             # save report (Dict) to a file
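
Put together, the evaluation path now runs in this order (a simplified sketch of the lines above, not verbatim source):

    # Simplified: fabric is only inherited if the taskpool has none, and the
    # new setup_taskpool hook runs before evaluation.
    taskpool = load_taskpool_from_config(cfg.taskpool)
    if taskpool._fabric is None:
        taskpool._fabric = modelpool._fabric
    modelpool.setup_taskpool(taskpool)  # injects tokenizer and back-reference
    report = taskpool.evaluate(merged_model)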
5 changes: 4 additions & 1 deletion fusion_bench/taskpool/__init__.py
@@ -2,15 +2,18 @@

 from .base_pool import TaskPool
 from .clip_image_classification import CLIPImageClassificationTaskPool
+from .gpt2_text_classification import GPT2TextClassificationTaskPool
 from .dummy import DummyTaskPool


 def load_taskpool_from_config(taskpool_config: DictConfig):
     if hasattr(taskpool_config, "type"):
         if taskpool_config.type == "dummy":
             return DummyTaskPool(taskpool_config)
-        if taskpool_config.type == "clip_vit_classification":
+        elif taskpool_config.type == "clip_vit_classification":
             return CLIPImageClassificationTaskPool(taskpool_config)
+        elif taskpool_config.type == "GPT2TextClassificationTaskPool":
+            return GPT2TextClassificationTaskPool(taskpool_config)
         else:
             raise ValueError(f"Unknown task pool type: {taskpool_config.type}")
     else:
2 changes: 1 addition & 1 deletion fusion_bench/taskpool/base_pool.py
@@ -42,7 +42,7 @@ def evaluate(self, model):
             dict: A dictionary containing the results of the evaluation for each task.
         """
         report = {}
-        for task_name in self.task_names():
+        for task_name in self.task_names:
             task = self.load_task(task_name)
             result = task.evaluate(model)
             report[task_name] = result
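
This fixes a call-site bug: `task_names` is evidently a property, so calling it raises at runtime. A tiny illustration with hypothetical names:

    class Demo:
        @property
        def task_names(self):
            return ["cola", "mrpc"]

    d = Demo()
    d.task_names    # ['cola', 'mrpc']
    d.task_names()  # TypeError: 'list' object is not callable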
154 changes: 154 additions & 0 deletions fusion_bench/taskpool/gpt2_text_classification.py
@@ -1,8 +1,162 @@
import functools
import itertools
import logging
from copy import deepcopy

import lightning as L
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, open_dict
from torch.nn.modules import Module
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, MeanMetric
from tqdm.autonotebook import tqdm
from transformers import (
    GPT2ForSequenceClassification,
    GPT2Model,
    GPT2Tokenizer,
    default_data_collator,
)

import fusion_bench
from fusion_bench.dataset.gpt2_glue import TokenizedGLUE
from fusion_bench.taskpool import TaskPool
from fusion_bench.tasks import BaseTask

log = logging.getLogger(__name__)


class GPT2ClassificationTask(BaseTask):
    _taskpool: "GPT2TextClassificationTaskPool" = None

    def __init__(
        self, task_config: DictConfig, fabric: L.Fabric, tokenizer: GPT2Tokenizer
    ):
        super().__init__(task_config)
        self._fabric = fabric
        self._tokenizer = tokenizer

    @property
    def num_classes(self):
        return len(self.test_dataset.unique("label"))

    @functools.cached_property
    def dataset(self):
        log.info('Loading dataset: "{}"'.format(self.config.dataset.name))
        dataset = TokenizedGLUE(tokenizer=self._tokenizer).load_dataset(
            self.config.dataset.name
        )
        return dataset

    @property
    def test_dataset(self):
        return self.dataset[self.config.dataset.split]

    @property
    def test_loader(self):
        test_dataset = self.test_dataset
        loader = DataLoader(
            test_dataset,
            collate_fn=default_data_collator,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=False,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return loader

    def get_classifier(self, model: GPT2Model) -> GPT2ForSequenceClassification:
        modelpool = self._taskpool._modelpool
        classifier = modelpool.load_classifier(self.config.name)
        classifier.transformer = deepcopy(model)
        return classifier

    @torch.no_grad()
    def evaluate(self, model: GPT2Model):
        accuracy = Accuracy("multiclass", num_classes=self.num_classes)
        loss_metric = MeanMetric()
        model: GPT2ForSequenceClassification = self.get_classifier(model)
        model = self._fabric.setup(model)

        if self.config.get("fast_dev_run", False):
            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
            test_loader = itertools.islice(self.test_loader, 1)
        else:
            test_loader = self.test_loader

        for batch in (
            pbar := tqdm(
                test_loader, desc="Evaluating", leave=False, dynamic_ncols=True
            )
        ):
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = F.cross_entropy(logits, labels)

            acc = accuracy(logits.detach().cpu(), labels.detach().cpu())
            loss_metric.update(loss.detach().cpu())
            pbar.set_postfix({"accuracy": acc.item(), "loss": loss.item()})

        acc = accuracy.compute().item()
        loss = loss_metric.compute().item()
        results = {"accuracy": acc, "loss": loss}
        log.info(f"Results for task {self.config.name}: {results}")
        return results


class GPT2TextClassificationTaskPool(TaskPool):
    _fabric: L.Fabric = None
    _tokenizer: GPT2Tokenizer = None
    _modelpool: "fusion_bench.modelpool.HuggingFaceGPT2ClassificationPool" = None

    def __init__(self, taskpool_config: DictConfig):
        super().__init__(taskpool_config)

    @property
    def fabric(self):
        if self._fabric is not None:
            return self._fabric
        else:
            self._fabric = L.Fabric(devices=1)
            self._fabric.launch()
            return self._fabric

    @property
    def tokenizer(self):
        if self._tokenizer is not None:
            return self._tokenizer
        else:
            raise ValueError("Tokenizer not set")

    def prepare_dataset_config(self, dataset_config: DictConfig):
        if not hasattr(dataset_config, "type"):
            with open_dict(dataset_config):
                dataset_config["type"] = self.config.dataset_type
        return dataset_config

    def prepare_task_config(self, task_config: DictConfig):
        for key in ["num_workers", "batch_size", "fast_dev_run"]:
            if not hasattr(task_config, key):
                with open_dict(task_config):
                    task_config[key] = self.config[key]
        return task_config

    def load_task(self, task_name_or_config: str | DictConfig):
        if isinstance(task_name_or_config, str):
            task_config = self.get_task_config(task_name_or_config)
        else:
            task_config = task_name_or_config
        task_config = self.prepare_task_config(task_config)

        # load the task from the configuration
        task = GPT2ClassificationTask(task_config, self.fabric, self.tokenizer)
        task._fabric = self._fabric
        task._tokenizer = self._tokenizer
        task._taskpool = self

        return task
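
End to end, the intended wiring looks roughly like this (a sketch assembled from the pieces above; `modelpool`, `taskpool_config`, and `merged_model` are assumed to come from earlier fusion steps and are not constructed here):

    taskpool = GPT2TextClassificationTaskPool(taskpool_config)
    modelpool.setup_taskpool(taskpool)  # hands over tokenizer and modelpool
    # merged_model is a bare GPT2Model backbone; each task re-attaches its own
    # classification head via get_classifier() before scoring.
    report = taskpool.evaluate(merged_model)
    # report: {task_name: {"accuracy": ..., "loss": ...}, ...}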