diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 3bef3f3226062..8d39e6f14a59c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -194,8 +194,8 @@ Text Generation (``--task generate``) - - ✅︎ * - :code:`GraniteForCausalLM` - - Granite 3.0, PowerLM - - :code:`ibm-granite/granite-3.0-2b-base`, :code:`ibm-granite/granite-3.0-8b-instruct`, :code:`ibm/PowerLM-3b`, etc. + - Granite 3.0, Granite 3.1, PowerLM + - :code:`ibm-granite/granite-3.0-2b-base`, :code:`ibm-granite/granite-3.1-8b-instruct`, :code:`ibm/PowerLM-3b`, etc. - ✅︎ - ✅︎ * - :code:`GraniteMoeForCausalLM` diff --git a/docs/source/usage/tool_calling.md b/docs/source/usage/tool_calling.md index f8be023307b0c..34b26647a959f 100644 --- a/docs/source/usage/tool_calling.md +++ b/docs/source/usage/tool_calling.md @@ -170,6 +170,12 @@ Recommended flags: `--tool-call-parser granite --chat-template examples/tool_cha `examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. +* `ibm-granite/granite-3.1-8b-instruct` + +Recommended flags: `--tool-call-parser granite` + +The chat template from Huggingface can be used directly. Parallel function calls are supported. + * `ibm-granite/granite-20b-functioncalling` Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` @@ -284,4 +290,3 @@ Then you can use this plugin in the command line like this. 
--tool-call-parser example \ --chat-template <your chat template> \ ``` - diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 6818ac44b2478..2241f1846e746 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -103,7 +103,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "supports_rocm": False, }, - "granite8b": { + "granite-3.0-8b": { "model": "ibm-granite/granite-3.0-8b-instruct", "arguments": [ @@ -111,6 +111,14 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], str(VLLM_PATH / "examples/tool_chat_template_granite.jinja") ], }, + "granite-3.1-8b": { + "model": "ibm-granite/granite-3.1-8b-instruct", + "arguments": [ + "--tool-call-parser", + "granite", + ], + "supports_parallel": True, + }, "internlm": { "model": "internlm/internlm2_5-7b-chat", diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index dae481a2154a1..8aefcd8d58a39 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -35,13 +35,18 @@ class GraniteToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) + # for granite 3.0, the token `<|tool_call|>` self.bot_token = "<|tool_call|>" + # for granite 3.1, the string `<tool_call>` + self.bot_string = "<tool_call>" def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: - # remove whitespace and the BOT token if it exists - stripped = model_output.strip().removeprefix(self.bot_token).lstrip() + stripped = model_output.strip()\ + .removeprefix(self.bot_token)\ + .removeprefix(self.bot_string)\ + .lstrip() if not stripped or stripped[0] != '[': return ExtractedToolCallInformation(tools_called=False, tool_calls=[], @@ -91,6 +96,9 @@ def extract_tool_calls_streaming( if current_text[start_idx:].startswith(self.bot_token): start_idx = consume_space(start_idx + len(self.bot_token),
current_text) + if current_text[start_idx:].startswith(self.bot_string): + start_idx = consume_space(start_idx + len(self.bot_string), + current_text) if not current_text or start_idx >= len(current_text)\ or current_text[start_idx] != '[': return DeltaMessage(content=delta_text)