Feature/encoder decoder dq restructure #766
Changes from 8 commits
@@ -1,18 +1,45 @@
from typing import List, Optional
from typing import List, Optional, Union
from warnings import warn

from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast

from dataquality.loggers.logger_config.seq2seq import seq2seq_logger_config
from dataquality.exceptions import GalileoException
from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
    EncoderDecoderLoggerConfig,
    encoder_decoder_logger_config,
)
from dataquality.schemas.split import Split
from dataquality.schemas.task_type import TaskType
from dataquality.utils.helpers import check_noop
from dataquality.utils.task_helpers import get_task_type


# TODO Sync with Elliott on how to differentiate between the
# encoder_decoder vs. decoder_only logger_configs in `watch`
def _get_seg2seg_logger_config(
    task_type: TaskType,
) -> Union[EncoderDecoderLoggerConfig]:
    """Get the correct Seq2Seq logger_config based on the task_type.

    Choices between:
        1. EncoderDecoder: task_type.seq2seq (to become task_type.encoder_decoder)
        2. DecoderOnly: task_type.decoder_only

    Raises an exception if the user has set / is using an incorrect task_type
    """
    if task_type == task_type.seq2seq:  # TODO Change to encoder_decoder
        return encoder_decoder_logger_config

    # TODO Change to encoder_decoder
    raise GalileoException(
        "Galileo's seq2seq watch method is only supported for seq2seq"
    )


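Editorial aside: the TODO above anticipates a second, decoder-only branch in this helper. Below is a minimal, self-contained sketch of that future shape; DecoderOnlyLoggerConfig, decoder_only_logger_config, and the encoder_decoder / decoder_only task types are stand-ins that do not exist in this PR.

from enum import Enum
from typing import Union


class TaskType(str, Enum):  # stand-in for dataquality.schemas.task_type.TaskType
    encoder_decoder = "encoder_decoder"
    decoder_only = "decoder_only"


class EncoderDecoderLoggerConfig:  # stand-ins for the real logger config classes
    pass


class DecoderOnlyLoggerConfig:
    pass


encoder_decoder_logger_config = EncoderDecoderLoggerConfig()
decoder_only_logger_config = DecoderOnlyLoggerConfig()


def _get_seg2seg_logger_config(
    task_type: TaskType,
) -> Union[EncoderDecoderLoggerConfig, DecoderOnlyLoggerConfig]:
    # Dispatch on the task type chosen at dq.init time; unknown types are an error.
    if task_type is TaskType.encoder_decoder:
        return encoder_decoder_logger_config
    if task_type is TaskType.decoder_only:
        return decoder_only_logger_config
    raise ValueError(f"Seq2Seq watch does not support task_type={task_type}")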
@check_noop
def set_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    logger_config: Union[EncoderDecoderLoggerConfig],
Review comment: i wouldn't have logger config as a param for this cause this is technically a user facing fn and we wouldn't expect them to pass in a logger config
(Editorial note: a sketch of this suggestion appears after the signature below.)
    max_input_tokens: Optional[int] = None,
    max_target_tokens: Optional[int] = None,
) -> None:

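Editorial aside on the review comment above: keeping set_tokenizer user-facing means resolving the logger config inside the function rather than accepting it as a parameter. A hedged sketch of that shape, reusing only names that appear in this diff (get_task_type, _get_seg2seg_logger_config); this is an illustration, not the PR's implementation.

from typing import Optional

from transformers import PreTrainedTokenizerFast

from dataquality.utils.helpers import check_noop
from dataquality.utils.task_helpers import get_task_type


@check_noop
def set_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    max_input_tokens: Optional[int] = None,
    max_target_tokens: Optional[int] = None,
) -> None:
    # Resolve the config internally so callers never construct or pass one.
    task_type = get_task_type()
    logger_config = _get_seg2seg_logger_config(task_type)  # module-level helper from this diff
    logger_config.tokenizer = tokenizer
    # ... max_input_tokens / max_target_tokens handling continues as in the PR ...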
@@ -24,19 +51,18 @@ def set_tokenizer(
    truncated after a certain length, which is set in the args max_input_tokens and
    max_target_tokens.
    """
    task_type = get_task_type()
    assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
    assert isinstance(
        tokenizer, PreTrainedTokenizerFast
    ), "Tokenizer must be an instance of PreTrainedTokenizerFast"
    assert getattr(tokenizer, "is_fast", False), "Tokenizer must be a fast tokenizer"
    for attr in ["encode", "decode", "encode_plus", "padding_side"]:
        assert hasattr(tokenizer, attr), f"Tokenizer must support `{attr}`"
    seq2seq_logger_config.tokenizer = tokenizer
    logger_config.tokenizer = tokenizer
Review comment: to get config we could call the get data logger config helper in this fn

    seq2seq_logger_config.max_input_tokens = max_input_tokens
    if seq2seq_logger_config.max_input_tokens is None:
        seq2seq_logger_config.max_input_tokens = tokenizer.model_max_length
    # This is relevant only for Encoder Decoder Models
    logger_config.max_input_tokens = max_input_tokens
    if logger_config.max_input_tokens is None:
        logger_config.max_input_tokens = tokenizer.model_max_length
        warn(
            (
                "The argument max_input_tokens is not set, we will use the value "

@@ -46,19 +72,26 @@ def set_tokenizer(
            )
        )

    seq2seq_logger_config.max_target_tokens = max_target_tokens
    if seq2seq_logger_config.max_target_tokens is None:
        seq2seq_logger_config.max_target_tokens = tokenizer.model_max_length
        warn(
            (
                "The argument max_target_tokens is not set, we will use the value "
                f"{tokenizer.model_max_length} from tokenizer.model_max_length. If you "
                "tokenized the target with another value, this can lead to confusing "
                "insights about this training run."
    if type(logger_config) == EncoderDecoderLoggerConfig:
        logger_config.max_target_tokens = max_target_tokens
        if logger_config.max_target_tokens is None:
            logger_config.max_target_tokens = tokenizer.model_max_length
            warn(
                (
                    "The argument max_target_tokens is not set, we will use the value "
                    f"{tokenizer.model_max_length} from tokenizer.model_max_length. "
                    f"If you tokenized the target with another value, this can lead "
                    f"to confusing insights about this training run."
                )
            )
    else:
        warn(
            "The argument max_target_tokens is only used when working with "
            "EncoderDecoder models. This value will be ignored."
        )

    # Seq2Seq doesn't have labels but we need to set this to avoid validation errors
    seq2seq_logger_config.labels = []
    logger_config.labels = []

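For orientation, a hedged usage sketch of set_tokenizer as it is written in this diff (the logger config is passed explicitly; the review comments above propose resolving it internally instead). Module paths are taken from the imports and docstrings in this PR; the token limits are illustrative values.

from transformers import T5TokenizerFast

from dataquality.integrations.seq2seq.hf import set_tokenizer
from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
    encoder_decoder_logger_config,
)

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
set_tokenizer(
    tokenizer,
    encoder_decoder_logger_config,
    max_input_tokens=512,   # illustrative; defaults to tokenizer.model_max_length
    max_target_tokens=128,  # only meaningful for encoder-decoder models
)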
@check_noop

@@ -70,6 +103,7 @@ def watch(
    max_input_tokens: Optional[int] = None,
    max_target_tokens: Optional[int] = None,
) -> None:
    # TODO Update comment
    """Seq2seq only. Log model generations for your run

    Iterates over a given dataset and logs the generations for each sample.

@@ -81,16 +115,17 @@ def watch(
    for consistency.
    """
    task_type = get_task_type()
    assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
    # Get the corresponding logger config - handling error checking
    logger_config = _get_seg2seg_logger_config(task_type)
    assert isinstance(
        model, PreTrainedModel
    ), "model must be an instance of transformers PreTrainedModel"
    assert model.can_generate(), "model must contain a `generate` method for seq2seq"

    set_tokenizer(tokenizer, max_input_tokens, max_target_tokens)
    set_tokenizer(tokenizer, logger_config, max_input_tokens, max_target_tokens)
Review comment: see above!

    seq2seq_logger_config.model = model
    seq2seq_logger_config.generation_config = generation_config
    logger_config.model = model
    logger_config.generation_config = generation_config

    generatation_splits = generatation_splits or []
    generation_splits_set = {Split.test}

@@ -104,4 +139,4 @@ def watch(

        generation_splits_set.add(Split[split])

    seq2seq_logger_config.generation_splits = generation_splits_set
    logger_config.generation_splits = generation_splits_set

New file added in this PR (137 lines):
@@ -0,0 +1,137 @@
from typing import Optional

from vaex.dataframe import DataFrame

from dataquality.loggers.data_logger.base_data_logger import (
    MetasType,
)
from dataquality.loggers.data_logger.seq2seq.seq2seq import Seq2SeqDataLogger
from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
    EncoderDecoderLoggerConfig,
    encoder_decoder_logger_config,
)
from dataquality.schemas.seq2seq import Seq2SeqInputCols as C
from dataquality.utils.seq2seq.offsets import (
    align_tokens_to_character_spans,
    get_cutoff_from_saved_offsets,
    get_cutoff_from_truncated_tokenization,
)


class EncoderDecoderDataLogger(Seq2SeqDataLogger):
    # TODO UPDATE COMMENT!!!
    """Logging input data for Seq2Seq fine-tuning tasks

    Logging input data for Seq2Seq requires 2 pieces of information:
    1. tokenizer: This must be an instance of PreTrainedTokenizerFast from huggingface
        (ie T5TokenizerFast or GPT2TokenizerFast, etc). Your tokenizer should have an
        `.is_fast` property that returns True if it's a fast tokenizer.
        This class must implement the `encode`, `decode`, and `encode_plus` methods

        You can set your tokenizer via the `set_tokenizer(tok)` function imported
        from `dataquality.integrations.seq2seq.hf`
    2. A dataset (pandas/huggingface etc) with input strings and output labels and ids.
        Ex: Billsum dataset, with `text` input and `summary` as the label
        id  text                              summary
        0   SECTION 1. LIABILITY ...          Shields a business entity ...
        1   SECTION 1. SHORT TITLE.\n\n ...   Human Rights Information Act ...
        2   SECTION 1. SHORT TITLE.\n\n ...   Jackie Robinson Commemorative Coin ...
        3   SECTION 1. NONRECOGNITION ...     Amends the Internal Revenue Code to ...
        4   SECTION 1. SHORT TITLE.\n\n ...   Native American Energy Act - (Sec. 3...

        You can log your dataset via the `dq.log_dataset` function, passing in the
        column mapping as necessary for `text`, `label`, and `id`
        `dq.log_dataset(ds, text="text", label="summary", id="id")`

    Putting it all together:
        from dataquality.integrations.seq2seq.hf import set_tokenizer
        from datasets import load_dataset
        from transformers import T5TokenizerFast

        tokenizer = T5TokenizerFast.from_pretrained("t5-small")
        ds = load_dataset("billsum")
        # Add `id` column to each dataset split as the idx
        ds = ds.map(lambda x, idx: {"id": idx}, with_indices=True)
        dq.init("seq2seq")
        set_tokenizer(tokenizer)
        dq.log_dataset(ds["train"], label="summary", split="train")

    NOTE: We assume that the tokenizer you provide is the same tokenizer used for
    training. This must be true in order to align inputs and outputs correctly. Ensure
    all necessary properties (like `add_eos_token`) are set before setting your
    tokenizer so as to match the tokenization process to your training process.
    """

    # TODO Change to encoder_decoder after updating API
    __logger_name__ = "seq2seq"  # encoder_decoder
Review comment: agreed it should be encoder_decoder
    logger_config: EncoderDecoderLoggerConfig = encoder_decoder_logger_config
    DATA_FOLDER_EXTENSION = {"emb": "hdf5", "prob": "hdf5", "data": "arrow"}

    def __init__(self, meta: Optional[MetasType] = None) -> None:
        super().__init__(meta)

    def validate_and_format(self) -> None:
        """Format Encoder-Decoder Data Format

        Tokenize self.labels, using the user's `max_target_tokens`. From
        the tokenized outputs generate the corresponding token alignments
        (i.e. label_offsets and label_positions).

        Save the tokenized labels for each sample as `id_to_tokens`. This
        is essential during model logging for extracting GT token label
        information.

        Note: the parent Seq2SeqDataLogger.validate_and_format() handles
        common data type validation.
        """
        super().validate_and_format()
        # TODO: question - type checking does not work in super()
Review comment: what do you mean? we can look into this together
        encoded_data = self.logger_config.tokenizer(  # type: ignore
            self.labels,
            return_offsets_mapping=True,
            max_length=self.logger_config.max_target_tokens,
            truncation=True,
        )
        tokenized_labels = encoded_data["input_ids"]
        aligned_data = align_tokens_to_character_spans(encoded_data["offset_mapping"])
        self.token_label_offsets = aligned_data.token_label_offsets
        self.token_label_positions = aligned_data.token_label_positions

        id_to_tokens = dict(zip(self.ids, tokenized_labels))
        self.logger_config.id_to_tokens[self.token_map_key].update(id_to_tokens)

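For readers unfamiliar with offset mappings, a small hedged illustration of the tokenizer output this method relies on: fast tokenizers return per-token (start, end) character spans when called with return_offsets_mapping=True. The exact ids and spans depend on the tokenizer; t5-small is used here only as an example.

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
encoded = tokenizer(
    ["Shields a business entity from liability"],  # one label string, as in Billsum
    return_offsets_mapping=True,
    max_length=8,  # stands in for logger_config.max_target_tokens
    truncation=True,
)
print(encoded["input_ids"][0])       # the tokenized label that gets saved to id_to_tokens
print(encoded["offset_mapping"][0])  # (start, end) character span for each token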
    @classmethod
    def calculate_cutoffs(cls, df: DataFrame) -> DataFrame:
        """Calculate the cutoff index for the input and target strings.

        When using Encoder-Decoder models, the input AND target tokens are truncated
        based on the respective Encoder (input) / Decoder (target) max_lengths
        OR user specified max_lengths (note: these may be different between the
        Encoder and Decoder).

        The model only "sees"/processes the tokens that remain after truncation,
        for example if max_length=512 for the Encoder, no matter how long the Input,
        the model will only process the first 512 tokens and ignore the rest.

        This function adds two columns to the df:
            - 'input_cutoff': the position of the last character in the input.
            - 'target_cutoff': the position of the last character in the target.
        """
        # Error checking
        super().calculate_cutoffs(df)

        # TODO we may be able to take advantage of shared code with Decoder
        tokenizer = cls.logger_config.tokenizer
        max_input_length = cls.logger_config.max_input_tokens
        df[C.input_cutoff.value] = get_cutoff_from_truncated_tokenization(
            df, C.text, tokenizer, max_input_length
        )

        target_offsets_colname = C.token_label_offsets
        if target_offsets_colname in df.get_column_names():
            df[C.target_cutoff.value] = get_cutoff_from_saved_offsets(
                df, target_offsets_colname
            )

        return df
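A hedged sketch of the cutoff idea the docstring describes, independent of the get_cutoff_from_* helpers (whose internals are not shown in this diff): re-tokenize with truncation and take the end of the last surviving token's character span as the cutoff position.

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
text = "SECTION 1. SHORT TITLE. " * 200  # a long input that will be truncated
encoded = tokenizer(text, return_offsets_mapping=True, max_length=512, truncation=True)

# Last character the model actually "sees"; special tokens have (0, 0) spans.
input_cutoff = max(end for _, end in encoded["offset_mapping"] if end)
print(f"{input_cutoff} of {len(text)} characters survive truncation")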
Review comment: i think we can just use the get current task type helpers, since they will have already initialized the project with dq.init and we will have the task type stored in the config file

Review comment: Look for other instances of where we do get_data_logger().logger_config

Review comment: Okay 👌 yes this seems helpful!
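To make that suggestion concrete, a hedged sketch of the pattern the reviewers describe, where the task type stored by dq.init drives the config lookup. get_data_logger is the helper named in the review; its import path below is an assumption and should be checked against the codebase.

from dataquality.schemas.task_type import TaskType
from dataquality.utils.task_helpers import get_task_type

# Assumed import path; the review only names the call get_data_logger().logger_config.
from dataquality.loggers import get_data_logger


def _resolve_logger_config():
    # The task type was stored in the config file by dq.init(...), so the user
    # never needs to pass a logger_config explicitly.
    task_type = get_task_type()
    assert task_type == TaskType.seq2seq, "watch is only supported for seq2seq"
    return get_data_logger().logger_config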