diff --git a/.gitignore b/.gitignore index 83edc3d..b471fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ __pycache__ .vscode node_modules .wrangler -.pytest_cache \ No newline at end of file +.pytest_cache +*.jpg +*.json \ No newline at end of file diff --git a/models.py b/models.py index 11eb949..44887d9 100644 --- a/models.py +++ b/models.py @@ -10,16 +10,14 @@ class ImageGenerationRequest(BaseModel): class FunctionParameter(BaseModel): type: str - properties: Dict[str, Dict[str, str]] + properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]] required: List[str] -# 定义 Function 模型 class Function(BaseModel): name: str description: str parameters: Optional[FunctionParameter] = Field(default=None, exclude=None) -# 定义 Tool 模型 class Tool(BaseModel): type: str function: Function @@ -58,6 +56,13 @@ class Message(BaseModel): class Config: extra = "allow" # 允许额外的字段 +class FunctionChoice(BaseModel): + name: str + +class ToolChoice(BaseModel): + type: str + function: Optional[FunctionChoice] = None + class RequestModel(BaseModel): model: str messages: List[Message] @@ -72,5 +77,5 @@ class RequestModel(BaseModel): frequency_penalty: Optional[float] = 0.0 n: Optional[int] = 1 user: Optional[str] = None - tool_choice: Optional[str] = None + tool_choice: Optional[Union[str, ToolChoice]] = None tools: Optional[List[Tool]] = None \ No newline at end of file diff --git a/request.py b/request.py index 800a3fe..765c66f 100644 --- a/request.py +++ b/request.py @@ -474,19 +474,21 @@ async def get_vertex_claude_payload(request, engine, provider): tools.append(json_tool) payload["tools"] = tools if "tool_choice" in payload: - if payload["tool_choice"]["type"] == "auto": - payload["tool_choice"] = { - "type": "auto" - } - if payload["tool_choice"]["type"] == "any": - payload["tool_choice"] = { - "type": "any" - } - if payload["tool_choice"]["type"] == "function": - payload["tool_choice"] = { - "type": "tool", - "name": payload["tool_choice"]["function"]["name"] - } + if isinstance(payload["tool_choice"], dict): + if payload["tool_choice"]["type"] == "function": + payload["tool_choice"] = { + "type": "tool", + "name": payload["tool_choice"]["function"]["name"] + } + if isinstance(payload["tool_choice"], str): + if payload["tool_choice"] == "auto": + payload["tool_choice"] = { + "type": "auto" + } + if payload["tool_choice"] == "none": + payload["tool_choice"] = { + "type": "any" + } if provider.get("tools") == False: payload.pop("tools", None) @@ -746,19 +748,21 @@ async def get_claude_payload(request, engine, provider): tools.append(json_tool) payload["tools"] = tools if "tool_choice" in payload: - if payload["tool_choice"]["type"] == "auto": - payload["tool_choice"] = { - "type": "auto" - } - if payload["tool_choice"]["type"] == "any": - payload["tool_choice"] = { - "type": "any" - } - if payload["tool_choice"]["type"] == "function": - payload["tool_choice"] = { - "type": "tool", - "name": payload["tool_choice"]["function"]["name"] - } + if isinstance(payload["tool_choice"], dict): + if payload["tool_choice"]["type"] == "function": + payload["tool_choice"] = { + "type": "tool", + "name": payload["tool_choice"]["function"]["name"] + } + if isinstance(payload["tool_choice"], str): + if payload["tool_choice"] == "auto": + payload["tool_choice"] = { + "type": "auto" + } + if payload["tool_choice"] == "none": + payload["tool_choice"] = { + "type": "any" + } if provider.get("tools") == False: payload.pop("tools", None) diff --git a/test/test_nostream.py b/test/test_nostream.py index 7378d7c..0ae7642 100644 --- a/test/test_nostream.py +++ b/test/test_nostream.py @@ -45,7 +45,6 @@ def get_model_response(image_base64): ] payload = { - "model": "claude-3-5-sonnet", "messages": [ { @@ -64,7 +63,7 @@ def get_model_response(image_base64): ] } ], - "stream": True, + # "stream": True, "tools": tools, "tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}}, "max_tokens": 300 @@ -117,5 +116,5 @@ def main(image_path): print("\n無法解析回應。") if __name__ == "__main__": - image_path = "00001 (8).jpg" # 替換為您的圖像路徑 + image_path = "1.jpg" # 替換為您的圖像路徑 main(image_path) \ No newline at end of file