From ca5d141a8cdb56ef0715cf2e5cc43590d2d321b5 Mon Sep 17 00:00:00 2001 From: mgoin Date: Sat, 21 Dec 2024 02:23:52 +0000 Subject: [PATCH] [Bugfix] Fall back to outlines for grammars that can't convert to GBNF Signed-off-by: mgoin --- .../guided_decoding/__init__.py | 87 ++++--------------- .../outlines_logits_processors.py | 23 +++-- .../{xgrammar_utils.py => utils.py} | 69 +++++++++++++++ .../guided_decoding/xgrammar_decoding.py | 4 +- 4 files changed, 102 insertions(+), 81 deletions(-) rename vllm/model_executor/guided_decoding/{xgrammar_utils.py => utils.py} (72%) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 550b892303feb..694c5b68b1cbd 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -3,6 +3,9 @@ from typing import TYPE_CHECKING from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.utils import ( + convert_lark_to_gbnf, grammar_is_likely_lark, + has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) from vllm.platforms import CpuArchEnum, current_platform if TYPE_CHECKING: @@ -15,76 +18,6 @@ logger = init_logger(__name__) -def has_xgrammar_unsupported_json_features(schema: dict) -> bool: - """Check if JSON schema contains features unsupported by xgrammar.""" - - def check_object(obj: dict) -> bool: - if not isinstance(obj, dict): - return False - - # Check for pattern restrictions - if "pattern" in obj: - return True - - # Check for numeric ranges - if obj.get("type") in ("integer", "number") and any( - key in obj for key in [ - "minimum", "maximum", "exclusiveMinimum", - "exclusiveMaximum", "multipleOf" - ]): - return True - - # Recursively check all nested objects and arrays - for value in obj.values(): - if isinstance(value, dict): - if check_object(value): - return True - elif isinstance(value, list): - for item in value: - if isinstance(item, dict) and check_object(item): - return True - - return False - - return check_object(schema) - - -def has_lmf_unsupported_json_features(schema: dict) -> bool: - """ - Check if JSON schema contains features unsupported - by lm_format_enforcer. - - Known issues: - - Regex patterns: - "grade": { - "type": "string", - "pattern": "^[A-D]$" # Regex pattern - }, - """ - - def check_object(obj: dict) -> bool: - if not isinstance(obj, dict): - return False - - # Check for pattern restrictions - if "pattern" in obj: - return True - - # Recursively check all nested objects and arrays - for value in obj.values(): - if isinstance(value, dict): - if check_object(value): - return True - elif isinstance(value, list): - for item in value: - if isinstance(item, dict) and check_object(item): - return True - - return False - - return check_object(schema) - - def maybe_backend_fallback( guided_params: GuidedDecodingParams) -> GuidedDecodingParams: # lm-format-enforce doesn't support grammar, fallback to xgrammar @@ -127,6 +60,20 @@ def maybe_backend_fallback( "Falling back to use outlines instead.") guided_params.backend = "outlines" + # xgrammar only supports GBNF grammars, so we must convert Lark. + # We must check if the grammar is likely Lark and if that + # grammar is convertible to GBNF + elif (guided_params.grammar is not None + and grammar_is_likely_lark(guided_params.grammar)): + try: + convert_lark_to_gbnf(guided_params.grammar) + except Exception: + logger.warning( + "xgrammar does not support Lark grammars and the " + "grammar failed to convert to GBNF. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + if (guided_params.backend == "outlines" and guided_params.json_object is not None): # outlines doesn't support json_object, fallback to xgrammar diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index b63fed1c8a8c3..e4eb3f16e56cf 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,10 +21,11 @@ import numpy as np import torch -from lark import Lark from outlines import grammars from outlines.caching import cache -from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write +from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide, + RegexGuide, Write) +from outlines.fsm.parsing import PartialLark from outlines_core.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from transformers import PreTrainedTokenizerBase @@ -34,7 +35,9 @@ class BaseLogitsProcessor: def __init__(self, guide: Guide): self._guide: Guide = guide - self._fsm_state: DefaultDict[int, int] = defaultdict(int) + # CFGState is used for the FSM state for CFGGuide + self._fsm_state: DefaultDict[int, Union[int, + CFGState]] = defaultdict(int) def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: @@ -54,15 +57,13 @@ def __call__(self, input_ids: List[int], # On the first time this is called, we simply re-create # the Lark object. if isinstance(self._guide, CFGGuide): - self._guide.parser = Lark( + self._guide.parser = PartialLark( self._guide.cfg_string, parser="lalr", - lexer="contextual", - propagate_positions=False, - maybe_placeholders=False, - regex=True, import_paths=[grammars.GRAMMAR_PATH], ) + self._fsm_state[seq_id] = CFGState( + parser_state=self._guide.parser.parse(""), prev_token=None) instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) @@ -200,7 +201,8 @@ def convert_token_to_string(token: str) -> str: string = tokenizer.convert_tokens_to_string([token]) # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + if (type(token) is str and token.startswith(SPIECE_UNDERLINE) + or token == "<0x20>"): return " " + string return string @@ -211,6 +213,9 @@ def change_decoder( """Sync vLLM's decoder with the outlines by returning list.""" def new_decoder(inp_tokens: List[int]) -> List[str]: + if (isinstance(inp_tokens, list) and len(inp_tokens) == 1 + and isinstance(inp_tokens[0], list)): + inp_tokens = inp_tokens[0] return [decoder(inp_tokens)] return new_decoder diff --git a/vllm/model_executor/guided_decoding/xgrammar_utils.py b/vllm/model_executor/guided_decoding/utils.py similarity index 72% rename from vllm/model_executor/guided_decoding/xgrammar_utils.py rename to vllm/model_executor/guided_decoding/utils.py index 9a0463964de49..abc11e9373baf 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -1,5 +1,74 @@ import re +def has_xgrammar_unsupported_json_features(schema: dict) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj for key in [ + "minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf" + ]): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def has_lmf_unsupported_json_features(schema: dict) -> bool: + """ + Check if JSON schema contains features unsupported + by lm_format_enforcer. + + Known issues: + - Regex patterns: + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + """ + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + def grammar_is_likely_lark(grammar_str: str) -> bool: """ diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 5b97f03257502..5e1948977bff4 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -14,8 +14,8 @@ except ImportError: pass -from vllm.model_executor.guided_decoding.xgrammar_utils import ( - convert_lark_to_gbnf, grammar_is_likely_lark) +from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, + grammar_is_likely_lark) from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer if TYPE_CHECKING: