This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Merge pull request #79 from mistralai/release/v0.4.0
release 0.4.0: add support for completion
jean-malo authored May 29, 2024
2 parents 6d2639e + 7d8cf44 commit ac6e138
Showing 7 changed files with 219 additions and 14 deletions.
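
For orientation, here is a minimal sketch (not part of the diff) of how the completion endpoint added in this release might be called. It mirrors the setup in examples/json_format.js; the model name, prompt, and suffix are illustrative values borrowed from the JSDoc examples in src/client.js below, and the result field access assumes the ChatCompletionResponse shape declared in src/client.d.ts.

import MistralClient from '@mistralai/mistralai';

const apiKey = process.env.MISTRAL_API_KEY;
const client = new MistralClient(apiKey);

// Fill-in-the-middle completion: the model generates text that fits between
// `prompt` and the optional `suffix`.
const completionResponse = await client.completion({
  model: 'mistral-small-latest', // illustrative model name
  prompt: 'def fibonacci(n: int):',
  suffix: 'n = int(input(\'Enter a number: \'))',
  temperature: 0.5,
  maxTokens: 64,
});

// Per the declared return type, the generated text sits under the message content.
console.log('Completion:', completionResponse.choices[0].message.content);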
2 changes: 1 addition & 1 deletion examples/json_format.js
@@ -5,7 +5,7 @@ const apiKey = process.env.MISTRAL_API_KEY;
const client = new MistralClient(apiKey);

const chatResponse = await client.chat({
- model: 'mistral-large',
+ model: 'mistral-large-latest',
messages: [{role: 'user', content: 'What is the best French cheese?'}],
responseFormat: {type: 'json_object'},
});
2 changes: 1 addition & 1 deletion examples/package-lock.json

Generated lockfile; diff not rendered.

2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@mistralai/mistralai",
- "version": "0.3.0",
+ "version": "0.4.0",
"description": "",
"author": "[email protected]",
"license": "ISC",
22 changes: 22 additions & 0 deletions src/client.d.ts
@@ -141,6 +141,17 @@ declare module "@mistralai/mistralai" {
responseFormat?: ResponseFormat;
}

export interface CompletionRequest {
model: string;
prompt: string;
suffix?: string;
temperature?: number;
maxTokens?: number;
topP?: number;
randomSeed?: number;
stop?: string | string[];
}

export interface ChatRequestOptions {
signal?: AbortSignal;
}
@@ -170,6 +181,17 @@ declare module "@mistralai/mistralai" {
options?: ChatRequestOptions
): AsyncGenerator<ChatCompletionResponseChunk, void>;

completion(
request: CompletionRequest,
options?: ChatRequestOptions
): Promise<ChatCompletionResponse>;

completionStream(
request: CompletionRequest,
options?: ChatRequestOptions
): AsyncGenerator<ChatCompletionResponseChunk, void>;


embeddings(options: {
model: string;
input: string | string[];
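
Based on the CompletionRequest interface and the completionStream signature above, a sketch of consuming the streaming variant (reusing the client from the earlier sketch). The chunk field access assumes the delta-based ChatCompletionResponseChunk layout that chatStream already yields; the model and prompt are again illustrative.

// Stream a fill-in-the-middle completion chunk by chunk.
const stream = client.completionStream({
  model: 'mistral-small-latest',
  prompt: 'def fibonacci(n: int):',
});

for await (const chunk of stream) {
  // Assumed chunk shape: incremental text under choices[].delta.content.
  const delta = chunk.choices[0]?.delta?.content;
  if (delta !== undefined) {
    process.stdout.write(delta);
  }
}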
171 changes: 170 additions & 1 deletion src/client.js
@@ -161,7 +161,7 @@ class MistralClient {
} else {
throw new MistralAPIError(
`HTTP error! status: ${response.status} ` +
- `Response: \n${await response.text()}`,
+ `Response: \n${await response.text()}`,
);
}
} catch (error) {
@@ -228,6 +228,47 @@ class MistralClient {
};
};

/**
* Creates a completion request
* @param {*} model
* @param {*} prompt
* @param {*} suffix
* @param {*} temperature
* @param {*} maxTokens
* @param {*} topP
* @param {*} randomSeed
* @param {*} stop
* @param {*} stream
* @return {Object}
*/
_makeCompletionRequest = function(
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
stream,
) {
// if modelDefault and model are undefined, throw an error
if (!model && !this.modelDefault) {
throw new MistralAPIError('You must provide a model name');
}
return {
model: model ?? this.modelDefault,
prompt: prompt,
suffix: suffix ?? undefined,
temperature: temperature ?? undefined,
max_tokens: maxTokens ?? undefined,
top_p: topP ?? undefined,
random_seed: randomSeed ?? undefined,
stop: stop ?? undefined,
stream: stream ?? undefined,
};
};

/**
* Returns a list of the available models
* @return {Promise<Object>}
@@ -401,6 +442,134 @@ class MistralClient {
const response = await this._request('post', 'v1/embeddings', request);
return response;
};

/**
* A completion endpoint without streaming.
*
* @param {Object} data - The main completion configuration.
* @param {*} data.model - the name of the model to use for the completion,
* e.g. mistral-tiny
* @param {*} data.prompt - the prompt to complete,
* e.g. 'def fibonacci(n: int):'
* @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5
* @param {*} data.maxTokens - the maximum number of tokens to generate,
* e.g. 100
* @param {*} data.topP - the cumulative probability of tokens to generate,
* e.g. 0.9
* @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42
* @param {*} data.stop - the stop sequence to use, e.g. ['\n']
* @param {*} data.suffix - the suffix to append to the prompt,
* e.g. 'n = int(input(\'Enter a number: \'))'
* @param {Object} options - Additional operational options.
* @param {*} [options.signal] - optional AbortSignal instance to control the
* request. The signal will be combined with the
* default timeout signal.
* @return {Promise<Object>}
*/
completion = async function(
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{signal} = {},
) {
const request = this._makeCompletionRequest(
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
false,
);
const response = await this._request(
'post',
'v1/fim/completions',
request,
signal,
);
return response;
};

/**
* A completion endpoint that streams responses.
*
* @param {Object} data - The main completion configuration.
* @param {*} data.model - the name of the model to use for the completion,
* e.g. mistral-tiny
* @param {*} data.prompt - the prompt to complete,
* e.g. 'def fibonacci(n: int):'
* @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5
* @param {*} data.maxTokens - the maximum number of tokens to generate,
* e.g. 100
* @param {*} data.topP - the cumulative probability of tokens to generate,
* e.g. 0.9
* @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42
* @param {*} data.stop - the stop sequence to use, e.g. ['\n']
* @param {*} data.suffix - the suffix to append to the prompt,
* e.g. 'n = int(input(\'Enter a number: \'))'
* @param {Object} options - Additional operational options.
* @param {*} [options.signal] - optional AbortSignal instance to control the
* request. The signal will be combined with the
* default timeout signal.
* @return {Promise<Object>}
*/
completionStream = async function* (
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{signal} = {},
) {
const request = this._makeCompletionRequest(
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
true,
);
const response = await this._request(
'post',
'v1/fim/completions',
request,
signal,
);

let buffer = '';
const decoder = new TextDecoder();
for await (const chunk of response) {
buffer += decoder.decode(chunk, {stream: true});
let firstNewline;
while ((firstNewline = buffer.indexOf('\n')) !== -1) {
const chunkLine = buffer.substring(0, firstNewline);
buffer = buffer.substring(firstNewline + 1);
if (chunkLine.startsWith('data:')) {
const json = chunkLine.substring(6).trim();
if (json !== '[DONE]') {
yield JSON.parse(json);
}
}
}
}
};
}

export default MistralClient;
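
For reference, a sketch of the request mapping performed by _makeCompletionRequest above: camelCase options are renamed to the snake_case fields of the v1/fim/completions endpoint, model falls back to modelDefault when omitted, and options passed as undefined drop out when the body is JSON-serialized (assuming _request stringifies the request, as elsewhere in the client). The streaming variant additionally sets stream: true and parses the response as server-sent events, yielding one parsed chunk per 'data:' line until the '[DONE]' sentinel. The values below are illustrative.

// A call like this...
await client.completion({
  model: 'mistral-small-latest',
  prompt: '# this is a',
  temperature: 0.5,
  maxTokens: 32,
});
// ...would be expected to POST v1/fim/completions with a body along these lines:
// {
//   "model": "mistral-small-latest",
//   "prompt": "# this is a",
//   "temperature": 0.5,
//   "max_tokens": 32,
//   "stream": false
// }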
26 changes: 20 additions & 6 deletions tests/client.test.js
@@ -23,7 +23,7 @@ describe('Mistral Client', () => {
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
- model: 'mistral-small',
+ model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -40,7 +40,7 @@ describe('Mistral Client', () => {
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
- model: 'mistral-small',
+ model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -58,7 +58,7 @@ describe('Mistral Client', () => {
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
- model: 'mistral-small',
+ model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -78,7 +78,7 @@ describe('Mistral Client', () => {
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
- model: 'mistral-small',
+ model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -101,7 +101,7 @@ describe('Mistral Client', () => {
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
- model: 'mistral-small',
+ model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -125,7 +125,7 @@ describe('Mistral Client', () => {
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
- model: 'mistral-small',
+ model: 'mistral-small-latest',
messages: [
{
role: 'user',
@@ -176,4 +176,18 @@ describe('Mistral Client', () => {
expect(response).toEqual(mockResponse);
});
});

describe('completion()', () => {
it('should return a chat response object', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
client._fetch = mockFetch(200, mockResponse);

const response = await client.completion({
model: 'mistral-small-latest',
prompt: '# this is a',
});
expect(response).toEqual(mockResponse);
});
});
});
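
The new completion() test mirrors the non-streaming chat() tests. This commit does not add a streaming counterpart; a hypothetical sketch of one, reusing the mockFetchStream and mockChatResponseStreamingPayload helpers already used by the chatStream tests in this file, might look like:

  describe('completionStream()', () => {
    it('should yield parsed response chunks', async() => {
      // Hypothetical test, not part of this commit: mock a streaming fetch.
      const mockResponse = mockChatResponseStreamingPayload();
      client._fetch = mockFetchStream(200, mockResponse);

      const response = client.completionStream({
        model: 'mistral-small-latest',
        prompt: '# this is a',
      });

      const parsedResponse = [];
      for await (const chunk of response) {
        parsedResponse.push(chunk);
      }

      // The exact count depends on the mock payload; just check chunks arrived.
      expect(parsedResponse.length).toBeGreaterThan(0);
    });
  });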
8 changes: 4 additions & 4 deletions tests/utils.js
@@ -78,7 +78,7 @@ export function mockListModels() {
],
},
{
- id: 'mistral-small',
+ id: 'mistral-small-latest',
object: 'model',
created: 1703186988,
owned_by: 'mistralai',
@@ -172,7 +172,7 @@ export function mockChatResponsePayload() {
index: 0,
},
],
- model: 'mistral-small',
+ model: 'mistral-small-latest',
usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
};
}
@@ -187,7 +187,7 @@ export function mockChatResponseStreamingPayload() {
[encoder.encode('data: ' +
JSON.stringify({
id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
- model: 'mistral-small',
+ model: 'mistral-small-latest',
choices: [
{
index: 0,
@@ -207,7 +207,7 @@
id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
object: 'chat.completion.chunk',
created: 1703168544,
- model: 'mistral-small',
+ model: 'mistral-small-latest',
choices: [
{
index: i,
