From 133146d63cd4921430375990f1ed8153e6099d8c Mon Sep 17 00:00:00 2001 From: Alexis Tacnet Date: Thu, 23 May 2024 10:01:37 +0200 Subject: [PATCH] Add tool call id in TS and example for v3 tokenization --- examples/function_calling.js | 100 +++++++++++++++++++++-------------- src/client.d.ts | 11 ++-- 2 files changed, 67 insertions(+), 44 deletions(-) diff --git a/examples/function_calling.js b/examples/function_calling.js index a3570a8..dd2382b 100644 --- a/examples/function_calling.js +++ b/examples/function_calling.js @@ -1,16 +1,20 @@ -import MistralClient from '@mistralai/mistralai'; +import MistralClient from "@mistralai/mistralai"; const apiKey = process.env.MISTRAL_API_KEY; // Assuming we have the following data const data = { - transactionId: ['T1001', 'T1002', 'T1003', 'T1004', 'T1005'], - customerId: ['C001', 'C002', 'C003', 'C002', 'C001'], - paymentAmount: [125.50, 89.99, 120.00, 54.30, 210.20], + transactionId: ["T1001", "T1002", "T1003", "T1004", "T1005"], + customerId: ["C001", "C002", "C003", "C002", "C001"], + paymentAmount: [125.5, 89.99, 120.0, 54.3, 210.2], paymentDate: [ - '2021-10-05', '2021-10-06', '2021-10-07', '2021-10-05', '2021-10-08', + "2021-10-05", + "2021-10-06", + "2021-10-07", + "2021-10-05", + "2021-10-08", ], - paymentStatus: ['Paid', 'Unpaid', 'Paid', 'Paid', 'Pending'], + paymentStatus: ["Paid", "Unpaid", "Paid", "Paid", "Pending"], }; /** @@ -19,12 +23,12 @@ const data = { * @param {string} transactionId - The transaction id. * @return {string} - The payment status. */ -function retrievePaymentStatus({data, transactionId}) { +function retrievePaymentStatus({ data, transactionId }) { const transactionIndex = data.transactionId.indexOf(transactionId); if (transactionIndex != -1) { - return JSON.stringify({status: data.payment_status[transactionIndex]}); + return JSON.stringify({ status: data.payment_status[transactionIndex] }); } else { - return JSON.stringify({status: 'error - transaction id not found.'}); + return JSON.stringify({ status: "error - transaction id not found." }); } } @@ -35,75 +39,80 @@ function retrievePaymentStatus({data, transactionId}) { * @return {string} - The payment date. * */ -function retrievePaymentDate({data, transactionId}) { +function retrievePaymentDate({ data, transactionId }) { const transactionIndex = data.transactionId.indexOf(transactionId); if (transactionIndex != -1) { - return JSON.stringify({status: data.payment_date[transactionIndex]}); + return JSON.stringify({ status: data.payment_date[transactionIndex] }); } else { - return JSON.stringify({status: 'error - transaction id not found.'}); + return JSON.stringify({ status: "error - transaction id not found." }); } } const namesToFunctions = { retrievePaymentStatus: (transactionId) => - retrievePaymentStatus({data, ...transactionId}), + retrievePaymentStatus({ data, ...transactionId }), retrievePaymentDate: (transactionId) => - retrievePaymentDate({data, ...transactionId}), + retrievePaymentDate({ data, ...transactionId }), }; const tools = [ { - type: 'function', + type: "function", function: { - name: 'retrievePaymentStatus', - description: 'Get payment status of a transaction id', + name: "retrievePaymentStatus", + description: "Get payment status of a transaction id", parameters: { - type: 'object', - required: ['transactionId'], - properties: {transactionId: - {type: 'string', description: 'The transaction id.'}, + type: "object", + required: ["transactionId"], + properties: { + transactionId: { type: "string", description: "The transaction id." }, }, }, }, }, { - type: 'function', + type: "function", function: { - name: 'retrievePaymentDate', - description: 'Get payment date of a transaction id', + name: "retrievePaymentDate", + description: "Get payment date of a transaction id", parameters: { - type: 'object', - required: ['transactionId'], - properties: {transactionId: - {type: 'string', description: 'The transaction id.'}, + type: "object", + required: ["transactionId"], + properties: { + transactionId: { type: "string", description: "The transaction id." }, }, }, }, }, ]; +const model = "mistral-small-latest"; -const model = 'mistral-large'; - -const client = new MistralClient(apiKey, 'https://api-2.aurocloud.net'); +const client = new MistralClient(apiKey); const messages = [ - {role: 'user', content: 'What\'s the status of my transaction?'}, + { role: "user", content: "What's the status of my transaction?" }, ]; let response = await client.chat({ - model: model, messages: messages, tools: tools, + model: model, + messages: messages, + tools: tools, }); - console.log(response.choices[0].message.content); -messages.push( - {role: 'assistant', content: response.choices[0].message.content}, -); -messages.push({role: 'user', content: 'My transaction ID is T1001.'}); +messages.push({ + role: "assistant", + content: response.choices[0].message.content, +}); +messages.push({ role: "user", content: "My transaction ID is T1001." }); -response = await client.chat({model: model, messages: messages, tools: tools}); +response = await client.chat({ + model: model, + messages: messages, + tools: tools, +}); const toolCall = response.choices[0].message.toolCalls[0]; const functionName = toolCall.function.name; @@ -115,8 +124,17 @@ console.log(`functionParams: ${toolCall.function.arguments}`); const functionResult = namesToFunctions[functionName](functionParams); messages.push(response.choices[0].message); -messages.push({role: 'tool', name: functionName, content: functionResult}); +messages.push({ + role: "tool", + name: functionName, + content: functionResult, + tol_call_id: toolCall.id, +}); -response = await client.chat({model: model, messages: messages, tools: tools}); +response = await client.chat({ + model: model, + messages: messages, + tools: tools, +}); console.log(response.choices[0].message.content); diff --git a/src/client.d.ts b/src/client.d.ts index 5b63e9c..47b661d 100644 --- a/src/client.d.ts +++ b/src/client.d.ts @@ -107,10 +107,15 @@ declare module "@mistralai/mistralai" { usage: TokenUsage; } - export interface Message { - role: string; + export type Message = { + role: "system" | "user" | "assistant"; content: string | string[]; - } + } & { + role: "tool"; + content: string | string[]; + name: string; + tool_call_id: string; + }; export interface Tool { type: "function";