diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py
index 8c8cc07435..b844cc08b8 100644
--- a/src/axolotl/prompt_strategies/alpaca_w_system.py
+++ b/src/axolotl/prompt_strategies/alpaca_w_system.py
@@ -49,6 +49,12 @@ def tokenize_prompt(self, prompt):
         tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
         tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
 
+        if "num_tokens_pre_truncation" in tokenized_prompt:
+            tokenized_prompt["num_tokens_pre_truncation"] = (
+                tokenized_prompt["num_tokens_pre_truncation"]
+                + tokenized_res_prompt["num_tokens_pre_truncation"]
+            )
+
         return tokenized_prompt
 
 
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index bd6e3f9dce..6b3efe1a1f 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -1,6 +1,7 @@
 """Module containing PromptTokenizingStrategy and Prompter classes"""
 
 import abc
+import functools
 import logging
 from typing import Dict, List, Tuple, Union
 
@@ -60,18 +61,23 @@ def supports_batched(self):
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
     ) -> BatchEncoding:
-        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+        empty = BatchEncoding(
+            data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0}
+        )
         if not prompt:
             LOG.warning("Empty text requested for tokenization.")
             return empty
 
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
+        _tokenize = functools.partial(
+            self.tokenizer,
             max_length=self.max_length,
             padding=False,
             return_tensors=None,
         )
+        result = _tokenize(
+            prompt,
+            truncation=True,
+        )
         if len(result["input_ids"]) == 0:
             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
             return empty
@@ -89,6 +95,20 @@ def _tokenize(
             result["attention_mask"] = result["attention_mask"][1:]
 
         result["labels"] = result["input_ids"].copy()
+
+        _all_tokens = _tokenize(prompt, truncation=False)
+        num_tokens_pre_truncation = len(_all_tokens["input_ids"])
+        if (
+            _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id
+            and add_eos_token
+        ):
+            num_tokens_pre_truncation += 1
+        if (
+            _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id
+            and strip_bos_token
+        ):
+            num_tokens_pre_truncation -= 1
+        result["num_tokens_pre_truncation"] = num_tokens_pre_truncation
         return result
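
Note (not part of the patch): the new num_tokens_pre_truncation field records how many tokens an example would have had without truncation to max_length, adjusted for the same EOS/BOS handling that _tokenize applies to the truncated encoding; comparing it with len(result["input_ids"]) lets downstream code spot clipped examples. Below is a minimal standalone sketch of that counting logic, assuming a Hugging Face-style tokenizer exposing eos_token_id and bos_token_id; the helper name count_tokens_pre_truncation is hypothetical and only illustrates the idea.

    from transformers import AutoTokenizer

    def count_tokens_pre_truncation(tokenizer, prompt, add_eos_token=True, strip_bos_token=False):
        # Encode without truncation to measure the sample's true length.
        all_ids = tokenizer(prompt, truncation=False, padding=False, return_tensors=None)["input_ids"]
        num_tokens = len(all_ids)
        # The truncated encoding appends EOS when it is missing, so count it here too.
        if add_eos_token and all_ids and all_ids[-1] != tokenizer.eos_token_id:
            num_tokens += 1
        # The truncated encoding drops a leading BOS when strip_bos_token is set, so subtract it.
        if strip_bos_token and all_ids and all_ids[0] == tokenizer.bos_token_id:
            num_tokens -= 1
        return num_tokens

    # Example usage with any causal-LM tokenizer (model name is illustrative only).
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    print(count_tokens_pre_truncation(tokenizer, "Below is an instruction that describes a task."))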