[Bugfix] Fix guided decoding with tokenizer mode mistral #11046

Merged
6 changes: 5 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -221,8 +221,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
command: pytest -v -s test_logits_processor.py
- tests/model_executor/test_guided_processors
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py

- label: Speculative decoding tests # 30min
source_file_dependencies:
51 changes: 45 additions & 6 deletions tests/model_executor/test_guided_processors.py
@@ -1,13 +1,24 @@
import contextlib
import pickle

import pytest
import torch
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
get_guided_decoding_logits_processor,
get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.xgrammar_decoding import TokenizerData
Review comment (Member):

Importing from xgrammar_decoding.py will also import (and therefore require) xgrammar, in addition to your import xgrammar as xgr line below. I would recommend making your test_pickle_xgrammar_tokenizer_data function skip if xgrammar can't be imported, and simply including these two imports within that function.

In fact, it may be best to add a dedicated xgrammar testing file.
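
A minimal sketch of that suggestion (not the patch as merged; it only assumes pytest plus the TokenizerData/VocabType names already used in this diff): keep both xgrammar-related imports inside the test so the file can still be collected when xgrammar is not installed.

```python
import pickle

import pytest


def test_pickle_xgrammar_tokenizer_data():
    # Skip rather than fail when the optional xgrammar dependency is missing.
    xgr = pytest.importorskip("xgrammar")
    from vllm.model_executor.guided_decoding.xgrammar_decoding import (
        TokenizerData)

    tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
    depickled = pickle.loads(pickle.dumps(tokenizer_data))
    assert depickled.vocab_type == xgr.VocabType.RAW
```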

Review comment (Member):

Thanks for the change, but you need to move the from vllm.model_executor.guided_decoding.xgrammar_decoding import TokenizerData line into the try-except as well.
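
A sketch of the requested import layout (module level, reusing the contextlib.suppress guard already introduced below), so that both xgrammar-dependent imports sit behind the same ImportError guard:

```python
import contextlib

# Neither import is required to collect the test module; both are only
# needed by the xgrammar-specific test.
with contextlib.suppress(ImportError):
    import xgrammar as xgr
    from vllm.model_executor.guided_decoding.xgrammar_decoding import (
        TokenizerData)
```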

from vllm.sampling_params import GuidedDecodingParams

with contextlib.suppress(ImportError):
import xgrammar as xgr

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'


def test_guided_logits_processors(sample_regex, sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
@@ -38,14 +49,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
["outlines", "lm-format-enforcer", "xgrammar"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')

config = ModelConfig(
MODEL_NAME,
task="generate",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = await get_guided_decoding_logits_processor(
regex_request, tokenizer)

regex_lp = get_local_guided_decoding_logits_processor(
regex_request, tokenizer, config) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, tokenizer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -59,7 +85,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, tokenizer)
json_request, tokenizer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -84,3 +110,16 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")


def test_pickle_xgrammar_tokenizer_data():

tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
pickled = pickle.dumps(tokenizer_data)

assert pickled is not None

depickled: TokenizerData = pickle.loads(pickled)

assert depickled is not None
assert depickled.vocab_type == xgr.VocabType.RAW
86 changes: 84 additions & 2 deletions tests/models/decoder_only/language/test_mistral.py
@@ -3,17 +3,20 @@
Run `pytest tests/models/test_mistral.py`.
"""
import copy
import json

import jsonschema
import jsonschema.exceptions
import pytest

from vllm import SamplingParams
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...utils import check_logprobs_close

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
]

MISTRAL_FORMAT_MODELS = [
@@ -126,6 +129,45 @@
}
]

SAMPLE_JSON_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -251,3 +293,43 @@ def test_mistral_function_calling(
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
vllm_runner,
model: str,
guided_backend: str,
) -> None:
with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model:

guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
backend=guided_backend)
params = SamplingParams(max_tokens=512,
temperature=0.7,
guided_decoding=guided_decoding)

messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
}]
outputs = vllm_model.model.chat(messages, sampling_params=params)

generated_text = outputs[0].outputs[0].text
json_response = json.loads(generated_text)
assert outputs is not None

try:
jsonschema.validate(instance=json_response,
schema=SAMPLE_JSON_SCHEMA)
except jsonschema.exceptions.ValidationError:
pytest.fail("Generated response is not valid with JSON schema")
27 changes: 14 additions & 13 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -96,25 +96,15 @@ async def get_guided_decoding_logits_processor(
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
return _get_local_guided_decoding_logits_processor(guided_params,
tokenizer, model_config)


def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:

guided_params = maybe_backend_fallback(guided_params)
Review comment (Contributor):

Can you revert this change?

This is used for the offline use case (with LLM), whereas get_guided_decoding_logits_processor is used for the online use case.
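
For context, a rough sketch of the offline path being described here (model name and regex are placeholders, not from this PR): the synchronous LLM entry point resolves its logits processor via get_local_guided_decoding_logits_processor, while the OpenAI-compatible server goes through the async get_guided_decoding_logits_processor.

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Offline usage: guided decoding is configured directly on SamplingParams.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral")
params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(regex=r"\d{1,3}(\.\d{1,3}){3}"),
)
outputs = llm.generate("Give an example IPv4 address: ", params)
print(outputs[0].outputs[0].text)
```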

Review comment (Contributor, PR author):

I reviewed what I did and agree it wasn't great, given the difference between the implementations of get_local_outlines_guided_decoding_logits_processor and get_outlines_guided_decoding_logits_processor. Rather than reverting everything, though, I tried something slightly different to avoid code duplication. See if you agree; if not, I won't insist and can revert with no problem. I also updated the tests to cover both the offline and online versions, so all of these code paths are exercised, including the offline one.

Review comment (Member):

I think this is even more confusing now that there are three functions. I would prefer a revert, since it seems you have no other changes to this file. We can consider a refactor in another PR.

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
Expand All @@ -123,6 +113,17 @@ def get_local_guided_decoding_logits_processor(
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)

return _get_local_guided_decoding_logits_processor(guided_params,
tokenizer, model_config)


def _get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:

assert guided_params.backend != 'outlines'

if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
Expand Down