diff --git a/dataquality/integrations/seq2seq/hf.py b/dataquality/integrations/seq2seq/hf.py
index b2fc974f8..2619b6489 100644
--- a/dataquality/integrations/seq2seq/hf.py
+++ b/dataquality/integrations/seq2seq/hf.py
@@ -3,7 +3,8 @@
 
 from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast
 
-from dataquality.loggers.logger_config.seq2seq import seq2seq_logger_config
+import dataquality
+from dataquality.loggers.logger_config.seq2seq.seq2seq_base import Seq2SeqLoggerConfig
 from dataquality.schemas.split import Split
 from dataquality.schemas.task_type import TaskType
 from dataquality.utils.helpers import check_noop
@@ -16,6 +17,7 @@ def set_tokenizer(
     max_input_tokens: Optional[int] = None,
     max_target_tokens: Optional[int] = None,
 ) -> None:
+    # TODO update
     """Seq2seq only. Set the tokenizer for your run
 
     Must be a fast tokenizer, and must support `decode`, `encode`, `encode_plus`.
@@ -23,20 +25,35 @@ def set_tokenizer(
     We will use this tokenizer for both the input and the target. They will both
     be truncated after a certain length, which is set in the args max_input_tokens
     and max_target_tokens.
+
+    1. tokenizer: This must be an instance of PreTrainedTokenizerFast from huggingface
+        (ie T5TokenizerFast or GPT2TokenizerFast, etc). Your tokenizer should have an
+        `.is_fast` property that returns True if it's a fast tokenizer.
+        This class must implement the `encode`, `decode`, and `encode_plus` methods
+
+        You can set your tokenizer via the `set_tokenizer(tok)` function imported
+        from `dataquality.integrations.seq2seq.hf`
+
+    NOTE: We assume that the tokenizer you provide is the same tokenizer used for
+    training. This must be true in order to align inputs and outputs correctly. Ensure
+    all necessary properties (like `add_eos_token`) are set before setting your
+    tokenizer so as to match the tokenization process to your training process.
     """
-    task_type = get_task_type()
-    assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
     assert isinstance(
         tokenizer, PreTrainedTokenizerFast
     ), "Tokenizer must be an instance of PreTrainedTokenizerFast"
     assert getattr(tokenizer, "is_fast", False), "Tokenizer must be a fast tokenizer"
     for attr in ["encode", "decode", "encode_plus", "padding_side"]:
         assert hasattr(tokenizer, attr), f"Tokenizer must support `{attr}`"
 
-    seq2seq_logger_config.tokenizer = tokenizer
-    seq2seq_logger_config.max_input_tokens = max_input_tokens
-    if seq2seq_logger_config.max_input_tokens is None:
-        seq2seq_logger_config.max_input_tokens = tokenizer.model_max_length
+    logger_config = dataquality.get_data_logger().logger_config
+    assert isinstance(logger_config, Seq2SeqLoggerConfig)
+    logger_config.tokenizer = tokenizer
+
+    # This is relevant only for Encoder Decoder Models
+    logger_config.max_input_tokens = max_input_tokens
+    if logger_config.max_input_tokens is None:
+        logger_config.max_input_tokens = tokenizer.model_max_length
         warn(
             (
                 "The argument max_input_tokens is not set, we will use the value "
@@ -46,19 +63,27 @@ def set_tokenizer(
             )
         )
 
-    seq2seq_logger_config.max_target_tokens = max_target_tokens
-    if seq2seq_logger_config.max_target_tokens is None:
-        seq2seq_logger_config.max_target_tokens = tokenizer.model_max_length
-        warn(
-            (
-                "The argument max_target_tokens is not set, we will use the value "
-                f"{tokenizer.model_max_length} from tokenizer.model_max_length. If you "
-                "tokenized the target with another value, this can lead to confusing "
-                "insights about this training run."
+    current_task_type = get_task_type()
+    if current_task_type == TaskType.encoder_decoder:
+        logger_config.max_target_tokens = max_target_tokens
+        if logger_config.max_target_tokens is None:
+            logger_config.max_target_tokens = tokenizer.model_max_length
+            warn(
+                (
+                    "The argument max_target_tokens is not set, we will use the value "
+                    f"{tokenizer.model_max_length} from tokenizer.model_max_length. "
+                    f"If you tokenized the target with another value, this can lead "
+                    f"to confusing insights about this training run."
+                )
             )
+    else:
+        warn(
+            "The argument max_target_tokens is only used when working with "
+            "EncoderDecoder models. This value will be ignored."
         )
+
     # Seq2Seq doesn't have labels but we need to set this to avoid validation errors
-    seq2seq_logger_config.labels = []
+    logger_config.labels = []
 
 
 @check_noop
@@ -70,6 +95,7 @@ def watch(
     max_input_tokens: Optional[int] = None,
     max_target_tokens: Optional[int] = None,
 ) -> None:
+    # TODO Update comment
    """Seq2seq only. Log model generations for your run
 
     Iterates over a given dataset and logs the generations for each sample.
@@ -80,8 +106,9 @@ def watch(
     and generation config and not attaching any hooks to the model. We call it
     'watch' for consistency.
     """
-    task_type = get_task_type()
-    assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
+    logger_config = dataquality.get_data_logger().logger_config
+    assert isinstance(logger_config, Seq2SeqLoggerConfig)
+
     assert isinstance(
         model, PreTrainedModel
     ), "model must be an instance of transformers PreTrainedModel"
@@ -89,8 +116,8 @@ def watch(
 
     set_tokenizer(tokenizer, max_input_tokens, max_target_tokens)
 
-    seq2seq_logger_config.model = model
-    seq2seq_logger_config.generation_config = generation_config
+    logger_config.model = model
+    logger_config.generation_config = generation_config
 
     generation_splits = generation_splits or []
     generation_splits_set = {Split.test}
@@ -104,4 +131,4 @@ def watch(
 
         generation_splits_set.add(Split[split])
 
-    seq2seq_logger_config.generation_splits = generation_splits_set
+    logger_config.generation_splits = generation_splits_set
diff --git a/dataquality/loggers/data_logger/__init__.py b/dataquality/loggers/data_logger/__init__.py
index 80145a720..7dae8ff97 100644
--- a/dataquality/loggers/data_logger/__init__.py
+++ b/dataquality/loggers/data_logger/__init__.py
@@ -2,13 +2,13 @@
     image_classification,
     object_detection,
     semantic_segmentation,
-    seq2seq,
     tabular_classification,
     text_classification,
     text_multi_label,
     text_ner,
 )
 from dataquality.loggers.data_logger.base_data_logger import BaseGalileoDataLogger
+from dataquality.loggers.data_logger.seq2seq import encoder_decoder
 
 __all__ = [
     "image_classification",
@@ -19,5 +19,5 @@
     "text_ner",
     "object_detection",
     "BaseGalileoDataLogger",
-    "seq2seq",
+    "encoder_decoder",
 ]
diff --git a/dataquality/loggers/data_logger/seq2seq/__init__.py b/dataquality/loggers/data_logger/seq2seq/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dataquality/loggers/data_logger/seq2seq/encoder_decoder.py b/dataquality/loggers/data_logger/seq2seq/encoder_decoder.py
new file mode 100644
index 000000000..fe767f04a
--- /dev/null
+++ b/dataquality/loggers/data_logger/seq2seq/encoder_decoder.py
@@ -0,0 +1,146 @@
+from typing import Optional
+
+from vaex.dataframe import DataFrame
+
+from dataquality.loggers.data_logger.base_data_logger import (
+    MetasType,
+)
+from dataquality.loggers.data_logger.seq2seq.seq2seq_base import Seq2SeqDataLogger
+from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
+    EncoderDecoderLoggerConfig,
+    encoder_decoder_logger_config,
+)
+from dataquality.schemas.seq2seq import Seq2SeqInputCols as C
+from dataquality.utils.seq2seq.offsets import (
+    align_tokens_to_character_spans,
+    get_cutoff_from_saved_offsets,
+    get_cutoff_from_truncated_tokenization,
+)
+
+
+class EncoderDecoderDataLogger(Seq2SeqDataLogger):
+    """Seq2Seq data logger for EncoderDecoder models
+
+    Logging input data for EncoderDecoder models requires:
+    1. tokenizer: This must be an instance of PreTrainedTokenizerFast from huggingface
+        (ie T5TokenizerFast or GPT2TokenizerFast, etc). Your tokenizer should have an
+        `.is_fast` property that returns True if it's a fast tokenizer.
+        This class must implement the `encode`, `decode`, and `encode_plus` methods
+
+        You can set your tokenizer via the seq2seq `watch(..., tok, ...)` function
+        imported from `dataquality.integrations.seq2seq.hf`
+    2. A two column (i.e. completion) dataset (pandas/huggingface etc) with string
+        'text' (model Input / Instruction / Prompt, ...) and 'label' (model
+        Target / Completion / ...) columns + a data sample id column.
+        Ex: Billsum dataset, with `text` and `summary` as the