enable generate.fsm with llamacpp by using outlines.processors
lapp0 committed Jul 15, 2024
1 parent cdd49e5 commit 4ead465
Showing 7 changed files with 47 additions and 29 deletions.
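For orientation, a minimal sketch of what this commit enables: FSM-constrained generation with a llama.cpp model through the unified processors path. The GGUF repo and file names below are placeholders, not taken from this commit:

import interegular

from outlines import generate, models

# Hypothetical GGUF checkpoint; substitute any llama.cpp-compatible model.
model = models.llamacpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf")

# Build an interegular FSM from a regex and hand it to generate.fsm.
fsm = interegular.parse_pattern(r"[0-9]{3}-[0-9]{4}").to_fsm()
generator = generate.fsm(model, fsm)
result = generator("A phone extension: ")  # output matches the pattern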
19 changes: 12 additions & 7 deletions outlines/fsm/guide.py
@@ -184,13 +184,7 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
        ) = create_states_mapping(regex_string, tokenizer)
        self.eos_token_id = tokenizer.eos_token_id
        self.final_states = fsm_finals | {-1}

        # cache returned masks token masks
        # this increases performance of the mask substantially
        self.states_to_token_mask = {
            state: torch.tensor(list(next_tokens_to_end_states.keys()))
            for state, next_tokens_to_end_states in self.states_to_token_maps.items()
        }
        self._cache_state_to_token_tensor()

    def get_next_instruction(self, state: int) -> Instruction:
        """Return the next instruction for guided generation.
@@ -285,8 +279,19 @@ def create_states_mapping_from_interegular_fsm(
            from_interegular_instance.empty_token_ids,
        ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
        from_interegular_instance._cache_state_to_token_tensor()
        return from_interegular_instance

    def _cache_state_to_token_tensor(self):
        """
        cache state -> token int tensor
        this increases performance of mask construction substantially
        """
        self.states_to_token_mask = {
            state: torch.tensor(list(next_tokens_to_end_states.keys()))
            for state, next_tokens_to_end_states in self.states_to_token_maps.items()
        }

    def is_final_state(self, state: int) -> bool:
        """Determine whether the current state of the guide is a final state."""
        return state in self.final_states
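The cached mapping built by _cache_state_to_token_tensor is just a dict from FSM state to a tensor of the token ids allowed in that state, so constructing the mask at decode time is a single lookup. A toy illustration with invented states and token ids:

import torch

# Toy states_to_token_maps: FSM state -> {allowed token id: next state}
states_to_token_maps = {0: {5: 1, 7: 1}, 1: {9: 2}}

# What the cache precomputes: FSM state -> tensor of allowed token ids
states_to_token_mask = {
    state: torch.tensor(list(next_tokens.keys()))
    for state, next_tokens in states_to_token_maps.items()
}

assert states_to_token_mask[0].tolist() == [5, 7]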
19 changes: 18 additions & 1 deletion outlines/generate/fsm.py
@@ -1,14 +1,31 @@
from functools import singledispatch

import interegular

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import SequenceGenerator
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, LlamaCpp, Transformers
from outlines.samplers import Sampler, multinomial


@singledispatch
def fsm(
    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
) -> SequenceGenerator:
    fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
    device = model.device
    generator = SequenceGenerator(fsm, model, sampler, device)
    return generator


@fsm.register(MLXLM)
@fsm.register(Transformers)
@fsm.register(LlamaCpp)
def fsm_unified(
    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    from outlines.processors import FSMLogitsProcessor

    fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
    logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
    return SequenceGeneratorAdapter(model, logits_processor, sampler)
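FSMLogitsProcessor comes from outlines.processors; its core job, sketched below in simplified form rather than as the library's actual implementation, is to mask out every logit whose token the guide does not allow from the current FSM state:

import math

import torch


def mask_logits(logits: torch.Tensor, allowed_token_ids: torch.Tensor) -> torch.Tensor:
    """Keep only logits of tokens the guide allows; everything else becomes -inf."""
    masked = torch.full_like(logits, -math.inf)
    masked[allowed_token_ids] = logits[allowed_token_ids]
    return masked

After sampling, the processor would advance the guide (via Guide.get_next_state) so the next decoding step uses the new state's mask.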
13 changes: 1 addition & 12 deletions outlines/generate/regex.py
@@ -41,6 +41,7 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):

@regex.register(MLXLM)
@regex.register(Transformers)
@regex.register(LlamaCpp)
def regex_unified(
    model,
    regex_str: str,
@@ -52,18 +53,6 @@ def regex_unified(
    return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(LlamaCpp)
def regex_llamacpp(
    model: LlamaCpp,
    regex_str: str,
    sampler: Sampler = multinomial(),
):
    from outlines.integrations.llamacpp import RegexLogitsProcessor

    logits_processor = RegexLogitsProcessor(regex_str, llm=model.model)
    return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(VLLM)
def regex_vllm(
    model: VLLM,
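With LlamaCpp registered on regex_unified, regex-constrained generation with a llama.cpp model goes through the same SequenceGeneratorAdapter path as Transformers and MLXLM; reusing the model from the sketch near the top:

from outlines import generate

generator = generate.regex(model, r"[0-9]{3}-[0-9]{4}")
print(generator("A phone extension: "))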
6 changes: 1 addition & 5 deletions outlines/generate/text.py
@@ -38,6 +38,7 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:

@text.register(MLXLM)
@text.register(Transformers)
@text.register(LlamaCpp)
def text_unified(model, sampler: Sampler = multinomial()):
    return SequenceGeneratorAdapter(model, None, sampler)

@@ -47,11 +48,6 @@ def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
    return SequenceGeneratorAdapter(model, None, sampler)


@text.register(LlamaCpp)
def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()):
    return SequenceGeneratorAdapter(model, None, sampler)


@text.register(OpenAI)
def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
    if not isinstance(sampler, multinomial):
4 changes: 4 additions & 0 deletions outlines/models/llamacpp.py
@@ -141,6 +141,10 @@ class LlamaCpp:
    def __init__(self, model: "Llama"):
        self.model = model

    @property
    def tokenizer(self):
        return LlamaCppTokenizer(self.model)

    def prepare_generation_parameters(
        self,
        generation_parameters: GenerationParameters,
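The new tokenizer property is what lets the dispatch functions above call model.tokenizer uniformly across backends. Note that, as written, each access constructs a fresh LlamaCppTokenizer around the underlying llama_cpp.Llama instance. A small sketch of how downstream code uses it (model as in the earlier sketch):

from outlines.fsm.guide import RegexGuide

tokenizer = model.tokenizer  # LlamaCppTokenizer wrapping model.model
guide = RegexGuide(r"[0-9]+", tokenizer)  # same tokenizer interface the other backends expose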
4 changes: 0 additions & 4 deletions outlines/models/transformers.py
@@ -182,10 +182,6 @@ def forward(

        return output.logits, output.past_key_values

    @property
    def device(self):
        return self.model.device

    def __call__(
        self,
        input_ids: "torch.LongTensor",
11 changes: 11 additions & 0 deletions tests/generate/test_generate.py
@@ -101,6 +101,17 @@ def test_generate_text_stream(request, model_fixture):
        assert isinstance(token, str)


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_fsm(request, model_fixture, pattern):
    import interegular

    model = request.getfixturevalue(model_fixture)
    generator = generate.fsm(model, interegular.parse_pattern(pattern).to_fsm())
    res = generator("test")
    assert re.fullmatch(pattern, res) is not None, res


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_regex(request, model_fixture, pattern):
