diff --git a/libs/langchain-mistralai/src/chat_models.ts b/libs/langchain-mistralai/src/chat_models.ts index 9ea8803a343e..0ddaa9e1389e 100644 --- a/libs/langchain-mistralai/src/chat_models.ts +++ b/libs/langchain-mistralai/src/chat_models.ts @@ -800,7 +800,7 @@ export class ChatMistralAI< }, }, ], - tool_choice: "auto", + tool_choice: "any", } as Partial); outputParser = new JsonOutputKeyToolsParser({ returnSingle: true, @@ -830,7 +830,7 @@ export class ChatMistralAI< function: openAIFunctionDefinition, }, ], - tool_choice: "auto", + tool_choice: "any", } as Partial); outputParser = new JsonOutputKeyToolsParser({ returnSingle: true, diff --git a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts index b052f2c3e30d..fd8d25f361d7 100644 --- a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts @@ -982,3 +982,32 @@ test("Invoke token count usage_metadata", async () => { res.usage_metadata.input_tokens + res.usage_metadata.output_tokens ); }); + +test("withStructuredOutput will always force tool usage", async () => { + const model = new ChatMistralAI({ + temperature: 0, + model: "mistral-large-latest", + }); + + const weatherTool = z + .object({ + location: z.string().describe("The name of city to get the weather for."), + }) + .describe( + "Get the weather of a specific location and return the temperature in Celsius." + ); + const modelWithTools = model.withStructuredOutput(weatherTool, { + name: "get_weather", + includeRaw: true, + }); + const response = await modelWithTools.invoke( + "What is the sum of 271623 and 281623? It is VERY important you use a calculator tool to give me the answer." + ); + + if (!("tool_calls" in response.raw)) { + throw new Error("Tool call not found in response"); + } + const castMessage = response.raw as AIMessage; + expect(castMessage.tool_calls).toHaveLength(1); + expect(castMessage.tool_calls?.[0].name).toBe("get_weather"); +});