Commit: add code to load glue benchmark for seq2seq generation
Showing 6 changed files with 553 additions and 1 deletion.
fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py (71 additions, 0 deletions)
@@ -0,0 +1,71 @@
"""
This script preprocesses any NLP dataset into a text-to-text format.
"""

import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, Union

from transformers import AutoTokenizer


def preprocess(
    tokenizer: AutoTokenizer,
    input_text: str,
    target_text: str,
    tokenizer_kwargs: Dict[str, Any] = None,
):
    """
    Standard preprocess function for a dataset.
    Tokenizes the input and target text and returns a dictionary of model inputs.

    Args:
        tokenizer: An instance of a tokenizer class used to preprocess text data.
        input_text (str): The input text to be tokenized.
        target_text (str, optional): The target text to be tokenized. If None, no labels are returned.
        tokenizer_kwargs (Dict[str, Any], optional): Extra keyword arguments passed to the tokenizer.

    Returns:
        A dictionary of model inputs containing the tokenized inputs and, if target text is given,
        a `labels` tensor with padding positions masked out.
    """
    if tokenizer_kwargs is None:
        tokenizer_kwargs = {}
    model_inputs = tokenizer(input_text, **tokenizer_kwargs)
    if target_text is not None:
        labels = tokenizer(target_text, **tokenizer_kwargs)
        labels = labels["input_ids"]
        # Mask padding positions so they are ignored by the cross-entropy loss.
        # Note: this boolean indexing assumes the tokenizer returns tensors
        # (e.g. tokenizer_kwargs includes return_tensors="pt").
        labels[labels == tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels
    return model_inputs


class DatasetPreprocessor:
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        tokenizer_kwargs: Dict[str, Any] = None,
        template: Union[str, Path, Dict] = None,
    ):
        """
        Initializes a DatasetPreprocessor with a tokenizer and an optional prompt template.

        Args:
            tokenizer: An instance of a tokenizer class used to preprocess text data.
            tokenizer_kwargs (Dict[str, Any], optional): Extra keyword arguments passed to the tokenizer.
            template: A prompt template, given either as a path to a JSON file or as a dictionary.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        if template is not None:
            if isinstance(template, (str, Path)):
                assert os.path.exists(
                    template
                ), f"Template file not found at {template}"
                with open(template, "r") as f:
                    self.template = json.load(f)
            elif isinstance(template, dict):
                self.template = template
            else:
                raise ValueError(
                    "Template must be a path to a json file or a dictionary"
                )
        else:
            # No template provided; subclasses may assign one later.
            self.template = None
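
The GLUE-specific preprocessors used by the loader below (the `glue_processors` registry) are among the commit's other files but are not shown in this excerpt. As a rough sketch of how a concrete subclass might plug into `DatasetPreprocessor` and `preprocess`, assuming a single-sentence task and template keys named `input_text`/`target_text` (all names here are illustrative, not the commit's actual code):

# Illustrative sketch only, not part of this diff: a hypothetical preprocessor
# for a single-sentence GLUE task. Template keys and column names are assumptions.
class ExampleSentencePreprocessor(DatasetPreprocessor):
    def __call__(self, examples):
        # Render each raw example into prompt/target strings via the template.
        input_text = [
            self.template["input_text"].format(sentence=sentence)
            for sentence in examples["sentence"]
        ]
        target_text = [str(label) for label in examples["label"]]
        # Tokenize both sides with the shared `preprocess` helper above.
        return preprocess(
            self.tokenizer,
            input_text,
            target_text,
            tokenizer_kwargs=self.tokenizer_kwargs,
        )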
fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py (51 additions, 0 deletions)
@@ -0,0 +1,51 @@
import logging
import os
from typing import Optional

from datasets import load_dataset, load_from_disk

from fusion_bench.utils import timeit_context

from .glue_preprocessors import glue_processors
from .glue_prompt_templates import glue_prompt_templates

log = logging.getLogger(__name__)


def _load_glue_dataset(name, tokenizer):
    dataset = load_dataset("glue", name)
    preprocessor = glue_processors[name](
        template=glue_prompt_templates[name],
        tokenizer=tokenizer,
        tokenizer_kwargs={
            "padding": "max_length",
            "truncation": True,
            "return_tensors": "pt",
        },
    )
    # Map every split into tokenized text-to-text examples, dropping the raw columns.
    dataset = dataset.map(
        preprocessor,
        batched=True,
        remove_columns=dataset["train"].column_names,
        num_proc=1,
    )
    return dataset


def load_glue_dataset(name, tokenizer, cache_dir: Optional[str]):
    """Load a GLUE task as a tokenized text-to-text dataset, caching the processed result on disk when `cache_dir` is given."""
    with timeit_context(f"Loading {name} dataset"):
        if cache_dir is not None:
            if not os.path.exists(cache_dir):
                os.makedirs(cache_dir)
            cache_path = os.path.join(
                cache_dir, "flan-t5", f"_load_{name}_dataset_cached"
            )
            if os.path.exists(cache_path):
                return load_from_disk(cache_path)
            else:
                dataset = _load_glue_dataset(name, tokenizer)
                log.info(f"Saving {name} dataset to {cache_path}")
                dataset.save_to_disk(cache_path)
                return dataset
        else:
            return _load_glue_dataset(name, tokenizer)
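
For reference, a minimal usage sketch of `load_glue_dataset`; the checkpoint name, task name, and cache directory are placeholders, and the task must be one registered in `glue_processors`:

# Usage sketch (placeholder names, not part of the commit): tokenize a GLUE
# task for Flan-T5 and cache the processed dataset on disk.
from transformers import AutoTokenizer

from fusion_bench.tasks.flan_t5_text_generation.glue_load_dataset import (
    load_glue_dataset,
)

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
glue_datasets = load_glue_dataset("rte", tokenizer, cache_dir="outputs/cache")

# Each split now carries tokenized input_ids / attention_mask / labels columns
# ready for seq2seq fine-tuning or evaluation.
print(glue_datasets["train"].column_names)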