diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index 32d43866c2fed..94520b3e3adb9 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -125,6 +125,7 @@ class XGrammarLogitsProcessor:
     matchers: List[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = 1
     token_bitmask: Optional[torch.Tensor] = None
+    prefilled: bool = False
 
     def __getstate__(self) -> Dict[str, Any]:
         return {'config': self.config}
@@ -136,6 +137,7 @@ def __setstate__(self, state: Dict[str, Any]):
         self.matchers = []
         self.batch_size = 1
         self.token_bitmask = None
+        self.prefilled = False
 
     def _ensure_ctx(self):
         """Lazily initialize the processor in the worker process"""
@@ -158,6 +160,15 @@ def __call__(self, input_ids: List[int],
             self.token_bitmask = xgr.allocate_token_bitmask(
                 self.batch_size, self.config.vocab_size)
 
+        if not self.prefilled:
+            # Have not sampled a token yet
+            self.prefilled = True
+        else:
+            for i, matcher in enumerate(self.matchers):
+                if not matcher.is_terminated():
+                    sampled_token = input_ids[-1]
+                    assert self.matchers[i].accept_token(sampled_token)
+
         for i, matcher in enumerate(self.matchers):
             if not matcher.is_terminated():
                 matcher.fill_next_token_bitmask(self.token_bitmask, i)
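
The sketch below (not vLLM or xgrammar code; `FakeMatcher` and `Processor` are hypothetical stand-ins for `xgr.GrammarMatcher` and the logits processor) illustrates why the `prefilled` flag is needed: on the very first call, `input_ids[-1]` is still a prompt token from prefill, so nothing should be fed to the grammar; on every later call, the previously sampled token must be accepted before the next bitmask is computed.

```python
from typing import List


class FakeMatcher:
    """Hypothetical stand-in for xgr.GrammarMatcher (assumption, for illustration)."""

    def __init__(self) -> None:
        self.accepted: List[int] = []

    def is_terminated(self) -> bool:
        return False

    def accept_token(self, token: int) -> bool:
        # Record the token; the real matcher advances its grammar state
        # and returns False if the token is not grammatically valid.
        self.accepted.append(token)
        return True


class Processor:
    """Minimal sketch of the prefill/decode distinction in __call__."""

    def __init__(self) -> None:
        self.matchers = [FakeMatcher()]
        self.prefilled = False

    def __call__(self, input_ids: List[int]) -> None:
        if not self.prefilled:
            # First call (prefill): input_ids[-1] is a prompt token,
            # not one the grammar sampled, so skip accept_token().
            self.prefilled = True
        else:
            # Decode steps: feed the last sampled token to each live matcher.
            for matcher in self.matchers:
                if not matcher.is_terminated():
                    assert matcher.accept_token(input_ids[-1])


proc = Processor()
proc([1, 2, 3])        # prefill call: no token accepted
proc([1, 2, 3, 42])    # decode step: 42 is accepted by the matcher
assert proc.matchers[0].accepted == [42]
```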