Feature/encoder decoder dq restructure #766

Closed
wants to merge 15 commits into from
81 changes: 58 additions & 23 deletions dataquality/integrations/seq2seq/hf.py
@@ -1,18 +1,45 @@
from typing import List, Optional
from typing import List, Optional, Union
from warnings import warn

from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast

from dataquality.loggers.logger_config.seq2seq import seq2seq_logger_config
from dataquality.exceptions import GalileoException
from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
EncoderDecoderLoggerConfig,
encoder_decoder_logger_config,
)
from dataquality.schemas.split import Split
from dataquality.schemas.task_type import TaskType
from dataquality.utils.helpers import check_noop
from dataquality.utils.task_helpers import get_task_type


# TODO Sync with Elliott on how to differentiate between the
# encoder_decoder vs. decoder_only logger_configs in `watch`
def _get_seg2seg_logger_config(
task_type: TaskType,
) -> Union[EncoderDecoderLoggerConfig]:
"""Get the correct Seq2Seq logger_config based on the task_type.

Choices between:
1. EncoderDecoder: task_type.seq2seq (to be renamed encoder_decoder)
2. DecoderOnly: task_type.decoder_only

Raises an exception if the user has set / is using an incorrect task_type
"""
if task_type == task_type.seq2seq: # TODO Change to encoder_decoder
return encoder_decoder_logger_config

# TODO Change to encoder_decoder
raise GalileoException(
"Galileo's seq2seq watch method is only supported for seq2seq"
)
Contributor Author:
I think we can just use the get-current-task-type helpers, since the user will have already initialized the project with dq.init and we will have the task type stored in the config file.

Contributor Author:
Look for other instances where we do get_data_logger().logger_config.

Contributor:
Okay 👌 yes this seems helpful!
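As a rough sketch of the direction discussed in this thread, the dispatch could lean on state that dq.init already stored instead of a dedicated union-typed helper. get_task_type, TaskType, and GalileoException are all imported in this diff; the import path for get_data_logger is an assumption for illustration only.

# Sketch only: resolve the active Seq2Seq logger config from state set by dq.init().
# The import location of `get_data_logger` below is assumed, not confirmed API.
from dataquality.exceptions import GalileoException
from dataquality.loggers import get_data_logger  # assumed location
from dataquality.schemas.task_type import TaskType
from dataquality.utils.task_helpers import get_task_type

def _resolve_seq2seq_logger_config():
    task_type = get_task_type()  # already stored by dq.init()
    if task_type != TaskType.seq2seq:  # TODO: encoder_decoder once the rename lands
        raise GalileoException(
            "Galileo's seq2seq watch method is only supported for seq2seq"
        )
    return get_data_logger().logger_config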



@check_noop
def set_tokenizer(
tokenizer: PreTrainedTokenizerFast,
logger_config: Union[EncoderDecoderLoggerConfig],
Contributor Author:
I wouldn't have logger_config as a param for this, because this is technically a user-facing function and we wouldn't expect users to pass in a logger config.

max_input_tokens: Optional[int] = None,
max_target_tokens: Optional[int] = None,
) -> None:
@@ -24,19 +51,18 @@ def set_tokenizer(
truncated after a certain length, which is set in the args max_input_tokens and
max_target_tokens.
"""
task_type = get_task_type()
assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
assert isinstance(
tokenizer, PreTrainedTokenizerFast
), "Tokenizer must be an instance of PreTrainedTokenizerFast"
assert getattr(tokenizer, "is_fast", False), "Tokenizer must be a fast tokenizer"
for attr in ["encode", "decode", "encode_plus", "padding_side"]:
assert hasattr(tokenizer, attr), f"Tokenizer must support `{attr}`"
seq2seq_logger_config.tokenizer = tokenizer
logger_config.tokenizer = tokenizer
Contributor Author:
To get the config, we could call the get-data-logger-config helper in this function.
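Building on the same assumed helper, a user-facing set_tokenizer could drop the logger_config parameter entirely and resolve the config internally; this is a sketch of the suggestion, not the final implementation.

# Sketch only: set_tokenizer without a logger_config parameter.
# `get_data_logger` and its import path are assumptions carried over from above.
from typing import Optional

from transformers import PreTrainedTokenizerFast

from dataquality.loggers import get_data_logger  # assumed location

def set_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    max_input_tokens: Optional[int] = None,
    max_target_tokens: Optional[int] = None,
) -> None:
    logger_config = get_data_logger().logger_config  # resolved here, not passed in
    logger_config.tokenizer = tokenizer
    logger_config.max_input_tokens = max_input_tokens or tokenizer.model_max_length
    logger_config.max_target_tokens = max_target_tokens or tokenizer.model_max_length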


seq2seq_logger_config.max_input_tokens = max_input_tokens
if seq2seq_logger_config.max_input_tokens is None:
seq2seq_logger_config.max_input_tokens = tokenizer.model_max_length
# This is relevant only for Encoder Decoder Models
logger_config.max_input_tokens = max_input_tokens
if logger_config.max_input_tokens is None:
logger_config.max_input_tokens = tokenizer.model_max_length
warn(
(
"The argument max_input_tokens is not set, we will use the value "
@@ -46,19 +72,26 @@ def set_tokenizer(
)
)

seq2seq_logger_config.max_target_tokens = max_target_tokens
if seq2seq_logger_config.max_target_tokens is None:
seq2seq_logger_config.max_target_tokens = tokenizer.model_max_length
warn(
(
"The argument max_target_tokens is not set, we will use the value "
f"{tokenizer.model_max_length} from tokenizer.model_max_length. If you "
"tokenized the target with another value, this can lead to confusing "
"insights about this training run."
if type(logger_config) == EncoderDecoderLoggerConfig:
logger_config.max_target_tokens = max_target_tokens
if logger_config.max_target_tokens is None:
logger_config.max_target_tokens = tokenizer.model_max_length
warn(
(
"The argument max_target_tokens is not set, we will use the value "
f"{tokenizer.model_max_length} from tokenizer.model_max_length. "
f"If you tokenized the target with another value, this can lead "
f"to confusing insights about this training run."
)
)
else:
warn(
"The argument max_target_tokens is only used when working with "
"EncoderDecoder models. This value will be ignored."
)

# Seq2Seq doesn't have labels but we need to set this to avoid validation errors
seq2seq_logger_config.labels = []
logger_config.labels = []


@check_noop
@@ -70,6 +103,7 @@ def watch(
max_input_tokens: Optional[int] = None,
max_target_tokens: Optional[int] = None,
) -> None:
# TODO Update comment
"""Seq2seq only. Log model generations for your run

Iterates over a given dataset and logs the generations for each sample.
@@ -81,16 +115,17 @@
for consistency.
"""
task_type = get_task_type()
assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
# Get the corresponding logger config - handling error checking
logger_config = _get_seg2seg_logger_config(task_type)
assert isinstance(
model, PreTrainedModel
), "model must be an instance of transformers PreTrainedModel"
assert model.can_generate(), "model must contain a `generate` method for seq2seq"

set_tokenizer(tokenizer, max_input_tokens, max_target_tokens)
set_tokenizer(tokenizer, logger_config, max_input_tokens, max_target_tokens)
Contributor Author:
see above!


seq2seq_logger_config.model = model
seq2seq_logger_config.generation_config = generation_config
logger_config.model = model
logger_config.generation_config = generation_config

generatation_splits = generatation_splits or []
generation_splits_set = {Split.test}
@@ -104,4 +139,4 @@

generation_splits_set.add(Split[split])

seq2seq_logger_config.generation_splits = generation_splits_set
logger_config.generation_splits = generation_splits_set
5 changes: 3 additions & 2 deletions dataquality/loggers/data_logger/__init__.py
@@ -2,13 +2,13 @@
image_classification,
object_detection,
semantic_segmentation,
seq2seq,
tabular_classification,
text_classification,
text_multi_label,
text_ner,
)
from dataquality.loggers.data_logger.base_data_logger import BaseGalileoDataLogger
from dataquality.loggers.data_logger.seq2seq import encoder_decoder, seq2seq

__all__ = [
"image_classification",
@@ -19,5 +19,6 @@
"text_ner",
"object_detection",
"BaseGalileoDataLogger",
"seq2seq",
"seq2seq", # TODO: Likely remove
"encoder_decoder",
]
Empty file.
137 changes: 137 additions & 0 deletions dataquality/loggers/data_logger/seq2seq/encoder_decoder.py
@@ -0,0 +1,137 @@
from typing import Optional

from vaex.dataframe import DataFrame

from dataquality.loggers.data_logger.base_data_logger import (
MetasType,
)
from dataquality.loggers.data_logger.seq2seq.seq2seq import Seq2SeqDataLogger
from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
EncoderDecoderLoggerConfig,
encoder_decoder_logger_config,
)
from dataquality.schemas.seq2seq import Seq2SeqInputCols as C
from dataquality.utils.seq2seq.offsets import (
align_tokens_to_character_spans,
get_cutoff_from_saved_offsets,
get_cutoff_from_truncated_tokenization,
)


class EncoderDecoderDataLogger(Seq2SeqDataLogger):
# TODO UPDATE COMMENT!!!
"""Logging input data for Seq2Seq fine-tuning tasks

Logging input data for Seq2Seq requires 2 pieces of information:
1. tokenizer: This must be an instance of PreTrainedTokenizerFast from huggingface
(e.g. T5TokenizerFast or GPT2TokenizerFast, etc.). Your tokenizer should have an
`.is_fast` property that returns True if it's a fast tokenizer.
This class must implement the `encode`, `decode`, and `encode_plus` methods

You can set your tokenizer via the `set_tokenizer(tok)` function imported
from `dataquality.integrations.seq2seq.hf`
2. A dataset (pandas/huggingface etc) with input strings and output labels and ids.
Ex: Billsum dataset, with `text` input and `summary` as the label
id text summary
0 SECTION 1. LIABILITY ... Shields a business entity ...
1 SECTION 1. SHORT TITLE.\n\n ... Human Rights Information Act ...
2 SECTION 1. SHORT TITLE.\n\n ... Jackie Robinson Commemorative Coin ...
3 SECTION 1. NONRECOGNITION ... Amends the Internal Revenue Code to ...
4 SECTION 1. SHORT TITLE.\n\n ... Native American Energy Act - (Sec. 3...

You can log your dataset via the `dq.log_dataset` function, passing in the
column mapping as necessary for `text`, `label`, and `id`
`dq.log_dataset(ds, text="text", label="summary", id="id")`

Putting it all together:
import dataquality as dq
from dataquality.integrations.seq2seq.hf import set_tokenizer
from datasets import load_dataset
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
ds = load_dataset("billsum")
# Add `id` column to each dataset split as the idx
ds = ds.map(lambda x,idx : {"id":idx},with_indices=True)
dq.init("seq2seq")
set_tokenizer(tokenizer)
dq.log_dataset(ds["train"], label="summary", split="train")

NOTE: We assume that the tokenizer you provide is the same tokenizer used for
training. This must be true in order to align inputs and outputs correctly. Ensure
all necessary properties (like `add_eos_token`) are set before setting your
tokenizer so as to match the tokenization process to your training process.
"""

# TODO Change to encoder_decoder after updating API
__logger_name__ = "seq2seq" # encoder_decoder
Contributor Author:
Agreed, it should be encoder_decoder.

logger_config: EncoderDecoderLoggerConfig = encoder_decoder_logger_config
DATA_FOLDER_EXTENSION = {"emb": "hdf5", "prob": "hdf5", "data": "arrow"}

def __init__(self, meta: Optional[MetasType] = None) -> None:
super().__init__(meta)

def validate_and_format(self) -> None:
"""Format Encoder-Decoder Data Format

Tokenize self.labels, using the user's `max_target_tokens`. From
the tokenized outputs generate the corresponding token alignments
(i.e. label_offsets and label_positions).

Save the tokenized labels for each sample as `id_to_tokens`. This
is essential during model logging for extracting GT token label
information.

Note: the parent Seq2SeqDataLogger.validate_and_format() handles
common data type validation.
"""
super().validate_and_format()
# TODO: question type checking does not work in super()
Contributor Author:
What do you mean? We can look into this together.

encoded_data = self.logger_config.tokenizer( # type: ignore
self.labels,
return_offsets_mapping=True,
max_length=self.logger_config.max_target_tokens,
truncation=True,
)
tokenized_labels = encoded_data["input_ids"]
aligned_data = align_tokens_to_character_spans(encoded_data["offset_mapping"])
self.token_label_offsets = aligned_data.token_label_offsets
self.token_label_positions = aligned_data.token_label_positions

id_to_tokens = dict(zip(self.ids, tokenized_labels))
self.logger_config.id_to_tokens[self.token_map_key].update(id_to_tokens)

@classmethod
def calculate_cutoffs(cls, df: DataFrame) -> DataFrame:
"""Calculate the cutoff index for the input and target strings.


When using Encoder-Decoder models, the input AND target tokens are truncated
based on the respective Encoder (input) / Decoder (target) max_lengths
OR user specified max_lengths (note: these may be different between the
Encoder and Decoder).

The model only "sees"/processes the tokens that remain after truncation,
for example if max_length=512 for the Encoder, no matter how long the Input,
the model will only process the first 512 tokens and ignore the rest.

This function adds two columns to the df:
- 'input_cutoff': the position of the last character in the input.
- 'target_cutoff': the position of the last character in the target.
"""
# Error checking
super().calculate_cutoffs(df)

# TODO we may be able to take advantage of shared code with Decoder
tokenizer = cls.logger_config.tokenizer
max_input_length = cls.logger_config.max_input_tokens
df[C.input_cutoff.value] = get_cutoff_from_truncated_tokenization(
df, C.text, tokenizer, max_input_length
)

target_offsets_colname = C.token_label_offsets
if target_offsets_colname in df.get_column_names():
df[C.target_cutoff.value] = get_cutoff_from_saved_offsets(
df, target_offsets_colname
)

return df
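To make the truncation cutoff described in this docstring concrete, here is a minimal standalone sketch (not the library's get_cutoff_from_truncated_tokenization) that computes the last input character a model would actually see, using a fast tokenizer's offset mapping.

# Sketch only: tokenize with truncation and take the end-character offset of the
# last surviving token. Requires a fast tokenizer so offsets are available.
from transformers import AutoTokenizer

def input_cutoff_char(text: str, tokenizer, max_length: int) -> int:
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )
    offsets = encoding["offset_mapping"]  # (start_char, end_char) per kept token
    return offsets[-1][1] if offsets else 0

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # any fast tokenizer works
print(input_cutoff_char("SECTION 1. SHORT TITLE. This Act may be cited as ...", tokenizer, 8))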