From 83ea81cd247669a8f7449f3b6b07b9c7ce7025aa Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Mon, 9 Dec 2024 14:29:59 -0300 Subject: [PATCH 01/10] [Bugfix] Fix guided decoding with tokenizer mode mistral Signed-off-by: Wallas Santos --- .buildkite/test-pipeline.yaml | 6 +- .../decoder_only/language/test_mistral.py | 85 ++++++++++++++++++- vllm/engine/async_llm_engine.py | 4 +- vllm/engine/llm_engine.py | 4 +- .../guided_decoding/__init__.py | 37 +++----- .../guided_decoding/xgrammar_decoding.py | 70 ++++++++++----- vllm/transformers_utils/tokenizer.py | 2 +- vllm/transformers_utils/tokenizers/mistral.py | 5 +- 8 files changed, 152 insertions(+), 61 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8f57006214c88..5a371c3d21f24 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -221,8 +221,12 @@ steps: mirror_hardwares: [amd] source_file_dependencies: - vllm/model_executor/layers + - vllm/model_executor/guided_decoding - tests/test_logits_processor - command: pytest -v -s test_logits_processor.py + - tests/model_executor/test_guided_processors + commands: + - pytest -v -s test_logits_processor.py + - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 30min source_file_dependencies: diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 99b5d5694f9f7..420e05ccd9551 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -3,17 +3,20 @@ Run `pytest tests/models/test_mistral.py`. """ import copy +import json +import jsonschema +import jsonschema.exceptions import pytest -from vllm import SamplingParams from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close MODELS = [ - "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.3", ] MISTRAL_FORMAT_MODELS = [ @@ -126,6 +129,45 @@ } ] +SAMPLE_JSON_SCHEMA = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] +} + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -251,3 +293,42 @@ def test_mistral_function_calling( assert parsed_message.tool_calls[ 0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa assert parsed_message.content is None + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("guided_backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) +def test_mistral_guided_decoding( + vllm_runner, + model: str, + guided_backend: str, +) -> None: + with vllm_runner(model, tokenizer_mode="mistral") as vllm_model: + + guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, + backend=guided_backend) + params = SamplingParams(max_tokens=512, + temperature=0.7, + guided_decoding=guided_decoding) + + messages = [{ + "role": "system", + 
"content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {SAMPLE_JSON_SCHEMA}" + }] + outputs = vllm_model.model.chat(messages, sampling_params=params) + + generated_text = outputs[0].outputs[0].text + json_response = json.loads(generated_text) + assert outputs is not None + + try: + jsonschema.validate(instance=json_response, + schema=SAMPLE_JSON_SCHEMA) + except jsonschema.exceptions.ValidationError: + pytest.fail("Generated response is not valid with JSON schema") diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 60dccd7a0812c..448226dd75952 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -551,9 +551,7 @@ async def build_guided_decoding_logits_processor_async( guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, - tokenizer=tokenizer, - model_config=model_config) + guided_params=guided_decoding, tokenizer=tokenizer) if processor: if sampling_params.logits_processors is None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 560f84a008291..5b6879813e740 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2008,9 +2008,7 @@ def _build_logits_processors( self.decoding_config.guided_decoding_backend processor = get_local_guided_decoding_logits_processor( - guided_params=guided_decoding, - tokenizer=tokenizer, - model_config=self.model_config) + guided_params=guided_decoding, tokenizer=tokenizer) if processor: logits_processors.append(processor) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index e631aec928ec5..a8a96f900fe40 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING from vllm.logger import init_logger @@ -8,7 +9,6 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer - from vllm.config import ModelConfig from vllm.logits_process import LogitsProcessor from vllm.sampling_params import GuidedDecodingParams @@ -86,8 +86,8 @@ def maybe_backend_fallback( async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines': @@ -105,7 +105,7 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( - guided_params, tokenizer, model_config) + guided_params, tokenizer) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. 
" @@ -113,27 +113,10 @@ async def get_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: - guided_params = maybe_backend_fallback(guided_params) - # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines': - # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 - from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa - get_local_outlines_guided_decoding_logits_processor) - return get_local_outlines_guided_decoding_logits_processor( - guided_params, tokenizer) - if guided_params.backend == 'lm-format-enforcer': - from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa - get_local_lm_format_enforcer_guided_decoding_logits_processor) - return get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_params, tokenizer) - if guided_params.backend == 'xgrammar': - from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa - get_local_xgrammar_guided_decoding_logits_processor) - return get_local_xgrammar_guided_decoding_logits_processor( - guided_params, tokenizer, model_config) + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer) -> LogitsProcessor | None: - raise ValueError( - f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") + loop = asyncio.get_event_loop() + f = get_guided_decoding_logits_processor(guided_params, tokenizer) + res = loop.run_until_complete(f) + return res \ No newline at end of file diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index b59a2269d2cd5..ef780f9005239 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -16,11 +16,11 @@ from vllm.model_executor.guided_decoding.xgrammar_utils import ( convert_lark_to_gbnf, grammar_is_likely_lark) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer if TYPE_CHECKING: from transformers import PreTrainedTokenizer - from vllm.config import ModelConfig from vllm.sampling_params import GuidedDecodingParams @@ -28,10 +28,8 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, - model_config=model_config, tokenizer=tokenizer, max_threads=max_threads) return XGrammarLogitsProcessor(config) @@ -41,7 +39,8 @@ class TokenizerData(NamedTuple): """Immutable container for cached tokenizer data.""" encoded_vocab: list[str] stop_token_ids: list[int] | None - backend_str: str + backend_str: str | None + vocab_type: xgr.VocabType | None class TokenizerDataCache: @@ -68,18 +67,26 @@ def get_tokenizer_data(cls, "get_vocab method.") from e stop_token_ids = None - backend_str = xgr.VocabType.RAW + backend_str = "" + vocab_type = xgr.VocabType.RAW + + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + if isinstance(tokenizer, PreTrainedTokenizerFast): backend_str = tokenizer.backend_tokenizer.to_str() - if stop_token_ids is None and hasattr( - tokenizer, - 
"eos_token_id") and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + + elif isinstance(tokenizer, MistralTokenizer): + # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + vocab_type = xgr.VocabType.BYTE_FALLBACK cls._cache[tokenizer_hash] = TokenizerData( encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, - backend_str=backend_str) + backend_str=backend_str, + vocab_type=vocab_type) return cls._cache[tokenizer_hash] @@ -99,10 +106,18 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: if cache_key not in cls._cache: assert config.encoded_vocab is not None - tokenizer_info = xgr.TokenizerInfo._create_from_handle( - xgr_core.TokenizerInfo.from_huggingface( - config.encoded_vocab, config.backend_str, - config.vocab_size, config.stop_token_ids)) + + if config.backend_str: + tokenizer_info = xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface( + config.encoded_vocab, config.backend_str, + config.vocab_size, config.stop_token_ids)) + else: + tokenizer_info = xgr.TokenizerInfo( + config.encoded_vocab, + config.vocab_type, + vocab_size=config.vocab_size, + stop_token_ids=config.stop_token_ids) cls._cache[cache_key] = xgr.GrammarCompiler( tokenizer_info, max_threads=config.max_threads) @@ -122,11 +137,11 @@ class GrammarConfig: encoded_vocab: list[str] | None = None stop_token_ids: list[int] | None = None backend_str: str | None = None + vocab_type: xgr.VocabType = xgr.VocabType.RAW @classmethod def from_guided_params(cls, guided_params: GuidedDecodingParams, - model_config: ModelConfig, tokenizer: PreTrainedTokenizer, max_threads: int = 8) -> GrammarConfig: @@ -136,11 +151,13 @@ def from_guided_params(cls, encoded_vocab = None stop_token_ids = None backend_str = None + vocab_type = xgr.VocabType.RAW else: tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) encoded_vocab = tokenizer_data.encoded_vocab stop_token_ids = tokenizer_data.stop_token_ids backend_str = tokenizer_data.backend_str + vocab_type = tokenizer_data.vocab_type if guided_params.json: if not isinstance(guided_params.json, str): @@ -148,12 +165,13 @@ def from_guided_params(cls, else: json_str = guided_params.json return cls(json_str=json_str, - vocab_size=model_config.hf_config.vocab_size, + vocab_size=tokenizer.vocab_size, encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, backend_str=backend_str, tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + max_threads=max_threads, + vocab_type=vocab_type) elif guided_params.grammar: # XGrammar only supports GBNF grammars, so we must convert Lark if grammar_is_likely_lark(guided_params.grammar): @@ -168,20 +186,22 @@ def from_guided_params(cls, else: grammar_str = guided_params.grammar return cls(grammar_str=grammar_str, - vocab_size=model_config.hf_config.vocab_size, + vocab_size=tokenizer.vocab_size, encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, backend_str=backend_str, tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + max_threads=max_threads, + vocab_type=vocab_type) elif guided_params.json_object: return cls(json_object=True, - vocab_size=model_config.hf_config.vocab_size, + vocab_size=tokenizer.vocab_size, encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, backend_str=backend_str, tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + max_threads=max_threads, + vocab_type=vocab_type) else: raise ValueError( "Currently only support JSON and 
EBNF grammar mode for xgrammar" @@ -257,10 +277,14 @@ def __call__(self, input_ids: list[int], # fill_next_token_bitmask so we move it to the device of scores device_type = scores.device.type if device_type != "cuda": - scores = scores.to("cpu") + scores = scores.to("cpu").unsqueeze(0) + + # Note: In this method, if the tensors have different dimensions + # on CPU device fails, but on GPU it runs without error. Hence the + # unsqueeze above for scores, to match the token bitmask shape xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device)) if device_type != "cuda": - scores = scores.to(device_type) + scores = scores.to(device_type).squeeze() return scores diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 54f9f895fe541..e6701f4c4b835 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -132,7 +132,7 @@ def get_tokenizer( if is_from_mistral_org and tokenizer_mode != "mistral": warnings.warn( 'It is strongly recommended to run mistral models with ' - '`--tokenizer_mode "mistral"` to ensure correct ' + '`--tokenizer-mode "mistral"` to ensure correct ' 'encoding and decoding.', FutureWarning, stacklevel=2) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 83b3c37d6f04c..b360bf4e38b13 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -314,12 +314,15 @@ def _token_to_id(t: str): if regular_tokens: decoded_list.append( - self.decode(regular_tokens)) # type: ignore + self.tokenizer.decode(regular_tokens)) # type: ignore decoded = ''.join(decoded_list) return decoded + # WARN: Outlines logits processors can be overwrite this method. + # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer + # for more. 
def decode(self, ids: Union[List[int], int], skip_special_tokens: bool = True) -> str: From 710fcc9ce69a4c5183ac067c9f6a5e77ba206bcf Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 10 Dec 2024 14:29:14 -0300 Subject: [PATCH 02/10] revert tokenizer.get_vocab_size Signed-off-by: Wallas Santos --- vllm/engine/async_llm_engine.py | 4 +++- vllm/engine/llm_engine.py | 4 +++- vllm/model_executor/guided_decoding/__init__.py | 4 +++- vllm/model_executor/guided_decoding/xgrammar_decoding.py | 4 ++++ 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 448226dd75952..92329e9d011a3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -551,7 +551,9 @@ async def build_guided_decoding_logits_processor_async( guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=model_config) if processor: if sampling_params.logits_processors is None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 387062e50b191..4d477907550c1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2010,7 +2010,9 @@ def _build_logits_processors( self.decoding_config.guided_decoding_backend processor = get_local_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=self.model_config) if processor: logits_processors.append(processor) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index a8a96f900fe40..4f33da3677631 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer + from vllm.config import ModelConfig from vllm.logits_process import LogitsProcessor from vllm.sampling_params import GuidedDecodingParams @@ -87,7 +88,8 @@ def maybe_backend_fallback( async def get_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, - tokenizer: PreTrainedTokenizer) -> LogitsProcessor | None: + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines': diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 8bcff329e60a8..b473a20243047 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer + from vllm.config import ModelConfig from vllm.sampling_params import GuidedDecodingParams @@ -28,8 +29,10 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, + model_config=model_config, tokenizer=tokenizer, max_threads=max_threads) return XGrammarLogitsProcessor(config) @@ -142,6 +145,7 @@ class GrammarConfig: @classmethod def from_guided_params(cls, guided_params: 
GuidedDecodingParams, + model_config: ModelConfig, tokenizer: PreTrainedTokenizer, max_threads: int = 8) -> GrammarConfig: From 9792cee6488f76df32bb7154fd45c6303bed948d Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 10 Dec 2024 15:02:26 -0300 Subject: [PATCH 03/10] reverted tokenizer get_vocab_size and adjusted tests Signed-off-by: Wallas Santos --- .../model_executor/test_guided_processors.py | 19 +++++++++++++++---- .../decoder_only/language/test_mistral.py | 4 +++- .../guided_decoding/__init__.py | 8 ++++---- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 9f4d81b583141..eaaebc1cd18bf 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -2,13 +2,14 @@ import torch from transformers import AutoTokenizer +from vllm.config import ModelConfig from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) from vllm.sampling_params import GuidedDecodingParams - +MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' def test_guided_logits_processors(sample_regex, sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') @@ -40,12 +41,22 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): ["outlines", "lm-format-enforcer", "xgrammar"]) async def test_guided_logits_processor_black_box(backend: str, sample_regex, sample_json_schema): - tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + + config = ModelConfig( + MODEL_NAME, + task="generate", + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="bfloat16", + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) token_ids = tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = await get_guided_decoding_logits_processor( - regex_request, tokenizer) + regex_request, tokenizer, config) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -59,7 +70,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( - json_request, tokenizer) + json_request, tokenizer, config) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 420e05ccd9551..d4736337849a7 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -303,7 +303,9 @@ def test_mistral_guided_decoding( model: str, guided_backend: str, ) -> None: - with vllm_runner(model, tokenizer_mode="mistral") as vllm_model: + with vllm_runner(model, + dtype='bfloat16', + tokenizer_mode="mistral") as vllm_model: guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, backend=guided_backend) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 4f33da3677631..467dbe38f6a1c 100644 --- 
a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -107,7 +107,7 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( - guided_params, tokenizer) + guided_params, tokenizer, model_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " @@ -115,10 +115,10 @@ async def get_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer: PreTrainedTokenizer) -> LogitsProcessor | None: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: loop = asyncio.get_event_loop() - f = get_guided_decoding_logits_processor(guided_params, tokenizer) + f = get_guided_decoding_logits_processor(guided_params, tokenizer, model_config) res = loop.run_until_complete(f) return res \ No newline at end of file From b3cb57175e7486bf3cf76e4f436ba332c11efc53 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 10 Dec 2024 15:13:16 -0300 Subject: [PATCH 04/10] fix linting Signed-off-by: Wallas Santos --- tests/model_executor/test_guided_processors.py | 4 +++- tests/models/decoder_only/language/test_mistral.py | 3 +-- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- vllm/model_executor/guided_decoding/__init__.py | 8 ++++---- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index eaaebc1cd18bf..0c97ebacf81a7 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -10,6 +10,8 @@ from vllm.sampling_params import GuidedDecodingParams MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' + + def test_guided_logits_processors(sample_regex, sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') @@ -41,7 +43,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): ["outlines", "lm-format-enforcer", "xgrammar"]) async def test_guided_logits_processor_black_box(backend: str, sample_regex, sample_json_schema): - + config = ModelConfig( MODEL_NAME, task="generate", diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index d4736337849a7..bdc1571784b5d 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -303,8 +303,7 @@ def test_mistral_guided_decoding( model: str, guided_backend: str, ) -> None: - with vllm_runner(model, - dtype='bfloat16', + with vllm_runner(model, dtype='bfloat16', tokenizer_mode="mistral") as vllm_model: guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 92329e9d011a3..60dccd7a0812c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -551,7 +551,7 @@ async def build_guided_decoding_logits_processor_async( guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, + guided_params=guided_decoding, tokenizer=tokenizer, 
model_config=model_config) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4d477907550c1..6eca304b45f07 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2010,7 +2010,7 @@ def _build_logits_processors( self.decoding_config.guided_decoding_backend processor = get_local_guided_decoding_logits_processor( - guided_params=guided_decoding, + guided_params=guided_decoding, tokenizer=tokenizer, model_config=self.model_config) if processor: diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 467dbe38f6a1c..af0e12d1e1948 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -87,8 +87,7 @@ def maybe_backend_fallback( async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer: PreTrainedTokenizer, + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead @@ -119,6 +118,7 @@ def get_local_guided_decoding_logits_processor( model_config: ModelConfig) -> LogitsProcessor | None: loop = asyncio.get_event_loop() - f = get_guided_decoding_logits_processor(guided_params, tokenizer, model_config) + f = get_guided_decoding_logits_processor(guided_params, tokenizer, + model_config) res = loop.run_until_complete(f) - return res \ No newline at end of file + return res From 4ce6b28765b5bdddfc15dff60064c2589868686e Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Wed, 11 Dec 2024 11:55:44 -0300 Subject: [PATCH 05/10] minor refactor on xgrammar_decoding Signed-off-by: Wallas Santos --- .../guided_decoding/xgrammar_decoding.py | 116 ++++++++++-------- 1 file changed, 64 insertions(+), 52 deletions(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index b473a20243047..be285b0150c93 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any import torch from transformers import PreTrainedTokenizerFast @@ -38,12 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor( return XGrammarLogitsProcessor(config) -class TokenizerData(NamedTuple): +@dataclass(frozen=True) +class TokenizerData: """Immutable container for cached tokenizer data.""" - encoded_vocab: list[str] - stop_token_ids: list[int] | None - backend_str: str | None - vocab_type: xgr.VocabType | None + encoded_vocab: list[str] = field(default_factory=list) + stop_token_ids: list[int] | None = None + # These fields are mutually exclusive `backend_str` is used to create a + # TokenizeInfo with `TokenizerInfo.from_huggingface` while vocab_type is + # used with the constructor of TokenizeInfo + backend_str: str | None = None + vocab_type: xgr.VocabType | None = None + + def __post_init__(self): + # Check for mutual exclusive + assert not (self.backend_str and self.vocab_type), \ + "backend_str and vocab_type are mutual exclusive" class TokenizerDataCache: @@ -80,6 +89,7 @@ def get_tokenizer_data(cls, if isinstance(tokenizer, PreTrainedTokenizerFast): backend_str = tokenizer.backend_tokenizer.to_str() + vocab_type = None elif isinstance(tokenizer, 
MistralTokenizer): # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 @@ -108,19 +118,21 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: cache_key = str(config.tokenizer_hash) if cache_key not in cls._cache: - assert config.encoded_vocab is not None + assert config.tokenizer_data is not None + assert config.tokenizer_data.encoded_vocab is not None - if config.backend_str: + config_data = config.tokenizer_data + if config_data.backend_str: tokenizer_info = xgr.TokenizerInfo._create_from_handle( xgr_core.TokenizerInfo.from_huggingface( - config.encoded_vocab, config.backend_str, - config.vocab_size, config.stop_token_ids)) + config_data.encoded_vocab, config_data.backend_str, + config.vocab_size, config_data.stop_token_ids)) else: tokenizer_info = xgr.TokenizerInfo( - config.encoded_vocab, - config.vocab_type, + config_data.encoded_vocab, + config_data.vocab_type, vocab_size=config.vocab_size, - stop_token_ids=config.stop_token_ids) + stop_token_ids=config_data.stop_token_ids) cls._cache[cache_key] = xgr.GrammarCompiler( tokenizer_info, max_threads=config.max_threads) @@ -137,10 +149,11 @@ class GrammarConfig: json_object: bool | None = None max_threads: int = 8 # Only populated if tokenizer_hash not in cache - encoded_vocab: list[str] | None = None - stop_token_ids: list[int] | None = None - backend_str: str | None = None - vocab_type: xgr.VocabType = xgr.VocabType.RAW + # encoded_vocab: list[str] | None = None + # stop_token_ids: list[int] | None = None + # backend_str: str | None = None + # vocab_type: xgr.VocabType | None = None + tokenizer_data: TokenizerData | None = None @classmethod def from_guided_params(cls, @@ -151,31 +164,24 @@ def from_guided_params(cls, tokenizer_hash = hash(tokenizer) # Only get tokenizer data if not already cached - if tokenizer_hash in TokenizerDataCache._cache: - encoded_vocab = None - stop_token_ids = None - backend_str = None - vocab_type = xgr.VocabType.RAW - else: - tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) - encoded_vocab = tokenizer_data.encoded_vocab - stop_token_ids = tokenizer_data.stop_token_ids - backend_str = tokenizer_data.backend_str - vocab_type = tokenizer_data.vocab_type + tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) \ + if tokenizer_hash not in TokenizerDataCache._cache else None if guided_params.json: if not isinstance(guided_params.json, str): json_str = json.dumps(guided_params.json) else: json_str = guided_params.json - return cls(json_str=json_str, - vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, - tokenizer_hash=tokenizer_hash, - max_threads=max_threads, - vocab_type=vocab_type) + return cls( + json_str=json_str, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + # encoded_vocab=encoded_vocab, + # stop_token_ids=stop_token_ids, + # backend_str=backend_str, + # vocab_type=vocab_type + tokenizer_data=tokenizer_data) elif guided_params.grammar: # XGrammar only supports GBNF grammars, so we must convert Lark if grammar_is_likely_lark(guided_params.grammar): @@ -189,23 +195,29 @@ def from_guided_params(cls, f"Conversion error: {str(e)}") from e else: grammar_str = guided_params.grammar - return cls(grammar_str=grammar_str, - vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - 
stop_token_ids=stop_token_ids, - backend_str=backend_str, - tokenizer_hash=tokenizer_hash, - max_threads=max_threads, - vocab_type=vocab_type) + return cls( + grammar_str=grammar_str, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data + # encoded_vocab=encoded_vocab, + # stop_token_ids=stop_token_ids, + # backend_str=backend_str, + # vocab_type=vocab_type + ) elif guided_params.json_object: - return cls(json_object=True, - vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, - tokenizer_hash=tokenizer_hash, - max_threads=max_threads, - vocab_type=vocab_type) + return cls( + json_object=True, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + # encoded_vocab=encoded_vocab, + # stop_token_ids=stop_token_ids, + # backend_str=backend_str, + # vocab_type=vocab_type + ) else: raise ValueError( "Currently only support JSON and EBNF grammar mode for xgrammar" From d7c7161b47c15a66bd07f707a9e7595e52550739 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Wed, 11 Dec 2024 14:18:42 -0300 Subject: [PATCH 06/10] code cleanup Signed-off-by: Wallas Santos --- .../guided_decoding/xgrammar_decoding.py | 46 ++++++------------- vllm/transformers_utils/tokenizers/mistral.py | 2 +- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index be285b0150c93..eacdc909aac68 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -43,9 +43,9 @@ class TokenizerData: """Immutable container for cached tokenizer data.""" encoded_vocab: list[str] = field(default_factory=list) stop_token_ids: list[int] | None = None - # These fields are mutually exclusive `backend_str` is used to create a - # TokenizeInfo with `TokenizerInfo.from_huggingface` while vocab_type is - # used with the constructor of TokenizeInfo + # These fields are mutually exclusive: `backend_str` is used to create a + # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is + # used within the constructor of TokenizeInfo backend_str: str | None = None vocab_type: xgr.VocabType | None = None @@ -148,11 +148,6 @@ class GrammarConfig: grammar_str: str | None = None json_object: bool | None = None max_threads: int = 8 - # Only populated if tokenizer_hash not in cache - # encoded_vocab: list[str] | None = None - # stop_token_ids: list[int] | None = None - # backend_str: str | None = None - # vocab_type: xgr.VocabType | None = None tokenizer_data: TokenizerData | None = None @classmethod @@ -172,16 +167,11 @@ def from_guided_params(cls, json_str = json.dumps(guided_params.json) else: json_str = guided_params.json - return cls( - json_str=json_str, - vocab_size=model_config.hf_text_config.vocab_size, - tokenizer_hash=tokenizer_hash, - max_threads=max_threads, - # encoded_vocab=encoded_vocab, - # stop_token_ids=stop_token_ids, - # backend_str=backend_str, - # vocab_type=vocab_type - tokenizer_data=tokenizer_data) + return cls(json_str=json_str, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data) elif guided_params.grammar: # XGrammar only supports GBNF grammars, so we must 
convert Lark if grammar_is_likely_lark(guided_params.grammar): @@ -195,17 +185,11 @@ def from_guided_params(cls, f"Conversion error: {str(e)}") from e else: grammar_str = guided_params.grammar - return cls( - grammar_str=grammar_str, - vocab_size=model_config.hf_text_config.vocab_size, - tokenizer_hash=tokenizer_hash, - max_threads=max_threads, - tokenizer_data=tokenizer_data - # encoded_vocab=encoded_vocab, - # stop_token_ids=stop_token_ids, - # backend_str=backend_str, - # vocab_type=vocab_type - ) + return cls(grammar_str=grammar_str, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data) elif guided_params.json_object: return cls( json_object=True, @@ -213,10 +197,6 @@ def from_guided_params(cls, tokenizer_hash=tokenizer_hash, max_threads=max_threads, tokenizer_data=tokenizer_data, - # encoded_vocab=encoded_vocab, - # stop_token_ids=stop_token_ids, - # backend_str=backend_str, - # vocab_type=vocab_type ) else: raise ValueError( diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index b360bf4e38b13..17d722e3d88fe 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -320,7 +320,7 @@ def _token_to_id(t: str): return decoded - # WARN: Outlines logits processors can be overwrite this method. + # WARN: Outlines logits processors can overwrite this method. # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer # for more. def decode(self, From d61257d78881a27d00bbd70056cd9f84482a0f5e Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Thu, 12 Dec 2024 19:37:09 -0300 Subject: [PATCH 07/10] answering to reviews suggestion Signed-off-by: Wallas Santos --- .../model_executor/test_guided_processors.py | 34 ++++++++++++++-- .../guided_decoding/__init__.py | 40 +++++++++++++------ .../guided_decoding/xgrammar_decoding.py | 9 +++++ 3 files changed, 67 insertions(+), 16 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 0c97ebacf81a7..9702a60cc75e7 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,14 +1,22 @@ +import contextlib +import pickle + import pytest import torch from transformers import AutoTokenizer from vllm.config import ModelConfig from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) + get_guided_decoding_logits_processor, + get_local_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.xgrammar_decoding import TokenizerData from vllm.sampling_params import GuidedDecodingParams +with contextlib.suppress(ImportError): + import xgrammar as xgr + MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' @@ -41,7 +49,9 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio @pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -async def test_guided_logits_processor_black_box(backend: str, sample_regex, +@pytest.mark.parametrize("is_local", [True, False]) +async def test_guided_logits_processor_black_box(backend: str, is_local: bool, + sample_regex, sample_json_schema): config = ModelConfig( @@ -57,8 +67,11 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, token_ids = 
tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) - regex_lp = await get_guided_decoding_logits_processor( - regex_request, tokenizer, config) + + regex_lp = get_local_guided_decoding_logits_processor( + regex_request, tokenizer, config) if is_local else \ + await get_guided_decoding_logits_processor( + regex_request, tokenizer, config) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -97,3 +110,16 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): with pytest.raises(ValueError, match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") + + +def test_pickle_xgrammar_tokenizer_data(): + + tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW) + pickled = pickle.dumps(tokenizer_data) + + assert pickled is not None + + depickled: TokenizerData = pickle.loads(pickled) + + assert depickled is not None + assert depickled.vocab_type == xgr.VocabType.RAW diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index af0e12d1e1948..83f270aac23d7 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from typing import TYPE_CHECKING from vllm.logger import init_logger @@ -97,6 +96,34 @@ async def get_guided_decoding_logits_processor( get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( guided_params, tokenizer) + + return _get_local_guided_decoding_logits_processor(guided_params, + tokenizer, model_config) + + +def get_local_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + + guided_params = maybe_backend_fallback(guided_params) + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines': + # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 + from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa + get_local_outlines_guided_decoding_logits_processor) + return get_local_outlines_guided_decoding_logits_processor( + guided_params, tokenizer) + + return _get_local_guided_decoding_logits_processor(guided_params, + tokenizer, model_config) + + +def _get_local_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + + assert guided_params.backend != 'outlines' + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) @@ -111,14 +138,3 @@ async def get_guided_decoding_logits_processor( raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. 
" "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") - - -def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: - - loop = asyncio.get_event_loop() - f = get_guided_decoding_logits_processor(guided_params, tokenizer, - model_config) - res = loop.run_until_complete(f) - return res diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index eacdc909aac68..a01e4c35fc6da 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -122,6 +122,15 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: assert config.tokenizer_data.encoded_vocab is not None config_data = config.tokenizer_data + + # In TokenizerDataCache.get_tokenizer_data, a serializable + # tokenizer_data is created and cached. This data is used to build + # a tokenizer_info and create an xgrammar compiler. + # - If tokenizer_data has backend_str set, use + # xgr_core.TokenizerInfo.from_huggingface (a C++ bind). + # - Otherwise, use the default constructor with vocab_type. + # - xgr_core.TokenizerInfo.from_huggingface != + # xgr.TokenizerInfo.from_huggingface. if config_data.backend_str: tokenizer_info = xgr.TokenizerInfo._create_from_handle( xgr_core.TokenizerInfo.from_huggingface( From 78e7dc2db7fee07bace560df4934074d9309ef98 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Mon, 16 Dec 2024 16:25:14 -0300 Subject: [PATCH 08/10] applying PR suggestions Signed-off-by: Wallas Santos --- .../model_executor/test_guided_processors.py | 9 ++++--- .../guided_decoding/__init__.py | 27 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 9702a60cc75e7..436a67fbab438 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,4 +1,3 @@ -import contextlib import pickle import pytest @@ -14,9 +13,6 @@ from vllm.model_executor.guided_decoding.xgrammar_decoding import TokenizerData from vllm.sampling_params import GuidedDecodingParams -with contextlib.suppress(ImportError): - import xgrammar as xgr - MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' @@ -114,6 +110,11 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): def test_pickle_xgrammar_tokenizer_data(): + # TODO: move to another test file for xgrammar + try: + import xgrammar as xgr + except ImportError: + pytest.skip("Could not import xgrammar to run test") tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW) pickled = pickle.dumps(tokenizer_data) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 83f270aac23d7..e631aec928ec5 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -96,15 +96,25 @@ async def get_guided_decoding_logits_processor( get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return 
get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config) - return _get_local_guided_decoding_logits_processor(guided_params, - tokenizer, model_config) + raise ValueError( + f"Unknown guided decoding backend '{guided_params.backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") def get_local_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig) -> LogitsProcessor | None: - guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines': @@ -113,17 +123,6 @@ def get_local_guided_decoding_logits_processor( get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_params, tokenizer) - - return _get_local_guided_decoding_logits_processor(guided_params, - tokenizer, model_config) - - -def _get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: - - assert guided_params.backend != 'outlines' - if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) From bfedea77a4c6962a6b5b3093d9770b2b15f38b3b Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Mon, 16 Dec 2024 16:56:30 -0300 Subject: [PATCH 09/10] fix import tokenizer_data on test Signed-off-by: Wallas Santos --- tests/model_executor/test_guided_processors.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 436a67fbab438..3334c0df149b5 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -10,7 +10,6 @@ get_local_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) -from vllm.model_executor.guided_decoding.xgrammar_decoding import TokenizerData from vllm.sampling_params import GuidedDecodingParams MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' @@ -115,6 +114,9 @@ def test_pickle_xgrammar_tokenizer_data(): import xgrammar as xgr except ImportError: pytest.skip("Could not import xgrammar to run test") + + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( + TokenizerData) tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW) pickled = pickle.dumps(tokenizer_data) From 0173c2dc78a96a24dbafbe62d07dcca4ee61bda5 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 17 Dec 2024 13:18:23 -0300 Subject: [PATCH 10/10] set lark to 1.2.2 Signed-off-by: Wallas Santos --- requirements-common.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 11984260c580d..4c84d9f659ecf 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,12 +14,13 @@ aiohttp openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field 
support) uvicorn[standard] pydantic >= 2.9 # Required for fastapi >= 0.113.0 -pillow # Required for image processing prometheus_client >= 0.18.0 +pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines == 0.1.9 +lark == 1.2.2 xgrammar >= 0.1.6; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
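
A minimal usage sketch of the behaviour this series enables, distilled from the new
test_mistral_guided_decoding test added in tests/models/decoder_only/language/test_mistral.py.
The model name, toy schema, prompt wording, and sampling values below are illustrative
assumptions, not part of the patch; the point being exercised is guided JSON decoding
with `tokenizer_mode="mistral"` through any of the supported backends.

    import json
    from vllm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    # Toy stand-in for SAMPLE_JSON_SCHEMA used by the test.
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }

    # Mistral tokenizer mode is what previously broke guided decoding.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral")

    params = SamplingParams(
        max_tokens=512,
        temperature=0.7,
        guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"),
    )

    outputs = llm.chat(
        [{"role": "user", "content": f"Give an employee profile matching this schema: {schema}"}],
        sampling_params=params,
    )

    # The generated text should parse as JSON conforming to `schema`.
    print(json.loads(outputs[0].outputs[0].text))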