[misc][core] lazy import outlines (vllm-project#7831)
youkaichao authored Aug 24, 2024
1 parent d81abef commit 7d9ffa2
Showing 4 changed files with 64 additions and 7 deletions.
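The change itself is mechanical: the module-level imports of vllm.model_executor.guided_decoding.outlines_decoding move inside the code paths that actually use the outlines backend, so importing vLLM alone never loads outlines. A minimal, runnable sketch of the pattern follows; the backend name and the decimal stand-in are illustrative only, not vLLM's actual API:

import sys


def get_processor(backend: str):
    if backend == "demo":
        # Deferred import: the dependency is loaded on first use,
        # not when this module is imported.
        import decimal  # stands in for the heavy outlines dependency
        return decimal.Decimal
    raise ValueError(f"unsupported backend: {backend!r}")


print("decimal" in sys.modules)  # False in a fresh interpreter
get_processor("demo")
print("decimal" in sys.modules)  # True only after the backend is used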
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -87,7 +87,8 @@ steps:
   commands:
   - pip install -e ./plugins/vllm_add_dummy_model
   - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
-  - pytest -v -s entrypoints/llm
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
+  - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/openai
 
 - label: Distributed Tests (4 GPUs) # 10min
48 changes: 48 additions & 0 deletions tests/entrypoints/llm/test_lazy_outlines.py
@@ -0,0 +1,48 @@
+import sys
+
+from vllm import LLM, SamplingParams
+
+
+def test_lazy_outlines(sample_regex):
+    """If users don't use guided decoding, outlines should not be imported.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    llm = LLM(model="facebook/opt-125m",
+              enforce_eager=True,
+              gpu_memory_utilization=0.3)
+    outputs = llm.generate(prompts, sampling_params)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    # make sure outlines is not imported
+    assert 'outlines' not in sys.modules
+
+    llm = LLM(model="facebook/opt-125m",
+              enforce_eager=True,
+              guided_decoding_backend="lm-format-enforcer",
+              gpu_memory_utilization=0.3)
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+    outputs = llm.generate(
+        prompts=[
+            f"Give an example IPv4 address with this regex: {sample_regex}"
+        ] * 2,
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_regex=sample_regex))
+
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    # make sure outlines is not imported
+    assert 'outlines' not in sys.modules
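Why the pipeline gives this file its own pytest invocation ("it needs a clean process"): sys.modules is process-global, so any earlier test in the same run that exercised the outlines backend would leave 'outlines' imported and the assertions above would fail for reasons unrelated to this change. A hedged sketch of the same check in a guaranteed-fresh interpreter, not part of the commit:

import subprocess
import sys

# Run the import check in a brand-new Python process so no prior
# test can have populated sys.modules with outlines.
code = "import sys, vllm; assert 'outlines' not in sys.modules"
subprocess.run([sys.executable, "-c", code], check=True)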
9 changes: 6 additions & 3 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -5,9 +5,6 @@
     CompletionRequest)
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest)
-from vllm.model_executor.guided_decoding.outlines_decoding import (
-    get_local_outlines_guided_decoding_logits_processor,
-    get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor
 
 
@@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor(
     request = _adapt_request_for_tool_use(request)
 
     if guided_decoding_backend == 'outlines':
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
+            get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
     if guided_decoding_backend == 'lm-format-enforcer':
@@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor(
     # request = _adapt_request_for_tool_use(request)
 
     if guided_decoding_backend == 'outlines':
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
+            get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
             guided_options, tokenizer)
     if guided_decoding_backend == 'lm-format-enforcer':
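A note on the trade-off: a function-local import executes on every call, but after the first call Python finds the module already cached in sys.modules, so subsequent imports are essentially a dict lookup plus a local name bind; the # noqa markers just keep the linter from flagging these non-top-level import lines. A tiny timing sketch, with json standing in for the lazily imported backend:

import timeit


def lazy():
    import json  # cached in sys.modules after the first call
    return json


# Typically well under a microsecond per call once the module is cached.
print(timeit.timeit(lazy, number=1_000_000))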
11 changes: 8 additions & 3 deletions vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -14,9 +14,6 @@
     CompletionRequest)
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest)
-from vllm.model_executor.guided_decoding.outlines_decoding import (
-    get_local_outlines_guided_decoding_logits_processor,
-    get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor
 
 
@@ -43,6 +40,10 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
         character_level_parser = RegexParser(request.guided_regex)
     elif request.guided_grammar:
         # CFG grammar not supported by LMFE, revert to outlines
+
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import (
+            get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
     elif (request.response_format is not None
@@ -87,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
         character_level_parser = RegexParser(guided_options.guided_regex)
     elif guided_options.guided_grammar:
         # CFG grammar not supported by LMFE, revert to outlines
+
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import (
+            get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
             guided_options, tokenizer)
     elif guided_options.guided_json_object:
