Allow specifying own HF tokenizer object
shawnz authored and rlouf committed Mar 1, 2024
1 parent c0b47a4 commit c1b4ffa
Showing 2 changed files with 86 additions and 64 deletions.
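
Before the per-file diffs, a minimal sketch of the usage this commit enables, adapted from the test added in tests/generate/test_integration_transformers.py below (the tiny test model is reused purely for illustration, and the test's rng seeding is omitted for brevity):

    from transformers import AutoTokenizer, AutoModelForCausalLM

    import outlines.generate as generate
    from outlines.models.transformers import Transformer, TransformerTokenizer

    # Build the Hugging Face model and tokenizer yourself...
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)

    # ...then hand the tokenizer object to TransformerTokenizer directly;
    # previously only a model-name string was accepted.
    tokenizer = TransformerTokenizer(hf_tokenizer)
    model = Transformer(hf_model, tokenizer)

    sequence = generate.text(model)("Write a short sentence ")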
133 changes: 70 additions & 63 deletions outlines/models/transformers.py
@@ -5,7 +5,7 @@
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 __all__ = ["transformers"]
 
@@ -55,13 +55,81 @@ class CodeLlamaTokenizerFast: # type: ignore
     )
 
 
+class TransformerTokenizer(Tokenizer):
+    """Represents a tokenizer for models in the `transformers` library."""
+
+    def __init__(self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs):
+        if isinstance(tokenizer_or_model_name, str):
+            from transformers import AutoTokenizer
+
+            kwargs.setdefault("padding_side", "left")
+            self.model_name = tokenizer_or_model_name
+            # TODO: Do something to make this hashable?
+            self.kwargs = kwargs
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_or_model_name, **kwargs)
+        else:
+            self.tokenizer = tokenizer_or_model_name
+
+        self.eos_token_id = self.tokenizer.eos_token_id
+        self.eos_token = self.tokenizer.eos_token
+
+        if not self.tokenizer.pad_token_id:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+            self.pad_token_id = self.eos_token_id
+        else:
+            self.pad_token_id = self.tokenizer.pad_token_id
+            self.pad_token = self.tokenizer.pad_token
+
+        self.special_tokens = set(self.tokenizer.all_special_tokens)
+
+        self.vocabulary = self.tokenizer.get_vocab()
+        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
+
+    def encode(
+        self, prompt: Union[str, List[str]], **kwargs
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        kwargs["padding"] = True
+        kwargs["return_tensors"] = "pt"
+        output = self.tokenizer(prompt, **kwargs)
+        return output["input_ids"], output["attention_mask"]
+
+    def decode(self, token_ids: torch.LongTensor) -> List[str]:
+        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+        return text
+
+    def convert_token_to_string(self, token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = self.tokenizer.convert_tokens_to_string([token])
+
+        if self.is_llama:
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+        return string
+
+    def __eq__(self, other):
+        if isinstance(other, type(self)):
+            if hasattr(self, "model_name") and hasattr(self, "kwargs"):
+                return other.model_name == self.model_name and other.kwargs == self.kwargs
+            else:
+                return other.tokenizer == self.tokenizer
+        return NotImplemented
+
+    def __hash__(self):
+        from datasets.fingerprint import Hasher
+
+        return hash(Hasher.hash(self.tokenizer))
+
+
 class Transformer:
     """Represents a `transformers` model."""
 
     def __init__(
         self,
         model: "PreTrainedModel",
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: TransformerTokenizer,
     ):
         self.device = model.device
         self.model = model
@@ -119,67 +187,6 @@ def __call__(
         return next_token_logits, kv_cache
 
 
-class TransformerTokenizer(Tokenizer):
-    """Represents a tokenizer for models in the `transformers` library."""
-
-    def __init__(self, model_name: str, **kwargs):
-        from transformers import AutoTokenizer
-
-        kwargs.setdefault("padding_side", "left")
-        self.model_name = model_name
-        # TODO: Do something to make this hashable?
-        self.kwargs = kwargs
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = self.tokenizer.eos_token
-
-        if not self.tokenizer.pad_token_id:
-            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-            self.pad_token_id = self.eos_token_id
-        else:
-            self.pad_token_id = self.tokenizer.pad_token_id
-            self.pad_token = self.tokenizer.pad_token
-
-        self.special_tokens = set(self.tokenizer.all_special_tokens)
-
-        self.vocabulary = self.tokenizer.get_vocab()
-        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
-
-    def encode(
-        self, prompt: Union[str, List[str]], **kwargs
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
-        kwargs["padding"] = True
-        kwargs["return_tensors"] = "pt"
-        output = self.tokenizer(prompt, **kwargs)
-        return output["input_ids"], output["attention_mask"]
-
-    def decode(self, token_ids: torch.LongTensor) -> List[str]:
-        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
-        return text
-
-    def convert_token_to_string(self, token: str) -> str:
-        from transformers.file_utils import SPIECE_UNDERLINE
-
-        string = self.tokenizer.convert_tokens_to_string([token])
-
-        if self.is_llama:
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-        return string
-
-    def __eq__(self, other):
-        if isinstance(other, type(self)):
-            return other.model_name == self.model_name and other.kwargs == self.kwargs
-        return NotImplemented
-
-    def __hash__(self):
-        from datasets.fingerprint import Hasher
-
-        return hash(Hasher.hash(self.tokenizer))
-
-
 def transformers(
     model_name: str,
     device: Optional[str] = None,
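A note on the reworked __eq__ above, with a hypothetical illustration (not part of the commit): tokenizers constructed from a model name are compared by model_name and kwargs as before, while wrapped tokenizer objects, which never set those attributes, are compared via the underlying transformers tokenizer.

    from transformers import AutoTokenizer
    from outlines.models.transformers import TransformerTokenizer

    name = "hf-internal-testing/tiny-random-GPTJForCausalLM"

    # Built from a model name: equality is decided by model_name and kwargs.
    assert TransformerTokenizer(name) == TransformerTokenizer(name)

    # Built from an existing tokenizer object: equality falls through to
    # comparing the wrapped tokenizers themselves.
    hf_tok = AutoTokenizer.from_pretrained(name)
    assert TransformerTokenizer(hf_tok) == TransformerTokenizer(hf_tok)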
17 changes: 16 additions & 1 deletion tests/generate/test_integration_transformers.py
@@ -10,7 +10,7 @@
 import outlines.generate as generate
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
-from outlines.models.transformers import TransformerTokenizer
+from outlines.models.transformers import Transformer, TransformerTokenizer
 from outlines.samplers import beam_search, multinomial
 
 
@@ -615,3 +615,18 @@ def __call__(
     )
 
     assert sequence == "c"
+
+
+def test_transformers_use_existing_model_and_tokenizer():
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    rng = torch.Generator()
+    rng.manual_seed(10000)
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = TransformerTokenizer(hf_tokenizer)
+    model = Transformer(hf_model, tokenizer)
+    sequence = generate.text(model)("Write a short sentence ", rng=rng)
+    assert isinstance(sequence, str)
