Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix guided decoding with tokenizer mode mistral #11046

Merged
6 changes: 5 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
command: pytest -v -s test_logits_processor.py
- tests/model_executor/test_guided_processors
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py

- label: Speculative decoding tests # 30min
source_file_dependencies:
Expand Down
19 changes: 15 additions & 4 deletions tests/model_executor/test_guided_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import torch
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams


MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
def test_guided_logits_processors(sample_regex, sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
Expand Down Expand Up @@ -40,12 +41,22 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
["outlines", "lm-format-enforcer", "xgrammar"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')

config = ModelConfig(
MODEL_NAME,
task="generate",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = await get_guided_decoding_logits_processor(
regex_request, tokenizer)
regex_request, tokenizer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
Expand All @@ -59,7 +70,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, tokenizer)
json_request, tokenizer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
Expand Down
87 changes: 85 additions & 2 deletions tests/models/decoder_only/language/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
Run `pytest tests/models/decoder_only/language/test_mistral.py`.
"""
import copy
import json

import jsonschema
import jsonschema.exceptions
import pytest

from vllm import SamplingParams
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...utils import check_logprobs_close

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
]

MISTRAL_FORMAT_MODELS = [
Expand Down Expand Up @@ -126,6 +129,45 @@
}
]

# JSON schema passed as the guided-decoding constraint in
# test_mistral_guided_decoding and then used to validate the generated text.
# It describes an object with four required properties: "name" (string),
# "age" (integer), "skills" (array of at least 3 strings, each at most 10
# characters) and "work_history" (array of objects that must carry "company"
# and "position"; "duration" is optional).
SAMPLE_JSON_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
Expand Down Expand Up @@ -251,3 +293,44 @@ def test_mistral_function_calling(
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
                         ["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
    vllm_runner,
    model: str,
    guided_backend: str,
) -> None:
    """Check JSON-guided decoding when the tokenizer mode is "mistral".

    For each model in ``MODELS`` and each guided-decoding backend, run a chat
    request constrained by ``SAMPLE_JSON_SCHEMA`` and verify that the
    generated text parses as JSON and validates against that schema.

    Args:
        vllm_runner: Fixture used as a context manager that yields the model
            wrapper; the chat call goes through its ``.model`` attribute.
        model: Model name to load.
        guided_backend: Backend name given to ``GuidedDecodingParams``.
    """
    with vllm_runner(model,
                     dtype='bfloat16',
                     tokenizer_mode="mistral") as vllm_model:

        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
                                               backend=guided_backend)
        params = SamplingParams(max_tokens=512,
                                temperature=0.7,
                                guided_decoding=guided_decoding)

        messages = [{
            "role": "system",
            "content": "you are a helpful assistant"
        }, {
            "role":
            "user",
            "content":
            f"Give an example JSON for an employee profile that "
            f"fits this schema: {SAMPLE_JSON_SCHEMA}"
        }]
        outputs = vllm_model.model.chat(messages, sampling_params=params)

    # Check the result before dereferencing it; asserting after
    # ``outputs[0]`` has been used could never fail.
    assert outputs is not None

    generated_text = outputs[0].outputs[0].text
    json_response = json.loads(generated_text)

    try:
        jsonschema.validate(instance=json_response,
                            schema=SAMPLE_JSON_SCHEMA)
    except jsonschema.exceptions.ValidationError:
        pytest.fail("Generated response is not valid with JSON schema")
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ async def build_guided_decoding_logits_processor_async(
guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding,
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=model_config)

Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,7 +2010,7 @@ def _build_logits_processors(
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=self.model_config)
if processor:
Expand Down
29 changes: 7 additions & 22 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from vllm.logger import init_logger
Expand Down Expand Up @@ -86,7 +87,8 @@


async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
Expand Down Expand Up @@ -115,25 +117,8 @@
def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you revert this change?

This is used for the offline use case (with `LLM`), whereas `get_guided_decoding_logits_processor` is used for the online use case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed what I did and saw that it was not a good approach, given the differences between the implementations of get_local_outlines_guided_decoding_logits_processor and get_outlines_guided_decoding_logits_processor. However, I tried something slightly different so as not to revert everything, just to avoid code duplication. See if you agree; if not, I won't insist and can revert with no problem. I also updated the tests to check both the offline and the online versions, so that all of these code paths are exercised, including the offline path.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is even more confusing now that there are three functions. I would prefer a revert, as it seems you have no other changes to this file. We can consider a refactor in another PR.

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
loop = asyncio.get_event_loop()
f = get_guided_decoding_logits_processor(guided_params, tokenizer, model_config)

Check failure on line 122 in vllm/model_executor/guided_decoding/__init__.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/model_executor/guided_decoding/__init__.py:122:81: E501 Line too long (84 > 80)
res = loop.run_until_complete(f)
return res
60 changes: 44 additions & 16 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from vllm.model_executor.guided_decoding.xgrammar_utils import (
convert_lark_to_gbnf, grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
Expand All @@ -41,7 +42,8 @@ class TokenizerData(NamedTuple):
"""Immutable container for cached tokenizer data."""
encoded_vocab: list[str]
stop_token_ids: list[int] | None
backend_str: str
backend_str: str | None
vocab_type: xgr.VocabType | None


class TokenizerDataCache:
Expand All @@ -68,18 +70,26 @@ def get_tokenizer_data(cls,
"get_vocab method.") from e

stop_token_ids = None
backend_str = xgr.VocabType.RAW
backend_str = ""
vocab_type = xgr.VocabType.RAW

if stop_token_ids is None and hasattr(
tokenizer,
"eos_token_id") and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]

if isinstance(tokenizer, PreTrainedTokenizerFast):
backend_str = tokenizer.backend_tokenizer.to_str()
if stop_token_ids is None and hasattr(
tokenizer,
"eos_token_id") and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]

elif isinstance(tokenizer, MistralTokenizer):
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type = xgr.VocabType.BYTE_FALLBACK

cls._cache[tokenizer_hash] = TokenizerData(
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str)
backend_str=backend_str,
vocab_type=vocab_type)

return cls._cache[tokenizer_hash]

Expand All @@ -99,10 +109,18 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:

if cache_key not in cls._cache:
assert config.encoded_vocab is not None
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
xgr_core.TokenizerInfo.from_huggingface(
config.encoded_vocab, config.backend_str,
config.vocab_size, config.stop_token_ids))

if config.backend_str:
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
xgr_core.TokenizerInfo.from_huggingface(
config.encoded_vocab, config.backend_str,
config.vocab_size, config.stop_token_ids))
else:
tokenizer_info = xgr.TokenizerInfo(
mgoin marked this conversation as resolved.
Show resolved Hide resolved
config.encoded_vocab,
config.vocab_type,
vocab_size=config.vocab_size,
stop_token_ids=config.stop_token_ids)
cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads)

Expand All @@ -122,6 +140,7 @@ class GrammarConfig:
encoded_vocab: list[str] | None = None
stop_token_ids: list[int] | None = None
backend_str: str | None = None
vocab_type: xgr.VocabType = xgr.VocabType.RAW

@classmethod
def from_guided_params(cls,
Expand All @@ -136,11 +155,13 @@ def from_guided_params(cls,
encoded_vocab = None
stop_token_ids = None
backend_str = None
vocab_type = xgr.VocabType.RAW
else:
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
encoded_vocab = tokenizer_data.encoded_vocab
stop_token_ids = tokenizer_data.stop_token_ids
backend_str = tokenizer_data.backend_str
vocab_type = tokenizer_data.vocab_type

if guided_params.json:
if not isinstance(guided_params.json, str):
Expand All @@ -153,7 +174,8 @@ def from_guided_params(cls,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
max_threads=max_threads,
vocab_type=vocab_type)
elif guided_params.grammar:
# XGrammar only supports GBNF grammars, so we must convert Lark
if grammar_is_likely_lark(guided_params.grammar):
Expand All @@ -173,15 +195,17 @@ def from_guided_params(cls,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
max_threads=max_threads,
vocab_type=vocab_type)
elif guided_params.json_object:
return cls(json_object=True,
vocab_size=model_config.hf_text_config.vocab_size,
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
max_threads=max_threads,
vocab_type=vocab_type)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
Expand Down Expand Up @@ -257,10 +281,14 @@ def __call__(self, input_ids: list[int],
# fill_next_token_bitmask so we move it to the device of scores
device_type = scores.device.type
if device_type != "cuda":
wallashss marked this conversation as resolved.
Show resolved Hide resolved
scores = scores.to("cpu")
scores = scores.to("cpu").unsqueeze(0)

# Note: on a CPU device the call below fails if the tensors have
# different dimensions, whereas on GPU it runs without error. Hence the
# unsqueeze above for scores, to match the token bitmask shape
xgr.apply_token_bitmask_inplace(scores,
self.token_bitmask.to(scores.device))
if device_type != "cuda":
scores = scores.to(device_type)
scores = scores.to(device_type).squeeze()

return scores
2 changes: 1 addition & 1 deletion vllm/transformers_utils/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_tokenizer(
if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn(
'It is strongly recommended to run mistral models with '
'`--tokenizer_mode "mistral"` to ensure correct '
'`--tokenizer-mode "mistral"` to ensure correct '
'encoding and decoding.',
FutureWarning,
stacklevel=2)
Expand Down
Loading
Loading