Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Commit

Permalink
Merge pull request #37 from mistralai/bam4d/safe_prompt
Browse files Browse the repository at this point in the history
deprecating safeMode in favour of safePrompt
  • Loading branch information
Bam4d authored Jan 15, 2024
2 parents 4cac6c5 + e6386c4 commit 1750bab
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 5 deletions.
14 changes: 13 additions & 1 deletion src/client.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ declare module '@mistralai/mistralai' {
topP?: number,
randomSeed?: number,
stream?: boolean,
safeMode?: boolean
/**
* @deprecated use safePrompt instead
*/
safeMode?: boolean,
safePrompt?: boolean
): object;

listModels(): Promise<ListModelsResponse>;
Expand All @@ -113,7 +117,11 @@ declare module '@mistralai/mistralai' {
maxTokens?: number;
topP?: number;
randomSeed?: number;
/**
* @deprecated use safePrompt instead
*/
safeMode?: boolean;
safePrompt?: boolean;
}): Promise<ChatCompletionResponse>;

chatStream(options: {
Expand All @@ -123,7 +131,11 @@ declare module '@mistralai/mistralai' {
maxTokens?: number;
topP?: number;
randomSeed?: number;
/**
* @deprecated use safePrompt instead
*/
safeMode?: boolean;
safePrompt?: boolean;
}): AsyncGenerator<ChatCompletionResponseChunk, void, unknown>;

embeddings(options: {
Expand Down
16 changes: 12 additions & 4 deletions src/client.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class MistralClient {
* @param {*} topP
* @param {*} randomSeed
* @param {*} stream
* @param {*} safeMode
* @param {*} safeMode deprecated use safePrompt instead
* @param {*} safePrompt
* @return {Promise<Object>}
*/
_makeChatCompletionRequest = function(
Expand All @@ -165,6 +166,7 @@ class MistralClient {
randomSeed,
stream,
safeMode,
safePrompt,
) {
return {
model: model,
Expand All @@ -174,7 +176,7 @@ class MistralClient {
top_p: topP ?? undefined,
random_seed: randomSeed ?? undefined,
stream: stream ?? undefined,
safe_prompt: safeMode ?? undefined,
safe_prompt: (safeMode || safePrompt) ?? undefined,
};
};

Expand All @@ -196,7 +198,8 @@ class MistralClient {
* @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
* @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
* @param {*} randomSeed the random seed to use for sampling, e.g. 42
* @param {*} safeMode whether to use safe mode, e.g. true
* @param {*} safeMode deprecated use safePrompt instead
* @param {*} safePrompt whether to use safe mode, e.g. true
* @return {Promise<Object>}
*/
chat = async function({
Expand All @@ -207,6 +210,7 @@ class MistralClient {
topP,
randomSeed,
safeMode,
safePrompt,
}) {
const request = this._makeChatCompletionRequest(
model,
Expand All @@ -217,6 +221,7 @@ class MistralClient {
randomSeed,
false,
safeMode,
safePrompt,
);
const response = await this._request(
'post',
Expand All @@ -235,7 +240,8 @@ class MistralClient {
* @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
* @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
* @param {*} randomSeed the random seed to use for sampling, e.g. 42
* @param {*} safeMode whether to use safe mode, e.g. true
* @param {*} safeMode deprecated use safePrompt instead
* @param {*} safePrompt whether to use safe mode, e.g. true
* @return {Promise<Object>}
*/
chatStream = async function* ({
Expand All @@ -246,6 +252,7 @@ class MistralClient {
topP,
randomSeed,
safeMode,
safePrompt,
}) {
const request = this._makeChatCompletionRequest(
model,
Expand All @@ -256,6 +263,7 @@ class MistralClient {
randomSeed,
true,
safeMode,
safePrompt,
);
const response = await this._request(
'post',
Expand Down
84 changes: 84 additions & 0 deletions tests/client.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,42 @@ describe('Mistral Client', () => {
});
expect(response).toEqual(mockResponse);
});

it('should return a chat response object if safeMode is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
messages: [
{
role: 'user',
content: 'What is the best French cheese?',
},
],
safeMode: true,
});
expect(response).toEqual(mockResponse);
});

it('should return a chat response object if safePrompt is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
messages: [
{
role: 'user',
content: 'What is the best French cheese?',
},
],
safePrompt: true,
});
expect(response).toEqual(mockResponse);
});
});

describe('chatStream()', () => {
Expand All @@ -58,6 +94,54 @@ describe('Mistral Client', () => {

expect(parsedResponse.length).toEqual(11);
});

it('should return parsed, streamed response with safeMode', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
messages: [
{
role: 'user',
content: 'What is the best French cheese?',
},
],
safeMode: true,
});

const parsedResponse = [];
for await (const r of response) {
parsedResponse.push(r);
}

expect(parsedResponse.length).toEqual(11);
});

it('should return parsed, streamed response with safePrompt', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
messages: [
{
role: 'user',
content: 'What is the best French cheese?',
},
],
safePrompt: true,
});

const parsedResponse = [];
for await (const r of response) {
parsedResponse.push(r);
}

expect(parsedResponse.length).toEqual(11);
});
});

describe('embeddings()', () => {
Expand Down

0 comments on commit 1750bab

Please sign in to comment.