From 7f1c9adb8196574f84d7bedacab75f4d8680e7f8 Mon Sep 17 00:00:00 2001 From: Alexis Tacnet Date: Thu, 23 May 2024 17:19:46 +0200 Subject: [PATCH] Add tool call id in TS and example for v3 tokenization (#76) Co-authored-by: Harizo Rajaona --- examples/function_calling.js | 56 ++++++++++++++++++++++++------------ src/client.d.ts | 15 +++++++--- 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/examples/function_calling.js b/examples/function_calling.js index a3570a8..9743ce0 100644 --- a/examples/function_calling.js +++ b/examples/function_calling.js @@ -6,9 +6,13 @@ const apiKey = process.env.MISTRAL_API_KEY; const data = { transactionId: ['T1001', 'T1002', 'T1003', 'T1004', 'T1005'], customerId: ['C001', 'C002', 'C003', 'C002', 'C001'], - paymentAmount: [125.50, 89.99, 120.00, 54.30, 210.20], + paymentAmount: [125.5, 89.99, 120.0, 54.3, 210.2], paymentDate: [ - '2021-10-05', '2021-10-06', '2021-10-07', '2021-10-05', '2021-10-08', + '2021-10-05', + '2021-10-06', + '2021-10-07', + '2021-10-05', + '2021-10-08', ], paymentStatus: ['Paid', 'Unpaid', 'Paid', 'Paid', 'Pending'], }; @@ -22,7 +26,7 @@ const data = { function retrievePaymentStatus({data, transactionId}) { const transactionIndex = data.transactionId.indexOf(transactionId); if (transactionIndex != -1) { - return JSON.stringify({status: data.payment_status[transactionIndex]}); + return JSON.stringify({status: data.paymentStatus[transactionIndex]}); } else { return JSON.stringify({status: 'error - transaction id not found.'}); } @@ -60,8 +64,8 @@ const tools = [ parameters: { type: 'object', required: ['transactionId'], - properties: {transactionId: - {type: 'string', description: 'The transaction id.'}, + properties: { + transactionId: {type: 'string', description: 'The transaction id.'}, }, }, }, @@ -74,38 +78,43 @@ const tools = [ parameters: { type: 'object', required: ['transactionId'], - properties: {transactionId: - {type: 'string', description: 'The transaction id.'}, + properties: { + transactionId: {type: 'string', description: 'The transaction id.'}, }, }, }, }, ]; +const model = 'mistral-small-latest'; -const model = 'mistral-large'; - -const client = new MistralClient(apiKey, 'https://api-2.aurocloud.net'); +const client = new MistralClient(apiKey); const messages = [ {role: 'user', content: 'What\'s the status of my transaction?'}, ]; let response = await client.chat({ - model: model, messages: messages, tools: tools, + model: model, + messages: messages, + tools: tools, }); - console.log(response.choices[0].message.content); -messages.push( - {role: 'assistant', content: response.choices[0].message.content}, -); +messages.push({ + role: 'assistant', + content: response.choices[0].message.content, +}); messages.push({role: 'user', content: 'My transaction ID is T1001.'}); -response = await client.chat({model: model, messages: messages, tools: tools}); +response = await client.chat({ + model: model, + messages: messages, + tools: tools, +}); -const toolCall = response.choices[0].message.toolCalls[0]; +const toolCall = response.choices[0].message.tool_calls[0]; const functionName = toolCall.function.name; const functionParams = JSON.parse(toolCall.function.arguments); @@ -115,8 +124,17 @@ console.log(`functionParams: ${toolCall.function.arguments}`); const functionResult = namesToFunctions[functionName](functionParams); messages.push(response.choices[0].message); -messages.push({role: 'tool', name: functionName, content: functionResult}); +messages.push({ + role: 'tool', + name: functionName, + content: functionResult, + tool_call_id: toolCall.id, +}); -response = await client.chat({model: model, messages: messages, tools: tools}); +response = await client.chat({ + model: model, + messages: messages, + tools: tools, +}); console.log(response.choices[0].message.content); diff --git a/src/client.d.ts b/src/client.d.ts index 5b63e9c..fe820ef 100644 --- a/src/client.d.ts +++ b/src/client.d.ts @@ -107,10 +107,17 @@ declare module "@mistralai/mistralai" { usage: TokenUsage; } - export interface Message { - role: string; - content: string | string[]; - } + export type Message = + | { + role: "system" | "user" | "assistant"; + content: string | string[]; + } + | { + role: "tool"; + content: string | string[]; + name: string; + tool_call_id: string; + }; export interface Tool { type: "function";