Skip to content

Commit

Permalink
revert tokenizer.get_vocab_size
Browse files Browse the repository at this point in the history
Signed-off-by: Wallas Santos <[email protected]>
  • Loading branch information
wallashss committed Dec 10, 2024
1 parent a9a1e3c commit 710fcc9
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
4 changes: 3 additions & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,9 @@ async def build_guided_decoding_logits_processor_async(
guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=model_config)

if processor:
if sampling_params.logits_processors is None:
Expand Down
4 changes: 3 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,7 +2010,9 @@ def _build_logits_processors(
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=self.model_config)
if processor:
logits_processors.append(processor)

Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams

Expand Down Expand Up @@ -87,7 +88,8 @@ def maybe_backend_fallback(

async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer) -> LogitsProcessor | None:
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.sampling_params import GuidedDecodingParams


# TODO: passing batch size to max threads here
def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
tokenizer=tokenizer,
max_threads=max_threads)
return XGrammarLogitsProcessor(config)
Expand Down Expand Up @@ -142,6 +145,7 @@ class GrammarConfig:
@classmethod
def from_guided_params(cls,
guided_params: GuidedDecodingParams,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
max_threads: int = 8) -> GrammarConfig:

Expand Down

0 comments on commit 710fcc9

Please sign in to comment.