Commit 919e5f8
fix: lint error
Signed-off-by: Aaron Pham <[email protected]>
aarnphm committed Nov 30, 2024
1 parent cef4201 commit 919e5f8
Showing 2 changed files with 19 additions and 19 deletions.
tests/model_executor/test_guided_processors.py (3 changes: 2 additions & 1 deletion)
@@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer", "xgrammar"])
+@pytest.mark.parametrize("backend",
+                         ["outlines", "lm-format-enforcer", "xgrammar"])
 async def test_guided_logits_processor_black_box(backend: str, sample_regex,
                                                  sample_json_schema):
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
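
Wrapping the decorator keeps the line within the line-length lint limit without changing behavior: pytest still expands the list into one independent test case per backend string. A minimal sketch of that expansion, assuming only pytest itself (the test body below is ours, not from the suite):

    import pytest

    @pytest.mark.parametrize("backend",
                             ["outlines", "lm-format-enforcer", "xgrammar"])
    def test_backend_names(backend: str):
        # pytest invokes this three times, binding one backend name per run.
        assert backend in {"outlines", "lm-format-enforcer", "xgrammar"}
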
vllm/model_executor/guided_decoding/xgrammar_decoding.py (35 changes: 17 additions & 18 deletions)
@@ -1,10 +1,11 @@
 # noqa: UP007
 from __future__ import annotations
 
-import json, torch
+import json
+import torch
 from transformers import PreTrainedTokenizerFast
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, Any, Optional, List
+from typing import TYPE_CHECKING, Dict, Optional, List, Any
 
 try:
     import xgrammar as xgr
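
Splitting the combined import fixes pycodestyle's E401 (multiple imports on one line) and lets import sorters keep the standard library separate from third-party packages. A short illustration of the rule, not taken from the diff:

    # Flagged by pycodestyle E401 (multiple imports on one line):
    # import json, torch

    # Lint-clean equivalent: one module per import statement.
    import json
    import torch
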
@@ -23,7 +24,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
         model_config: ModelConfig,
-        max_threads=8):
+        max_threads: int = 8):
     config = GrammarConfig.from_guided_params(guided_params=guided_params,
                                               model_config=model_config,
                                               tokenizer=tokenizer,
@@ -57,10 +58,9 @@ def from_guided_params(cls,
                     key=lambda x: x[1])
             ]
         except AttributeError as e:
-            msg = (
-                f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer "
-                "should have a get_vocab method.")
-            raise ValueError(msg) from e
+            raise ValueError(
+                f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer should have a get_vocab method."
+            ) from e
 
         stop_token_ids = None
         backend_str = xgr.VocabType.RAW
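
Raising the ValueError directly removes the temporary msg variable; the `from e` clause preserves the original AttributeError as __cause__, so the traceback shows both exceptions. A self-contained sketch of the chaining pattern, with a toy tokenizer of our own invention:

    class NoVocabTokenizer:
        """Deliberately lacks get_vocab, to exercise the error path."""

    def vocab_items(tokenizer):
        try:
            return tokenizer.get_vocab()
        except AttributeError as e:
            # "from e" chains the AttributeError as __cause__.
            raise ValueError(
                f"Cannot get the vocabulary of the tokenizer "
                f"{type(tokenizer)}. The tokenizer should have a "
                "get_vocab method.") from e

    # vocab_items(NoVocabTokenizer())  # ValueError, chained from AttributeError
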
@@ -79,11 +79,6 @@ def from_guided_params(cls,
                 tokenizer,
                 "eos_token_id") and tokenizer.eos_token_id is not None:
             stop_token_ids = [tokenizer.eos_token_id]
-        else:
-            logger.warning(
-                "When constructing TokenizerInfo from a huggingface tokenizer, "
-                "stop_token_ids is neither provided by user nor found from the tokenizer. "
-                "It will be automatically detected.")
 
         if guided_params.json:
             if not isinstance(guided_params.json, str):
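
With the warning gone, the fallback is silent: when no stop tokens are supplied, the tokenizer's eos_token_id (if present) becomes the sole stop token; otherwise detection is left to the backend, as the deleted message said. A condensed sketch of that resolution logic, with names simplified from the diff:

    from typing import List, Optional

    def resolve_stop_token_ids(tokenizer,
                               user_ids: Optional[List[int]] = None
                               ) -> Optional[List[int]]:
        # Prefer caller-supplied stop tokens; else fall back to EOS.
        if user_ids is not None:
            return user_ids
        if getattr(tokenizer, "eos_token_id", None) is not None:
            return [tokenizer.eos_token_id]
        return None  # let the grammar backend detect stop tokens
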
@@ -124,8 +119,8 @@ class XGrammarLogitsProcessor:
     ctx: Optional[xgr.CompiledGrammar] = None
     matchers: List[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = 1
-    token_bitmask: Optional[torch.Tensor] = None
-    prefilled: boolean = False
+    token_bitmask: torch.Tensor = None
+    prefilled: bool = False
 
     def __getstate__(self) -> Dict[str, Any]:
         return {'config': self.config}
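
`boolean` is not a Python name, so the old annotation was a genuine lint error, yet it never crashed: with `from __future__ import annotations` in effect, annotations stay unevaluated strings (PEP 563), so only a linter or type checker notices the undefined name. A small reproduction of the difference:

    from __future__ import annotations  # PEP 563: annotations not evaluated

    from dataclasses import dataclass

    @dataclass
    class Example:
        prefilled: bool = False       # builtin type: lint-clean
        # prefilled: boolean = False  # undefined name: flagged by linters,
        #                             # but silent at runtime under PEP 563
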
@@ -152,7 +147,9 @@ def _ensure_ctx(self):
 
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
-        if self.ctx is None: self._ensure_ctx()
+        if self.ctx is None:
+            self._ensure_ctx()
+
         if len(self.matchers) == 0:
             self.matchers = [
                 xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
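
Splitting `if self.ctx is None: self._ensure_ctx()` onto two lines fixes pycodestyle's E701 (statement on the same line as a colon). The surrounding code is a lazy-initialization idiom: compile the grammar once on first call, then build one matcher per sequence in the batch, all sharing the compiled grammar. A stripped-down sketch of the idiom; the stub functions are hypothetical stand-ins, not xgrammar APIs:

    from typing import Any, List, Optional

    def compile_grammar() -> Any:
        return object()  # stand-in for an expensive one-time compilation

    def make_matcher(ctx: Any) -> Any:
        return (ctx,)    # stand-in for a cheap per-sequence matcher

    class LazyBatchProcessor:
        def __init__(self, batch_size: int = 1):
            self.ctx: Optional[Any] = None
            self.matchers: List[Any] = []
            self.batch_size = batch_size

        def __call__(self) -> None:
            if self.ctx is None:
                self.ctx = compile_grammar()
            if len(self.matchers) == 0:
                # One matcher per sequence, sharing the compiled grammar.
                self.matchers = [make_matcher(self.ctx)
                                 for _ in range(self.batch_size)]
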
@@ -174,9 +171,11 @@ def __call__(self, input_ids: List[int],
             matcher.fill_next_token_bitmask(self.token_bitmask, i)
 
         device_type = scores.device.type
-        if device_type != "cuda": scores = scores.to("cpu")
+        if device_type != "cuda":
+            scores = scores.to("cpu")
         xgr.apply_token_bitmask_inplace(scores,
                                         self.token_bitmask.to(scores.device))
-        if device_type != "cuda": scores = scores.to(device_type)
+        if device_type != "cuda":
+            scores = scores.to(device_type)
 
         return scores
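
The same E701 fix applies to both one-line `if` statements here; behavior is unchanged. The logic is a device round-trip: non-CUDA scores move to CPU, the token bitmask is applied in place, and the result moves back to the original device. A hedged sketch of that round-trip with a plain boolean mask standing in for xgrammar's bitmask kernel:

    import torch

    def mask_scores(scores: torch.Tensor,
                    allowed: torch.Tensor) -> torch.Tensor:
        # allowed: bool tensor, True where a token may be generated.
        # masked_fill_ stands in for xgr.apply_token_bitmask_inplace.
        device_type = scores.device.type
        if device_type != "cuda":
            scores = scores.to("cpu")
        scores.masked_fill_(~allowed.to(scores.device), float("-inf"))
        if device_type != "cuda":
            scores = scores.to(device_type)
        return scores
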
