reverted tokenizer get_vocab_size and adjusted tests
Signed-off-by: Wallas Santos <[email protected]>
wallashss committed Dec 10, 2024
1 parent 710fcc9 commit 9792cee
Showing 3 changed files with 22 additions and 9 deletions.
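
At a glance: the commit reverts the tokenizer get_vocab_size change and instead threads a ModelConfig through the guided-decoding entry points alongside the tokenizer. A minimal sketch of the resulting call pattern, using only names that appear in the diffs below (the regex request is a placeholder, not part of the commit):

    from transformers import AutoTokenizer

    from vllm.config import ModelConfig
    from vllm.model_executor.guided_decoding import (
        get_local_guided_decoding_logits_processor)
    from vllm.sampling_params import GuidedDecodingParams

    MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'

    # The ModelConfig now rides along with the tokenizer into the backends.
    config = ModelConfig(
        MODEL_NAME,
        task="generate",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Placeholder request; backend may be "outlines", "lm-format-enforcer",
    # or "xgrammar", matching the parametrized test below.
    params = GuidedDecodingParams(regex=r"\d{1,3}", backend="xgrammar")
    lp = get_local_guided_decoding_logits_processor(params, tokenizer, config)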
19 changes: 15 additions & 4 deletions tests/model_executor/test_guided_processors.py
@@ -2,13 +2,14 @@
 import torch
 from transformers import AutoTokenizer
 
+from vllm.config import ModelConfig
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
 
-
+MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
 def test_guided_logits_processors(sample_regex, sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
@@ -40,12 +41,22 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
                          ["outlines", "lm-format-enforcer", "xgrammar"])
 async def test_guided_logits_processor_black_box(backend: str, sample_regex,
                                                  sample_json_schema):
-    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+
+    config = ModelConfig(
+        MODEL_NAME,
+        task="generate",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
     regex_lp = await get_guided_decoding_logits_processor(
-        regex_request, tokenizer)
+        regex_request, tokenizer, config)
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -59,7 +70,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     json_request = GuidedDecodingParams(json=sample_json_schema,
                                         backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
-        json_request, tokenizer)
+        json_request, tokenizer, config)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
4 changes: 3 additions & 1 deletion tests/models/decoder_only/language/test_mistral.py
@@ -303,7 +303,9 @@ def test_mistral_guided_decoding(
     model: str,
     guided_backend: str,
 ) -> None:
-    with vllm_runner(model, tokenizer_mode="mistral") as vllm_model:
+    with vllm_runner(model,
+                     dtype='bfloat16',
+                     tokenizer_mode="mistral") as vllm_model:
 
         guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
                                                backend=guided_backend)
8 changes: 4 additions & 4 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -107,18 +107,18 @@ async def get_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer)
+            guided_params, tokenizer, model_config)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
         "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
 
 
 def get_local_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams,
-        tokenizer: PreTrainedTokenizer) -> LogitsProcessor | None:
+        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig) -> LogitsProcessor | None:
 
     loop = asyncio.get_event_loop()
-    f = get_guided_decoding_logits_processor(guided_params, tokenizer)
+    f = get_guided_decoding_logits_processor(guided_params, tokenizer, model_config)
 Check failure on line 122 (GitHub Actions / ruff (3.12)): vllm/model_executor/guided_decoding/__init__.py:122:81: E501 Line too long (84 > 80)
     res = loop.run_until_complete(f)
     return res
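
The Ruff annotation above flags the rewritten synchronous call at 84 columns. A minimal wrap that would satisfy the 80-column E501 limit (a sketch of one obvious fix, not part of this commit):

    f = get_guided_decoding_logits_processor(guided_params, tokenizer,
                                             model_config)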
