Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Commit

Permalink
Add support for our Fine-tuning API (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
fuegoio authored Jun 5, 2024
1 parent ac6e138 commit 0d1cc9c
Show file tree
Hide file tree
Showing 13 changed files with 685 additions and 52 deletions.
41 changes: 41 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,47 @@ const embeddingsBatchResponse = await client.embeddings({
console.log('Embeddings Batch:', embeddingsBatchResponse.data);
```

### Files

```typescript
// Create a new file
const file = fs.readFileSync('file.jsonl');
const createdFile = await client.files.create({ file });

// List files
const files = await client.files.list();

// Retrieve a file
const retrievedFile = await client.files.retrieve({ fileId: createdFile.id });

// Delete a file
const deletedFile = await client.files.delete({ fileId: createdFile.id });
```

### Fine-tuning Jobs

```typescript
// Create a new job
const createdJob = await client.jobs.create({
model: 'open-mistral-7B',
trainingFiles: [trainingFile.id],
validationFiles: [validationFile.id],
hyperparameters: {
trainingSteps: 10,
learningRate: 0.0001,
},
});

// List jobs
const jobs = await client.jobs.list();

// Retrieve a job
const retrievedJob = await client.jobs.retrieve({ jobId: createdJob.id });

// Cancel a job
const canceledJob = await client.jobs.cancel({ jobId: createdJob.id });
```

## Run examples

You can run the examples in the examples directory by installing them locally:
Expand Down
3 changes: 3 additions & 0 deletions examples/file.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters.", "weight": 0}]}
27 changes: 27 additions & 0 deletions examples/files.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import MistralClient from '@mistralai/mistralai';
import * as fs from 'fs';


const apiKey = process.env.MISTRAL_API_KEY;

const client = new MistralClient(apiKey);

// Create a new file
const blob = new Blob(
[fs.readFileSync('file.jsonl')],
{type: 'application/json'},
);
const createdFile = await client.files.create({file: blob});
console.log(createdFile);

// List files
const files = await client.files.list();
console.log(files);

// Retrieve a file
const retrievedFile = await client.files.retrieve({fileId: createdFile.id});
console.log(retrievedFile);

// Delete a file
const deletedFile = await client.files.delete({fileId: createdFile.id});
console.log(deletedFile);
39 changes: 39 additions & 0 deletions examples/jobs.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import MistralClient from '@mistralai/mistralai';
import * as fs from 'fs';


const apiKey = process.env.MISTRAL_API_KEY;

const client = new MistralClient(apiKey);

// Create a new file
const blob = new Blob(
[fs.readFileSync('file.jsonl')],
{type: 'application/json'},
);
const createdFile = await client.files.create({file: blob});

// Create a new job
const hyperparameters = {
training_steps: 10,
learning_rate: 0.0001,
};
const createdJob = await client.jobs.create({
model: 'open-mistral-7b',
trainingFiles: [createdFile.id],
validationFiles: [createdFile.id],
hyperparameters,
});
console.log(createdJob);

// List jobs
const jobs = await client.jobs.list();
console.log(jobs);

// Retrieve a job
const retrievedJob = await client.jobs.retrieve({jobId: createdJob.id});
console.log(retrievedJob);

// Cancel a job
const canceledJob = await client.jobs.cancel({jobId: createdJob.id});
console.log(canceledJob);
9 changes: 4 additions & 5 deletions src/client.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,15 @@ declare module "@mistralai/mistralai" {
): AsyncGenerator<ChatCompletionResponseChunk, void>;

completion(
request: CompletionRequest,
options?: ChatRequestOptions
request: CompletionRequest,
options?: ChatRequestOptions
): Promise<ChatCompletionResponse>;

completionStream(
request: CompletionRequest,
options?: ChatRequestOptions
request: CompletionRequest,
options?: ChatRequestOptions
): AsyncGenerator<ChatCompletionResponseChunk, void>;


embeddings(options: {
model: string;
input: string | string[];
Expand Down
40 changes: 17 additions & 23 deletions src/client.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import FilesClient from './files.js';
import JobsClient from './jobs.js';

const VERSION = '0.3.0';
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
const ENDPOINT = 'https://api.mistral.ai';
Expand Down Expand Up @@ -79,6 +82,9 @@ class MistralClient {
if (this.endpoint.indexOf('inference.azure.com')) {
this.modelDefault = 'mistral';
}

this.files = new FilesClient(this);
this.jobs = new JobsClient(this);
}

/**
Expand All @@ -98,9 +104,10 @@ class MistralClient {
* @param {*} path
* @param {*} request
* @param {*} signal
* @param {*} formData
* @return {Promise<*>}
*/
_request = async function(method, path, request, signal) {
_request = async function(method, path, request, signal, formData = null) {
const url = `${this.endpoint}/${path}`;
const options = {
method: method,
Expand All @@ -110,13 +117,18 @@ class MistralClient {
'Content-Type': 'application/json',
'Authorization': `Bearer ${this.apiKey}`,
},
body: method !== 'get' ? JSON.stringify(request) : null,
signal: combineSignals([
AbortSignal.timeout(this.timeout * 1000),
signal,
]),
body: method !== 'get' ? formData ?? JSON.stringify(request) : null,
timeout: this.timeout * 1000,
};

if (formData) {
delete options.headers['Content-Type'];
}

for (let attempts = 0; attempts < this.maxRetries; attempts++) {
try {
const response = await this._fetch(url, options);
Expand Down Expand Up @@ -161,7 +173,7 @@ class MistralClient {
} else {
throw new MistralAPIError(
`HTTP error! status: ${response.status} ` +
`Response: \n${await response.text()}`,
`Response: \n${await response.text()}`,
);
}
} catch (error) {
Expand Down Expand Up @@ -467,16 +479,7 @@ class MistralClient {
* @return {Promise<Object>}
*/
completion = async function(
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop},
{signal} = {},
) {
const request = this._makeCompletionRequest(
Expand Down Expand Up @@ -523,16 +526,7 @@ class MistralClient {
* @return {Promise<Object>}
*/
completionStream = async function* (
{
model,
prompt,
suffix,
temperature,
maxTokens,
topP,
randomSeed,
stop,
},
{model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop},
{signal} = {},
) {
const request = this._makeCompletionRequest(
Expand Down
30 changes: 30 additions & 0 deletions src/files.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export enum Purpose {
finetune = 'fine-tune',
}

export interface FileObject {
id: string;
object: string;
bytes: number;
created_at: number;
filename: string;
purpose?: Purpose;
}

export interface FileDeleted {
id: string;
object: string;
deleted: boolean;
}

export class FilesClient {
constructor(client: MistralClient);

create(options: { file: File; purpose?: string }): Promise<FileObject>;

retrieve(options: { fileId: string }): Promise<FileObject>;

list(): Promise<FileObject[]>;

delete(options: { fileId: string }): Promise<FileDeleted>;
}
65 changes: 65 additions & 0 deletions src/files.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* Class representing a client for file operations.
*/
class FilesClient {
/**
* Create a FilesClient object.
* @param {MistralClient} client - The client object used for making requests.
*/
constructor(client) {
this.client = client;
}

/**
* Create a new file.
* @param {File} file - The file to be created.
* @param {string} purpose - The purpose of the file. Default is 'fine-tune'.
* @return {Promise<*>} A promise that resolves to a FileObject.
* @throws {MistralAPIError} If no response is received from the server.
*/
async create({file, purpose = 'fine-tune'}) {
const formData = new FormData();
formData.append('file', file);
formData.append('purpose', purpose);
const response = await this.client._request(
'post',
'v1/files',
null,
undefined,
formData,
);
return response;
}

/**
* Retrieve a file.
* @param {string} fileId - The ID of the file to retrieve.
* @return {Promise<*>} A promise that resolves to the file data.
*/
async retrieve({fileId}) {
const response = await this.client._request('get', `v1/files/${fileId}`);
return response;
}

/**
* List all files.
* @return {Promise<Array<FileObject>>} A promise that resolves to
* an array of FileObject.
*/
async list() {
const response = await this.client._request('get', 'v1/files');
return response;
}

/**
* Delete a file.
* @param {string} fileId - The ID of the file to delete.
* @return {Promise<*>} A promise that resolves to the response.
*/
async delete({fileId}) {
const response = await this.client._request('delete', `v1/files/${fileId}`);
return response;
}
}

export default FilesClient;
Loading

0 comments on commit 0d1cc9c

Please sign in to comment.