Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Commit

Permalink
Update version to 0.1.3
Browse files Browse the repository at this point in the history
  • Loading branch information
GitHub Actions committed Feb 26, 2024
1 parent 18fbf66 commit 72ab604
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 7 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,6 @@ dist
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
.pnp.*

changes.diff
122 changes: 122 additions & 0 deletions examples/function_calling.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import MistralClient from '@mistralai/mistralai';

const apiKey = process.env.MISTRAL_API_KEY;

// Assuming we have the following data
const data = {
transactionId: ['T1001', 'T1002', 'T1003', 'T1004', 'T1005'],
customerId: ['C001', 'C002', 'C003', 'C002', 'C001'],
paymentAmount: [125.50, 89.99, 120.00, 54.30, 210.20],
paymentDate: [
'2021-10-05', '2021-10-06', '2021-10-07', '2021-10-05', '2021-10-08',
],
paymentStatus: ['Paid', 'Unpaid', 'Paid', 'Paid', 'Pending'],
};

/**
* This function retrieves the payment status of a transaction id.
* @param {object} data - The data object.
* @param {string} transactionId - The transaction id.
* @return {string} - The payment status.
*/
function retrievePaymentStatus({data, transactionId}) {
const transactionIndex = data.transactionId.indexOf(transactionId);
if (transactionIndex != -1) {
return JSON.stringify({status: data.payment_status[transactionIndex]});
} else {
return JSON.stringify({status: 'error - transaction id not found.'});
}
}

/**
* This function retrieves the payment date of a transaction id.
* @param {object} data - The data object.
* @param {string} transactionId - The transaction id.
* @return {string} - The payment date.
*
*/
function retrievePaymentDate({data, transactionId}) {
const transactionIndex = data.transactionId.indexOf(transactionId);
if (transactionIndex != -1) {
return JSON.stringify({status: data.payment_date[transactionIndex]});
} else {
return JSON.stringify({status: 'error - transaction id not found.'});
}
}

const namesToFunctions = {
retrievePaymentStatus: (transactionId) =>
retrievePaymentStatus({data, ...transactionId}),
retrievePaymentDate: (transactionId) =>
retrievePaymentDate({data, ...transactionId}),
};

const tools = [
{
type: 'function',
function: {
name: 'retrievePaymentStatus',
description: 'Get payment status of a transaction id',
parameters: {
type: 'object',
required: ['transactionId'],
properties: {transactionId:
{type: 'string', description: 'The transaction id.'},
},
},
},
},
{
type: 'function',
function: {
name: 'retrievePaymentDate',
description: 'Get payment date of a transaction id',
parameters: {
type: 'object',
required: ['transactionId'],
properties: {transactionId:
{type: 'string', description: 'The transaction id.'},
},
},
},
},
];


const model = 'mistral-large';

const client = new MistralClient(apiKey, 'https://api-2.aurocloud.net');

const messages = [
{role: 'user', content: 'What\'s the status of my transaction?'},
];

let response = await client.chat({
model: model, messages: messages, tools: tools,
});


console.log(response.choices[0].message.content);

messages.push(
{role: 'assistant', content: response.choices[0].message.content},
);
messages.push({role: 'user', content: 'My transaction ID is T1001.'});

response = await client.chat({model: model, messages: messages, tools: tools});

const toolCall = response.choices[0].message.toolCalls[0];
const functionName = toolCall.function.name;
const functionParams = JSON.parse(toolCall.function.arguments);

console.log(`calling functionName: ${functionName}`);
console.log(`functionParams: ${toolCall.function.arguments}`);

const functionResult = namesToFunctions[functionName](functionParams);

messages.push(response.choices[0].message);
messages.push({role: 'tool', name: functionName, content: functionResult});

response = await client.chat({model: model, messages: messages, tools: tools});

console.log(response.choices[0].message.content);
13 changes: 13 additions & 0 deletions examples/json_format.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import MistralClient from '@mistralai/mistralai';

const apiKey = process.env.MISTRAL_API_KEY;

const client = new MistralClient(apiKey);

const chatResponse = await client.chat({
model: 'mistral-large',
messages: [{role: 'user', content: 'What is the best French cheese?'}],
responseFormat: {type: 'json_object'},
});

console.log('Chat:', chatResponse.choices[0].message.content);
54 changes: 50 additions & 4 deletions src/client.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,42 @@ declare module '@mistralai/mistralai' {
data: Model[];
}

export interface Function {
name: string;
description: string;
parameters: object;
}

export enum ToolType {
function = 'function',
}

export interface FunctionCall {
name: string;
arguments: string;
}

export interface ToolCalls {
id: 'null';
type: ToolType = ToolType.function;
function: FunctionCall;
}

export enum ResponseFormats {
text = 'text',
json_object = 'json_object',
}

export enum ToolChoice {
auto = 'auto',
any = 'any',
none = 'none',
}

export interface ResponseFormat {
type: ResponseFormats = ResponseFormats.text;
}

export interface TokenUsage {
prompt_tokens: number;
completion_tokens: number;
Expand All @@ -49,6 +85,7 @@ declare module '@mistralai/mistralai' {
delta: {
role?: string;
content?: string;
tool_calls?: ToolCalls[];
};
finish_reason: string;
}
Expand Down Expand Up @@ -95,7 +132,8 @@ declare module '@mistralai/mistralai' {

private _makeChatCompletionRequest(
model: string,
messages: Array<{ role: string; content: string }>,
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>,
tools?: Array<{ type: string; function:Function; }>,
temperature?: number,
maxTokens?: number,
topP?: number,
Expand All @@ -105,14 +143,17 @@ declare module '@mistralai/mistralai' {
* @deprecated use safePrompt instead
*/
safeMode?: boolean,
safePrompt?: boolean
safePrompt?: boolean,
toolChoice?: ToolChoice,
responseFormat?: ResponseFormat
): object;

listModels(): Promise<ListModelsResponse>;

chat(options: {
model: string;
messages: Array<{ role: string; content: string }>;
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
tools?: Array<{ type: string; function:Function; }>;
temperature?: number;
maxTokens?: number;
topP?: number;
Expand All @@ -122,11 +163,14 @@ declare module '@mistralai/mistralai' {
*/
safeMode?: boolean;
safePrompt?: boolean;
toolChoice?: ToolChoice;
responseFormat?: ResponseFormat;
}): Promise<ChatCompletionResponse>;

chatStream(options: {
model: string;
messages: Array<{ role: string; content: string }>;
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
tools?: Array<{ type: string; function:Function; }>;
temperature?: number;
maxTokens?: number;
topP?: number;
Expand All @@ -136,6 +180,8 @@ declare module '@mistralai/mistralai' {
*/
safeMode?: boolean;
safePrompt?: boolean;
toolChoice?: ToolChoice;
responseFormat?: ResponseFormat;
}): AsyncGenerator<ChatCompletionResponseChunk, void, unknown>;

embeddings(options: {
Expand Down
Loading

0 comments on commit 72ab604

Please sign in to comment.