From 9792cee6488f76df32bb7154fd45c6303bed948d Mon Sep 17 00:00:00 2001
From: Wallas Santos
Date: Tue, 10 Dec 2024 15:02:26 -0300
Subject: [PATCH] Revert tokenizer get_vocab_size and adjust tests

Pass the ModelConfig through to the guided decoding entry points:
get_local_guided_decoding_logits_processor() now takes a model_config
argument and forwards it, including to the xgrammar backend. The logits
processor tests build a ModelConfig for the test model and pass it in,
and the Mistral guided decoding test pins dtype to bfloat16.

Signed-off-by: Wallas Santos
---
 .../model_executor/test_guided_processors.py | 20 ++++++++++++++++----
 .../decoder_only/language/test_mistral.py    |  4 +++-
 .../guided_decoding/__init__.py              | 11 ++++++-----
 3 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index 9f4d81b583141..eaaebc1cd18bf 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -2,13 +2,16 @@
 import torch
 from transformers import AutoTokenizer
 
+from vllm.config import ModelConfig
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
 
-
+MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
+
+
 def test_guided_logits_processors(sample_regex, sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
@@ -40,12 +43,21 @@
                          ["outlines", "lm-format-enforcer", "xgrammar"])
 async def test_guided_logits_processor_black_box(backend: str, sample_regex,
                                                  sample_json_schema):
-    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    config = ModelConfig(
+        MODEL_NAME,
+        task="generate",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
     regex_lp = await get_guided_decoding_logits_processor(
-        regex_request, tokenizer)
+        regex_request, tokenizer, config)
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -59,7 +71,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
-        json_request, tokenizer)
+        json_request, tokenizer, config)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py
index 420e05ccd9551..d4736337849a7 100644
--- a/tests/models/decoder_only/language/test_mistral.py
+++ b/tests/models/decoder_only/language/test_mistral.py
@@ -303,7 +303,9 @@ def test_mistral_guided_decoding(
     model: str,
     guided_backend: str,
 ) -> None:
-    with vllm_runner(model, tokenizer_mode="mistral") as vllm_model:
+    with vllm_runner(model,
+                     dtype='bfloat16',
+                     tokenizer_mode="mistral") as vllm_model:
         guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
                                                backend=guided_backend)
 
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 4f33da3677631..467dbe38f6a1c 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -107,7 +107,7 @@ async def get_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer)
+            guided_params, tokenizer, model_config)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -115,9 +115,10 @@ async def get_guided_decoding_logits_processor(
 
 
 def get_local_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams,
-        tokenizer: PreTrainedTokenizer) -> LogitsProcessor | None:
+        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig) -> LogitsProcessor | None:
     loop = asyncio.get_event_loop()
-    f = get_guided_decoding_logits_processor(guided_params, tokenizer)
+    f = get_guided_decoding_logits_processor(guided_params, tokenizer,
+                                             model_config)
     res = loop.run_until_complete(f)
-    return res
\ No newline at end of file
+    return res
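
Usage note (illustrative sketch, not part of the applied diff): with this
patch, the guided decoding entry points take the ModelConfig as a third
argument. The snippet below mirrors the ModelConfig construction used in
test_guided_logits_processor_black_box above; the regex value is just an
arbitrary example request.

    from transformers import AutoTokenizer

    from vllm.config import ModelConfig
    from vllm.model_executor.guided_decoding import (
        get_local_guided_decoding_logits_processor)
    from vllm.sampling_params import GuidedDecodingParams

    MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
    config = ModelConfig(
        MODEL_NAME,
        task="generate",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Example request; any GuidedDecodingParams is passed the same way.
    params = GuidedDecodingParams(regex=r"\d+\.\d+\.\d+\.\d+",
                                  backend="xgrammar")

    # model_config is the new third argument introduced by this patch.
    lp = get_local_guided_decoding_logits_processor(params, tokenizer, config)
    assert lp is not None  # signature returns LogitsProcessor | None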