Fix vLLM integration (dottxt-ai#711)
When integrating Outlines with vLLM, I faced the following issues, which
are fixed in this PR:

1. When calling `vllm.LLM.generate`, vLLM internally makes a
`copy.deepcopy` of the vLLM `SamplingParams` object, which includes the
logits processor from Outlines (`RegexLogitsProcessor`, say). This
requires everything to be picklable, and
`RegexLogitsProcessor.fsm.vocabulary` is a `dict_values` object, which
is not (see the short demonstration after this list). The fix is easy:
just convert it to a list. This doesn't affect how the `vocabulary`
variable is used in the code.
2. The `RegexLogitsProcessor` takes an `llm` argument, which the
docstring states should be a `vllm.LLM` object, but it then attempts to
extract the underlying tokenizer via `llm.tokenizer.tokenizer`. The
tokenizer of a `vllm.LLM` currently lives in the
`vllm.LLM.llm_engine.tokenizer.tokenizer` attribute, but this is a big
mess and isn't backwards compatible with previous vLLM versions.
Instead, vLLM provides a convenience method, `vllm.LLM.get_tokenizer`,
which fetches the tokenizer. To preserve backwards compatibility, in
case people have supplied `vllm.LLM.llm_engine` directly to
`RegexLogitsProcessor`, it falls back to a `tokenizer` or
`tokenizer.tokenizer` attribute.
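
To illustrate the first issue, here is a minimal, self-contained demonstration; the toy `vocabulary` dict is hypothetical and stands in for `tokenizer.vocabulary`:

```python
import copy

vocabulary = {"a": 0, "b": 1}

# `dict_values` objects cannot be pickled, so `copy.deepcopy` fails on them:
try:
    copy.deepcopy(vocabulary.values())
except TypeError as err:
    print(err)  # cannot pickle 'dict_values' object

# Converting to a list first makes the value deep-copyable (and picklable):
copied = copy.deepcopy(list(vocabulary.values()))
print(copied)  # [0, 1]
```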

I also updated the vLLM example script, as that was outdated as well
(it used the previous `_patched_apply_logits_processors`).

Closes dottxt-ai#704
saattrupdan authored Feb 27, 2024
1 parent d85e67f commit d938678
Showing 5 changed files with 19 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ docs/build
 .coverage
 .idea/
 *.gguf
+.venv
10 changes: 3 additions & 7 deletions examples/vllm_integration.py
@@ -1,20 +1,16 @@
 import vllm
-import vllm.model_executor.layers.sampler as sampler
 from pydantic import BaseModel
 
-from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors
-
-# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
-sampler._apply_logits_processors = _patched_apply_logits_processors
+from outlines.serve.vllm import JSONLogitsProcessor
 
 
 class User(BaseModel):
     id: int
     name: str
 
 
-llm = vllm.LLM(model="gpt2")
-logits_processor = JSONLogitsProcessor(User, llm)
+llm = vllm.LLM(model="openai-community/gpt2")
+logits_processor = JSONLogitsProcessor(schema=User, llm=llm)
 result = llm.generate(
     ["A prompt", "Another prompt"],
     sampling_params=vllm.SamplingParams(
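The example diff is truncated just before the `SamplingParams` arguments. As a rough sketch of how the processor is typically wired in, assuming vLLM's `logits_processors` parameter on `vllm.SamplingParams` (the `max_tokens` value here is an arbitrary choice):

```python
# Sketch only: pass the Outlines processor through vLLM's sampling parameters,
# so it can constrain the token choices at each decoding step.
result = llm.generate(
    ["A prompt", "Another prompt"],
    sampling_params=vllm.SamplingParams(
        max_tokens=100,
        logits_processors=[logits_processor],
    ),
)
```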
4 changes: 2 additions & 2 deletions outlines/fsm/fsm.py
@@ -121,7 +121,7 @@ def create_states_mapping(
         self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
             regex_string, tuple(sorted(tokenizer.vocabulary.items()))
         )
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = list(tokenizer.vocabulary.values())
         self.eos_token_id = tokenizer.eos_token_id
 
     def allowed_token_ids(self, state: FSMState) -> List[int]:
@@ -218,7 +218,7 @@ def create_states_mapping_from_interegular_fsm(
         ) = create_states_mapping_from_interegular_fsm(
             interegular_fsm, tuple(sorted(tokenizer.vocabulary.items()))
         )
-        from_interegular_instance.vocabulary = tokenizer.vocabulary.values()
+        from_interegular_instance.vocabulary = list(tokenizer.vocabulary.values())
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         return from_interegular_instance
 
14 changes: 13 additions & 1 deletion outlines/serve/vllm.py
@@ -47,7 +47,19 @@ def __init__(self, regex_string, llm):
             An instance of `vllm.LLM`
 
         """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer)
+        if hasattr(llm, "get_tokenizer"):
+            tokenizer = llm.get_tokenizer()
+        elif hasattr(llm, "tokenizer"):
+            if hasattr(llm.tokenizer, "tokenizer"):
+                tokenizer = llm.tokenizer.tokenizer
+            else:
+                tokenizer = llm.tokenizer
+        else:
+            raise ValueError(
+                "The provided LLM instance in `RegexLogitsProcessor` neither has a "
+                "`tokenizer` attribute nor a `get_tokenizer` method."
+            )
+        tokenizer = self.adapt_tokenizer(tokenizer=tokenizer)
 
         fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
