From a30ca910616e52091cab571f385fde31357a0474 Mon Sep 17 00:00:00 2001 From: Jonathan Gomes Selman Date: Tue, 26 Dec 2023 17:27:57 -0500 Subject: [PATCH] feat: generation peft (#822) --- dataquality/__init__.py | 2 +- dataquality/integrations/seq2seq/core.py | 4 +++- .../loggers/logger_config/seq2seq/seq2seq_base.py | 5 +++-- dataquality/utils/seq2seq/generation.py | 14 ++++++++++++++ pyproject.toml | 1 + 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/dataquality/__init__.py b/dataquality/__init__.py index 0c262777d..8f9908d07 100644 --- a/dataquality/__init__.py +++ b/dataquality/__init__.py @@ -31,7 +31,7 @@ """ -__version__ = "1.4.1" +__version__ = "1.4.2" import sys from typing import Any, List, Optional diff --git a/dataquality/integrations/seq2seq/core.py b/dataquality/integrations/seq2seq/core.py index 05ffaa736..cb02e8262 100644 --- a/dataquality/integrations/seq2seq/core.py +++ b/dataquality/integrations/seq2seq/core.py @@ -1,6 +1,7 @@ from typing import List, Optional, Union from warnings import warn +from peft import PeftModel from tokenizers import Tokenizer from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast @@ -172,8 +173,9 @@ def watch( # A model of the correct type is required if we need to generate if generation_splits: assert isinstance( - model, PreTrainedModel + model, (PreTrainedModel, PeftModel) ), "model must be an instance of transformers PreTrainedModel" + assert ( model.can_generate() ), "model must contain a `generate` method for seq2seq" diff --git a/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py b/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py index 1164007db..0a09ba0e7 100644 --- a/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py +++ b/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py @@ -1,6 +1,7 @@ from collections import defaultdict -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union +from peft import PeftModel from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig @@ -15,7 +16,7 @@ class Seq2SeqLoggerConfig(BaseLoggerConfig): max_target_tokens: Optional[int] = None # For each split/inference-name, store sample id -> List[token_id] for the label id_to_tokens: Dict[str, Dict[int, List[int]]] = defaultdict(dict) - model: Optional[PreTrainedModel] = None + model: Optional[Union[PreTrainedModel, PeftModel]] = None generation_config: Optional[GenerationConfig] = None generation_splits: Set[Split] = set() model_type: Optional[Seq2SeqModelType] = None diff --git a/dataquality/utils/seq2seq/generation.py b/dataquality/utils/seq2seq/generation.py index 3899bf500..4865a7a5a 100644 --- a/dataquality/utils/seq2seq/generation.py +++ b/dataquality/utils/seq2seq/generation.py @@ -142,6 +142,17 @@ def add_generated_output_to_df( Updated Dataframe with the generated columns added (see above) """ model.eval() + # When generating it is important to set `use_cache = True`. + # - WHAT? Caching stores intermediate token activations / representations. + # During autoregressive generation, the cache is updated each time a token + # is generated. + # - WHY? Caching prevents re-computing token information during auto-regressive + # generation, DRAMATICALLY speeding up performance. Every time a new token is + # generated, we only need to do the forward pass for a single new token, as we + # leverage the cached information to compute transformer based attention. + model_cache_flag = model.config.use_cache + model.config.use_cache = True + generated_data = BatchGenerationData() num_batches = math.ceil(len(df) / GENERATION_BATCH_SIZE) @@ -183,4 +194,7 @@ def add_generated_output_to_df( generated_data.generated_top_logprobs, type=TOP_LOGPROBS_SCHEMA ) + # Reset the cache flag for the model + model.config.use_cache = model_cache_flag + return df diff --git a/pyproject.toml b/pyproject.toml index 007dfa292..810dc792f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "ipywidgets>=8.1.0", "imagededup>=0.3.1", "pyjwt>=2.8.0", + "peft" ] [[project.authors]] name = "Galileo Technologies, Inc."