From 26725f50a0d836b8e1fe5184029165e1b098654b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 Nov 2024 01:15:46 +0000 Subject: [PATCH] Fix auto detection not working when loop iterable is being piped to a filter Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 008f0f06a4b37..9d14b4e59342a 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -163,27 +163,32 @@ def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: return False +def _iter_self_and_descendants(node: jinja2.nodes.Node): + yield node + yield from node.find_all(jinja2.nodes.Node) + + def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template): # Search for {%- for message in messages -%} loops for loop_ast in chat_template_ast.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter loop_target = loop_ast.target - if _is_var_access(loop_iter, "messages"): - assert isinstance(loop_target, jinja2.nodes.Name) - yield loop_ast, loop_target.name + for loop_iter_desc in _iter_self_and_descendants(loop_ast.iter): + if _is_var_access(loop_iter_desc, "messages"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template): for node, message_varname in _iter_nodes_define_message(chat_template_ast): # Search for {%- for content in message['content'] -%} loops for loop_ast in node.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter loop_target = loop_ast.target - if _is_attr_access(loop_iter, message_varname, "content"): - assert isinstance(loop_target, jinja2.nodes.Name) - yield loop_iter, loop_target.name + for loop_iter_desc in _iter_self_and_descendants(loop_ast.iter): + if _is_attr_access(loop_iter_desc, message_varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_iter_desc, loop_target.name def _detect_content_format(