diff --git a/.gitignore b/.gitignore index 6a7d6d8..2065ee5 100644 --- a/.gitignore +++ b/.gitignore @@ -127,4 +127,6 @@ dist .yarn/unplugged .yarn/build-state.yml .yarn/install-state.gz -.pnp.* \ No newline at end of file +.pnp.* + +changes.diff \ No newline at end of file diff --git a/examples/function_calling.js b/examples/function_calling.js new file mode 100644 index 0000000..a3570a8 --- /dev/null +++ b/examples/function_calling.js @@ -0,0 +1,122 @@ +import MistralClient from '@mistralai/mistralai'; + +const apiKey = process.env.MISTRAL_API_KEY; + +// Assuming we have the following data +const data = { + transactionId: ['T1001', 'T1002', 'T1003', 'T1004', 'T1005'], + customerId: ['C001', 'C002', 'C003', 'C002', 'C001'], + paymentAmount: [125.50, 89.99, 120.00, 54.30, 210.20], + paymentDate: [ + '2021-10-05', '2021-10-06', '2021-10-07', '2021-10-05', '2021-10-08', + ], + paymentStatus: ['Paid', 'Unpaid', 'Paid', 'Paid', 'Pending'], +}; + +/** + * This function retrieves the payment status of a transaction id. + * @param {object} data - The data object. + * @param {string} transactionId - The transaction id. + * @return {string} - The payment status. + */ +function retrievePaymentStatus({data, transactionId}) { + const transactionIndex = data.transactionId.indexOf(transactionId); + if (transactionIndex != -1) { + return JSON.stringify({status: data.payment_status[transactionIndex]}); + } else { + return JSON.stringify({status: 'error - transaction id not found.'}); + } +} + +/** + * This function retrieves the payment date of a transaction id. + * @param {object} data - The data object. + * @param {string} transactionId - The transaction id. + * @return {string} - The payment date. + * + */ +function retrievePaymentDate({data, transactionId}) { + const transactionIndex = data.transactionId.indexOf(transactionId); + if (transactionIndex != -1) { + return JSON.stringify({status: data.payment_date[transactionIndex]}); + } else { + return JSON.stringify({status: 'error - transaction id not found.'}); + } +} + +const namesToFunctions = { + retrievePaymentStatus: (transactionId) => + retrievePaymentStatus({data, ...transactionId}), + retrievePaymentDate: (transactionId) => + retrievePaymentDate({data, ...transactionId}), +}; + +const tools = [ + { + type: 'function', + function: { + name: 'retrievePaymentStatus', + description: 'Get payment status of a transaction id', + parameters: { + type: 'object', + required: ['transactionId'], + properties: {transactionId: + {type: 'string', description: 'The transaction id.'}, + }, + }, + }, + }, + { + type: 'function', + function: { + name: 'retrievePaymentDate', + description: 'Get payment date of a transaction id', + parameters: { + type: 'object', + required: ['transactionId'], + properties: {transactionId: + {type: 'string', description: 'The transaction id.'}, + }, + }, + }, + }, +]; + + +const model = 'mistral-large'; + +const client = new MistralClient(apiKey, 'https://api-2.aurocloud.net'); + +const messages = [ + {role: 'user', content: 'What\'s the status of my transaction?'}, +]; + +let response = await client.chat({ + model: model, messages: messages, tools: tools, +}); + + +console.log(response.choices[0].message.content); + +messages.push( + {role: 'assistant', content: response.choices[0].message.content}, +); +messages.push({role: 'user', content: 'My transaction ID is T1001.'}); + +response = await client.chat({model: model, messages: messages, tools: tools}); + +const toolCall = response.choices[0].message.toolCalls[0]; +const functionName = toolCall.function.name; +const functionParams = JSON.parse(toolCall.function.arguments); + +console.log(`calling functionName: ${functionName}`); +console.log(`functionParams: ${toolCall.function.arguments}`); + +const functionResult = namesToFunctions[functionName](functionParams); + +messages.push(response.choices[0].message); +messages.push({role: 'tool', name: functionName, content: functionResult}); + +response = await client.chat({model: model, messages: messages, tools: tools}); + +console.log(response.choices[0].message.content); diff --git a/examples/json_format.js b/examples/json_format.js new file mode 100644 index 0000000..7803c56 --- /dev/null +++ b/examples/json_format.js @@ -0,0 +1,13 @@ +import MistralClient from '@mistralai/mistralai'; + +const apiKey = process.env.MISTRAL_API_KEY; + +const client = new MistralClient(apiKey); + +const chatResponse = await client.chat({ + model: 'mistral-large', + messages: [{role: 'user', content: 'What is the best French cheese?'}], + responseFormat: {type: 'json_object'}, +}); + +console.log('Chat:', chatResponse.choices[0].message.content); diff --git a/src/client.d.ts b/src/client.d.ts index 5a82b60..45adc72 100644 --- a/src/client.d.ts +++ b/src/client.d.ts @@ -29,6 +29,42 @@ declare module '@mistralai/mistralai' { data: Model[]; } + export interface Function { + name: string; + description: string; + parameters: object; + } + + export enum ToolType { + function = 'function', + } + + export interface FunctionCall { + name: string; + arguments: string; + } + + export interface ToolCalls { + id: 'null'; + type: ToolType = ToolType.function; + function: FunctionCall; + } + + export enum ResponseFormats { + text = 'text', + json_object = 'json_object', + } + + export enum ToolChoice { + auto = 'auto', + any = 'any', + none = 'none', + } + + export interface ResponseFormat { + type: ResponseFormats = ResponseFormats.text; + } + export interface TokenUsage { prompt_tokens: number; completion_tokens: number; @@ -49,6 +85,7 @@ declare module '@mistralai/mistralai' { delta: { role?: string; content?: string; + tool_calls?: ToolCalls[]; }; finish_reason: string; } @@ -95,7 +132,8 @@ declare module '@mistralai/mistralai' { private _makeChatCompletionRequest( model: string, - messages: Array<{ role: string; content: string }>, + messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>, + tools?: Array<{ type: string; function:Function; }>, temperature?: number, maxTokens?: number, topP?: number, @@ -105,14 +143,17 @@ declare module '@mistralai/mistralai' { * @deprecated use safePrompt instead */ safeMode?: boolean, - safePrompt?: boolean + safePrompt?: boolean, + toolChoice?: ToolChoice, + responseFormat?: ResponseFormat ): object; listModels(): Promise; chat(options: { model: string; - messages: Array<{ role: string; content: string }>; + messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>; + tools?: Array<{ type: string; function:Function; }>; temperature?: number; maxTokens?: number; topP?: number; @@ -122,11 +163,14 @@ declare module '@mistralai/mistralai' { */ safeMode?: boolean; safePrompt?: boolean; + toolChoice?: ToolChoice; + responseFormat?: ResponseFormat; }): Promise; chatStream(options: { model: string; - messages: Array<{ role: string; content: string }>; + messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>; + tools?: Array<{ type: string; function:Function; }>; temperature?: number; maxTokens?: number; topP?: number; @@ -136,6 +180,8 @@ declare module '@mistralai/mistralai' { */ safeMode?: boolean; safePrompt?: boolean; + toolChoice?: ToolChoice; + responseFormat?: ResponseFormat; }): AsyncGenerator; embeddings(options: { diff --git a/src/client.js b/src/client.js index d0a801b..10def72 100644 --- a/src/client.js +++ b/src/client.js @@ -61,6 +61,10 @@ class MistralClient { this.maxRetries = maxRetries; this.timeout = timeout; + + if (this.endpoint.indexOf('inference.azure.com')) { + this.modelDefault = 'mistral'; + } } /** @@ -149,6 +153,7 @@ class MistralClient { * Creates a chat completion request * @param {*} model * @param {*} messages + * @param {*} tools * @param {*} temperature * @param {*} maxTokens * @param {*} topP @@ -156,11 +161,14 @@ class MistralClient { * @param {*} stream * @param {*} safeMode deprecated use safePrompt instead * @param {*} safePrompt + * @param {*} toolChoice + * @param {*} responseFormat * @return {Promise} */ _makeChatCompletionRequest = function( model, messages, + tools, temperature, maxTokens, topP, @@ -168,16 +176,27 @@ class MistralClient { stream, safeMode, safePrompt, + toolChoice, + responseFormat, ) { + // if modelDefault and model are undefined, throw an error + if (!model && !this.modelDefault) { + throw new MistralAPIError( + 'You must provide a model name', + ); + } return { - model: model, + model: model ?? this.modelDefault, messages: messages, + tools: tools ?? undefined, temperature: temperature ?? undefined, max_tokens: maxTokens ?? undefined, top_p: topP ?? undefined, random_seed: randomSeed ?? undefined, stream: stream ?? undefined, safe_prompt: (safeMode || safePrompt) ?? undefined, + tool_choice: toolChoice ?? undefined, + response_format: responseFormat ?? undefined, }; }; @@ -195,27 +214,34 @@ class MistralClient { * @param {*} model the name of the model to chat with, e.g. mistral-tiny * @param {*} messages an array of messages to chat with, e.g. * [{role: 'user', content: 'What is the best French cheese?'}] + * @param {*} tools a list of tools to use. * @param {*} temperature the temperature to use for sampling, e.g. 0.5 * @param {*} maxTokens the maximum number of tokens to generate, e.g. 100 * @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9 * @param {*} randomSeed the random seed to use for sampling, e.g. 42 * @param {*} safeMode deprecated use safePrompt instead * @param {*} safePrompt whether to use safe mode, e.g. true + * @param {*} toolChoice the tool to use, e.g. 'auto' + * @param {*} responseFormat the format of the response, e.g. 'json_format' * @return {Promise} */ chat = async function({ model, messages, + tools, temperature, maxTokens, topP, randomSeed, safeMode, safePrompt, + toolChoice, + responseFormat, }) { const request = this._makeChatCompletionRequest( model, messages, + tools, temperature, maxTokens, topP, @@ -223,6 +249,8 @@ class MistralClient { false, safeMode, safePrompt, + toolChoice, + responseFormat, ); const response = await this._request( 'post', @@ -237,27 +265,34 @@ class MistralClient { * @param {*} model the name of the model to chat with, e.g. mistral-tiny * @param {*} messages an array of messages to chat with, e.g. * [{role: 'user', content: 'What is the best French cheese?'}] + * @param {*} tools a list of tools to use. * @param {*} temperature the temperature to use for sampling, e.g. 0.5 * @param {*} maxTokens the maximum number of tokens to generate, e.g. 100 * @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9 * @param {*} randomSeed the random seed to use for sampling, e.g. 42 * @param {*} safeMode deprecated use safePrompt instead * @param {*} safePrompt whether to use safe mode, e.g. true + * @param {*} toolChoice the tool to use, e.g. 'auto' + * @param {*} responseFormat the format of the response, e.g. 'json_format' * @return {Promise} */ chatStream = async function* ({ model, messages, + tools, temperature, maxTokens, topP, randomSeed, safeMode, safePrompt, + toolChoice, + responseFormat, }) { const request = this._makeChatCompletionRequest( model, messages, + tools, temperature, maxTokens, topP, @@ -265,6 +300,8 @@ class MistralClient { true, safeMode, safePrompt, + toolChoice, + responseFormat, ); const response = await this._request( 'post', diff --git a/version.txt b/version.txt deleted file mode 100644 index b1e80bb..0000000 --- a/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.1.3