Feat/add transformers integration #728

Merged · 1 commit · Mar 12, 2024
examples/transformers_integration.py — 25 additions, 0 deletions (new file)

```python
"""Example of integrating `outlines` with `transformers`."""

from pydantic import BaseModel
from transformers import pipeline

from outlines.integrations.transformers import JSONPrefixAllowedTokens


class Person(BaseModel):
    first_name: str
    surname: str


pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1")
prefix_allowed_tokens_fn = JSONPrefixAllowedTokens(
    schema=Person, tokenizer_or_pipe=pipe, whitespace_pattern=r" ?"
)
results = pipe(
    ["He is Tom Jones", "She saw Linda Smith"],
    return_full_text=False,
    do_sample=False,
    max_new_tokens=50,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(results)
```
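
For readers new to this hook: `transformers` calls `prefix_allowed_tokens_fn(batch_id, input_ids)` before each decoding step and masks every token id not in the returned list, which is how `JSONPrefixAllowedTokens` enforces the schema. A minimal no-op sketch of that contract (the function and vocabulary size below are illustrative, not part of this PR):

```python
from typing import List

import torch


def toy_prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> List[int]:
    """Illustrative stand-in: `input_ids` holds the ids generated so far for
    one sequence; a real constraint maps this prefix to an FSM state and
    returns only the token ids that keep the output valid JSON."""
    vocab_size = 32_000  # hypothetical; a real implementation reads the tokenizer
    return list(range(vocab_size))  # no-op: every token remains allowed
```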
examples/vllm_integration.py — 12 additions, 8 deletions

```diff
@@ -1,20 +1,24 @@
 """Example of integrating `outlines` with `vllm`."""

 import vllm
 from pydantic import BaseModel

-from outlines.serve.vllm import JSONLogitsProcessor
+from outlines.integrations.vllm import JSONLogitsProcessor


-class User(BaseModel):
-    id: int
-    name: str
+class Person(BaseModel):
+    first_name: str
+    surname: str


-llm = vllm.LLM(model="openai-community/gpt2")
-logits_processor = JSONLogitsProcessor(schema=User, llm=llm)
+llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1", max_model_len=512)
+logits_processor = JSONLogitsProcessor(schema=Person, llm=llm, whitespace_pattern=r" ?")
 result = llm.generate(
-    ["A prompt", "Another prompt"],
+    ["He is Tom Jones", "She saw Linda Smith"],
     sampling_params=vllm.SamplingParams(
-        max_tokens=100, logits_processors=[logits_processor]
+        temperature=0.0,
+        max_tokens=50,
+        logits_processors=[logits_processor],
     ),
 )
 print(result)
```
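
The vLLM side uses a different hook for the same purpose: each entry in `logits_processors` is called with the token ids generated so far plus the raw next-token logits, and returns adjusted logits. A no-op sketch of that interface (illustrative only; `JSONLogitsProcessor` is the real implementation):

```python
from typing import List

import torch


def toy_logits_processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in: a structured-generation processor would set the
    logits of every token that cannot legally extend the JSON output to -inf;
    this version leaves the distribution untouched."""
    return logits
```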
outlines/fsm/guide.py — 6 additions, 1 deletion

```diff
@@ -62,6 +62,9 @@
     def is_final_state(self, state: int) -> bool:
         ...

+    def copy(self) -> "Guide":
+        ...
+

 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
@@ -189,7 +192,9 @@ def get_next_state(self, state: int, token_id: int) -> int:
         """
         if token_id == self.eos_token_id:
            return -1
-        elif state in self.final_states:
+        elif (
+            state in self.final_states
+        ):  # Necessary because we keep generating EOS tokens when finished
            return state

        last_token_to_end_state = self.states_to_token_maps[state]
```
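
Two details worth noting here: the `Guide` protocol gains a `copy` method, and `get_next_state` now spells out why final states are pinned. Once the FSM has matched the full pattern, batched generation can keep stepping the sequence, so the guide must keep reporting the same final state (from which only EOS is produced) rather than advancing past the end. A toy sketch of that behaviour (class and transitions invented for illustration, not the PR's `RegexGuide`):

```python
class ToyGuide:
    """Invented example to illustrate the final-state handling above."""

    eos_token_id = 2
    final_states = {7}

    def get_next_state(self, state: int, token_id: int) -> int:
        if token_id == self.eos_token_id:
            return -1  # sentinel: this sequence is done
        elif state in self.final_states:
            return state  # stay put so the model can only keep emitting EOS
        return state + 1  # stand-in for the real states_to_token_maps lookup

    def copy(self) -> "ToyGuide":
        return self  # a stateless toy can share itself; real guides may clone state
```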
outlines/fsm/types.py — 33 additions, 10 deletions

```diff
@@ -1,5 +1,5 @@
 import datetime
-from typing import Any, Callable, Tuple
+from typing import Protocol, Tuple, Type, Union

 INTEGER = r"[+-]?(0|[1-9][0-9]*)"
 BOOLEAN = "(True|False)"
@@ -9,26 +9,49 @@
 DATETIME = rf"({DATE})(\s)({TIME})"


-def python_types_to_regex(python_type: Any) -> Tuple[str, Callable[[str], Any]]:
+class FormatFunction(Protocol):
+    def __call__(
+        self, sequence: str
+    ) -> Union[int, float, bool, datetime.date, datetime.time, datetime.datetime]:
+        ...
+
+
+def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
     if python_type == float:
-        float_format_fn = lambda x: float(x)
+
+        def float_format_fn(sequence: str) -> float:
+            return float(sequence)
+
         return FLOAT, float_format_fn
     elif python_type == int:
-        int_format_fn = lambda x: int(x)
+
+        def int_format_fn(sequence: str) -> int:
+            return int(sequence)
+
         return INTEGER, int_format_fn
     elif python_type == bool:
-        bool_format_fn = lambda x: bool(x)
+
+        def bool_format_fn(sequence: str) -> bool:
+            return bool(sequence)
+
         return BOOLEAN, bool_format_fn
     elif python_type == datetime.date:
-        date_format_fn = lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date()
+
+        def date_format_fn(sequence: str) -> datetime.date:
+            return datetime.datetime.strptime(sequence, "%Y-%m-%d").date()
+
         return DATE, date_format_fn
     elif python_type == datetime.time:
-        time_format_fn = lambda s: datetime.datetime.strptime(s, "%H:%M:%S").time()
+
+        def time_format_fn(sequence: str) -> datetime.time:
+            return datetime.datetime.strptime(sequence, "%H:%M:%S").time()
+
         return TIME, time_format_fn
     elif python_type == datetime.datetime:
-        datetime_format_fn = lambda s: datetime.datetime.strptime(
-            s, "%Y-%m-%d %H:%M:%S"
-        )
+
+        def datetime_format_fn(sequence: str) -> datetime.datetime:
+            return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")
+
         return DATETIME, datetime_format_fn
     else:
         raise NotImplementedError(
```
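
The behaviour of the rewritten helpers is unchanged; each branch still returns a `(regex, formatter)` pair, now typed with the precise `FormatFunction` protocol instead of `Callable[[str], Any]`. A short usage sketch built only on what this diff defines:

```python
import datetime

from outlines.fsm.types import python_types_to_regex

# Each call returns the regex matching the type plus a typed parser.
int_regex, to_int = python_types_to_regex(int)
assert to_int("42") == 42

date_regex, to_date = python_types_to_regex(datetime.date)
assert to_date("2024-03-12") == datetime.date(2024, 3, 12)
```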
outlines/generate/api.py — 11 additions, 7 deletions

```diff
@@ -1,9 +1,14 @@
+import datetime
 from typing import Iterator, List, Optional, Union

 import torch

 from outlines.generate.generator import sequence_generator

+FormattedOutput = Union[
+    str, int, float, bool, datetime.date, datetime.time, datetime.datetime
+]
+

 class SequenceGenerator:
     def __init__(
@@ -100,7 +105,7 @@ def strip_stop_sequences(

         return sequence

-    def format_sequence(self, sequence: str) -> str:
+    def format_sequence(self, sequence: str) -> FormattedOutput:
         """Translate the generated sequence to another type.

         This method is for instance overridden when generating JSON to either
@@ -124,7 +129,7 @@ def __call__(
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional[torch.Generator] = None,
-    ) -> Union[str, List[str], List[List[str]]]:
+    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
         """Generate the full text sequence.

         Since `SequenceGenerator.stream` calls the tokenizer at every step this
@@ -148,8 +153,7 @@

         Returns
         -------
-        A string or list of strings that contain the generated text.
-
+        The generation(s), potentially cast to another type.
         """

         if isinstance(prompts, str):
@@ -222,7 +226,7 @@ def __call__(
         formatted = [self.format_sequence(sequence) for sequence in stripped]

         # We reshape the output to (batch_size, sample_size)
-        output = []
+        output: List[List[FormattedOutput]] = list()
         for i in range(batch_size):
             output.append(formatted[i : i + num_samples])

@@ -242,7 +246,7 @@ def stream(
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional[torch.Generator] = None,
-    ) -> Iterator[Union[List[str], List[List[str]], str]]:
+    ) -> Iterator[Union[List[str], str, List[List[str]]]]:
         """Generate the text sequence one token at a time.

         Since `Tokenizer.decode` strips the whitespaces from the tokens we have no
@@ -352,7 +356,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
             ]

             # We reshape the output to (batch_size, sample_size)
-            output = []
+            output: List[List[str]] = list()
             for i in range(batch_size):
                 output.append(next_tokens[i : i + num_samples])
```
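
`FormattedOutput` makes explicit that `format_sequence` is a casting hook: subclasses override it to turn raw generated text into one of the union's members. A minimal sketch of such an override (class name invented for illustration):

```python
import datetime


class ToyDateGenerator:
    """Invented subclass-style example: parse the raw generation into a
    `datetime.date`, one member of the `FormattedOutput` union."""

    def format_sequence(self, sequence: str) -> datetime.date:
        return datetime.datetime.strptime(sequence, "%Y-%m-%d").date()
```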
outlines/generate/regex.py — 4 additions, 7 deletions

```diff
@@ -2,12 +2,9 @@

 from outlines.fsm.guide import RegexGuide
 from outlines.generate.api import SequenceGenerator
+from outlines.integrations.llamacpp import RegexLogitsProcessor
 from outlines.models import OpenAI
-from outlines.models.llamacpp import (
-    LlamaCpp,
-    LlamaSequenceGenerator,
-    RegexLogitsProcessor,
-)
+from outlines.models.llamacpp import LlamaCpp, LlamaSequenceGenerator
 from outlines.samplers import Sampler, multinomial
@@ -52,8 +49,8 @@ def regex_llamacpp(
         + "than the multinomial sampler."
     )

-    logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
-    generator = LlamaSequenceGenerator(logits_processor, model)
+    logits_processor = RegexLogitsProcessor(regex_str, llm=model.model)
+    generator = LlamaSequenceGenerator(logits_processor=logits_processor, model=model)

     return generator
```
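
None of this changes the public entry point; user code still goes through `outlines.generate.regex`. A hedged usage sketch for the llama.cpp path touched above — the local model file is a hypothetical path, not something this PR ships:

```python
import outlines

# Hypothetical local weights; any llama.cpp-compatible GGUF file would do.
model = outlines.models.llamacpp("./mistral-7b-v0.1.Q4_K_M.gguf")
generator = outlines.generate.regex(model, r"[+-]?(0|[1-9][0-9]*)")
print(generator("How many minutes are there in an hour? Answer: "))
```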
outlines/integrations/__init__.py — 1 addition, 0 deletions (new file)

```python
"""Utility functions and classes used to integrate `outlines` into other packages."""
```