Skip to content

Commit

Permalink
🐛 Bug: Fix the bug of tool use request body format error
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Sep 4, 2024
1 parent 0ce2715 commit 44caf41
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 34 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ __pycache__
.vscode
node_modules
.wrangler
.pytest_cache
.pytest_cache
*.jpg
*.json
13 changes: 9 additions & 4 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@ class ImageGenerationRequest(BaseModel):

class FunctionParameter(BaseModel):
type: str
properties: Dict[str, Dict[str, str]]
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
required: List[str]

# 定义 Function 模型
class Function(BaseModel):
name: str
description: str
parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)

# 定义 Tool 模型
class Tool(BaseModel):
type: str
function: Function
Expand Down Expand Up @@ -58,6 +56,13 @@ class Message(BaseModel):
class Config:
extra = "allow" # 允许额外的字段

class FunctionChoice(BaseModel):
name: str

class ToolChoice(BaseModel):
type: str
function: Optional[FunctionChoice] = None

class RequestModel(BaseModel):
model: str
messages: List[Message]
Expand All @@ -72,5 +77,5 @@ class RequestModel(BaseModel):
frequency_penalty: Optional[float] = 0.0
n: Optional[int] = 1
user: Optional[str] = None
tool_choice: Optional[str] = None
tool_choice: Optional[Union[str, ToolChoice]] = None
tools: Optional[List[Tool]] = None
56 changes: 30 additions & 26 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,19 +474,21 @@ async def get_vertex_claude_payload(request, engine, provider):
tools.append(json_tool)
payload["tools"] = tools
if "tool_choice" in payload:
if payload["tool_choice"]["type"] == "auto":
payload["tool_choice"] = {
"type": "auto"
}
if payload["tool_choice"]["type"] == "any":
payload["tool_choice"] = {
"type": "any"
}
if payload["tool_choice"]["type"] == "function":
payload["tool_choice"] = {
"type": "tool",
"name": payload["tool_choice"]["function"]["name"]
}
if isinstance(payload["tool_choice"], dict):
if payload["tool_choice"]["type"] == "function":
payload["tool_choice"] = {
"type": "tool",
"name": payload["tool_choice"]["function"]["name"]
}
if isinstance(payload["tool_choice"], str):
if payload["tool_choice"] == "auto":
payload["tool_choice"] = {
"type": "auto"
}
if payload["tool_choice"] == "none":
payload["tool_choice"] = {
"type": "any"
}

if provider.get("tools") == False:
payload.pop("tools", None)
Expand Down Expand Up @@ -746,19 +748,21 @@ async def get_claude_payload(request, engine, provider):
tools.append(json_tool)
payload["tools"] = tools
if "tool_choice" in payload:
if payload["tool_choice"]["type"] == "auto":
payload["tool_choice"] = {
"type": "auto"
}
if payload["tool_choice"]["type"] == "any":
payload["tool_choice"] = {
"type": "any"
}
if payload["tool_choice"]["type"] == "function":
payload["tool_choice"] = {
"type": "tool",
"name": payload["tool_choice"]["function"]["name"]
}
if isinstance(payload["tool_choice"], dict):
if payload["tool_choice"]["type"] == "function":
payload["tool_choice"] = {
"type": "tool",
"name": payload["tool_choice"]["function"]["name"]
}
if isinstance(payload["tool_choice"], str):
if payload["tool_choice"] == "auto":
payload["tool_choice"] = {
"type": "auto"
}
if payload["tool_choice"] == "none":
payload["tool_choice"] = {
"type": "any"
}

if provider.get("tools") == False:
payload.pop("tools", None)
Expand Down
5 changes: 2 additions & 3 deletions test/test_nostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def get_model_response(image_base64):
]

payload = {

"model": "claude-3-5-sonnet",
"messages": [
{
Expand All @@ -64,7 +63,7 @@ def get_model_response(image_base64):
]
}
],
"stream": True,
# "stream": True,
"tools": tools,
"tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
"max_tokens": 300
Expand Down Expand Up @@ -117,5 +116,5 @@ def main(image_path):
print("\n無法解析回應。")

if __name__ == "__main__":
image_path = "00001 (8).jpg" # 替換為您的圖像路徑
image_path = "1.jpg" # 替換為您的圖像路徑
main(image_path)

0 comments on commit 44caf41

Please sign in to comment.