Skip to content

Commit

Permalink
fix: correctness for prefill
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <[email protected]>
  • Loading branch information
aarnphm committed Nov 30, 2024
1 parent ee8e796 commit cef4201
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class XGrammarLogitsProcessor:
matchers: List[xgr.GrammarMatcher] = field(default_factory=list)
batch_size: int = 1
token_bitmask: Optional[torch.Tensor] = None
prefilled: boolean = False

Check failure on line 128 in vllm/model_executor/guided_decoding/xgrammar_decoding.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Name "boolean" is not defined [name-defined]

Check failure on line 128 in vllm/model_executor/guided_decoding/xgrammar_decoding.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Name "boolean" is not defined [name-defined]

Check failure on line 128 in vllm/model_executor/guided_decoding/xgrammar_decoding.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Name "boolean" is not defined [name-defined]

Check failure on line 128 in vllm/model_executor/guided_decoding/xgrammar_decoding.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Name "boolean" is not defined [name-defined]

def __getstate__(self) -> Dict[str, Any]:
return {'config': self.config}
Expand All @@ -136,6 +137,7 @@ def __setstate__(self, state: Dict[str, Any]):
self.matchers = []
self.batch_size = 1
self.token_bitmask = None
self.prefilled = False

def _ensure_ctx(self):
"""Lazily initialize the processor in the worker process"""
Expand All @@ -158,6 +160,15 @@ def __call__(self, input_ids: List[int],
self.token_bitmask = xgr.allocate_token_bitmask(
self.batch_size, self.config.vocab_size)

if not self.prefilled:
# Have not sampled a token yet
self.prefilled = True
else:
for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
sampled_token = input_ids[-1]
assert self.matchers[i].accept_token(sampled_token)

for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
matcher.fill_next_token_bitmask(self.token_bitmask, i)
Expand Down

0 comments on commit cef4201

Please sign in to comment.