diff --git a/dataquality/integrations/seq2seq/hf.py b/dataquality/integrations/seq2seq/hf.py
index b2fc974f8..2619b6489 100644
--- a/dataquality/integrations/seq2seq/hf.py
+++ b/dataquality/integrations/seq2seq/hf.py
@@ -3,7 +3,8 @@
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast
-from dataquality.loggers.logger_config.seq2seq import seq2seq_logger_config
+import dataquality
+from dataquality.loggers.logger_config.seq2seq.seq2seq_base import Seq2SeqLoggerConfig
from dataquality.schemas.split import Split
from dataquality.schemas.task_type import TaskType
from dataquality.utils.helpers import check_noop
@@ -16,6 +17,7 @@ def set_tokenizer(
max_input_tokens: Optional[int] = None,
max_target_tokens: Optional[int] = None,
) -> None:
+ # TODO update
"""Seq2seq only. Set the tokenizer for your run
Must be a fast tokenizer, and must support `decode`, `encode`, `encode_plus`.
@@ -23,20 +25,35 @@ def set_tokenizer(
We will use this tokenizer for both the input and the target. They will both be
truncated after a certain length, which is set in the args max_input_tokens and
max_target_tokens.
+
+ 1. tokenizer: This must be an instance of PreTrainedTokenizerFast from huggingface
+        (i.e. T5TokenizerFast or GPT2TokenizerFast, etc.). Your tokenizer should have an
+ `.is_fast` property that returns True if it's a fast tokenizer.
+ This class must implement the `encode`, `decode`, and `encode_plus` methods
+
+ You can set your tokenizer via the `set_tokenizer(tok)` function imported
+ from `dataquality.integrations.seq2seq.hf`
+
+ NOTE: We assume that the tokenizer you provide is the same tokenizer used for
+ training. This must be true in order to align inputs and outputs correctly. Ensure
+ all necessary properties (like `add_eos_token`) are set before setting your
+ tokenizer so as to match the tokenization process to your training process.
"""
- task_type = get_task_type()
- assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
assert isinstance(
tokenizer, PreTrainedTokenizerFast
), "Tokenizer must be an instance of PreTrainedTokenizerFast"
assert getattr(tokenizer, "is_fast", False), "Tokenizer must be a fast tokenizer"
for attr in ["encode", "decode", "encode_plus", "padding_side"]:
assert hasattr(tokenizer, attr), f"Tokenizer must support `{attr}`"
- seq2seq_logger_config.tokenizer = tokenizer
- seq2seq_logger_config.max_input_tokens = max_input_tokens
- if seq2seq_logger_config.max_input_tokens is None:
- seq2seq_logger_config.max_input_tokens = tokenizer.model_max_length
+ logger_config = dataquality.get_data_logger().logger_config
+ assert isinstance(logger_config, Seq2SeqLoggerConfig)
+ logger_config.tokenizer = tokenizer
+
+ # This is relevant only for Encoder Decoder Models
+ logger_config.max_input_tokens = max_input_tokens
+ if logger_config.max_input_tokens is None:
+ logger_config.max_input_tokens = tokenizer.model_max_length
warn(
(
"The argument max_input_tokens is not set, we will use the value "
@@ -46,19 +63,27 @@ def set_tokenizer(
)
)
- seq2seq_logger_config.max_target_tokens = max_target_tokens
- if seq2seq_logger_config.max_target_tokens is None:
- seq2seq_logger_config.max_target_tokens = tokenizer.model_max_length
- warn(
- (
- "The argument max_target_tokens is not set, we will use the value "
- f"{tokenizer.model_max_length} from tokenizer.model_max_length. If you "
- "tokenized the target with another value, this can lead to confusing "
- "insights about this training run."
+ current_task_type = get_task_type()
+ if current_task_type == TaskType.encoder_decoder:
+ logger_config.max_target_tokens = max_target_tokens
+ if logger_config.max_target_tokens is None:
+ logger_config.max_target_tokens = tokenizer.model_max_length
+ warn(
+ (
+ "The argument max_target_tokens is not set, we will use the value "
+ f"{tokenizer.model_max_length} from tokenizer.model_max_length. "
+ f"If you tokenized the target with another value, this can lead "
+ f"to confusing insights about this training run."
+ )
)
+    elif max_target_tokens is not None:
+ warn(
+ "The argument max_target_tokens is only used when working with "
+ "EncoderDecoder models. This value will be ignored."
)
+
# Seq2Seq doesn't have labels but we need to set this to avoid validation errors
- seq2seq_logger_config.labels = []
+ logger_config.labels = []
@check_noop
@@ -70,6 +95,7 @@ def watch(
max_input_tokens: Optional[int] = None,
max_target_tokens: Optional[int] = None,
) -> None:
+ # TODO Update comment
"""Seq2seq only. Log model generations for your run
Iterates over a given dataset and logs the generations for each sample.
@@ -80,8 +106,9 @@ def watch(
and generation config and not attaching any hooks to the model. We call it 'watch'
for consistency.
"""
- task_type = get_task_type()
- assert task_type == TaskType.seq2seq, "This method is only supported for seq2seq"
+ logger_config = dataquality.get_data_logger().logger_config
+ assert isinstance(logger_config, Seq2SeqLoggerConfig)
+
assert isinstance(
model, PreTrainedModel
), "model must be an instance of transformers PreTrainedModel"
@@ -89,8 +116,8 @@ def watch(
set_tokenizer(tokenizer, max_input_tokens, max_target_tokens)
- seq2seq_logger_config.model = model
- seq2seq_logger_config.generation_config = generation_config
+ logger_config.model = model
+ logger_config.generation_config = generation_config
generation_splits = generation_splits or []
generation_splits_set = {Split.test}
@@ -104,4 +131,4 @@ def watch(
generation_splits_set.add(Split[split])
- seq2seq_logger_config.generation_splits = generation_splits_set
+ logger_config.generation_splits = generation_splits_set
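For reference, a minimal sketch of how the reworked `watch` integration is intended to be called; the argument order follows the tests below (e.g. `watch(model_T5, tokenizer_T5, mock_generation_config)`), and the model, token limits, and generation settings here are illustrative assumptions:

    import dataquality as dq
    from dataquality.integrations.seq2seq.hf import watch
    from transformers import GenerationConfig, T5ForConditionalGeneration, T5TokenizerFast

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5TokenizerFast.from_pretrained("t5-small")
    generation_config = GenerationConfig(max_new_tokens=64)

    dq.init("seq2seq")
    # Registers the tokenizer, model, and generation config on the active logger config;
    # max_input_tokens / max_target_tokens fall back to tokenizer.model_max_length when omitted.
    watch(model, tokenizer, generation_config, max_input_tokens=512, max_target_tokens=128)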
diff --git a/dataquality/loggers/data_logger/__init__.py b/dataquality/loggers/data_logger/__init__.py
index 80145a720..7dae8ff97 100644
--- a/dataquality/loggers/data_logger/__init__.py
+++ b/dataquality/loggers/data_logger/__init__.py
@@ -2,13 +2,13 @@
image_classification,
object_detection,
semantic_segmentation,
- seq2seq,
tabular_classification,
text_classification,
text_multi_label,
text_ner,
)
from dataquality.loggers.data_logger.base_data_logger import BaseGalileoDataLogger
+from dataquality.loggers.data_logger.seq2seq import encoder_decoder
__all__ = [
"image_classification",
@@ -19,5 +19,5 @@
"text_ner",
"object_detection",
"BaseGalileoDataLogger",
- "seq2seq",
+ "encoder_decoder",
]
diff --git a/dataquality/loggers/data_logger/seq2seq/__init__.py b/dataquality/loggers/data_logger/seq2seq/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dataquality/loggers/data_logger/seq2seq/encoder_decoder.py b/dataquality/loggers/data_logger/seq2seq/encoder_decoder.py
new file mode 100644
index 000000000..fe767f04a
--- /dev/null
+++ b/dataquality/loggers/data_logger/seq2seq/encoder_decoder.py
@@ -0,0 +1,146 @@
+from typing import Optional
+
+from vaex.dataframe import DataFrame
+
+from dataquality.loggers.data_logger.base_data_logger import (
+ MetasType,
+)
+from dataquality.loggers.data_logger.seq2seq.seq2seq_base import Seq2SeqDataLogger
+from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
+ EncoderDecoderLoggerConfig,
+ encoder_decoder_logger_config,
+)
+from dataquality.schemas.seq2seq import Seq2SeqInputCols as C
+from dataquality.utils.seq2seq.offsets import (
+ align_tokens_to_character_spans,
+ get_cutoff_from_saved_offsets,
+ get_cutoff_from_truncated_tokenization,
+)
+
+
+class EncoderDecoderDataLogger(Seq2SeqDataLogger):
+ """Seq2Seq data logger for EncoderDecoder models
+
+ Logging input data for EncoderDecoder models requires:
+ 1. tokenizer: This must be an instance of PreTrainedTokenizerFast from huggingface
+        (i.e. T5TokenizerFast or GPT2TokenizerFast, etc.). Your tokenizer should have an
+ `.is_fast` property that returns True if it's a fast tokenizer.
+ This class must implement the `encode`, `decode`, and `encode_plus` methods
+
+ You can set your tokenizer via the seq2seq `watch(..., tok, ...)` function
+ imported from `dataquality.integrations.seq2seq.hf`
+ 2. A two column (i.e. completion) dataset (pandas/huggingface etc) with string
+        'text' (model <Input> / <Instruction> / <Prompt>, ...) and 'label' (model
+        <Target> / <Completion> / ...) columns + a data sample id column.
+        Ex: Billsum dataset, with `text` and `summary` as the text and label columns:
+ id text summary
+ 0 SECTION 1. LIABILITY ... Shields a business entity ...
+ 1 SECTION 1. SHORT TITLE.\n\n ... Human Rights Information Act ...
+ 2 SECTION 1. SHORT TITLE.\n\n ... Jackie Robinson Commemorative Coin ...
+ 3 SECTION 1. NONRECOGNITION ... Amends the Internal Revenue Code to ...
+ 4 SECTION 1. SHORT TITLE.\n\n ... Native American Energy Act - (Sec. 3...
+
+ You can log your dataset via the `dq.log_dataset` function, passing in the
+ column mapping as necessary for `text`, `label`, and `id`
+ `dq.log_dataset(ds, text="text", label="summary", id="id")`
+
+ Putting it all together:
+ from dataquality.integrations.seq2seq.hf import watch
+ from datasets import load_dataset
+ from transformers import T5TokenizerFast
+
+ tokenizer = T5TokenizerFast.from_pretrained("t5-small")
+ ds = load_dataset("billsum")
+ # Add `id` column to each dataset split as the idx
+ ds = ds.map(lambda x,idx : {"id":idx},with_indices=True)
+ dq.init("seq2seq")
+ # See `watch` for additional input parameters
+ watch(
+ ...,
+ tokenizer,
+ ...
+ )
+ dq.log_dataset(ds["train"], label="summary", split="train")
+
+ NOTE: We assume that the tokenizer you provide is the same tokenizer used for
+ training. This must be true in order to align inputs and outputs correctly. Ensure
+ all necessary properties (like `add_eos_token`) are set before setting your
+ tokenizer so as to match the tokenization process to your training process.
+
+    NOTE 2: Unlike DecoderOnly models, EncoderDecoder models explicitly separate the
+    processing of the <Input> and <Target> data. Therefore, we do not need any
+    additional information to isolate / extract information on the <Target> data.
+ """
+
+ __logger_name__ = "encoder_decoder"
+ logger_config: EncoderDecoderLoggerConfig = encoder_decoder_logger_config
+ DATA_FOLDER_EXTENSION = {"emb": "hdf5", "prob": "hdf5", "data": "arrow"}
+
+ def __init__(self, meta: Optional[MetasType] = None) -> None:
+ super().__init__(meta)
+
+ def validate_and_format(self) -> None:
+ """Format Encoder-Decoder Data Format
+
+        Tokenize self.labels, using the user's `max_target_tokens`. From
+        the tokenized outputs generate the corresponding token alignments
+        (i.e. label_offsets and label_positions).
+
+ Save the tokenized labels for each sample as `id_to_tokens`. This
+ is essential during model logging for extracting GT token label
+ information.
+
+ Note: the parent Seq2SeqDataLogger.validate_and_format() handles
+ common data type validation.
+ """
+ super().validate_and_format()
+ # We ensure tokenizer is set in the parent class
+ encoded_data = self.logger_config.tokenizer( # type: ignore
+ self.labels,
+ return_offsets_mapping=True,
+ max_length=self.logger_config.max_target_tokens,
+ truncation=True,
+ )
+ tokenized_labels = encoded_data["input_ids"]
+ aligned_data = align_tokens_to_character_spans(encoded_data["offset_mapping"])
+ self.token_label_offsets = aligned_data.token_label_offsets
+ self.token_label_positions = aligned_data.token_label_positions
+
+ id_to_tokens = dict(zip(self.ids, tokenized_labels))
+ self.logger_config.id_to_tokens[self.token_map_key].update(id_to_tokens)
+
+ @classmethod
+ def calculate_cutoffs(cls, df: DataFrame) -> DataFrame:
+ """Calculate the cutoff index for the input and target strings.
+
+ When using Encoder-Decoder models, the input AND target tokens are truncated
+ based on the respective Encoder (input) / Decoder (target) max_lengths
+ OR user specified max_lengths (note: these may be different between the
+ Encoder and Decoder).
+
+        The model only "sees"/processes the tokens that remain after truncation.
+        For example, if max_length=512 for the Encoder, then no matter how long
+        the Input, the model will only process the first 512 tokens and ignore the rest.
+
+ This function adds two columns to the df:
+ - 'input_cutoff': the position of the last character in the input.
+ - 'target_cutoff': the position of the last character in the target.
+ """
+ # Error checking
+ super().calculate_cutoffs(df)
+
+ # TODO we may be able to take advantage of shared code with Decoder
+ tokenizer = cls.logger_config.tokenizer
+ max_input_length = cls.logger_config.max_input_tokens
+ df[C.input_cutoff.value] = get_cutoff_from_truncated_tokenization(
+ df, C.text, tokenizer, max_input_length
+ )
+
+ target_offsets_colname = C.token_label_offsets
+ if target_offsets_colname in df.get_column_names():
+ df[C.target_cutoff.value] = get_cutoff_from_saved_offsets(
+ df, target_offsets_colname
+ )
+
+ return df
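As context for the `validate_and_format` logic above, a rough illustration of the fast-tokenizer output it consumes — `input_ids` are what get stored per sample id in `id_to_tokens`, and `offset_mapping` drives the token/character alignment (the exact values below are tokenizer dependent and purely illustrative):

    from transformers import T5TokenizerFast

    tokenizer = T5TokenizerFast.from_pretrained("t5-small")
    encoded = tokenizer(
        ["Shields a business entity from liability"],
        return_offsets_mapping=True,
        truncation=True,
        max_length=8,
    )
    print(encoded["input_ids"][0])       # tokenized label ids for the sample
    print(encoded["offset_mapping"][0])  # per-token (start, end) character spans in the label string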
diff --git a/dataquality/loggers/data_logger/seq2seq.py b/dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
similarity index 69%
rename from dataquality/loggers/data_logger/seq2seq.py
rename to dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
index bd9811c78..9c198164c 100644
--- a/dataquality/loggers/data_logger/seq2seq.py
+++ b/dataquality/loggers/data_logger/seq2seq/seq2seq_base.py
@@ -12,7 +12,7 @@
DataSet,
MetasType,
)
-from dataquality.loggers.logger_config.seq2seq import (
+from dataquality.loggers.logger_config.seq2seq.seq2seq_base import (
Seq2SeqLoggerConfig,
seq2seq_logger_config,
)
@@ -22,11 +22,6 @@
from dataquality.utils.seq2seq.generation import (
add_generated_output_to_df,
)
-from dataquality.utils.seq2seq.offsets import (
- align_tokens_to_character_spans,
- get_cutoff_from_saved_offsets,
- get_cutoff_from_truncated_tokenization,
-)
from dataquality.utils.vaex import rename_df
if TYPE_CHECKING:
@@ -34,49 +29,40 @@
class Seq2SeqDataLogger(BaseGalileoDataLogger):
- """Logging input data for Seq2Seq fine-tuning tasks
-
- Logging input data for Seq2Seq requires 2 pieces of information:
- 1. tokenizer: This must be an instance of PreTrainedTokenizerFast from huggingface
- (ie T5TokenizerFast or GPT2TokenizerFast, etc). Your tokenizer should have an
- `.is_fast` property that returns True if it's a fast tokenizer.
- This class must implement the `encode`, `decode`, and `encode_plus` methods
-
- You can set your tokenizer via the `set_tokenizer(tok)` function imported
- from `dataquality.integrations.seq2seq.hf`
- 2. A dataset (pandas/huggingface etc) with input strings and output labels and ids.
- Ex: Billsum dataset, with `text` input and `summary` as the label
- id text summary
- 0 SECTION 1. LIABILITY ... Shields a business entity ...
- 1 SECTION 1. SHORT TITLE.\n\n ... Human Rights Information Act ...
- 2 SECTION 1. SHORT TITLE.\n\n ... Jackie Robinson Commemorative Coin ...
- 3 SECTION 1. NONRECOGNITION ... Amends the Internal Revenue Code to ...
- 4 SECTION 1. SHORT TITLE.\n\n ... Native American Energy Act - (Sec. 3...
-
- You can log your dataset via the `dq.log_dataset` function, passing in the
- column mapping as necessary for `text`, `label`, and `id`
- `dq.log_dataset(ds, text="text", label="summary", id="id")`
-
- Putting it all together:
- from dataquality.integrations.seq2seq.hf import set_tokenizer
- from datasets import load_dataset
- from transformers import T5TokenizerFast
-
- tokenizer = T5TokenizerFast.from_pretrained("t5-small")
- ds = load_dataset("billsum")
- # Add `id` column to each dataset split as the idx
- ds = ds.map(lambda x,idx : {"id":idx},with_indices=True)
- dq.init("seq2seq")
- set_tokenizer(tokenizer)
- dq.log_dataset(ds["train"], label="summary", split="train")
-
- NOTE: We assume that the tokenizer you provide is the same tokenizer used for
- training. This must be true in order to align inputs and outputs correctly. Ensure
- all necessary properties (like `add_eos_token`) are set before setting your
- tokenizer so as to match the tokenization process to your training process.
+ """Seq2Seq base data logger
+
+ This class defines the base functionality for logging input data in Seq2Seq
+ tasks - i.e. shared between EncoderDecoder and DecoderOnly architectures.
+
+ At its core, Seq2Seq data logging expects the user's tokenizer (logged through
+ the provided 'watch' integration) and expects the dataset to be formatted
+    as a two column dataset - corresponding to Inputs and Targets.
+
+ During processing, we use the tokenizer to tokenize the Target data (used later
+ during model output logging) and prepare for the alignment of token-level and
+ string character level information.
+
+ After processing, the following key information is extracted:
+ - ids
+        - texts: corresponding to the <Input> data column
+        - labels: corresponding to the <Target> data column
+        - token_label_offsets + token_label_positions: used for alignment of
+          token level and string character level information within the UI. Note
+          this only applies to the <Target> data.
+
+    Critically, we also save the tokenized Target data as the ground truth
+    "labels" for model output logging.
+
+ While much of the general Seq2Seq logic can be shared between EncoderDecoder and
+ DecoderOnly models, there are nuances and specific information that differentiate
+    them. Therefore, the following abstract functions must be overridden by subclasses:
+ - validate_and_format
+ - calculate_cutoffs
+
+ Note that some shared functionality is implemented here - generally around error
+ handling.
"""
- __logger_name__ = "seq2seq"
logger_config: Seq2SeqLoggerConfig = seq2seq_logger_config
DATA_FOLDER_EXTENSION = {"emb": "hdf5", "prob": "hdf5", "data": "arrow"}
@@ -97,6 +83,11 @@ def token_map_key(self) -> str:
return str(self.split)
def validate_and_format(self) -> None:
+ """Validation backbone for Seq2Seq
+
+ Provides basic validation checking across Seq2Seq tasks.
+        See subclasses for formatting and further validation.
+ """
super().validate_and_format()
label_len = len(self.labels)
text_len = len(self.texts)
@@ -109,19 +100,6 @@ def validate_and_format(self) -> None:
"You must set your tokenizer before logging. "
"Use `dq.integrations.seq2seq.hf.set_tokenizer`"
)
- encoded_data = self.logger_config.tokenizer(
- self.labels,
- return_offsets_mapping=True,
- max_length=self.logger_config.max_target_tokens,
- truncation=True,
- )
- tokenized_labels = encoded_data["input_ids"]
- aligned_data = align_tokens_to_character_spans(encoded_data["offset_mapping"])
- self.token_label_offsets = aligned_data.token_label_offsets
- self.token_label_positions = aligned_data.token_label_positions
-
- id_to_tokens = dict(zip(self.ids, tokenized_labels))
- self.logger_config.id_to_tokens[self.token_map_key].update(id_to_tokens)
def _get_input_df(self) -> DataFrame:
return vaex.from_dict(
@@ -231,7 +209,9 @@ def create_in_out_frames(
)
@classmethod
- def add_generated_output_to_df(cls, df: DataFrame, split: str) -> DataFrame:
+ def add_generated_output_to_df(
+ cls, df: DataFrame, split: str
+ ) -> Optional[DataFrame]:
"""Adds the generated output to the dataframe
Adds the generated output to the dataframe, and also adds the
`token_label_positions` column
@@ -296,14 +276,22 @@ def separate_dataframe(
@classmethod
def calculate_cutoffs(cls, df: DataFrame) -> DataFrame:
- """
- Calculate the cutoff index of the input and target strings that were used by
- the model. The input/target are typically truncated and the model will only look
- at the first n characters, for example from the beginning until we reach 512
- tokens.
- This function adds two columns to the dataframe:
- - 'input_cutoff': the position of the last character in the input
- - 'target_cutoff': the position of the last character in the target
+        """Calculates cutoff indexes for the input and/or target string.
+
+ Transformer models (or sub-modules) are trained over a maximum number of
+ tokens / sequence length. This max_length controls the maximum number of
+ tokens that the transformer model can process / "see." During training,
+ the tokenizer uses this max_length to truncate additional tokens - so any
+ tokens beyond the max token length are fully ignored.
+
+ `calculate_cutoffs` adds relevant max_length information at the string
+ character level for the `target` and/or `input` columns. This character
+ info communicates to the UI how much of the respective string gets "seen"
+ during processing by the model.
+
+ In this abstract definition, we provide very basic error checking.
+
+        See subclasses (EncoderDecoder and DecoderOnly) for model specific details.
"""
tokenizer = cls.logger_config.tokenizer
if tokenizer is None:
@@ -311,16 +299,6 @@ def calculate_cutoffs(cls, df: DataFrame) -> DataFrame:
"You must set your tokenizer before calling dq.finish. Use "
"`dataquality.integrations.seq2seq.hf.watch`"
)
- max_input_length = cls.logger_config.max_input_tokens
- df[C.input_cutoff.value] = get_cutoff_from_truncated_tokenization(
- df, C.text, tokenizer, max_input_length
- )
-
- target_offsets_colname = C.token_label_offsets
- if target_offsets_colname in df.get_column_names():
- df[C.target_cutoff.value] = get_cutoff_from_saved_offsets(
- df, target_offsets_colname
- )
return df
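To make the cutoff idea in `calculate_cutoffs` concrete, here is a hypothetical standalone helper (not the library's `get_cutoff_from_truncated_tokenization`): re-tokenize with truncation and take the furthest character offset among the surviving tokens as the last character the model actually "sees":

    from transformers import PreTrainedTokenizerFast, T5TokenizerFast

    def input_cutoff(text: str, tokenizer: PreTrainedTokenizerFast, max_tokens: int) -> int:
        # Hypothetical sketch: end offset of the last non-special token kept after truncation
        encoded = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=max_tokens)
        return max(end for _, end in encoded["offset_mapping"])

    tokenizer = T5TokenizerFast.from_pretrained("t5-small")
    print(input_cutoff("a b c d e f g h i j", tokenizer, max_tokens=4))  # only the first few words are "seen"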
diff --git a/dataquality/loggers/logger_config/seq2seq/__init__.py b/dataquality/loggers/logger_config/seq2seq/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dataquality/loggers/logger_config/seq2seq/encoder_decoder.py b/dataquality/loggers/logger_config/seq2seq/encoder_decoder.py
new file mode 100644
index 000000000..e200c039a
--- /dev/null
+++ b/dataquality/loggers/logger_config/seq2seq/encoder_decoder.py
@@ -0,0 +1,12 @@
+from dataquality.loggers.logger_config.seq2seq.seq2seq_base import Seq2SeqLoggerConfig
+
+
+class EncoderDecoderLoggerConfig(Seq2SeqLoggerConfig):
+ """Encoder Decoder logger config
+
+ For now, the Encoder Decoder logger config has the same fields
+ as the base Seq2Seq logger config
+ """
+
+
+encoder_decoder_logger_config = EncoderDecoderLoggerConfig()
diff --git a/dataquality/loggers/logger_config/seq2seq.py b/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py
similarity index 100%
rename from dataquality/loggers/logger_config/seq2seq.py
rename to dataquality/loggers/logger_config/seq2seq/seq2seq_base.py
diff --git a/dataquality/loggers/model_logger/__init__.py b/dataquality/loggers/model_logger/__init__.py
index a96aa75f9..086faa2d4 100644
--- a/dataquality/loggers/model_logger/__init__.py
+++ b/dataquality/loggers/model_logger/__init__.py
@@ -1,12 +1,12 @@
from dataquality.loggers.model_logger import (
image_classification,
- seq2seq,
tabular_classification,
text_classification,
text_multi_label,
text_ner,
)
from dataquality.loggers.model_logger.base_model_logger import BaseGalileoModelLogger
+from dataquality.loggers.model_logger.seq2seq import encoder_decoder, seq2seq_base
__all__ = [
"image_classification",
@@ -15,5 +15,6 @@
"text_multi_label",
"text_ner",
"BaseGalileoModelLogger",
- "seq2seq",
+ "seq2seq_base", # TODO Likely remove
+ "encoder_decoder",
]
diff --git a/dataquality/loggers/model_logger/seq2seq/__init__.py b/dataquality/loggers/model_logger/seq2seq/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dataquality/loggers/model_logger/seq2seq/encoder_decoder.py b/dataquality/loggers/model_logger/seq2seq/encoder_decoder.py
new file mode 100644
index 000000000..2cec56fe3
--- /dev/null
+++ b/dataquality/loggers/model_logger/seq2seq/encoder_decoder.py
@@ -0,0 +1,57 @@
+from typing import List, Optional, Union
+
+import numpy as np
+
+from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
+ EncoderDecoderLoggerConfig,
+ encoder_decoder_logger_config,
+)
+from dataquality.loggers.model_logger.seq2seq.seq2seq_base import Seq2SeqModelLogger
+
+
+class EncoderDecoderModelLogger(Seq2SeqModelLogger):
+ __logger_name__ = "encoder_decoder"
+ logger_config: EncoderDecoderLoggerConfig = encoder_decoder_logger_config
+ log_file_ext = "arrow"
+
+ def __init__(
+ self,
+ embs: Optional[Union[List, np.ndarray]] = None,
+ probs: Optional[Union[List, np.ndarray]] = None,
+ logits: Optional[Union[List, np.ndarray]] = None,
+ ids: Optional[Union[List, np.ndarray]] = None,
+ split: str = "",
+ epoch: Optional[int] = None,
+ inference_name: Optional[str] = None,
+ labels: Optional[np.ndarray] = None,
+ ) -> None:
+ super().__init__(
+ embs=embs,
+ probs=probs,
+ logits=logits,
+ ids=ids,
+ split=split,
+ epoch=epoch,
+ inference_name=inference_name,
+ labels=labels,
+ )
+
+ def validate_and_format(self) -> None:
+ """Compute token level log-prob info for Encoder-Decoder Models
+
+ Encoder-Decoder models output `logits` just over the target tokens.
+ Therefore, we can very easily extract token log-prob info without
+ any additional data formatting / token splitting.
+ """
+ super().validate_and_format()
+
+ # TODO: [JON] computing softmax on GPU can lead to speedups of ~5x
+        # TODO: Question, the validation done in the parent class does not seem
+        #   to propagate. Here e.g. we convert ids to np.array in super()
+ logprobs = self.convert_logits_to_logprobs(self.logits)
+ (
+ self.token_logprobs,
+ self.top_logprobs,
+ ) = self.process_logprobs(
+ self.ids, logprobs # type: ignore
+ )
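A rough numpy illustration (assumed shapes, not the library implementation) of the two quantities computed above via `convert_logits_to_logprobs` and `process_logprobs`: per-token log-probs of the ground-truth labels and the top-k candidates at each position:

    import numpy as np
    from scipy.special import log_softmax

    rng = np.random.default_rng(0)
    logits = rng.normal(size=(5, 32))    # (num_target_tokens, vocab_size) for one sample
    labels = np.array([3, 7, 1, 0, 9])   # ground-truth target token ids

    logprobs = log_softmax(logits, axis=-1)
    token_logprobs = logprobs[np.arange(len(labels)), labels]  # log-prob assigned to each GT token
    top_k_ids = np.argsort(-logprobs, axis=-1)[:, :5]          # top-5 candidate token ids per position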
diff --git a/dataquality/loggers/model_logger/seq2seq.py b/dataquality/loggers/model_logger/seq2seq/seq2seq_base.py
similarity index 91%
rename from dataquality/loggers/model_logger/seq2seq.py
rename to dataquality/loggers/model_logger/seq2seq/seq2seq_base.py
index 521254bfd..85021156b 100644
--- a/dataquality/loggers/model_logger/seq2seq.py
+++ b/dataquality/loggers/model_logger/seq2seq/seq2seq_base.py
@@ -4,7 +4,7 @@
import pyarrow as pa
from scipy.special import log_softmax
-from dataquality.loggers.logger_config.seq2seq import (
+from dataquality.loggers.logger_config.seq2seq.seq2seq_base import (
Seq2SeqLoggerConfig,
seq2seq_logger_config,
)
@@ -23,7 +23,6 @@
class Seq2SeqModelLogger(BaseGalileoModelLogger):
- __logger_name__ = "seq2seq"
logger_config: Seq2SeqLoggerConfig = seq2seq_logger_config
log_file_ext = "arrow"
@@ -59,7 +58,12 @@ def token_map_key(self) -> str:
return str(self.split)
def validate_and_format(self) -> None:
- """Validate the lengths, calculate token level dep, extract GT probs"""
+ """Validate shared data format for seq2seq
+
+        Note that this base method performs the validation;
+        subclasses that inherit from it can add modality specific
+        formatting (such as EncoderDecoder and DecoderOnly).
+ """
if self.labels is not None:
self.labels = self._convert_tensor_ndarray(self.labels)
self.logits = self._convert_tensor_ndarray(self.logits)
@@ -75,20 +79,10 @@ def validate_and_format(self) -> None:
self.logger_config.tokenizer is not None
), "Must set your tokenizer. Use `dq.integrations.seq2seq.hf.set_tokenizer`"
- # TODO: This is potentially slow. This is what needs to be optimized. Can we
- # potentially do this on the GPU with torch? And dont convert to a np array
- # [JON] computing softmax on GPU can lead to speedups of around 5x in my
- # experience
- logprobs = self.convert_logits_to_logprobs(self.logits)
- (
- self.token_logprobs,
- self.top_logprobs,
- ) = self.process_logprobs(self.ids, logprobs)
-
def process_logprobs(
self, batch_ids: np.ndarray, batch_logprobs: np.ndarray
) -> Tuple[pa.array, pa.array]:
- """Handle processing of a batch of sample logprobs
+ """Handle processing for a batch of sample logprobs
For each sample in the batch extract / compute the following values:
- Token level logprobs for the GT label
diff --git a/dataquality/schemas/task_type.py b/dataquality/schemas/task_type.py
index ffb8b03ed..9d70b19bf 100644
--- a/dataquality/schemas/task_type.py
+++ b/dataquality/schemas/task_type.py
@@ -14,8 +14,9 @@ class TaskType(str, Enum):
object_detection = "object_detection"
semantic_segmentation = "semantic_segmentation"
prompt_evaluation = "prompt_evaluation"
- seq2seq = "seq2seq"
+ seq2seq = "seq2seq" # deprecated, use encoder_decoder or decoder_only
llm_monitor = "llm_monitor"
+ encoder_decoder = "encoder_decoder"
@staticmethod
def get_valid_tasks() -> List["TaskType"]:
@@ -38,6 +39,7 @@ def get_mapping(task_int: int) -> "TaskType":
5: TaskType.object_detection,
6: TaskType.semantic_segmentation,
7: TaskType.prompt_evaluation,
- 8: TaskType.seq2seq,
+ 8: TaskType.seq2seq, # deprecated
9: TaskType.llm_monitor,
+ 10: TaskType.encoder_decoder,
}[task_int]
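A quick illustrative check of the new enum entry and integer mapping (values taken from the hunk above):

    from dataquality.schemas.task_type import TaskType

    assert TaskType.encoder_decoder.value == "encoder_decoder"
    assert TaskType.get_mapping(10) is TaskType.encoder_decoder
    assert TaskType.get_mapping(8) is TaskType.seq2seq  # retained but deprecated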
diff --git a/tests/loggers/test_seq2seq.py b/tests/loggers/test_seq2seq.py
index 1da645809..996dbf9b0 100644
--- a/tests/loggers/test_seq2seq.py
+++ b/tests/loggers/test_seq2seq.py
@@ -12,9 +12,16 @@
import dataquality as dq
from dataquality.integrations.seq2seq.hf import set_tokenizer, watch
from dataquality.loggers.data_logger.base_data_logger import DataSet
-from dataquality.loggers.data_logger.seq2seq import Seq2SeqDataLogger
-from dataquality.loggers.logger_config.seq2seq import seq2seq_logger_config
-from dataquality.loggers.model_logger.seq2seq import Seq2SeqModelLogger
+from dataquality.loggers.data_logger.seq2seq.encoder_decoder import (
+ EncoderDecoderDataLogger,
+)
+from dataquality.loggers.data_logger.seq2seq.seq2seq_base import Seq2SeqDataLogger
+from dataquality.loggers.logger_config.seq2seq.encoder_decoder import (
+ encoder_decoder_logger_config,
+)
+from dataquality.loggers.model_logger.seq2seq.encoder_decoder import (
+ EncoderDecoderModelLogger,
+)
from dataquality.schemas.seq2seq import (
TOP_K,
BatchGenerationData,
@@ -56,18 +63,19 @@
),
],
)
-def test_log_dataset(
+def test_log_dataset_encoder_decoder(
dataset: DataSet,
set_test_config: Callable,
cleanup_after_use: Callable,
test_session_vars: TestSessionVariables,
) -> None:
+ # TODO Test with watch
set_test_config(task_type="seq2seq")
- logger = Seq2SeqDataLogger()
+ logger = EncoderDecoderDataLogger()
with patch("dataquality.core.log.get_data_logger") as mock_method:
mock_method.return_value = logger
- set_tokenizer(tokenizer)
+ set_tokenizer(tokenizer, encoder_decoder_logger_config)
dq.log_dataset(
dataset, text="summary", label="title", id="my_id", split="train"
)
@@ -98,6 +106,7 @@ def test_log_dataset_no_tokenizer(set_test_config: Callable) -> None:
"my_id": [1, 2, 3],
}
)
+ # Note this functionality is tested fully by the Seq2Seq parent class
logger = Seq2SeqDataLogger()
with patch("dataquality.core.log.get_data_logger") as mock_method:
mock_method.return_value = logger
@@ -109,11 +118,12 @@ def test_log_dataset_no_tokenizer(set_test_config: Callable) -> None:
)
-def test_log_model_outputs(
+def test_log_model_outputs_encoder_decoder(
set_test_config: Callable,
cleanup_after_use: Callable,
test_session_vars: TestSessionVariables,
) -> None:
+    # TODO Add comment
set_test_config(task_type="seq2seq")
tokenized_labels = [
@@ -146,7 +156,7 @@ def test_log_model_outputs(
split="training",
epoch=0,
)
- logger = Seq2SeqModelLogger(**log_data)
+ logger = EncoderDecoderModelLogger(**log_data)
logger.logger_config = config
with patch("dataquality.core.log.get_model_logger") as mock_method:
mock_method.return_value = logger
@@ -275,6 +285,7 @@ def test_tokenize_input_provide_maxlength(
set_test_config: Callable,
cleanup_after_use: Generator,
) -> None:
+ # TODO comment!
"""
Test that as we generate output and the user provided the max_input_tokens argument,
the input is tokenized correctly to the length set by max_input_tokens.
@@ -286,13 +297,14 @@ def test_tokenize_input_provide_maxlength(
mock_model.generate.return_value = seq2seq_generated_output
mock_generation_config = Mock(spec=GenerationConfig)
- set_tokenizer(tokenizer_T5, max_input_tokens=7)
+    # TODO: for now encoder_decoder covers the general case
+ set_tokenizer(tokenizer_T5, encoder_decoder_logger_config, max_input_tokens=7)
input_text = "a b c d e f g h i j"
generate_sample_output(
input_text,
mock_model,
tokenizer_T5,
- seq2seq_logger_config.max_input_tokens,
+ encoder_decoder_logger_config.max_input_tokens,
mock_generation_config,
)
@@ -328,13 +340,14 @@ def test_tokenize_input_doesnt_provide_maxlength(
mock_model.generate.return_value = seq2seq_generated_output
mock_generation_config = Mock(spec=GenerationConfig)
- set_tokenizer(tokenizer_T5)
+    # TODO: for now encoder_decoder covers the general case
+ set_tokenizer(tokenizer_T5, encoder_decoder_logger_config)
input_text = "a b c d e f g h i j" * 100
generate_sample_output(
input_text,
mock_model,
tokenizer_T5,
- seq2seq_logger_config.max_input_tokens,
+ encoder_decoder_logger_config.max_input_tokens,
mock_generation_config,
)
@@ -351,9 +364,10 @@ def test_tokenize_input_doesnt_provide_maxlength(
mock_process_sample_logprobs.assert_called_once()
-def test_tokenize_target_provide_maxlength(
+def test_tokenize_target_provide_maxlength_encoder_decoder(
set_test_config: Callable, cleanup_after_use: Generator
) -> None:
+ # TODO Update based on hf support for encoder-decoder vs. decoder-only
"""
Test that the target is tokenized correctly to the length provided by the user in
the max_target_tokens argument.
@@ -370,26 +384,29 @@ def test_tokenize_target_provide_maxlength(
)
dq.log_dataset(ds, text="input", label="target", split="train")
- assert set(seq2seq_logger_config.id_to_tokens["training"]) == {0, 1}
- assert len(seq2seq_logger_config.id_to_tokens["training"][0]) == 7
+ assert set(encoder_decoder_logger_config.id_to_tokens["training"]) == {0, 1}
+ assert len(encoder_decoder_logger_config.id_to_tokens["training"][0]) == 7
# Check that it has two tokens: the token "2" + EOS token
- assert len(seq2seq_logger_config.id_to_tokens["training"][1]) == 2
+ assert len(encoder_decoder_logger_config.id_to_tokens["training"][1]) == 2
# Check that both sentences end with the same EOS token
assert (
- seq2seq_logger_config.id_to_tokens["training"][0][-1]
- == seq2seq_logger_config.id_to_tokens["training"][1][-1]
+ encoder_decoder_logger_config.id_to_tokens["training"][0][-1]
+ == encoder_decoder_logger_config.id_to_tokens["training"][1][-1]
)
-def test_tokenize_target_doesnt_provide_maxlength(
+def test_tokenize_target_doesnt_provide_maxlength_encoder_decoder(
set_test_config: Callable, cleanup_after_use: Generator
) -> None:
+ # TODO Update based on hf support for encoder-decoder vs. decoder-only
"""
Test that the target is tokenized correctly when the user does not provide a
max_target_tokens argument, i.e., to the length set by default in the tokenizer.
"""
set_test_config(task_type=TaskType.seq2seq)
mock_generation_config = Mock(spec=GenerationConfig)
+ # TODO Does using a real model here take a lot of time?
+ # should we just mock the model and add a max length?
watch(model_T5, tokenizer_T5, mock_generation_config)
ds = Dataset.from_dict(
{
@@ -400,23 +417,26 @@ def test_tokenize_target_doesnt_provide_maxlength(
)
dq.log_dataset(ds, text="input", label="target", split="train")
- assert set(seq2seq_logger_config.id_to_tokens["training"]) == {0, 1}
+ assert set(encoder_decoder_logger_config.id_to_tokens["training"]) == {0, 1}
# Make sure that the target is large enough to require truncation
assert len(ds["target"][0]) > tokenizer_T5.model_max_length
assert (
- len(seq2seq_logger_config.id_to_tokens["training"][0])
+ len(encoder_decoder_logger_config.id_to_tokens["training"][0])
== tokenizer_T5.model_max_length
)
# Check that it has two tokens: the token "2" + EOS token
- assert len(seq2seq_logger_config.id_to_tokens["training"][1]) == 2
+ assert len(encoder_decoder_logger_config.id_to_tokens["training"][1]) == 2
# Check that both sentences end with the same EOS token
assert (
- seq2seq_logger_config.id_to_tokens["training"][0][-1]
- == seq2seq_logger_config.id_to_tokens["training"][1][-1]
+ encoder_decoder_logger_config.id_to_tokens["training"][0][-1]
+ == encoder_decoder_logger_config.id_to_tokens["training"][1][-1]
)
-def test_calculate_cutoffs(set_test_config: Callable, cleanup_after_use: Generator):
+def test_calculate_cutoffs_encoder_decoder(
+ set_test_config: Callable, cleanup_after_use: Generator
+):
+ # TODO Add comment!
"""Test that calculate_cutoffs works correctly for both input/target"""
set_test_config(task_type=TaskType.seq2seq)
mock_model = Mock(spec=T5ForConditionalGeneration)
@@ -441,7 +461,7 @@ def test_calculate_cutoffs(set_test_config: Callable, cleanup_after_use: Generat
)
dq.log_dataset(ds, text="input", label="target", split="train")
- data_logger = Seq2SeqDataLogger()
+ data_logger = EncoderDecoderDataLogger()
in_frame_split = vaex.open(
f"{data_logger.input_data_path}/training/*.{data_logger.INPUT_DATA_FILE_EXT}"
)
diff --git a/tests/utils/test_seq2seq_offset.py b/tests/utils/test_seq2seq_offset.py
index 96cf089df..608ad75b7 100644
--- a/tests/utils/test_seq2seq_offset.py
+++ b/tests/utils/test_seq2seq_offset.py
@@ -8,7 +8,9 @@
import dataquality as dq
from dataquality.integrations.seq2seq.hf import watch
-from dataquality.loggers.data_logger.seq2seq import Seq2SeqDataLogger
+from dataquality.loggers.data_logger.seq2seq.encoder_decoder import (
+ EncoderDecoderDataLogger,
+)
from dataquality.schemas.seq2seq import Seq2SeqInputCols as C
from dataquality.schemas.task_type import TaskType
from dataquality.utils.seq2seq.offsets import (
@@ -143,9 +145,13 @@ def test_rollup_spans(
def test_get_position_of_last_offset_input(
set_test_config: Callable, cleanup_after_use: Generator
):
+    # TODO Consider whether this needs to be separate for EncoderDecoder and Decoder-Only
"""
Test that get_position_of_last_offset_input returns the correct cut-off point for
the input text string.
+
+    We use an EncoderDecoder model, but this serves as a generic test for both
+    model types.
"""
set_test_config(task_type=TaskType.seq2seq)
mock_model = Mock(spec=T5ForConditionalGeneration)
@@ -163,7 +169,7 @@ def test_get_position_of_last_offset_input(
)
dq.log_dataset(ds, text="input", label="target", split="train")
- data_logger = Seq2SeqDataLogger()
+ data_logger = EncoderDecoderDataLogger()
in_frame_split = vaex.open(
f"{data_logger.input_data_path}/training/*.{data_logger.INPUT_DATA_FILE_EXT}"
)
@@ -185,6 +191,8 @@ def test_get_position_of_last_offset_target(
"""
Test that get_position_of_last_offset_target returns the correct cut-off point for
the target text string.
+
+ Note that this just applies to EncoderDecoder models.
"""
set_test_config(task_type=TaskType.seq2seq)
mock_model = Mock(spec=T5ForConditionalGeneration)
@@ -202,7 +210,7 @@ def test_get_position_of_last_offset_target(
)
dq.log_dataset(ds, text="input", label="target", split="train")
- data_logger = Seq2SeqDataLogger()
+ data_logger = EncoderDecoderDataLogger()
in_frame_split = vaex.open(
f"{data_logger.input_data_path}/training/*.{data_logger.INPUT_DATA_FILE_EXT}"
)
diff --git a/tests/utils/test_seq2seq_utils.py b/tests/utils/test_seq2seq_utils.py
index deb4d8e86..8700e2ead 100644
--- a/tests/utils/test_seq2seq_utils.py
+++ b/tests/utils/test_seq2seq_utils.py
@@ -8,7 +8,9 @@
import torch
from dataquality.exceptions import GalileoException
-from dataquality.loggers.model_logger.seq2seq import Seq2SeqModelLogger
+from dataquality.loggers.model_logger.seq2seq.encoder_decoder import (
+ EncoderDecoderModelLogger,
+)
from dataquality.schemas.seq2seq import (
TOP_K,
AlignedTokenData,
@@ -280,7 +282,7 @@ def test_model_logger_remove_padding() -> None:
split="training",
epoch=0,
)
- logger = Seq2SeqModelLogger(**log_data)
+ logger = EncoderDecoderModelLogger(**log_data)
logger.logger_config = config
for sample_id, (sample_logprobs, sample_top_indices) in enumerate(
zip(logprobs, top_indices)