diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 68eecaccb2..9d8abd2777 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -33,6 +33,7 @@
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    ToolConfig,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -155,8 +156,13 @@ class AgentConfigCommon(BaseModel):
     output_shields: Optional[List[str]] = Field(default_factory=list)
     toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
     client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    tool_choice: Optional[ToolChoice] = Field(
+        default=ToolChoice.auto, deprecated="use tool_config instead"
+    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=None, deprecated="use tool_config instead"
+    )
+    tool_config: Optional[ToolConfig] = Field(default=None)
 
     max_infer_iters: int = 10
 
@@ -280,7 +286,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     toolgroups: Optional[List[AgentToolGroup]] = None
 
     stream: Optional[bool] = False
-
+    tool_config: Optional[ToolConfig] = None
 
 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):
@@ -327,6 +333,7 @@ async def create_agent_turn(
         stream: Optional[bool] = False,
         documents: Optional[List[Document]] = None,
         toolgroups: Optional[List[AgentToolGroup]] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
     @webmethod(
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 2debce1a7e..2c4e57766a 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -310,14 +310,48 @@ class CompletionResponseStreamChunk(BaseModel):
     logprobs: Optional[List[TokenLogProbs]] = None
 
 
+@json_schema_type
+class SystemMessageBehavior(Enum):
+    """Config for how to override the default system prompt.
+
+    :cvar append: Appends the provided system message to the default system prompt:
+        https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
+    :cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
+        '{{function_definitions}}' to indicate where the function definitions should be inserted.
+    """
+
+    append = "append"
+    replace = "replace"
+
+
+@json_schema_type
+class ToolConfig(BaseModel):
+    """Configuration for tool use.
+
+    :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+    :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
+        - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+        - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
+        - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+ """ + + tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) + tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) + system_message_behavior: SystemMessageBehavior = Field( + default=SystemMessageBehavior.append + ) + + # This is an internally used class +@json_schema_type class ChatCompletionRequest(BaseModel): model: str messages: List[Message] sampling_params: Optional[SamplingParams] = SamplingParams() + tools: Optional[List[ToolDefinition]] = Field(default_factory=list) - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) - tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) + tool_config: Optional[ToolConfig] = None + response_format: Optional[ResponseFormat] = None stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None @@ -406,6 +440,7 @@ async def chat_completion( response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ]: @@ -416,15 +451,20 @@ async def chat_completion( :param sampling_params: Parameters to control the sampling strategy :param tools: (Optional) List of tool definitions available to the model :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + .. deprecated:: + Use tool_config instead. :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. + .. deprecated:: + Use tool_config instead. :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options: - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format. - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it. :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False. :param logprobs: (Optional) If specified, log probabilities for each token position will be returned. + :param tool_config: (Optional) Configuration for tool use. :returns: If stream=False, returns a ChatCompletionResponse with the full completion. 
                  If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
         """
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 6bb2045bd5..77c0204878 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -24,6 +24,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -138,6 +139,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.routing_table.get_model(model_id)
         if model is None:
@@ -146,6 +148,20 @@ async def chat_completion(
             raise ValueError(
                 f"Model '{model_id}' is an embedding model and does not support chat completions"
             )
+        if tool_config:
+            if tool_choice != tool_config.tool_choice:
+                raise ValueError(
+                    "tool_choice and tool_config.tool_choice must match"
+                )
+            if tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
+        else:
+            tool_config = ToolConfig(
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+            )
         params = dict(
             model_id=model_id,
             messages=messages,
@@ -156,6 +172,7 @@ async def chat_completion(
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
         if stream:
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 706dd74f1b..248f69b034 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -515,10 +515,11 @@ async def _run(
                     for tool in tool_defs.values()
                     if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
                 ],
-                tool_prompt_format=self.agent_config.tool_prompt_format,
+                tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
                 response_format=self.agent_config.response_format,
                 stream=True,
                 sampling_params=sampling_params,
+                tool_config=self.agent_config.tool_config,
             ):
                 event = chunk.event
                 if event.event_type == ChatCompletionResponseEventType.start:
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index b1844f4d0f..2e574b5613 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -25,7 +25,12 @@
     Session,
     Turn,
 )
-from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    ToolConfig,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
@@ -146,6 +151,7 @@ async def create_agent_turn(
         toolgroups: Optional[List[AgentToolGroup]] = None,
         documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
@@ -154,6 +160,7 @@ async def create_agent_turn(
             stream=True,
             toolgroups=toolgroups,
             documents=documents,
+            tool_config=tool_config,
         )
         if stream:
             return self._create_agent_turn_streaming(request)
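
Taken together, the API changes above fold the loose `tool_choice` / `tool_prompt_format` arguments into a single `ToolConfig` that is threaded from the agents API through the inference router. A minimal sketch of the intended call shape (not part of the patch; the `inference` handle and the model id are assumed placeholders):

# Illustrative sketch only -- not part of the patch. Assumes an `inference`
# object implementing the updated Inference.chat_completion protocol above;
# the model id is a placeholder.
from llama_stack.apis.inference import ToolChoice, ToolConfig, UserMessage


async def ask_weather(inference, model_id: str):
    # Tool behaviour is grouped into a single ToolConfig. Per the router change
    # above, the deprecated tool_choice/tool_prompt_format arguments are left at
    # their defaults (ToolChoice.auto / None), which must agree with the values
    # carried inside tool_config or the router raises ValueError.
    tool_config = ToolConfig(tool_choice=ToolChoice.auto)
    return await inference.chat_completion(
        model_id=model_id,
        messages=[UserMessage(content="What is the weather in Tokyo?")],
        tool_config=tool_config,
    )
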
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index a96409cab2..5d832206d7 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -400,7 +400,7 @@ def chat_completion(
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 request.messages,
-                request.tool_prompt_format,
+                request.tool_config.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
             temperature=temperature,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 73962ca7f8..24948f16da 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -38,6 +38,7 @@
     ResponseFormat,
     TokenLogProbs,
     ToolChoice,
+    ToolConfig,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -270,6 +271,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@@ -280,11 +282,10 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         self.check_model(request)
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 3920ee1ad9..d34befbd9a 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -17,6 +17,7 @@
     ToolChoice,
     ToolDefinition,
     ToolPromptFormat,
+    ToolConfig,
 )
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -71,5 +72,6 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 49dd8316e3..6dc3a8d6bf 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -30,6 +30,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -159,6 +160,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
         assert self.engine is not None
 
@@ -167,10 +169,9 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         log.info("Sampling params: %s", sampling_params)
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 10b51e86b6..5886c3a3e1 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -24,6 +24,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -102,6 +103,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
@@ -111,11 +113,10 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 0b6ce142ce..601edd9326 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -24,6 +24,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -130,6 +131,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
@@ -142,6 +144,7 @@ async def chat_completion(
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 2964b2aaa0..4de26c6971 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -89,16 +89,16 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 5c98d20542..f53a970ce5 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -25,6 +25,7 @@
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -208,6 +209,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
@@ -215,11 +217,10 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index e3f3fefa3e..8e552c8b52 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -99,6 +99,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
@@ -117,10 +118,9 @@ async def chat_completion(
                 sampling_params=sampling_params,
                 response_format=response_format,
                 tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
                 stream=stream,
                 logprobs=logprobs,
+                tool_config=tool_config,
             )
         )
 
diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 99fa8219cd..74cbf3607d 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -79,7 +79,7 @@ def convert_chat_completion_request(
         # so we exclude it for now
         warnings.warn("repetition_penalty is not supported")
 
-    if request.tool_prompt_format != ToolPromptFormat.json:
+    if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
Ignoring.") sampling_options = get_sampling_strategy_options(request.sampling_params) @@ -93,7 +93,11 @@ def convert_chat_completion_request( temperature=sampling_options.get("temperature", 1.0), top_p=sampling_options.get("top_p", 1.0), tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []], - tool_choice=request.tool_choice.value if request.tool_choice else None, + tool_choice=( + request.tool_config.tool_choice.value + if request.tool_config.tool_choice + else None + ), ) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 81751e038a..af2376c9cd 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -178,6 +178,7 @@ async def chat_completion( tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ]: @@ -193,10 +194,9 @@ async def chat_completion( sampling_params=sampling_params, response_format=response_format, tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ), n=1, ) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 43be0fc94b..61003953c5 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -253,9 +253,9 @@ def convert_chat_completion_request( payload.update( tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools] ) - if request.tool_choice: + if request.tool_config.tool_choice: payload.update( - tool_choice=request.tool_choice.value + tool_choice=request.tool_config.tool_choice.value ) # we cannot include tool_choice w/o tools, server will complain if request.logprobs: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 6811d435b9..2c62b72345 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -29,6 +29,7 @@ ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -224,6 +225,7 @@ async def chat_completion( tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -231,11 +233,10 @@ async def chat_completion( messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, response_format=response_format, + tool_config=tool_config, ) if stream: return self._stream_chat_completion(request) @@ -322,6 +323,7 @@ async def _stream_chat_completion( self, request: ChatCompletionRequest ) -> AsyncGenerator: params = await self._get_params(request) + print(params) async def _generate_and_convert_to_openai_compat(): if "messages" in params: diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index e5b19426f9..12d42777d7 100644 
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -85,10 +85,9 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index b601d4b3f9..2b0d7f7bc8 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -125,10 +125,9 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         request_sambanova = await self.convert_chat_completion_request(request)
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 7f8c9d8abf..4ed3c9ab8b 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -26,6 +26,7 @@
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -213,6 +214,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
@@ -220,11 +222,10 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 605b3ce976..e1dd1e1acf 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -24,6 +24,7 @@
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -198,6 +199,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
@@ -205,11 +207,10 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 0cf16f0133..9258c12d0e 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -27,6 +27,7 @@
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -119,6 +120,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
@@ -126,11 +128,10 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
             response_format=response_format,
+            tool_config=tool_config,
         )
         if stream:
             return self._stream_chat_completion(request, self.client)
diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
index 5e07978718..7538f4a0ba 100644
--- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py
+++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
@@ -181,7 +181,7 @@ def test_includes_greedy_strategy(self):
 
     def test_includes_tool_choice(self):
         request = self._dummy_chat_completion_request()
-        request.tool_choice = ToolChoice.required
+        request.tool_config.tool_choice = ToolChoice.required
 
         converted = convert_chat_completion_request(request)
 
diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
index 4826e89d53..0596758b17 100644
--- a/llama_stack/providers/tests/inference/test_prompt_adapter.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -13,7 +13,12 @@
     ToolPromptFormat,
 )
 
-from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    SystemMessage,
+    ToolConfig,
+    UserMessage,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )
@@ -73,7 +78,7 @@ async def test_system_custom_only(self):
                     },
                 )
             ],
-            tool_prompt_format=ToolPromptFormat.json,
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
         )
         messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index e497719809..5e33877921 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -49,6 +49,7 @@
     SystemMessage,
     ToolChoice,
     UserMessage,
+    SystemMessageBehavior,
 )
 from llama_stack.providers.utils.inference import supported_inference_models
 
@@ -319,7 +320,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
 
     existing_messages = request.messages
     existing_system_message = None
@@ -368,7 +369,7 @@ def _process(c):
 
     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
     if has_custom_tools:
-        fmt = request.tool_prompt_format or ToolPromptFormat.json
+        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
         if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
         elif fmt == ToolPromptFormat.function_tag:
@@ -389,7 +390,7 @@ def _process(c):
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
 
     existing_messages = request.messages
     existing_system_message = None
@@ -419,19 +420,24 @@ def augment_messages_for_tools_llama_3_2(
     custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
 
     if custom_tools:
-        fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
             raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
+                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
            )
         tool_gen = PythonListCustomToolGenerator()
-        tool_template = tool_gen.gen(custom_tools)
+
+        system_prompt = None
+        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+            system_prompt = existing_system_message.content
+
+        tool_template = tool_gen.gen(custom_tools, system_prompt)
 
         sys_content += tool_template.render()
         sys_content += "\n"
 
-    if existing_system_message:
+    if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.append:
         sys_content += interleaved_content_as_str(
             existing_system_message.content, sep="\n"
         )
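
The `system_message_behavior` handling above is what makes `SystemMessageBehavior.replace` meaningful for Llama 3.2 style tool prompts: with `replace`, the caller's system message (optionally containing the literal '{{function_definitions}}' placeholder) is handed to the python-list tool generator instead of being appended after the generated tool section. A rough sketch of exercising this path, modelled on the updated test_prompt_adapter.py and using import paths taken from this patch (the model descriptor and the tool definition are illustrative assumptions):

# Illustrative sketch, not part of the patch. The model descriptor string and
# the tool definition are assumptions; imports mirror those used in this diff.
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    SystemMessage,
    SystemMessageBehavior,
    ToolConfig,
    ToolDefinition,
    UserMessage,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

MODEL3_2 = "Llama3.2-3B-Instruct"  # assumed model descriptor

request = ChatCompletionRequest(
    model=MODEL3_2,
    messages=[
        # With SystemMessageBehavior.replace, this content replaces the default
        # tool system prompt; '{{function_definitions}}' marks where the generated
        # function definitions should be inserted.
        SystemMessage(content="You are a helpful pirate. {{function_definitions}}"),
        UserMessage(content="What is the boiling point of water?"),
    ],
    tools=[
        ToolDefinition(
            tool_name="get_boiling_point",
            description="Returns the boiling point of a liquid",
        )
    ],
    tool_config=ToolConfig(system_message_behavior=SystemMessageBehavior.replace),
)

# augment_messages_for_tools_llama_3_2 (above) passes the system message into the
# python-list tool generator rather than appending it after the tool section.
messages = chat_completion_request_to_messages(request, MODEL3_2)
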