From 919e5f8fc42f72667779cca685150f6b27cd0361 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 30 Nov 2024 00:23:52 +0000 Subject: [PATCH] fix: lint error Signed-off-by: Aaron Pham --- .../model_executor/test_guided_processors.py | 3 +- .../guided_decoding/xgrammar_decoding.py | 35 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 2dde6dad85e8f..9f4d81b583141 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio -@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer", "xgrammar"]) +@pytest.mark.parametrize("backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) async def test_guided_logits_processor_black_box(backend: str, sample_regex, sample_json_schema): tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 94520b3e3adb9..7bd7f5002c326 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -1,10 +1,11 @@ +# noqa: UP007 from __future__ import annotations -import json, torch - +import json +import torch from transformers import PreTrainedTokenizerFast from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Any, Optional, List +from typing import TYPE_CHECKING, Dict, Optional, List, Any try: import xgrammar as xgr @@ -23,7 +24,7 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig, - max_threads=8): + max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, model_config=model_config, tokenizer=tokenizer, @@ -57,10 +58,9 @@ def from_guided_params(cls, key=lambda x: x[1]) ] except AttributeError as e: - msg = ( - f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer " - "should have a get_vocab method.") - raise ValueError(msg) from e + raise ValueError( + f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer should have a get_vocab method." + ) from e stop_token_ids = None backend_str = xgr.VocabType.RAW @@ -79,11 +79,6 @@ def from_guided_params(cls, tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: stop_token_ids = [tokenizer.eos_token_id] - else: - logger.warning( - "When constructing TokenizerInfo from a huggingface tokenizer, " - "stop_token_ids is neither provided by user nor found from the tokenizer. " - "It will be automatically detected.") if guided_params.json: if not isinstance(guided_params.json, str): @@ -124,8 +119,8 @@ class XGrammarLogitsProcessor: ctx: Optional[xgr.CompiledGrammar] = None matchers: List[xgr.GrammarMatcher] = field(default_factory=list) batch_size: int = 1 - token_bitmask: Optional[torch.Tensor] = None - prefilled: boolean = False + token_bitmask: torch.Tensor = None + prefilled: bool = False def __getstate__(self) -> Dict[str, Any]: return {'config': self.config} @@ -152,7 +147,9 @@ def _ensure_ctx(self): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: - if self.ctx is None: self._ensure_ctx() + if self.ctx is None: + self._ensure_ctx() + if len(self.matchers) == 0: self.matchers = [ xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) @@ -174,9 +171,11 @@ def __call__(self, input_ids: List[int], matcher.fill_next_token_bitmask(self.token_bitmask, i) device_type = scores.device.type - if device_type != "cuda": scores = scores.to("cpu") + if device_type != "cuda": + scores = scores.to("cpu") xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device)) - if device_type != "cuda": scores = scores.to(device_type) + if device_type != "cuda": + scores = scores.to(device_type) return scores