Skip to content

Commit

Permalink
Use LogitsProcessors for models.transformers -> outlines.generate.*
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Jun 14, 2024
1 parent 18aaba1 commit f2e8c19
Show file tree
Hide file tree
Showing 16 changed files with 476 additions and 94 deletions.
53 changes: 52 additions & 1 deletion docs/reference/models/transformers.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Outlines provides an integration with the `torch` implementation of causal model
```python
from outlines import models

model = models.transformers("mistralai/Mistral-7B-v0.1", device="cuda")
model = models.transformers("mistralai/Mistral-7B-v0.3", device="cuda")
```

If you need more fine-grained control you can also initialize the model and tokenizer separately:
Expand All @@ -30,4 +30,55 @@ tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = models.Transformers(llm, tokenizer)
```

# Using Logits Processors

There are two ways to use Outlines Structured Generation with HuggingFace Transformers:
- 1) Use Outlines generation wrapper, `outlines.models.transformers`
- 2) Use `OutlinesLogitsProcessor` with `transformers.AutoModelForCausalLM`

Outlines supports a myriad of logits processors for structured generation. In these example, we will use the `RegexLogitsProcessor` which guarantees generated text matches the specified pattern.

## Example: `outlines.models.transformers`

```
import outlines
time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda")
generator = outlines.generate.regex(model, time_regex_pattern)
output = generator("The the best time to visit a dentist is at ")
print(output)
# 2:30 pm
```

## Example: Direct `transformers` library use

```
import outlines
import transformers
model_uri = "microsoft/Phi-3-mini-4k-instruct"
outlines_tokenizer = outlines.models.TransformerTokenizer(
transformers.AutoTokenizer.from_pretrained(model_uri)
)
phone_number_logits_processor = outlines.processors.RegexLogitsProcessor(
"\\+?[1-9][0-9]{7,14}", # phone number pattern
outlines_tokenizer,
)
generator = transformers.pipeline('text-generation', model=model_uri)
output = generator(
"Jenny gave me her number it's ",
logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor])
)
print(output)
# [{'generated_text': "Jenny gave me her number it's 2125550182"}]
# not quite 8675309 what we expected, but it is a valid phone number
```

[transformers]: https://github.com/huggingface/transformers
1 change: 1 addition & 0 deletions outlines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import outlines.generate
import outlines.grammars
import outlines.models
import outlines.processors
import outlines.types
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
Expand Down
42 changes: 8 additions & 34 deletions outlines/generate/cfg.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from functools import singledispatch

from outlines.fsm.guide import CFGGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial


@singledispatch
def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator:
def cfg(
model, cfg_str: str, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
"""Generate text in the language of a Context-Free Grammar
Arguments
Expand All @@ -24,40 +22,16 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
Returns
-------
A `SequenceGenerator` instance that generates text.
A `SequenceGeneratorAdapter` instance that generates text.
"""
fsm = CFGGuide(cfg_str, model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)

return generator


@cfg.register(MLXLM)
@cfg.register(VLLM)
def cfg_unimplemented(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
):
raise NotImplementedError(
f"The CFG Logits processor is not available for {type(model)}."
f"The CFG Logits processor is not available for {type(model)}. "
+ "Please subscribe to https://github.com/outlines-dev/outlines/issues/684"
+ " for updates on the fix."
)


@cfg.register(LlamaCpp)
def cfg_llamacpp(
model: LlamaCpp,
cfg_str: str,
sampler: Sampler = multinomial(),
):
from outlines.integrations.llamacpp import CFGLogitsProcessor

logits_processor = CFGLogitsProcessor(cfg_str, model.model)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@cfg.register(OpenAI)
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
raise NotImplementedError(
Expand Down
6 changes: 4 additions & 2 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.transformers import Transformers
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial

Expand Down Expand Up @@ -39,8 +40,9 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):


@regex.register(MLXLM)
def regex_mlxlm(
model: MLXLM,
@regex.register(Transformers)
def regex_unified(
model,
regex_str: str,
sampler: Sampler = multinomial(),
):
Expand Down
5 changes: 3 additions & 2 deletions outlines/generate/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from outlines.fsm.guide import StopAtEOSGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -37,7 +37,8 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:


@text.register(MLXLM)
def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
@text.register(Transformers)
def text_unified(model, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)


Expand Down
2 changes: 1 addition & 1 deletion outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .mamba import Mamba, mamba
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, transformers
from .transformers import Transformers, TransformerTokenizer, transformers
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba]
6 changes: 3 additions & 3 deletions outlines/models/mlxlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from transformers import PreTrainedTokenizer

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.processors import BaseLogitsProcessor
from outlines.processors import OutlinesLogitsProcessor


class MLXLM:
Expand Down Expand Up @@ -120,7 +120,7 @@ def generate_step(
temp: Optional[float],
top_p: Optional[float],
sampler: str,
logits_processor: "BaseLogitsProcessor",
logits_processor: "OutlinesLogitsProcessor",
) -> Generator[Tuple[int, float], None, None]:
"""
Adapted from
Expand All @@ -135,7 +135,7 @@ def generate_step(
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
sampler (str): The sampler string defined by SequenceGeneratorAdapter
logits_processor (BaseLogitsProcessor): Augment logits before sampling.
logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
"""
import mlx.core as mx
import mlx_lm
Expand Down
Loading

0 comments on commit f2e8c19

Please sign in to comment.