Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Add support for our Fine-tuning API #80

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,47 @@ const embeddingsBatchResponse = await client.embeddings({
console.log('Embeddings Batch:', embeddingsBatchResponse.data);
```

### Files

```typescript
// Create a new file
const file = fs.readFileSync('file.jsonl');
const createdFile = await client.files.create({ file });

// List files
const files = await client.files.list();

// Retrieve a file
const retrievedFile = await client.files.retrieve({ fileId: createdFile.id });

// Delete a file
const deletedFile = await client.files.delete({ fileId: createdFile.id });
```

### Fine-tuning Jobs

```typescript
// Create a new job
const createdJob = await client.jobs.create({
model: 'open-mistral-7B',
trainingFiles: [trainingFile.id],
validationFiles: [validationFile.id],
hyperparameters: {
trainingSteps: 10,
learningRate: 0.0001,
},
});

// List jobs
const jobs = await client.jobs.list();

// Retrieve a job
const retrievedJob = await client.jobs.retrieve({ jobId: createdJob.id });

// Cancel a job
const canceledJob = await client.jobs.cancel({ jobId: createdJob.id });
```

## Run examples

You can run the examples in the examples directory by installing them locally:
Expand Down
3 changes: 3 additions & 0 deletions examples/file.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters.", "weight": 0}]}
27 changes: 27 additions & 0 deletions examples/files.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import MistralClient from '@mistralai/mistralai';
import * as fs from 'fs';


const apiKey = process.env.MISTRAL_API_KEY;

const client = new MistralClient(apiKey);

// Create a new file
const blob = new Blob(
[fs.readFileSync('file.jsonl')],
{type: 'application/json'},
);
const createdFile = await client.files.create({file: blob});
console.log(createdFile);

// List files
const files = await client.files.list();
console.log(files);

// Retrieve a file
const retrievedFile = await client.files.retrieve({fileId: createdFile.id});
console.log(retrievedFile);

// Delete a file
const deletedFile = await client.files.delete({fileId: createdFile.id});
console.log(deletedFile);
39 changes: 39 additions & 0 deletions examples/jobs.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import MistralClient from '@mistralai/mistralai';
import * as fs from 'fs';


const apiKey = process.env.MISTRAL_API_KEY;

const client = new MistralClient(apiKey);

// Create a new file
const blob = new Blob(
[fs.readFileSync('file.jsonl')],
{type: 'application/json'},
);
const createdFile = await client.files.create({file: blob});

// Create a new job
const hyperparameters = {
training_steps: 10,
learning_rate: 0.0001,
};
const createdJob = await client.jobs.create({
model: 'open-mistral-7b',
trainingFiles: [createdFile.id],
validationFiles: [createdFile.id],
hyperparameters,
});
console.log(createdJob);

// List jobs
const jobs = await client.jobs.list();
console.log(jobs);

// Retrieve a job
const retrievedJob = await client.jobs.retrieve({jobId: createdJob.id});
console.log(retrievedJob);

// Cancel a job
const canceledJob = await client.jobs.cancel({jobId: createdJob.id});
console.log(canceledJob);
9 changes: 4 additions & 5 deletions src/client.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,15 @@ declare module "@mistralai/mistralai" {
): AsyncGenerator<ChatCompletionResponseChunk, void>;

completion(
request: CompletionRequest,
options?: ChatRequestOptions
request: CompletionRequest,
options?: ChatRequestOptions
): Promise<ChatCompletionResponse>;

completionStream(
request: CompletionRequest,
options?: ChatRequestOptions
request: CompletionRequest,
options?: ChatRequestOptions
): AsyncGenerator<ChatCompletionResponseChunk, void>;


embeddings(options: {
model: string;
input: string | string[];
Expand Down
40 changes: 17 additions & 23 deletions src/client.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import FilesClient from './files.js';
import JobsClient from './jobs.js';

const VERSION = '0.3.0';
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
const ENDPOINT = 'https://api.mistral.ai';
Expand Down Expand Up @@ -79,6 +82,9 @@ class MistralClient {
if (this.endpoint.indexOf('inference.azure.com')) {
this.modelDefault = 'mistral';
}

this.files = new FilesClient(this);
this.jobs = new JobsClient(this);
}

/**
Expand All @@ -98,9 +104,10 @@ class MistralClient {
* @param {*} path
* @param {*} request
* @param {*} signal
* @param {*} formData
* @return {Promise<*>}
*/
_request = async function(method, path, request, signal) {
_request = async function(method, path, request, signal, formData = null) {
const url = `${this.endpoint}/${path}`;
const options = {
method: method,
Expand All @@ -110,13 +117,18 @@ class MistralClient {
'Content-Type': 'application/json',
'Authorization': `Bearer ${this.apiKey}`,
},
body: method !== 'get' ? JSON.stringify(request) : null,
signal: combineSignals([
AbortSignal.timeout(this.timeout * 1000),
signal,
]),
body: method !== 'get' ? formData ?? JSON.stringify(request) : null,
timeout: this.timeout * 1000,
};

if (formData) {
delete options.headers['Content-Type'];
}

for (let attempts = 0; attempts < this.maxRetries; attempts++) {
try {
const response = await this._fetch(url, options);
Expand Down Expand Up @@ -161,7 +173,7 @@ class MistralClient {
} else {
throw new MistralAPIError(
`HTTP error! status: ${response.status} ` +
`Response: \n${await response.text()}`,
`Response: \n${await response.text()}`,
);
}
} catch (error) {
Expand Down Expand Up @@ -467,16 +479,7 @@ class MistralClient {
* @return {Promise<Object>}
*/
completion = async function(
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop},
{signal} = {},
) {
const request = this._makeCompletionRequest(
Expand Down Expand Up @@ -523,16 +526,7 @@ class MistralClient {
* @return {Promise<Object>}
*/
completionStream = async function* (
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop},
{signal} = {},
) {
const request = this._makeCompletionRequest(
Expand Down
30 changes: 30 additions & 0 deletions src/files.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export enum Purpose {
finetune = 'fine-tune',
}

export interface FileObject {
id: string;
object: string;
bytes: number;
created_at: number;
filename: string;
purpose?: Purpose;
}

export interface FileDeleted {
id: string;
object: string;
deleted: boolean;
}

export class FilesClient {
constructor(client: MistralClient);

create(options: { file: File; purpose?: string }): Promise<FileObject>;

retrieve(options: { fileId: string }): Promise<FileObject>;

list(): Promise<FileObject[]>;

delete(options: { fileId: string }): Promise<FileDeleted>;
}
65 changes: 65 additions & 0 deletions src/files.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* Class representing a client for file operations.
*/
class FilesClient {
/**
* Create a FilesClient object.
* @param {MistralClient} client - The client object used for making requests.
*/
constructor(client) {
this.client = client;
}

/**
* Create a new file.
* @param {File} file - The file to be created.
* @param {string} purpose - The purpose of the file. Default is 'fine-tune'.
* @return {Promise<*>} A promise that resolves to a FileObject.
* @throws {MistralAPIError} If no response is received from the server.
*/
async create({file, purpose = 'fine-tune'}) {
const formData = new FormData();
formData.append('file', file);
formData.append('purpose', purpose);
const response = await this.client._request(
'post',
'v1/files',
null,
undefined,
formData,
);
return response;
}

/**
* Retrieve a file.
* @param {string} fileId - The ID of the file to retrieve.
* @return {Promise<*>} A promise that resolves to the file data.
*/
async retrieve({fileId}) {
const response = await this.client._request('get', `v1/files/${fileId}`);
return response;
}

/**
* List all files.
* @return {Promise<Array<FileObject>>} A promise that resolves to
* an array of FileObject.
*/
async list() {
const response = await this.client._request('get', 'v1/files');
return response;
}

/**
* Delete a file.
* @param {string} fileId - The ID of the file to delete.
* @return {Promise<*>} A promise that resolves to the response.
*/
async delete({fileId}) {
const response = await this.client._request('delete', `v1/files/${fileId}`);
return response;
}
}

export default FilesClient;
Loading