diff --git a/aphrodite/modeling/guided_decoding/__init__.py b/aphrodite/modeling/guided_decoding/__init__.py index 139b8f78b..ffac0c6b7 100644 --- a/aphrodite/modeling/guided_decoding/__init__.py +++ b/aphrodite/modeling/guided_decoding/__init__.py @@ -8,11 +8,6 @@ GuidedDecodingRequest) from aphrodite.triton_utils import HAS_TRITON -if HAS_TRITON: - from aphrodite.modeling.guided_decoding.outlines_decoding import ( - get_local_outlines_guided_decoding_logits_processor, - get_outlines_guided_decoding_logits_processor) - async def get_guided_decoding_logits_processor( guided_decoding_backend: str, request: Union[CompletionRequest, @@ -20,6 +15,8 @@ async def get_guided_decoding_logits_processor( tokenizer) -> Optional[LogitsProcessorFunc]: request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + from aphrodite.modeling.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) if HAS_TRITON: return await get_outlines_guided_decoding_logits_processor( request, tokenizer) @@ -42,6 +39,8 @@ def get_local_guided_decoding_logits_processor( # request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + from aphrodite.modeling.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_options, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': diff --git a/aphrodite/modeling/guided_decoding/lm_format_enforcer_decoding.py b/aphrodite/modeling/guided_decoding/lm_format_enforcer_decoding.py index d21af3289..f5a814621 100644 --- a/aphrodite/modeling/guided_decoding/lm_format_enforcer_decoding.py +++ b/aphrodite/modeling/guided_decoding/lm_format_enforcer_decoding.py @@ -16,12 +16,6 @@ from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import ( # noqa: E501 build_aphrodite_logits_processor, build_aphrodite_token_enforcer_tokenizer_data) -from aphrodite.triton_utils import HAS_TRITON - -if HAS_TRITON: - from aphrodite.modeling.guided_decoding.outlines_decoding import ( - get_local_outlines_guided_decoding_logits_processor, - get_outlines_guided_decoding_logits_processor) async def get_lm_format_enforcer_guided_decoding_logits_processor( @@ -46,6 +40,9 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( elif request.guided_regex: character_level_parser = RegexParser(request.guided_regex) elif request.guided_grammar: + from aphrodite.modeling.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) + # CFG grammar not supported by LMFE, revert to outlines return await get_outlines_guided_decoding_logits_processor( request, tokenizer) @@ -83,6 +80,9 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( elif guided_options.guided_regex: character_level_parser = RegexParser(guided_options.guided_regex) elif guided_options.guided_grammar: + from aphrodite.modeling.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor) + # CFG grammar not supported by LMFE, revert to outlines return get_local_outlines_guided_decoding_logits_processor( guided_options, tokenizer)