diff --git a/src/inspect_ai/model/_providers/mistral.py b/src/inspect_ai/model/_providers/mistral.py index c34d34eca..d66d1f827 100644 --- a/src/inspect_ai/model/_providers/mistral.py +++ b/src/inspect_ai/model/_providers/mistral.py @@ -3,7 +3,6 @@ from typing import Any from mistralai import ( - ChatCompletionRequestToolChoice, FunctionCall, FunctionName, Mistral, @@ -179,11 +178,11 @@ def mistral_chat_tools(tools: list[ToolInfo]) -> list[MistralTool]: def mistral_chat_tool_choice( tool_choice: ToolChoice, -) -> ChatCompletionRequestToolChoice: +) -> str | dict[str, Any]: if isinstance(tool_choice, ToolFunction): return MistralToolChoice( type="function", function=FunctionName(name=tool_choice.name) - ) + ).model_dump() elif tool_choice == "any": return "any" elif tool_choice == "auto": diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index 6c5289840..f30bf5b02 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -77,7 +77,9 @@ async def add(x: int, y: int): return add -def check_tools(model: Model, disable: list[str] = ["calls", "force", "none"]) -> None: +def check_tools( + model: Model, disable: list[Literal["calls", "force", "none"]] = [] +) -> None: if "calls" not in disable: check_tools_calls(model) if "force" not in disable: @@ -171,7 +173,7 @@ def test_anthropic_tools(): @skip_if_no_mistral def test_mistral_tools(): - check_tools("mistral/mistral-large-latest", disable=["force"]) + check_tools("mistral/mistral-large-latest") @skip_if_no_groq