Merge branch 'main' into feat/add-transformers-integration
saattrupdan authored Mar 4, 2024
2 parents d2affb8 + 0488ad2 commit 962c419
Showing 6 changed files with 130 additions and 74 deletions.
6 changes: 3 additions & 3 deletions outlines/fsm/fsm.py
@@ -78,7 +78,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.
         """
-        if token_id == self.eos_token_id:
+        if token_id == self.eos_token_id or state == self.final_state:
             return self.final_state
 
         return self.first_state
@@ -172,7 +172,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.
         """
-        if token_id == self.eos_token_id:
+        if token_id == self.eos_token_id or state == self.final_state:
             return self.final_state
 
         last_token_to_end_state = self.states_to_token_maps[state]
@@ -354,7 +354,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         -------
         The new state of the FSM.
         """
-        if token_id == self.tokenizer.eos_token_id:
+        if token_id == self.tokenizer.eos_token_id or state == self.final_state:
             return self.final_state
 
         self.generation += self.tokenizer.decode([token_id])[0]
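All three `next_state` changes above add the same guard: once the FSM has reached `final_state`, it stays there, instead of being reset by whatever token happens to follow EOS. A minimal sketch of the new behaviour using a toy class (not the library's actual FSMs, which also track token maps and generation state):

# Toy FSM mirroring the new guard: the final state is absorbing.
class ToyStopAtEosFSM:
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id
        self.first_state = 0
        self.final_state = 1

    def next_state(self, state: int, token_id: int) -> int:
        # New behaviour: any token received while in final_state keeps us there.
        if token_id == self.eos_token_id or state == self.final_state:
            return self.final_state
        return self.first_state

fsm = ToyStopAtEosFSM(eos_token_id=2)
state = fsm.next_state(fsm.first_state, 2)          # EOS observed
assert state == fsm.final_state
assert fsm.next_state(state, 7) == fsm.final_state  # absorbing now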
10 changes: 9 additions & 1 deletion outlines/fsm/json_schema.py
@@ -138,7 +138,7 @@ def to_regex(
         if any(is_required):
             last_required_pos = max([i for i, value in enumerate(is_required) if value])
             for i, (name, value) in enumerate(properties.items()):
-                subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
+                subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}'
                 subregex += to_regex(resolver, value, whitespace_pattern)
                 if i < last_required_pos:
                     subregex = f"{subregex}{whitespace_pattern},"
@@ -216,6 +216,14 @@ def to_regex(

return f"({'|'.join(choices)})"

elif "const" in instance:
const = instance["const"]
if type(const) in [int, float, bool, None]:
const = re.escape(str(const))
elif type(const) == str:
const = f'"{re.escape(const)}"'
return const

elif "$ref" in instance:
path = f"{instance['$ref']}"
instance = resolver.lookup(path).contents
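Two changes land in `to_regex`: property names are now passed through `re.escape` so regex metacharacters in keys match literally, and a new `const` branch turns a fixed JSON value into an exact-match pattern, with strings additionally wrapped in JSON quotes. A standalone sketch of what the const branch produces (not the library function itself):

import re

def const_to_regex(const):
    # Mirrors the new branch verbatim. Note that `None` here is compared
    # against type(const), so a JSON null const falls through both arms
    # and this sketch returns None for it, exactly as the diff would.
    if type(const) in [int, float, bool, None]:
        return re.escape(str(const))
    elif type(const) == str:
        return f'"{re.escape(const)}"'

assert const_to_regex(0) == "0"
assert const_to_regex(".*") == r'"\.\*"'
assert re.fullmatch(const_to_regex(".*"), '".*"') is not None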
14 changes: 8 additions & 6 deletions outlines/generate/api.py
@@ -279,13 +279,9 @@ def stream(
             stop_sequences = stop_at
         num_samples = self.num_samples
 
-        if rng is None:
-            rng = torch.Generator(device=self.device)
-            rng.seed()
-
         prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
         prompt_token_ids = prompt_token_ids.to(self.device)
-        attention_masks = attention_masks.to(self.device)
+        attention_masks = attention_masks.to(prompt_token_ids.device)
 
         # To draw multiple samples we repeat the prompt as many times
         # as there are samples. We copy the FSMs and initialize the
@@ -298,9 +294,15 @@ def stream(
         fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
         fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
         weights = torch.zeros(
-            (batch_size * num_samples), dtype=torch.float, device=self.device
+            (batch_size * num_samples),
+            dtype=torch.float,
+            device=prompt_token_ids.device,
         )
 
+        if rng is None:
+            rng = torch.Generator(device=prompt_token_ids.device)
+            rng.seed()
+
         states = sequence_generator(
             self.model,
             self.sampler,
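The `stream` change defers RNG creation until after the prompts are encoded, and builds the generator, the weights tensor, and the attention masks on `prompt_token_ids.device` rather than `self.device`, so everything lands on one device even when the model reports a different one. The pattern, sketched in isolation:

import torch

# Stand-in for the encoded prompts; in the library this comes from
# self.tokenizer.encode(prompts).
prompt_token_ids = torch.tensor([[1, 2, 3]])
device = prompt_token_ids.device

rng = torch.Generator(device=device)
rng.seed()  # non-deterministic seeding, as in the diff

# Allocate working tensors on the same device as the prompts, then
# sample with the matching generator; a mismatched generator device
# would raise at sampling time.
weights = torch.zeros(4, dtype=torch.float, device=device)
probs = torch.ones(4, device=device)
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)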
139 changes: 76 additions & 63 deletions outlines/models/transformers.py
@@ -5,7 +5,7 @@
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 __all__ = ["transformers"]

@@ -55,13 +55,87 @@ class CodeLlamaTokenizerFast:  # type: ignore
     )
 
 
+class TransformerTokenizer(Tokenizer):
+    """Represents a tokenizer for models in the `transformers` library."""
+
+    def __init__(
+        self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs
+    ):
+        if isinstance(tokenizer_or_model_name, str):
+            from transformers import AutoTokenizer
+
+            kwargs.setdefault("padding_side", "left")
+            self.model_name = tokenizer_or_model_name
+            # TODO: Do something to make this hashable?
+            self.kwargs = kwargs
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_or_model_name, **kwargs
+            )
+        else:
+            self.tokenizer = tokenizer_or_model_name
+
+        self.eos_token_id = self.tokenizer.eos_token_id
+        self.eos_token = self.tokenizer.eos_token
+
+        if not self.tokenizer.pad_token_id:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+            self.pad_token_id = self.eos_token_id
+        else:
+            self.pad_token_id = self.tokenizer.pad_token_id
+            self.pad_token = self.tokenizer.pad_token
+
+        self.special_tokens = set(self.tokenizer.all_special_tokens)
+
+        self.vocabulary = self.tokenizer.get_vocab()
+        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
+
+    def encode(
+        self, prompt: Union[str, List[str]], **kwargs
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        kwargs["padding"] = True
+        kwargs["return_tensors"] = "pt"
+        output = self.tokenizer(prompt, **kwargs)
+        return output["input_ids"], output["attention_mask"]
+
+    def decode(self, token_ids: torch.LongTensor) -> List[str]:
+        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+        return text
+
+    def convert_token_to_string(self, token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = self.tokenizer.convert_tokens_to_string([token])
+
+        if self.is_llama:
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+        return string
+
+    def __eq__(self, other):
+        if isinstance(other, type(self)):
+            if hasattr(self, "model_name") and hasattr(self, "kwargs"):
+                return (
+                    other.model_name == self.model_name and other.kwargs == self.kwargs
+                )
+            else:
+                return other.tokenizer == self.tokenizer
+        return NotImplemented
+
+    def __hash__(self):
+        from datasets.fingerprint import Hasher
+
+        return hash(Hasher.hash(self.tokenizer))
+
+
 class Transformer:
     """Represents a `transformers` model."""
 
     def __init__(
         self,
         model: "PreTrainedModel",
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: TransformerTokenizer,
     ):
         self.device = model.device
         self.model = model
@@ -119,67 +193,6 @@ def __call__(
         return next_token_logits, kv_cache
 
 
-class TransformerTokenizer(Tokenizer):
-    """Represents a tokenizer for models in the `transformers` library."""
-
-    def __init__(self, model_name: str, **kwargs):
-        from transformers import AutoTokenizer
-
-        kwargs.setdefault("padding_side", "left")
-        self.model_name = model_name
-        # TODO: Do something to make this hashable?
-        self.kwargs = kwargs
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = self.tokenizer.eos_token
-
-        if not self.tokenizer.pad_token_id:
-            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-            self.pad_token_id = self.eos_token_id
-        else:
-            self.pad_token_id = self.tokenizer.pad_token_id
-            self.pad_token = self.tokenizer.pad_token
-
-        self.special_tokens = set(self.tokenizer.all_special_tokens)
-
-        self.vocabulary = self.tokenizer.get_vocab()
-        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
-
-    def encode(
-        self, prompt: Union[str, List[str]], **kwargs
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
-        kwargs["padding"] = True
-        kwargs["return_tensors"] = "pt"
-        output = self.tokenizer(prompt, **kwargs)
-        return output["input_ids"], output["attention_mask"]
-
-    def decode(self, token_ids: torch.LongTensor) -> List[str]:
-        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
-        return text
-
-    def convert_token_to_string(self, token: str) -> str:
-        from transformers.file_utils import SPIECE_UNDERLINE
-
-        string = self.tokenizer.convert_tokens_to_string([token])
-
-        if self.is_llama:
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-        return string
-
-    def __eq__(self, other):
-        if isinstance(other, type(self)):
-            return other.model_name == self.model_name and other.kwargs == self.kwargs
-        return NotImplemented
-
-    def __hash__(self):
-        from datasets.fingerprint import Hasher
-
-        return hash(Hasher.hash(self.tokenizer))
-
-
 def transformers(
     model_name: str,
     device: Optional[str] = None,
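With this refactor, `TransformerTokenizer` accepts either a model name (in which case `AutoTokenizer.from_pretrained` is called as before) or an already-instantiated `transformers` tokenizer, and `__eq__` falls back to comparing the wrapped tokenizers when no `model_name`/`kwargs` were recorded. Both construction paths, sketched with the tiny test checkpoint used further below (requires Hub access):

from transformers import AutoTokenizer

from outlines.models.transformers import TransformerTokenizer

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"

# Path 1: from a model name; kwargs are forwarded to AutoTokenizer.
tok_a = TransformerTokenizer(model_name, padding_side="left")

# Path 2: wrap a tokenizer configured elsewhere.
hf_tok = AutoTokenizer.from_pretrained(model_name)
tok_b = TransformerTokenizer(hf_tok)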
18 changes: 18 additions & 0 deletions tests/fsm/test_json_schema.py
@@ -163,6 +163,24 @@ def test_match_number(pattern, does_match):
("0", False),
],
),
# Const string
(
{"title": "Foo", "const": "Marc", "type": "string"},
'"Marc"',
[('"Marc"', True), ('"Jean"', False), ('"John"', False)],
),
# Make sure strings are escaped
(
{"title": "Foo", "const": ".*", "type": "string"},
r'"\.\*"',
[('".*"', True), (r'"\s*"', False), (r'"\.\*"', False)],
),
# Const integer
(
{"title": "Foo", "const": 0, "type": "integer"},
"0",
[("0", True), ("1", False), ("a", False)],
),
# Enum string
(
{"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"},
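Each new parametrized case pairs a schema with its expected pattern and a list of (candidate, should_match) checks. The escaped-const case can be verified by hand:

import re

pattern = r'"\.\*"'  # the pattern expected for const ".*"
assert re.fullmatch(pattern, '".*"') is not None    # matches the literal value
assert re.fullmatch(pattern, r'"\s*"') is None      # no regex interpretation
assert re.fullmatch(pattern, r'"\.\*"') is None     # backslashes are not literal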
17 changes: 16 additions & 1 deletion tests/generate/test_integration_transformers.py
@@ -10,7 +10,7 @@
 import outlines.generate as generate
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
-from outlines.models.transformers import TransformerTokenizer
+from outlines.models.transformers import Transformer, TransformerTokenizer
 from outlines.samplers import beam_search, multinomial


@@ -615,3 +615,18 @@ def __call__(
     )
 
     assert sequence == "c"
+
+
+def test_transformers_use_existing_model_and_tokenizer():
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    rng = torch.Generator()
+    rng.manual_seed(10000)
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = TransformerTokenizer(hf_tokenizer)
+    model = Transformer(hf_model, tokenizer)
+    sequence = generate.text(model)("Write a short sentence ", rng=rng)
+    assert isinstance(sequence, str)
