[Bugfix] Fix guided decoding with tokenizer mode mistral #11046

Merged
6 changes: 5 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -224,8 +224,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
- tests/model_executor/test_guided_processors
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py

- label: Speculative decoding tests # 30min
source_file_dependencies:
3 changes: 2 additions & 1 deletion requirements-common.txt
@@ -14,12 +14,13 @@ aiohttp
openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
uvicorn[standard]
pydantic >= 2.9 # Required for fastapi >= 0.113.0
prometheus_client >= 0.18.0
pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines == 0.1.11
lark == 1.2.2
xgrammar >= 0.1.6; platform_machine == "x86_64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
54 changes: 48 additions & 6 deletions tests/model_executor/test_guided_processors.py
@@ -1,13 +1,19 @@
import pickle

import pytest
import torch
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor,
get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'


def test_guided_logits_processors(sample_regex, sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
@@ -38,14 +44,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
["outlines", "lm-format-enforcer", "xgrammar"])
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
sample_json_schema):

config = ModelConfig(
MODEL_NAME,
task="generate",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

regex_lp = get_local_guided_decoding_logits_processor(
regex_request, tokenizer, config) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, tokenizer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -59,7 +80,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, tokenizer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
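
Note on the hunk above: the new is_local parameter exercises a second retrieval path. A minimal sketch of the two paths, assuming only the signatures shown in this diff (model, dtype, and regex values are illustrative):

import asyncio

from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor,
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
config = ModelConfig(MODEL_NAME, task="generate", tokenizer=MODEL_NAME,
                     tokenizer_mode="auto", trust_remote_code=False,
                     seed=0, dtype="bfloat16")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
request = GuidedDecodingParams(regex=r"\d+\.\d+\.\d+\.\d+",
                               backend="outlines")

# Synchronous variant, newly covered by the is_local=True case:
local_lp = get_local_guided_decoding_logits_processor(
    request, tokenizer, config)

# Coroutine variant; after this PR it also takes the model config:
async_lp = asyncio.run(
    get_guided_decoding_logits_processor(request, tokenizer, config))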
@@ -84,3 +105,24 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")


def test_pickle_xgrammar_tokenizer_data():

# TODO: move to another test file for xgrammar
try:
import xgrammar as xgr
except ImportError:
pytest.skip("Could not import xgrammar to run test")

from vllm.model_executor.guided_decoding.xgrammar_decoding import (
TokenizerData)
tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
pickled = pickle.dumps(tokenizer_data)

assert pickled is not None

depickled: TokenizerData = pickle.loads(pickled)

assert depickled is not None
assert depickled.vocab_type == xgr.VocabType.RAW
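
(Presumably the round-trip above is exercised because TokenizerData has to cross process boundaries, e.g. when guided-decoding state is shipped to worker processes; the test itself only asserts that pickling preserves vocab_type.)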
86 changes: 84 additions & 2 deletions tests/models/decoder_only/language/test_mistral.py
@@ -3,17 +3,20 @@
Run `pytest tests/models/test_mistral.py`.
"""
import copy
import json

import jsonschema
import jsonschema.exceptions
import pytest

from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...utils import check_logprobs_close

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
]

MISTRAL_FORMAT_MODELS = [
@@ -126,6 +129,45 @@
}
]

SAMPLE_JSON_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}
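
For orientation, a document that satisfies SAMPLE_JSON_SCHEMA looks like the following sketch (values are illustrative): skills needs at least three strings of at most ten characters each, and every work_history entry must include company and position.

import jsonschema

example = {
    "name": "Jane Doe",
    "age": 34,
    "skills": ["python", "cuda", "triton"],
    "work_history": [{
        "company": "Acme",
        "position": "engineer",
        "duration": 2.5,
    }],
}
# Raises jsonschema.exceptions.ValidationError on mismatch:
jsonschema.validate(instance=example, schema=SAMPLE_JSON_SCHEMA)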


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -251,3 +293,43 @@ def test_mistral_function_calling(
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
vllm_runner,
model: str,
guided_backend: str,
) -> None:
with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model:

guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
backend=guided_backend)
params = SamplingParams(max_tokens=512,
temperature=0.7,
guided_decoding=guided_decoding)

messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
}]
outputs = vllm_model.model.chat(messages, sampling_params=params)
assert outputs is not None

generated_text = outputs[0].outputs[0].text
json_response = json.loads(generated_text)

try:
jsonschema.validate(instance=json_response,
schema=SAMPLE_JSON_SCHEMA)
except jsonschema.exceptions.ValidationError:
pytest.fail("Generated response is not valid with JSON schema")
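
Outside the test harness, the behavior this PR fixes can be sketched with the offline LLM API (a sketch, not part of the diff; the model, backend, and prompt mirror the test above, and LLM.chat is the same entry point as vllm_model.model.chat):

import json

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# tokenizer_mode="mistral" is the case this PR fixes for guided decoding.
llm = LLM("mistralai/Mistral-7B-Instruct-v0.3",
          dtype="bfloat16", tokenizer_mode="mistral")
params = SamplingParams(
    max_tokens=512,
    temperature=0.7,
    guided_decoding=GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
                                         backend="xgrammar"))
outputs = llm.chat([{
    "role": "user",
    "content": f"Give an example JSON for an employee profile that "
               f"fits this schema: {SAMPLE_JSON_SCHEMA}",
}], sampling_params=params)
print(json.loads(outputs[0].outputs[0].text))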