Add pre-truncation token counting for completion
chiragjn committed Nov 19, 2024
1 parent cd6203e commit 57167dd
Showing 2 changed files with 30 additions and 4 deletions.
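This commit records, alongside each tokenized sample, how many tokens the prompt contained before truncation to max_length, so downstream code can tell whether (and by how much) a sample was cut. As a hypothetical illustration (not part of this commit), the new num_tokens_pre_truncation field makes a truncation audit a one-liner:

def count_truncated_samples(tokenized_samples, max_length):
    # A sample was truncated iff its full token count exceeded max_length.
    return sum(
        1
        for sample in tokenized_samples
        if sample["num_tokens_pre_truncation"] > max_length
    )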
src/axolotl/prompt_strategies/alpaca_w_system.py (6 additions, 0 deletions)
@@ -49,6 +49,12 @@ def tokenize_prompt(self, prompt):
         tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
         tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
 
+        if "num_tokens_pre_truncation" in tokenized_prompt:
+            tokenized_prompt["num_tokens_pre_truncation"] = (
+                tokenized_prompt["num_tokens_pre_truncation"]
+                + tokenized_res_prompt["num_tokens_pre_truncation"]
+            )
+
         return tokenized_prompt
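This strategy tokenizes the instruction/system portion and the response separately, then concatenates the pieces: input_ids, attention_mask, and labels are list-concatenated, while the new count is a plain integer and must be summed instead. A minimal sketch of that merge pattern, with illustrative segment dicts (not the actual strategy class):

def merge_segments(seg_a, seg_b):
    # Sequence fields concatenate; the pre-truncation token count adds.
    merged = {
        "input_ids": seg_a["input_ids"] + seg_b["input_ids"],
        "attention_mask": seg_a["attention_mask"] + seg_b["attention_mask"],
        "labels": seg_a["labels"] + seg_b["labels"],
    }
    if "num_tokens_pre_truncation" in seg_a:
        merged["num_tokens_pre_truncation"] = (
            seg_a["num_tokens_pre_truncation"] + seg_b["num_tokens_pre_truncation"]
        )
    return merged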


src/axolotl/prompt_tokenizers.py (24 additions, 4 deletions)
@@ -1,6 +1,7 @@
 """Module containing PromptTokenizingStrategy and Prompter classes"""
 
 import abc
+import functools
 import logging
 from typing import Dict, List, Tuple, Union
 
@@ -60,18 +61,23 @@ def supports_batched(self):
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
     ) -> BatchEncoding:
-        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+        empty = BatchEncoding(
+            data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0}
+        )
         if not prompt:
             LOG.warning("Empty text requested for tokenization.")
             return empty
 
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
+        _tokenize = functools.partial(
+            self.tokenizer,
             max_length=self.max_length,
             padding=False,
             return_tensors=None,
         )
+        result = _tokenize(
+            prompt,
+            truncation=True,
+        )
         if len(result["input_ids"]) == 0:
             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
             return empty
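The refactor above exists so the same tokenizer arguments can back two calls that differ only in truncation: the original truncated call, and a second untruncated call added further down to measure the full length. A self-contained sketch of how functools.partial keeps the two call sites in sync (fake_tokenizer is a stand-in, not the real Hugging Face tokenizer):

import functools

def fake_tokenizer(text, truncation=True, max_length=None, padding=False,
                   return_tensors=None):
    # Stand-in: one "token" per whitespace-separated word.
    ids = list(range(len(text.split())))
    if truncation and max_length is not None:
        ids = ids[:max_length]
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}

_tokenize = functools.partial(
    fake_tokenizer, max_length=4, padding=False, return_tensors=None
)
print(_tokenize("a b c d e f", truncation=True)["input_ids"])   # [0, 1, 2, 3]
print(_tokenize("a b c d e f", truncation=False)["input_ids"])  # [0, 1, 2, 3, 4, 5]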
@@ -89,6 +95,20 @@ def _tokenize(
             result["attention_mask"] = result["attention_mask"][1:]
 
         result["labels"] = result["input_ids"].copy()
+
+        _all_tokens = _tokenize(prompt, truncation=False)
+        num_tokens_pre_truncation = len(_all_tokens["input_ids"])
+        if (
+            _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id
+            and add_eos_token
+        ):
+            num_tokens_pre_truncation += 1
+        if (
+            _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id
+            and strip_bos_token
+        ):
+            num_tokens_pre_truncation -= 1
+        result["num_tokens_pre_truncation"] = num_tokens_pre_truncation
         return result
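The count is taken from an untruncated second pass and then adjusted so it stays comparable with the truncated result: elsewhere in this method (outside the hunk) an EOS token is appended when add_eos_token is set, and a leading BOS token is stripped when strip_bos_token is set, so the full-length count mirrors both adjustments. A standalone restatement of the rule, with made-up token ids:

def pre_truncation_count(ids, eos_id, bos_id, add_eos_token, strip_bos_token):
    n = len(ids)
    if add_eos_token and ids[-1] != eos_id:
        n += 1  # an EOS would have been appended
    if strip_bos_token and ids[0] == bos_id:
        n -= 1  # the leading BOS would have been stripped
    return n

# BOS=1, EOS=2: three raw tokens, EOS appended, BOS stripped -> still 3.
assert pre_truncation_count([1, 5, 6], eos_id=2, bos_id=1,
                            add_eos_token=True, strip_bos_token=True) == 3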


