
Commit 8c9aa7a

[Bugfix] Fix guided decoding with tokenizer mode mistral (vllm-project#11046)

Signed-off-by: Sage Moore <[email protected]>
wallashss authored and SageMoore committed Dec 19, 2024
1 parent b6082ff commit 8c9aa7a
Showing 7 changed files with 217 additions and 52 deletions.
6 changes: 5 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -224,8 +224,12 @@ steps:
   mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/model_executor/layers
+  - vllm/model_executor/guided_decoding
   - tests/test_logits_processor
-  command: pytest -v -s test_logits_processor.py
+  - tests/model_executor/test_guided_processors
+  commands:
+  - pytest -v -s test_logits_processor.py
+  - pytest -v -s model_executor/test_guided_processors.py
 
 - label: Speculative decoding tests # 30min
   source_file_dependencies:
3 changes: 2 additions & 1 deletion requirements-common.txt
@@ -14,12 +14,13 @@ aiohttp
 openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
 uvicorn[standard]
 pydantic >= 2.9 # Required for fastapi >= 0.113.0
-pillow # Required for image processing
 prometheus_client >= 0.18.0
+pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.9, < 0.11
 outlines == 0.1.11
+lark == 1.2.2
 xgrammar >= 0.1.6; platform_machine == "x86_64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
54 changes: 48 additions & 6 deletions tests/model_executor/test_guided_processors.py
@@ -1,13 +1,19 @@
+import pickle
+
 import pytest
 import torch
 from transformers import AutoTokenizer
 
+from vllm.config import ModelConfig
 from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    get_guided_decoding_logits_processor,
+    get_local_guided_decoding_logits_processor)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
 
+MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
+
 
 def test_guided_logits_processors(sample_regex, sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
@@ -38,14 +44,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-async def test_guided_logits_processor_black_box(backend: str, sample_regex,
+@pytest.mark.parametrize("is_local", [True, False])
+async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
+                                                 sample_regex,
                                                  sample_json_schema):
-    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+
+    config = ModelConfig(
+        MODEL_NAME,
+        task="generate",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
-    regex_lp = await get_guided_decoding_logits_processor(
-        regex_request, tokenizer)
+
+    regex_lp = get_local_guided_decoding_logits_processor(
+        regex_request, tokenizer, config) if is_local else \
+        await get_guided_decoding_logits_processor(
+            regex_request, tokenizer, config)
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -59,7 +80,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
-        json_request, tokenizer)
+        json_request, tokenizer, config)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -84,3 +105,24 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
     with pytest.raises(ValueError,
                        match="You can only use one kind of guided"):
         GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
+
+
+def test_pickle_xgrammar_tokenizer_data():
+
+    # TODO: move to another test file for xgrammar
+    try:
+        import xgrammar as xgr
+    except ImportError:
+        pytest.skip("Could not import xgrammar to run test")
+
+    from vllm.model_executor.guided_decoding.xgrammar_decoding import (
+        TokenizerData)
+    tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
+    pickled = pickle.dumps(tokenizer_data)
+
+    assert pickled is not None
+
+    depickled: TokenizerData = pickle.loads(pickled)
+
+    assert depickled is not None
+    assert depickled.vocab_type == xgr.VocabType.RAW
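
For reference, the snippet below is a minimal local-mode usage sketch, not part of the commit: it drives the new get_local_guided_decoding_logits_processor entry point the way the black-box test above does. The lp(token_ids, scores) call convention and the prompt string are assumptions based on vLLM's logits-processor interface; the regex is illustrative.

import torch
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'

# Build the same ModelConfig the test constructs.
config = ModelConfig(MODEL_NAME, task="generate", tokenizer=MODEL_NAME,
                     tokenizer_mode="auto", trust_remote_code=False,
                     seed=0, dtype="bfloat16")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Constrain decoding to an IPv4-shaped regex (illustrative pattern).
request = GuidedDecodingParams(regex=r"\d+\.\d+\.\d+\.\d+", backend="xgrammar")
lp = get_local_guided_decoding_logits_processor(request, tokenizer, config)

token_ids = tokenizer.encode("Give an example IPv4 address")
scores = torch.rand(32000)      # stand-in for real next-token logits
masked = lp(token_ids, scores)  # assumed lp(token_ids, scores) convention
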
86 changes: 84 additions & 2 deletions tests/models/decoder_only/language/test_mistral.py
@@ -3,17 +3,20 @@
 Run `pytest tests/models/test_mistral.py`.
 """
 import copy
+import json
 
+import jsonschema
+import jsonschema.exceptions
 import pytest
 
-from vllm import SamplingParams
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
     MistralToolParser)
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 from ...utils import check_logprobs_close
 
 MODELS = [
-    "mistralai/Mistral-7B-Instruct-v0.1",
+    "mistralai/Mistral-7B-Instruct-v0.3",
 ]
 
 MISTRAL_FORMAT_MODELS = [
@@ -126,6 +129,45 @@
 }
 ]
 
+SAMPLE_JSON_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work_history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "number"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work_history"]
+}
+
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -251,3 +293,43 @@ def test_mistral_function_calling(
     assert parsed_message.tool_calls[
         0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
     assert parsed_message.content is None
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("guided_backend",
+                         ["outlines", "lm-format-enforcer", "xgrammar"])
+def test_mistral_guided_decoding(
+    vllm_runner,
+    model: str,
+    guided_backend: str,
+) -> None:
+    with vllm_runner(model, dtype='bfloat16',
+                     tokenizer_mode="mistral") as vllm_model:
+
+        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
+                                               backend=guided_backend)
+        params = SamplingParams(max_tokens=512,
+                                temperature=0.7,
+                                guided_decoding=guided_decoding)
+
+        messages = [{
+            "role": "system",
+            "content": "you are a helpful assistant"
+        }, {
+            "role":
+            "user",
+            "content":
+            f"Give an example JSON for an employee profile that "
+            f"fits this schema: {SAMPLE_JSON_SCHEMA}"
+        }]
+        outputs = vllm_model.model.chat(messages, sampling_params=params)
+
+        generated_text = outputs[0].outputs[0].text
+        json_response = json.loads(generated_text)
+        assert outputs is not None
+
+        try:
+            jsonschema.validate(instance=json_response,
+                                schema=SAMPLE_JSON_SCHEMA)
+        except jsonschema.exceptions.ValidationError:
+            pytest.fail("Generated response is not valid with JSON schema")
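
For context, here is a minimal end-to-end sketch, not part of the commit, of the scenario this bugfix enables: guided JSON decoding through the offline LLM.chat API with tokenizer_mode="mistral". The tiny inline schema and the prompt are illustrative stand-ins for SAMPLE_JSON_SCHEMA above.

import json

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Illustrative stand-in for SAMPLE_JSON_SCHEMA above.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}

llm = LLM("mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral")
params = SamplingParams(
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"))

outputs = llm.chat(
    [{"role": "user", "content": f"Give an example JSON for: {schema}"}],
    sampling_params=params)
# With the fix, the constrained output parses as valid JSON.
print(json.loads(outputs[0].outputs[0].text))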