Commit: sys_prompt
ehhuang committed Feb 2, 2025
1 parent 15dcc4e commit 6d035b3
Showing 26 changed files with 147 additions and 48 deletions.
13 changes: 10 additions & 3 deletions llama_stack/apis/agents/agents.py
@@ -33,6 +33,7 @@
ToolResponse,
ToolResponseMessage,
UserMessage,
ToolConfig,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@@ -155,8 +156,13 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
tool_choice: Optional[ToolChoice] = Field(
default=ToolChoice.auto, deprecated="use tool_config instead"
)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=None, deprecated="use tool_config instead"
)
tool_config: Optional[ToolConfig] = Field(default=None)

max_infer_iters: int = 10

@@ -280,7 +286,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
toolgroups: Optional[List[AgentToolGroup]] = None

stream: Optional[bool] = False

tool_config: Optional[ToolConfig] = None

@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
@@ -327,6 +333,7 @@ async def create_agent_turn(
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

@webmethod(
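For context, a minimal sketch of configuring an agent against the new field (a hedged example, not taken from this commit: AgentConfig's model, instructions, and enable_session_persistence fields are assumed from the surrounding codebase, and the model id and toolgroup name are placeholders; ToolConfig itself is defined in the inference API change below):

from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.inference import ToolChoice, ToolConfig, ToolPromptFormat

agent_config = AgentConfig(
    model="meta-llama/Llama-3.2-3B-Instruct",       # placeholder model id
    instructions="You are a helpful assistant.",
    enable_session_persistence=False,
    toolgroups=["builtin::websearch"],              # placeholder toolgroup name
    # tool_choice=... and tool_prompt_format=... still parse but are deprecated;
    # the same knobs now live on tool_config.
    tool_config=ToolConfig(
        tool_choice=ToolChoice.auto,
        tool_prompt_format=ToolPromptFormat.python_list,
    ),
)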
44 changes: 42 additions & 2 deletions llama_stack/apis/inference/inference.py
@@ -310,14 +310,48 @@ class CompletionResponseStreamChunk(BaseModel):
logprobs: Optional[List[TokenLogProbs]] = None


@json_schema_type
class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.
:cvar append: Appends the provided system message to the default system prompt:
https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
:cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""

append = "append"
replace = "replace"


@json_schema_type
class ToolConfig(BaseModel):
"""Configuration for tool use.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
"""

tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
system_message_behavior: SystemMessageBehavior = Field(
default=SystemMessageBehavior.append
)


# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()

tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
tool_config: Optional[ToolConfig] = None

response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
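To make the new knobs concrete, here is a hedged sketch of constructing the config (it assumes SystemMessage and SystemMessageBehavior are importable from llama_stack.apis.inference alongside ToolConfig; the prompt text is illustrative):

from llama_stack.apis.inference import (
    SystemMessage,
    SystemMessageBehavior,
    ToolChoice,
    ToolConfig,
    ToolPromptFormat,
)

# Defaults: tool use is automatic, the prompt format is chosen per model, and
# the function definitions are appended to the stock system prompt.
default_config = ToolConfig()

# Force tool calls, emit them as JSON, and replace the default system prompt;
# '{{function_definitions}}' marks where the tool definitions get injected.
custom_config = ToolConfig(
    tool_choice=ToolChoice.required,
    tool_prompt_format=ToolPromptFormat.json,
    system_message_behavior=SystemMessageBehavior.replace,
)
custom_system_message = SystemMessage(
    content="You are a terse weather bot.\n\n{{function_definitions}}"
)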
@@ -406,6 +440,7 @@ async def chat_completion(
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
@@ -416,15 +451,20 @@
:param sampling_params: Parameters to control the sampling strategy
:param tools: (Optional) List of tool definitions available to the model
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
.. deprecated::
Use tool_config instead.
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
"""
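A hedged usage sketch of the extended call, preferring tool_config over the deprecated arguments (the inference handle, model id, and tool definition are placeholders, not part of this commit):

from llama_stack.apis.inference import ToolChoice, ToolConfig, UserMessage

async def ask_with_tools(inference, weather_tool):
    # `inference` is any object implementing the Inference API; `weather_tool`
    # is a ToolDefinition built elsewhere. Both are placeholders for this sketch.
    return await inference.chat_completion(
        model_id="meta-llama/Llama-3.2-3B-Instruct",   # placeholder model id
        messages=[UserMessage(content="What's the weather in Paris?")],
        tools=[weather_tool],
        tool_config=ToolConfig(tool_choice=ToolChoice.auto),
    )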
17 changes: 17 additions & 0 deletions llama_stack/distribution/routers/routers.py
@@ -24,6 +24,7 @@
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -138,6 +139,7 @@ async def chat_completion(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
@@ -146,6 +148,20 @@
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
if tool_config:
if tool_choice != tool_config.tool_choice:
raise ValueError(
"tool_choice and tool_config.tool_choice must match"
)
if tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else:
tool_config = ToolConfig(
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
)
params = dict(
model_id=model_id,
messages=messages,
@@ -156,6 +172,7 @@
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
if stream:
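The effect is backward compatibility: callers that still pass the deprecated arguments get them folded into a ToolConfig, while passing both styles with disagreeing values raises a ValueError. A hedged sketch of the legacy path (the router handle and model id are placeholders, and the tool_choice default of ToolChoice.auto is assumed from the Inference protocol):

from llama_stack.apis.inference import ToolPromptFormat, UserMessage

async def legacy_style_call(router):
    # No tool_config supplied, so the router builds one from the legacy arguments,
    # roughly ToolConfig(tool_choice=ToolChoice.auto, tool_prompt_format=ToolPromptFormat.json),
    # before dispatching to the provider implementation.
    return await router.chat_completion(
        model_id="meta-llama/Llama-3.2-3B-Instruct",   # placeholder model id
        messages=[UserMessage(content="ping")],
        tool_prompt_format=ToolPromptFormat.json,
    )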
@@ -515,10 +515,11 @@ async def _run(
for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
9 changes: 8 additions & 1 deletion llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -25,7 +25,12 @@
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
@@ -146,6 +151,7 @@ async def create_agent_turn(
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@@ -154,6 +160,7 @@
stream=True,
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
)
if stream:
return self._create_agent_turn_streaming(request)
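This plumbs the per-turn override end to end. A hedged sketch of a caller using it (the agents handle and ids are placeholders; agent_id, session_id, and messages come from the part of the signature not shown in this hunk):

from llama_stack.apis.inference import ToolChoice, ToolConfig, UserMessage

async def run_turn(agents_impl, agent_id, session_id):
    # Require tool use for this one turn without touching the agent's stored config.
    stream = await agents_impl.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="Look up today's top AI headline.")],
        stream=True,
        tool_config=ToolConfig(tool_choice=ToolChoice.required),
    )
    async for chunk in stream:
        print(chunk.event)  # AgentTurnResponseStreamChunk events as they arrive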
@@ -400,7 +400,7 @@ def chat_completion(
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
request.messages,
request.tool_prompt_format,
request.tool_config.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
@@ -38,6 +38,7 @@
ResponseFormat,
TokenLogProbs,
ToolChoice,
ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -270,6 +271,7 @@ async def chat_completion(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@@ -280,11 +282,10 @@
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
self.check_model(request)

@@ -17,6 +17,7 @@
ToolChoice,
ToolDefinition,
ToolPromptFormat,
ToolConfig,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -71,5 +72,6 @@ async def chat_completion(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
5 changes: 3 additions & 2 deletions llama_stack/providers/inline/inference/vllm/vllm.py
@@ -30,6 +30,7 @@
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -159,6 +160,7 @@ async def chat_completion(
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
assert self.engine is not None

@@ -167,10 +169,9 @@
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)

log.info("Sampling params: %s", sampling_params)
5 changes: 3 additions & 2 deletions llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -24,6 +24,7 @@
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -102,6 +103,7 @@ async def chat_completion(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
@@ -111,11 +113,10 @@
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)

if stream:
3 changes: 3 additions & 0 deletions llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -24,6 +24,7 @@
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -130,6 +131,7 @@ async def chat_completion(
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -142,6 +144,7 @@
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)

if stream:
@@ -89,16 +89,16 @@ async def chat_completion(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)

client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
5 changes: 3 additions & 2 deletions llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -25,6 +25,7 @@
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -208,18 +209,18 @@ async def chat_completion(
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)

if stream: