[Bugfix] Fix guided decoding with tokenizer mode mistral #11046

Merged
6 changes: 5 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -221,8 +221,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
command: pytest -v -s test_logits_processor.py
- tests/model_executor/test_guided_processors
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py

- label: Speculative decoding tests # 30min
source_file_dependencies:
19 changes: 16 additions & 3 deletions tests/model_executor/test_guided_processors.py
@@ -2,12 +2,15 @@
import torch
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'


def test_guided_logits_processors(sample_regex, sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
@@ -40,12 +43,22 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
["outlines", "lm-format-enforcer", "xgrammar"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')

config = ModelConfig(
MODEL_NAME,
task="generate",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = await get_guided_decoding_logits_processor(
regex_request, tokenizer)
regex_request, tokenizer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -59,7 +72,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, tokenizer)
json_request, tokenizer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
86 changes: 84 additions & 2 deletions tests/models/decoder_only/language/test_mistral.py
@@ -3,17 +3,20 @@
Run `pytest tests/models/test_mistral.py`.
"""
import copy
import json

import jsonschema
import jsonschema.exceptions
import pytest

from vllm import SamplingParams
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...utils import check_logprobs_close

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
]

MISTRAL_FORMAT_MODELS = [
@@ -126,6 +129,45 @@
}
]

SAMPLE_JSON_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -251,3 +293,43 @@ def test_mistral_function_calling(
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
vllm_runner,
model: str,
guided_backend: str,
) -> None:
with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model:

guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
backend=guided_backend)
params = SamplingParams(max_tokens=512,
temperature=0.7,
guided_decoding=guided_decoding)

messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
}]
outputs = vllm_model.model.chat(messages, sampling_params=params)

generated_text = outputs[0].outputs[0].text
json_response = json.loads(generated_text)
assert outputs is not None

try:
jsonschema.validate(instance=json_response,
schema=SAMPLE_JSON_SCHEMA)
except jsonschema.exceptions.ValidationError:
pytest.fail("Generated response is not valid with JSON schema")
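
For reference, a document like the one below satisfies SAMPLE_JSON_SCHEMA, so it is the kind of output the guided decoding test above expects the model to produce. This is an illustrative sketch, not part of the PR; it assumes SAMPLE_JSON_SCHEMA from the test file above is in scope.

import jsonschema

# Made-up example instance: all required keys present, at least three skills,
# each skill at most 10 characters, and every work_history entry containing
# "company" and "position".
example_profile = {
    "name": "Jane Doe",
    "age": 34,
    "skills": ["python", "rust", "sql"],
    "work_history": [
        {"company": "Acme", "position": "Engineer", "duration": 2.5},
    ],
}

# Raises jsonschema.exceptions.ValidationError if the instance does not match.
jsonschema.validate(instance=example_profile, schema=SAMPLE_JSON_SCHEMA)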
27 changes: 6 additions & 21 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from vllm.logger import init_logger
@@ -115,25 +116,9 @@ async def get_guided_decoding_logits_processor(
def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
Contributor:

Can you revert this change?

This is being used for the offline use case, with LLM, whereas get_guided_decoding_logits_processor is being used for the online use case.

Contributor (author):

I reviewed what I did and saw it was not ideal, given the difference between the implementations of get_local_outlines_guided_decoding_logits_processor and get_outlines_guided_decoding_logits_processor. But I tried something a little different rather than reverting everything, just to avoid code duplication. See if you agree; if not, I won't insist and can revert with no problem. I also updated the tests to cover both the offline and online versions, so all of these code paths are exercised, including the offline path.

Member:

I think this is even more confusing now that there are three functions. I would prefer a revert, as it seems you have no other changes to this file? We can consider a refactor in another PR.

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
loop = asyncio.get_event_loop()
f = get_guided_decoding_logits_processor(guided_params, tokenizer,
model_config)
res = loop.run_until_complete(f)
return res
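
To ground the online vs. offline distinction raised in the review thread above: the synchronous get_local_guided_decoding_logits_processor is reached through the offline LLM entrypoint, while the async get_guided_decoding_logits_processor serves the OpenAI-compatible server. A rough usage sketch of the offline path follows; the model name, regex, and prompt are placeholders, not taken from the PR.

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Offline use case: the LLM entrypoint builds its logits processor through the
# local (synchronous) helper, which after this change drives the async variant
# on the event loop.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral")
params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(regex=r"\d{4}-\d{2}-\d{2}"),
)
outputs = llm.generate(["Reply with a date in ISO 8601 format: "], params)
print(outputs[0].outputs[0].text)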