
[Frontend] Automatic detection of chat content format from AST #9919

Merged: 27 commits into main from chat-template-content-format on Nov 16, 2024

Conversation

@DarkLight1337 (Member) commented Nov 1, 2024

This PR renames --chat-template-text-format (introduced by #9358) to --chat-template-content-format and moves it to the CLI parser specific to OpenAI-compatible server. Also, it removes the redundant hardcoded logic for Llama-3.2-Vision (last updated by #9393) since we can now run online inference with --chat-template-content-format openai.

To avoid causing incompatibilities with how users are currently serving Llama-3.2-Vision, I have added code to automatically detect the format to use based on the AST of the provided chat template.

cc @vrdn-23 @ywang96 @heheda12345 @alex-jw-brooks

FIX #10286


github-actions bot commented Nov 1, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which executes a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge

🚀

@mergify mergify bot added documentation Improvements or additions to documentation frontend labels Nov 1, 2024
@vrdn-23 (Contributor) commented Nov 1, 2024

Great idea with the PR, @DarkLight1337!
The problem with auto-detecting is that many chat templates do not throw errors in Jinja even when the input does not fit the expected format, which is what made the bug in #9294 so subtle. The content string was simply not being looped over, so no content was added to the conversation. I'm not completely familiar with how Jinja works, so if you figure out a way to detect this, let me know and I can help out!
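As a sketch of why the mismatch fails silently (using a hypothetical minimal template, not one from the repo): when string content is looped over as if it were a list, Jinja iterates over its characters, each subscript access resolves to Undefined, and rendering succeeds with empty output.

```python
# Hypothetical minimal template illustrating the silent failure: looping over
# string content iterates its characters instead of raising an error.
import jinja2

env = jinja2.Environment()
tmpl = env.from_string(
    "{% for part in message['content'] %}{{ part['text'] }}{% endfor %}"
)

# OpenAI-style list content renders as expected...
ok = tmpl.render(message={"content": [{"type": "text", "text": "hi"}]})
print(repr(ok))  # 'hi'

# ...but plain-string content is iterated character by character, and each
# part['text'] silently resolves to Undefined (rendered as '').
bad = tmpl.render(message={"content": "hello"})
print(repr(bad))  # ''
```

No exception is raised in either case, which matches the observation that the wrong format only manifests as missing conversation content.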

@DarkLight1337 (Member Author) commented Nov 2, 2024

> Great idea with the PR, @DarkLight1337!
> The problem with auto-detecting is that many chat templates do not throw errors in Jinja even when the input does not fit the expected format, which is what made the bug in #9294 so subtle. The content string was simply not being looped over, so no content was added to the conversation. I'm not completely familiar with how Jinja works, so if you figure out a way to detect this, let me know and I can help out!

Right now I am thinking of using Jinja's AST parser and working off that. The basic idea is to detect whether messages[int]['content'] is being treated as a string or a list of dictionaries.
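For illustration, a minimal standalone sketch of that idea (not the PR's actual helper functions): parse the template with a Jinja environment and walk the resulting AST for nested `for` loops.

```python
# Minimal sketch of the AST-based idea (not the PR's actual helpers):
# parse the chat template and inspect the nodes Jinja produces.
import jinja2

env = jinja2.Environment()

# OpenAI-style template: message['content'] is iterated as a list.
openai_style = (
    "{%- for message in messages -%}"
    "{%- for content in message['content'] -%}"
    "{{ content['text'] }}"
    "{%- endfor -%}{%- endfor -%}"
)

ast = env.parse(openai_style)
loops = list(ast.find_all(jinja2.nodes.For))
print(len(loops))  # 2 -- the outer loop over messages and the inner one over content

# The inner loop's iterable is a Getitem node: message['content'].
inner_iter = loops[1].iter
print(type(inner_iter).__name__, inner_iter.arg.value)  # Getitem content
```

If no such inner loop exists, the template presumably treats `message['content']` as a plain string.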

Comment on lines 147 to 217
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    if isinstance(node, jinja2.nodes.Name):
        return node.ctx == "load" and node.name == varname

    return False


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    if isinstance(node, jinja2.nodes.Getitem):
        return (node.ctx == "load" and _is_var_access(node.node, varname)
                and isinstance(node.arg, jinja2.nodes.Const)
                and node.arg.value == key)

    if isinstance(node, jinja2.nodes.Getattr):
        return (node.ctx == "load" and _is_var_access(node.node, varname)
                and node.attr == key)

    return False


def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template):
    # Search for {%- for message in messages -%} loops
    for loop_ast in chat_template_ast.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter
        loop_target = loop_ast.target

        if _is_var_access(loop_iter, "messages"):
            assert isinstance(loop_target, jinja2.nodes.Name)
            yield loop_ast, loop_target.name


def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template):
    for node, message_varname in _iter_nodes_define_message(chat_template_ast):
        # Search for {%- for content in message['content'] -%} loops
        for loop_ast in node.find_all(jinja2.nodes.For):
            loop_iter = loop_ast.iter
            loop_target = loop_ast.target

            if _is_attr_access(loop_iter, message_varname, "content"):
                assert isinstance(loop_target, jinja2.nodes.Name)
                yield loop_iter, loop_target.name


def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    try:
        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
        jinja_ast = jinja_compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return default

    try:
        next(_iter_nodes_define_content_item(jinja_ast))
    except StopIteration:
        return "string"
    else:
        return "openai"
@DarkLight1337 (Member Author) commented Nov 2, 2024

This handles the most common case of iterating through OpenAI-formatted message['content'] as a list, assuming that no relevant variable reassignments are made other than those in the for loops.

Please tell me if you are aware of any chat templates that don't work with this code.
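To make the behaviour concrete, here is a simplified, self-contained re-implementation of the same detection idea (it omits vLLM's template-compilation step, the `Getattr` form, and the helper split above, so treat it as a sketch rather than the shipped code):

```python
import jinja2


def detect_content_format(chat_template: str) -> str:
    """Return 'openai' if the template loops over message['content'],
    else 'string'. Simplified sketch of the detection in this PR."""
    ast = jinja2.Environment().parse(chat_template)

    for outer in ast.find_all(jinja2.nodes.For):
        # Match {%- for message in messages -%}
        if not (isinstance(outer.iter, jinja2.nodes.Name)
                and outer.iter.name == "messages"):
            continue
        assert isinstance(outer.target, jinja2.nodes.Name)
        msg_var = outer.target.name

        for inner in outer.find_all(jinja2.nodes.For):
            # Match {%- for content in message['content'] -%}
            it = inner.iter
            if (isinstance(it, jinja2.nodes.Getitem)
                    and isinstance(it.node, jinja2.nodes.Name)
                    and it.node.name == msg_var
                    and isinstance(it.arg, jinja2.nodes.Const)
                    and it.arg.value == "content"):
                return "openai"

    return "string"


string_tmpl = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
openai_tmpl = ("{% for m in messages %}{% for c in m['content'] %}"
               "{{ c['text'] }}{% endfor %}{% endfor %}")
print(detect_content_format(string_tmpl))  # string
print(detect_content_format(openai_tmpl))  # openai
```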

@@ -380,10 +521,7 @@ def load_chat_template(

# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

logger.info("Using supplied chat template:\n%s", resolved_chat_template)
DarkLight1337 (Member Author):

This logging line has been moved to vllm/entrypoints/openai/api_server.py.

Comment on lines +1083 to +1095
chat_template: Optional[str] = Field(
    default=None,
    description=(
        "A Jinja template to use for this conversion. "
        "As of transformers v4.44, default chat template is no longer "
        "allowed, so you must provide a chat template if the tokenizer "
        "does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
    default=None,
    description=("Additional kwargs to pass to the template renderer. "
                 "Will be accessible by the chat template."),
)
DarkLight1337 (Member Author):

These arguments are present in other chat-based APIs so I added them here as well.
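For context, a sketch of what a client request using these fields might look like (the message text and template below are invented for illustration; only the field names come from the Pydantic model above):

```python
# Hypothetical payload exercising the two fields; values are made up.
payload = {
    "messages": [{"role": "user", "content": "Hello"}],
    "chat_template": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
    "chat_template_kwargs": {"add_generation_prompt": True},
}

# chat_template_kwargs is forwarded to the template renderer, so each key
# becomes a variable accessible from inside the Jinja template.
print(sorted(payload))  # ['chat_template', 'chat_template_kwargs', 'messages']
```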

@DarkLight1337 DarkLight1337 changed the title [Frontend] Rename and auto-detect --chat-template-text-format [Frontend] Automatic detection of chat template content format using AST parsing Nov 2, 2024
@DarkLight1337 DarkLight1337 changed the title [Frontend] Automatic detection of chat template content format using AST parsing [Frontend] Automatic detection of chat content format from AST Nov 2, 2024
@DarkLight1337 DarkLight1337 force-pushed the chat-template-content-format branch from 8ce013b to e262745 Compare November 2, 2024 16:58
@mergify mergify bot added the ci/build label Nov 2, 2024
mergify bot commented Nov 2, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @DarkLight1337, please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

vllm/entrypoints/chat_utils.py (outdated review comment, resolved)
loop_target = loop_ast.target

for varname in message_varnames:
    if _is_var_or_elems_access(loop_iter, varname, "content"):
Contributor:

Does this also handle cases where content is reassigned?

Pseudocode example:

for message in messages:
    content = message["content"]
    for c in content:
        do_stuff(c)
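For reference, the reassignment pattern above corresponds to an `Assign` node in Jinja's AST (a standalone illustration, not code from the PR):

```python
import jinja2

env = jinja2.Environment()
template = (
    "{%- for message in messages -%}"
    "{%- set content = message['content'] -%}"
    "{%- for c in content -%}{{ c['text'] }}{%- endfor -%}"
    "{%- endfor -%}"
)

ast = env.parse(template)

# {% set content = message['content'] %} shows up as an Assign node whose
# right-hand side is the Getitem access; a detector would have to track
# that 'content' now aliases message['content'].
assign = next(ast.find_all(jinja2.nodes.Assign))
print(assign.target.name)          # content
print(type(assign.node).__name__)  # Getitem
```

The inner loop then iterates over a plain `Name` node (`content`), so a detector that only matches `message['content']` in loop iterables misses this case.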

DarkLight1337 (Member Author):

No, currently it doesn't do that. Let me think a bit about how to handle this...

@DarkLight1337 (Member Author) commented Nov 13, 2024

I wrote some code to enable this, but found that this causes false positives. In particular, tool_chat_template_mistral.jinja is detected as having OpenAI format because of L54 and L57 in the chat template.

It would be quite complicated to condition the detected content format on message["role"]... we might as well build a CFG (control-flow graph); otherwise, our code would become quite unmaintainable 😅

Let's keep this simple for now. I am by no means an expert in program analysis.

@DarkLight1337 (Member Author) commented Nov 13, 2024

For future reference, here's the code I changed to handle reassignment of message["content"]:

diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index d6ab3c04e..c0edb7c24 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -204,21 +204,47 @@ def _is_var_or_elems_access(
     ) # yapf: enable
 
 
-def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
-    # Global variable that is implicitly defined at the root
-    yield root, varname
+def _iter_nodes_assign_var_or_elems(
+    root: jinja2.nodes.Node,
+    varname: str,
+    key: Optional[str] = None,
+):
+    if key is None:
+        # Global variable that is implicitly defined at the root
+        yield root, varname
 
     related_varnames: List[str] = [varname]
     for assign_ast in root.find_all(jinja2.nodes.Assign):
         lhs = assign_ast.target
         rhs = assign_ast.node
 
-        if any(_is_var_or_elems_access(rhs, name) for name in related_varnames):
+        if any(_is_var_or_elems_access(rhs, related_varname, key)
+               for related_varname in related_varnames):
             assert isinstance(lhs, jinja2.nodes.Name)
             yield assign_ast, lhs.name
             related_varnames.append(lhs.name)
 
 
+def _iter_nodes_assign_elem(
+    root: jinja2.nodes.Node,
+    varname: str,
+    key: Optional[str] = None,
+):
+    for loop_ast in root.find_all(jinja2.nodes.For):
+        loop_iter = loop_ast.iter
+        loop_target = loop_ast.target
+
+        if _is_var_or_elems_access(loop_iter, varname, key):
+            assert isinstance(loop_target, jinja2.nodes.Name)
+            yield loop_ast, loop_target.name
+            break
+
+    if key is not None:
+        for _, related_varname in _iter_nodes_assign_var_or_elems(
+            root, varname, key):
+            yield from _iter_nodes_assign_elem(root, related_varname)
+
+
 # NOTE: The proper way to handle this is to build a CFG so that we can handle
 # the scope in which each variable is defined, but that is too complicated
 def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
@@ -227,16 +253,8 @@ def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
         for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
     ]
 
-    # Search for {%- for message in messages -%} loops
-    for loop_ast in root.find_all(jinja2.nodes.For):
-        loop_iter = loop_ast.iter
-        loop_target = loop_ast.target
-
-        for varname in messages_varnames:
-            if _is_var_or_elems_access(loop_iter, varname):
-                assert isinstance(loop_target, jinja2.nodes.Name)
-                yield loop_ast, loop_target.name
-                break
+    for messages_varname in messages_varnames:
+        yield from _iter_nodes_assign_elem(root, messages_varname)
 
 
 def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
@@ -244,16 +262,8 @@ def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
         varname for _, varname in _iter_nodes_assign_messages_item(root)
     ]
 
-    # Search for {%- for content in message['content'] -%} loops
-    for loop_ast in root.find_all(jinja2.nodes.For):
-        loop_iter = loop_ast.iter
-        loop_target = loop_ast.target
-
-        for varname in message_varnames:
-            if _is_var_or_elems_access(loop_iter, varname, "content"):
-                assert isinstance(loop_target, jinja2.nodes.Name)
-                yield loop_ast, loop_target.name
-                break
+    for message_varname in message_varnames:
+        yield from _iter_nodes_assign_elem(root, message_varname, "content")
 
 
 def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:

mergify bot commented Nov 14, 2024

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @DarkLight1337.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 14, 2024
@mergify mergify bot removed the needs-rebase label Nov 14, 2024
@DarkLight1337 (Member Author):
@maxdebayser does this look good to you now?

@maxdebayser (Contributor) left a comment

@DarkLight1337, I've left a few comments; I think the one about the assignment search is worthy of your consideration, but other than that it looks good to me.

vllm/entrypoints/chat_utils.py (3 outdated review comments, resolved)
Signed-off-by: DarkLight1337 <[email protected]>
mergify bot commented Nov 15, 2024

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @DarkLight1337.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 15, 2024
@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 15, 2024
@njhill (Member) commented Nov 15, 2024

@DarkLight1337 looks like there's one test failure remaining

@DarkLight1337 (Member Author) commented Nov 15, 2024

The network is quite slow right now (HF keeps timing out for a lot of other PRs). This error comes from not being able to download the video before timeout occurs. (It passes when I run it locally.) Can you approve this PR? Then I'll retry the CI once the network returns to normal.

@njhill (Member) left a review

@DarkLight1337 merged commit 32e46e0 into main on Nov 16, 2024
52 checks passed
@DarkLight1337 DarkLight1337 deleted the chat-template-content-format branch November 16, 2024 05:35
coolkp pushed a commit to coolkp/vllm that referenced this pull request Nov 20, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
prashantgupta24 pushed a commit to opendatahub-io/vllm that referenced this pull request Dec 3, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Labels
ci/build documentation Improvements or additions to documentation frontend ready ONLY add when PR is ready to merge/full CI is needed
Development

Successfully merging this pull request may close these issues.

[Bug]: vllm serve works incorrect for (some) Vision LM models
5 participants