diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 44f47fac1c1b3..b563c96343f92 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -224,8 +224,12 @@ steps:
   mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/model_executor/layers
+  - vllm/model_executor/guided_decoding
   - tests/test_logits_processor
-  command: pytest -v -s test_logits_processor.py
+  - tests/model_executor/test_guided_processors
+  commands:
+  - pytest -v -s test_logits_processor.py
+  - pytest -v -s model_executor/test_guided_processors.py
 
 - label: Speculative decoding tests # 30min
   source_file_dependencies:
diff --git a/requirements-common.txt b/requirements-common.txt
index bd2b4b7a01668..1c935303c8d79 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -14,12 +14,13 @@ aiohttp
 openai >= 1.45.0  # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
 uvicorn[standard]
 pydantic >= 2.9  # Required for fastapi >= 0.113.0
-pillow  # Required for image processing
 prometheus_client >= 0.18.0
+pillow  # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.9, < 0.11
 outlines == 0.1.11
+lark == 1.2.2
 xgrammar >= 0.1.6; platform_machine == "x86_64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index 9f4d81b583141..3334c0df149b5 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -1,13 +1,19 @@
+import pickle
+
 import pytest
 import torch
 from transformers import AutoTokenizer
 
+from vllm.config import ModelConfig
 from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    get_guided_decoding_logits_processor,
+    get_local_guided_decoding_logits_processor)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
 
+MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
+
 
 def test_guided_logits_processors(sample_regex, sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
@@ -38,14 +44,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-async def test_guided_logits_processor_black_box(backend: str, sample_regex,
+@pytest.mark.parametrize("is_local", [True, False])
+async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
+                                                 sample_regex,
                                                  sample_json_schema):
-    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+
+    config = ModelConfig(
+        MODEL_NAME,
+        task="generate",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
-    regex_lp = await get_guided_decoding_logits_processor(
-        regex_request, tokenizer)
+
+    regex_lp = get_local_guided_decoding_logits_processor(
+        regex_request, tokenizer, config) if is_local else \
+        await get_guided_decoding_logits_processor(
+            regex_request, tokenizer, config)
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -59,7 +80,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
-        json_request, tokenizer)
+        json_request, tokenizer, config)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -84,3 +105,24 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
     with pytest.raises(ValueError,
                        match="You can only use one kind of guided"):
         GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
+
+
+def test_pickle_xgrammar_tokenizer_data():
+
+    # TODO: move to another test file for xgrammar
+    try:
+        import xgrammar as xgr
+    except ImportError:
+        pytest.skip("Could not import xgrammar to run test")
+
+    from vllm.model_executor.guided_decoding.xgrammar_decoding import (
+        TokenizerData)
+    tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
+    pickled = pickle.dumps(tokenizer_data)
+
+    assert pickled is not None
+
+    depickled: TokenizerData = pickle.loads(pickled)
+
+    assert depickled is not None
+    assert depickled.vocab_type == xgr.VocabType.RAW
diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py
index 99b5d5694f9f7..bdc1571784b5d 100644
--- a/tests/models/decoder_only/language/test_mistral.py
+++ b/tests/models/decoder_only/language/test_mistral.py
@@ -3,17 +3,20 @@
 Run `pytest tests/models/test_mistral.py`.
 """
 import copy
+import json
 
+import jsonschema
+import jsonschema.exceptions
 import pytest
 
-from vllm import SamplingParams
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (  # noqa
     MistralToolParser)
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 from ...utils import check_logprobs_close
 
 MODELS = [
-    "mistralai/Mistral-7B-Instruct-v0.1",
+    "mistralai/Mistral-7B-Instruct-v0.3",
 ]
 
 MISTRAL_FORMAT_MODELS = [
@@ -126,6 +129,45 @@
     }
 ]
 
+SAMPLE_JSON_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work_history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "number"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work_history"]
+}
+
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -251,3 +293,43 @@ def test_mistral_function_calling(
     assert parsed_message.tool_calls[
         0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}'  # noqa
     assert parsed_message.content is None
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("guided_backend",
+                         ["outlines", "lm-format-enforcer", "xgrammar"])
+def test_mistral_guided_decoding(
+    vllm_runner,
+    model: str,
+    guided_backend: str,
+) -> None:
+    with vllm_runner(model, dtype='bfloat16',
+                     tokenizer_mode="mistral") as vllm_model:
+
+        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
+                                               backend=guided_backend)
+        params = SamplingParams(max_tokens=512,
+                                temperature=0.7,
+                                guided_decoding=guided_decoding)
+
+        messages = [{
"role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {SAMPLE_JSON_SCHEMA}" + }] + outputs = vllm_model.model.chat(messages, sampling_params=params) + + generated_text = outputs[0].outputs[0].text + json_response = json.loads(generated_text) + assert outputs is not None + + try: + jsonschema.validate(instance=json_response, + schema=SAMPLE_JSON_SCHEMA) + except jsonschema.exceptions.ValidationError: + pytest.fail("Generated response is not valid with JSON schema") diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index fc45e37cf6f06..5b97f03257502 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any import torch from transformers import PreTrainedTokenizerFast @@ -16,6 +16,7 @@ from vllm.model_executor.guided_decoding.xgrammar_utils import ( convert_lark_to_gbnf, grammar_is_likely_lark) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor( return XGrammarLogitsProcessor(config) -class TokenizerData(NamedTuple): +@dataclass(frozen=True) +class TokenizerData: """Immutable container for cached tokenizer data.""" - encoded_vocab: list[str] - stop_token_ids: list[int] | None - backend_str: str + encoded_vocab: list[str] = field(default_factory=list) + stop_token_ids: list[int] | None = None + # These fields are mutually exclusive: `backend_str` is used to create a + # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is + # used within the constructor of TokenizeInfo + backend_str: str | None = None + vocab_type: xgr.VocabType | None = None + + def __post_init__(self): + # Check for mutual exclusive + assert not (self.backend_str and self.vocab_type), \ + "backend_str and vocab_type are mutual exclusive" class TokenizerDataCache: @@ -68,18 +79,27 @@ def get_tokenizer_data(cls, "get_vocab method.") from e stop_token_ids = None - backend_str = xgr.VocabType.RAW + backend_str = "" + vocab_type = xgr.VocabType.RAW + + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + if isinstance(tokenizer, PreTrainedTokenizerFast): backend_str = tokenizer.backend_tokenizer.to_str() - if stop_token_ids is None and hasattr( - tokenizer, - "eos_token_id") and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + vocab_type = None + + elif isinstance(tokenizer, MistralTokenizer): + # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + vocab_type = xgr.VocabType.BYTE_FALLBACK cls._cache[tokenizer_hash] = TokenizerData( encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, - backend_str=backend_str) + backend_str=backend_str, + vocab_type=vocab_type) return cls._cache[tokenizer_hash] @@ -98,11 +118,30 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: cache_key = str(config.tokenizer_hash) if cache_key not in cls._cache: - assert config.encoded_vocab is not None - 
-            tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                xgr_core.TokenizerInfo.from_huggingface(
-                    config.encoded_vocab, config.backend_str,
-                    config.vocab_size, config.stop_token_ids))
+            assert config.tokenizer_data is not None
+            assert config.tokenizer_data.encoded_vocab is not None
+
+            config_data = config.tokenizer_data
+
+            # In TokenizerDataCache.get_tokenizer_data, a serializable
+            # tokenizer_data is created and cached. That data is used here to
+            # build a tokenizer_info and create an xgrammar compiler.
+            # - If tokenizer_data has backend_str set, use
+            #   xgr_core.TokenizerInfo.from_huggingface (a C++ binding).
+            # - Otherwise, use the TokenizerInfo constructor with vocab_type.
+            # - Note that xgr_core.TokenizerInfo.from_huggingface is not the
+            #   same as xgr.TokenizerInfo.from_huggingface.
+            if config_data.backend_str:
+                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
+                    xgr_core.TokenizerInfo.from_huggingface(
+                        config_data.encoded_vocab, config_data.backend_str,
+                        config.vocab_size, config_data.stop_token_ids))
+            else:
+                tokenizer_info = xgr.TokenizerInfo(
+                    config_data.encoded_vocab,
+                    config_data.vocab_type,
+                    vocab_size=config.vocab_size,
+                    stop_token_ids=config_data.stop_token_ids)
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)
 
@@ -118,10 +157,7 @@ class GrammarConfig:
     grammar_str: str | None = None
     json_object: bool | None = None
     max_threads: int = 8
-    # Only populated if tokenizer_hash not in cache
-    encoded_vocab: list[str] | None = None
-    stop_token_ids: list[int] | None = None
-    backend_str: str | None = None
+    tokenizer_data: TokenizerData | None = None
 
     @classmethod
     def from_guided_params(cls,
@@ -132,9 +168,6 @@ def from_guided_params(cls,
 
         tokenizer_hash = hash(tokenizer)
         tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
-        encoded_vocab = tokenizer_data.encoded_vocab
-        stop_token_ids = tokenizer_data.stop_token_ids
-        backend_str = tokenizer_data.backend_str
 
         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -152,11 +185,9 @@ def from_guided_params(cls,
 
             return cls(json_str=json_str,
                        vocab_size=model_config.hf_text_config.vocab_size,
-                       encoded_vocab=encoded_vocab,
-                       stop_token_ids=stop_token_ids,
-                       backend_str=backend_str,
                        tokenizer_hash=tokenizer_hash,
-                       max_threads=max_threads)
+                       max_threads=max_threads,
+                       tokenizer_data=tokenizer_data)
         elif guided_params.grammar:
             # XGrammar only supports GBNF grammars, so we must convert Lark
             if grammar_is_likely_lark(guided_params.grammar):
@@ -181,19 +212,17 @@ def from_guided_params(cls,
 
             return cls(grammar_str=grammar_str,
                        vocab_size=model_config.hf_text_config.vocab_size,
-                       encoded_vocab=encoded_vocab,
-                       stop_token_ids=stop_token_ids,
-                       backend_str=backend_str,
                        tokenizer_hash=tokenizer_hash,
-                       max_threads=max_threads)
+                       max_threads=max_threads,
+                       tokenizer_data=tokenizer_data)
         elif guided_params.json_object:
-            return cls(json_object=True,
-                       vocab_size=model_config.hf_text_config.vocab_size,
-                       encoded_vocab=encoded_vocab,
-                       stop_token_ids=stop_token_ids,
-                       backend_str=backend_str,
-                       tokenizer_hash=tokenizer_hash,
-                       max_threads=max_threads)
+            return cls(
+                json_object=True,
+                vocab_size=model_config.hf_text_config.vocab_size,
+                tokenizer_hash=tokenizer_hash,
+                max_threads=max_threads,
+                tokenizer_data=tokenizer_data,
+            )
         else:
             raise ValueError(
                 "Currently only support JSON and EBNF grammar mode for xgrammar"
@@ -269,10 +298,14 @@ def __call__(self, input_ids: list[int],
         # fill_next_token_bitmask so we move it to the device of scores
         device_type = scores.device.type
"cuda": - scores = scores.to("cpu") + scores = scores.to("cpu").unsqueeze(0) + + # Note: In this method, if the tensors have different dimensions + # on CPU device fails, but on GPU it runs without error. Hence the + # unsqueeze above for scores, to match the token bitmask shape xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device)) if device_type != "cuda": - scores = scores.to(device_type) + scores = scores.to(device_type).squeeze() return scores diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 54f9f895fe541..e6701f4c4b835 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -132,7 +132,7 @@ def get_tokenizer( if is_from_mistral_org and tokenizer_mode != "mistral": warnings.warn( 'It is strongly recommended to run mistral models with ' - '`--tokenizer_mode "mistral"` to ensure correct ' + '`--tokenizer-mode "mistral"` to ensure correct ' 'encoding and decoding.', FutureWarning, stacklevel=2) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 83b3c37d6f04c..17d722e3d88fe 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -314,12 +314,15 @@ def _token_to_id(t: str): if regular_tokens: decoded_list.append( - self.decode(regular_tokens)) # type: ignore + self.tokenizer.decode(regular_tokens)) # type: ignore decoded = ''.join(decoded_list) return decoded + # WARN: Outlines logits processors can overwrite this method. + # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer + # for more. def decode(self, ids: Union[List[int], int], skip_special_tokens: bool = True) -> str: