This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

release 0.4.0: add support for completion #79

Merged
merged 1 commit on May 29, 2024

2 changes: 1 addition & 1 deletion examples/json_format.js
@@ -5,7 +5,7 @@ const apiKey = process.env.MISTRAL_API_KEY;
const client = new MistralClient(apiKey);

const chatResponse = await client.chat({
model: 'mistral-large',
model: 'mistral-large-latest',
messages: [{role: 'user', content: 'What is the best French cheese?'}],
responseFormat: {type: 'json_object'},
});
2 changes: 1 addition & 1 deletion examples/package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@mistralai/mistralai",
"version": "0.3.0",
"version": "0.4.0",
"description": "",
"author": "[email protected]",
"license": "ISC",
22 changes: 22 additions & 0 deletions src/client.d.ts
@@ -141,6 +141,17 @@ declare module "@mistralai/mistralai" {
responseFormat?: ResponseFormat;
}

export interface CompletionRequest {
model: string;
prompt: string;
suffix?: string;
temperature?: number;
maxTokens?: number;
topP?: number;
randomSeed?: number;
stop?: string | string[];
}

export interface ChatRequestOptions {
signal?: AbortSignal;
}
@@ -170,6 +181,17 @@ declare module "@mistralai/mistralai" {
options?: ChatRequestOptions
): AsyncGenerator<ChatCompletionResponseChunk, void>;

completion(
request: CompletionRequest,
options?: ChatRequestOptions
): Promise<ChatCompletionResponse>;

completionStream(
request: CompletionRequest,
options?: ChatRequestOptions
): AsyncGenerator<ChatCompletionResponseChunk, void>;


embeddings(options: {
model: string;
input: string | string[];
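
For orientation, a value conforming to the new CompletionRequest interface could look like the sketch below; only model and prompt are required, and every concrete value here (including the model name) is illustrative:

// Illustrative CompletionRequest value; all optional fields may be omitted.
const fimRequest = {
  model: 'mistral-small-latest',
  prompt: 'def fibonacci(n: int):',
  suffix: 'n = int(input(\'Enter a number: \'))', // text expected after the generated span
  temperature: 0.2,
  maxTokens: 64,
  topP: 0.9,
  randomSeed: 42,
  stop: ['\n\n'],
};
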
171 changes: 170 additions & 1 deletion src/client.js
@@ -161,7 +161,7 @@ class MistralClient {
} else {
throw new MistralAPIError(
`HTTP error! status: ${response.status} ` +
`Response: \n${await response.text()}`,
`Response: \n${await response.text()}`,
);
}
} catch (error) {
@@ -228,6 +228,47 @@ class MistralClient {
};
};

/**
* Creates a completion request
* @param {string} model the model name
* @param {string} prompt the prompt to complete
* @param {string} [suffix] text expected to follow the completion
* @param {number} [temperature] sampling temperature
* @param {number} [maxTokens] maximum number of tokens to generate
* @param {number} [topP] nucleus sampling probability mass
* @param {number} [randomSeed] seed for reproducible sampling
* @param {string|string[]} [stop] stop sequence(s)
* @param {boolean} [stream] whether to request a streaming response
* @return {Object} the request body for the completion endpoint
*/
_makeCompletionRequest = function(
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
stream,
) {
// if modelDefault and model are undefined, throw an error
if (!model && !this.modelDefault) {
throw new MistralAPIError('You must provide a model name');
}
return {
model: model ?? this.modelDefault,
prompt: prompt,
suffix: suffix ?? undefined,
temperature: temperature ?? undefined,
max_tokens: maxTokens ?? undefined,
top_p: topP ?? undefined,
random_seed: randomSeed ?? undefined,
stop: stop ?? undefined,
stream: stream ?? undefined,
};
};
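
To make the mapping concrete: completion() below calls this helper with stream set to false, and completionStream() with true. For a call that sets only model, prompt, temperature and maxTokens, the helper returns an object equivalent to the sketch below (unset optional fields come back as undefined and are assumed to be dropped when the body is JSON-serialized by the request layer):

// Equivalent request body for:
//   _makeCompletionRequest('mistral-small-latest', '# this is a',
//                          undefined, 0.7, 64, undefined, undefined,
//                          undefined, false)
// camelCase arguments are renamed to the snake_case fields the API expects.
{
  model: 'mistral-small-latest',
  prompt: '# this is a',
  temperature: 0.7,
  max_tokens: 64,
  stream: false,
}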

/**
* Returns a list of the available models
* @return {Promise<Object>}
@@ -401,6 +442,134 @@ class MistralClient {
const response = await this._request('post', 'v1/embeddings', request);
return response;
};

/**
* A completion endpoint without streaming.
*
* @param {Object} data - The main completion configuration.
* @param {string} data.model - the name of the model to use,
* e.g. mistral-tiny
* @param {string} data.prompt - the prompt to complete,
* e.g. 'def fibonacci(n: int):'
* @param {number} data.temperature - the temperature to use for sampling, e.g. 0.5
* @param {number} data.maxTokens - the maximum number of tokens to generate,
* e.g. 100
* @param {number} data.topP - the cumulative probability cutoff used for
* nucleus sampling, e.g. 0.9
* @param {number} data.randomSeed - the random seed to use for sampling, e.g. 42
* @param {string|string[]} data.stop - the stop sequence(s) to use, e.g. ['\n']
* @param {string} data.suffix - the text expected to follow the completion
* (fill-in-the-middle), e.g. 'n = int(input(\'Enter a number: \'))'
* @param {Object} options - Additional operational options.
* @param {AbortSignal} [options.signal] - optional AbortSignal instance to
* cancel the request; it will be combined with the default timeout signal
* @return {Promise<Object>}
*/
completion = async function(
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{signal} = {},
) {
const request = this._makeCompletionRequest(
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
false,
);
const response = await this._request(
'post',
'v1/fim/completions',
request,
signal,
);
return response;
};
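
A minimal non-streaming call could look like the sketch below; the model name is illustrative, and reading choices[0].message.content assumes the response follows the ChatCompletionResponse shape declared in client.d.ts:

// Fill-in-the-middle sketch; `client` is assumed to be a MistralClient
// constructed with a valid API key, as in the examples directory.
const fimResponse = await client.completion({
  model: 'mistral-small-latest',
  prompt: 'def fibonacci(n: int):',
  suffix: 'n = int(input(\'Enter a number: \'))',
  maxTokens: 64,
});
console.log(fimResponse.choices[0].message.content);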

/**
* A completion endpoint that streams responses.
*
* @param {Object} data - The main completion configuration.
* @param {string} data.model - the name of the model to use,
* e.g. mistral-tiny
* @param {string} data.prompt - the prompt to complete,
* e.g. 'def fibonacci(n: int):'
* @param {number} data.temperature - the temperature to use for sampling, e.g. 0.5
* @param {number} data.maxTokens - the maximum number of tokens to generate,
* e.g. 100
* @param {number} data.topP - the cumulative probability cutoff used for
* nucleus sampling, e.g. 0.9
* @param {number} data.randomSeed - the random seed to use for sampling, e.g. 42
* @param {string|string[]} data.stop - the stop sequence(s) to use, e.g. ['\n']
* @param {string} data.suffix - the text expected to follow the completion
* (fill-in-the-middle), e.g. 'n = int(input(\'Enter a number: \'))'
* @param {Object} options - Additional operational options.
* @param {AbortSignal} [options.signal] - optional AbortSignal instance to
* cancel the request; it will be combined with the default timeout signal
* @return {AsyncGenerator<Object, void>}
*/
completionStream = async function* (
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{signal} = {},
) {
const request = this._makeCompletionRequest(
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
true,
);
const response = await this._request(
'post',
'v1/fim/completions',
request,
signal,
);

let buffer = '';
const decoder = new TextDecoder();
for await (const chunk of response) {
buffer += decoder.decode(chunk, {stream: true});
let firstNewline;
while ((firstNewline = buffer.indexOf('\n')) !== -1) {
const chunkLine = buffer.substring(0, firstNewline);
buffer = buffer.substring(firstNewline + 1);
if (chunkLine.startsWith('data:')) {
const json = chunkLine.substring(6).trim();
if (json !== '[DONE]') {
yield JSON.parse(json);
}
}
}
}
};
}

export default MistralClient;
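
For completeness, the streaming variant can be consumed as an async iterator; the sketch below assumes each yielded chunk carries its text under choices[0].delta.content, mirroring the chat streaming chunks:

// Streaming completion sketch: completionStream() returns an async generator
// of parsed SSE chunks; the delta/content layout is an assumption based on
// the chat streaming payloads used in the tests below.
const stream = client.completionStream({
  model: 'mistral-small-latest',
  prompt: '# this is a',
});
let generated = '';
for await (const chunk of stream) {
  generated += chunk.choices[0]?.delta?.content ?? '';
}
console.log(generated);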
26 changes: 20 additions & 6 deletions tests/client.test.js
@@ -23,7 +23,7 @@ describe('Mistral Client', () => {
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -40,7 +40,7 @@
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -58,7 +58,7 @@
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -78,7 +78,7 @@
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -101,7 +101,7 @@
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -125,7 +125,7 @@
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -176,4 +176,18 @@ describe('Mistral Client', () => {
expect(response).toEqual(mockResponse);
});
});

describe('completion()', () => {
it('should return a chat response object', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
client._fetch = mockFetch(200, mockResponse);

const response = await client.completion({
model: 'mistral-small-latest',
prompt: '# this is a',
});
expect(response).toEqual(mockResponse);
});
});
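
A streaming counterpart could be exercised in the same style; the sketch below reuses the streaming mocks from the chatStream tests and assumes their payload shape carries over to the fill-in-the-middle endpoint:

describe('completionStream()', () => {
  it('should yield parsed streaming chunks', async() => {
    // Mock the fetch function with the streaming payload used above.
    const mockResponse = mockChatResponseStreamingPayload();
    client._fetch = mockFetchStream(200, mockResponse);

    const response = client.completionStream({
      model: 'mistral-small-latest',
      prompt: '# this is a',
    });

    const parsedResponse = [];
    for await (const chunk of response) {
      parsedResponse.push(chunk);
    }
    expect(parsedResponse.length).toBeGreaterThan(0);
  });
});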
});
8 changes: 4 additions & 4 deletions tests/utils.js
@@ -78,7 +78,7 @@ export function mockListModels() {
],
},
{
id: 'mistral-small',
id: 'mistral-small-latest',
object: 'model',
created: 1703186988,
owned_by: 'mistralai',
@@ -172,7 +172,7 @@ export function mockChatResponsePayload() {
index: 0,
},
],
model: 'mistral-small',
model: 'mistral-small-latest',
usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
};
}
@@ -187,7 +187,7 @@ export function mockChatResponseStreamingPayload() {
[encoder.encode('data: ' +
JSON.stringify({
id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
model: 'mistral-small',
model: 'mistral-small-latest',
choices: [
{
index: 0,
@@ -207,7 +207,7 @@
id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
object: 'chat.completion.chunk',
created: 1703168544,
model: 'mistral-small',
model: 'mistral-small-latest',
choices: [
{
index: i,