diff --git a/README.md b/README.md index 23d9e4b..a66e1a3 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,47 @@ const embeddingsBatchResponse = await client.embeddings({ console.log('Embeddings Batch:', embeddingsBatchResponse.data); ``` +### Files + +```typescript +// Create a new file +const file = fs.readFileSync('file.jsonl'); +const createdFile = await client.files.create({ file }); + +// List files +const files = await client.files.list(); + +// Retrieve a file +const retrievedFile = await client.files.retrieve({ fileId: createdFile.id }); + +// Delete a file +const deletedFile = await client.files.delete({ fileId: createdFile.id }); +``` + +### Fine-tuning Jobs + +```typescript +// Create a new job +const createdJob = await client.jobs.create({ + model: 'open-mistral-7B', + trainingFiles: [trainingFile.id], + validationFiles: [validationFile.id], + hyperparameters: { + trainingSteps: 10, + learningRate: 0.0001, + }, +}); + +// List jobs +const jobs = await client.jobs.list(); + +// Retrieve a job +const retrievedJob = await client.jobs.retrieve({ jobId: createdJob.id }); + +// Cancel a job +const canceledJob = await client.jobs.cancel({ jobId: createdJob.id }); +``` + ## Run examples You can run the examples in the examples directory by installing them locally: diff --git a/examples/file.jsonl b/examples/file.jsonl new file mode 100644 index 0000000..51bbfd4 --- /dev/null +++ b/examples/file.jsonl @@ -0,0 +1,3 @@ +{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} +{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} +{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters.", "weight": 0}]} \ No newline at end of file diff --git a/examples/files.js b/examples/files.js new file mode 100644 index 0000000..f89124f --- /dev/null +++ b/examples/files.js @@ -0,0 +1,27 @@ +import MistralClient from '@mistralai/mistralai'; +import * as fs from 'fs'; + + +const apiKey = process.env.MISTRAL_API_KEY; + +const client = new MistralClient(apiKey); + +// Create a new file +const blob = new Blob( + [fs.readFileSync('file.jsonl')], + {type: 'application/json'}, +); +const createdFile = await client.files.create({file: blob}); +console.log(createdFile); + +// List files +const files = await client.files.list(); +console.log(files); + +// Retrieve a file +const retrievedFile = await client.files.retrieve({fileId: createdFile.id}); +console.log(retrievedFile); + +// Delete a file +const deletedFile = await client.files.delete({fileId: createdFile.id}); +console.log(deletedFile); diff --git a/examples/jobs.js b/examples/jobs.js new file mode 100644 index 0000000..f72a1ec --- /dev/null +++ b/examples/jobs.js @@ -0,0 +1,39 @@ +import MistralClient from '@mistralai/mistralai'; +import * as fs from 'fs'; + + +const apiKey = process.env.MISTRAL_API_KEY; + +const client = new MistralClient(apiKey); + +// Create a new file +const blob = new Blob( + [fs.readFileSync('file.jsonl')], + {type: 'application/json'}, +); +const createdFile = await client.files.create({file: blob}); + +// Create a new job +const hyperparameters = { + training_steps: 10, + learning_rate: 0.0001, +}; +const createdJob = await client.jobs.create({ + model: 'open-mistral-7b', + trainingFiles: [createdFile.id], + validationFiles: [createdFile.id], + hyperparameters, +}); +console.log(createdJob); + +// List jobs +const jobs = await client.jobs.list(); +console.log(jobs); + +// Retrieve a job +const retrievedJob = await client.jobs.retrieve({jobId: createdJob.id}); +console.log(retrievedJob); + +// Cancel a job +const canceledJob = await client.jobs.cancel({jobId: createdJob.id}); +console.log(canceledJob); diff --git a/src/client.d.ts b/src/client.d.ts index 411a0b3..fedca05 100644 --- a/src/client.d.ts +++ b/src/client.d.ts @@ -182,16 +182,15 @@ declare module "@mistralai/mistralai" { ): AsyncGenerator; completion( - request: CompletionRequest, - options?: ChatRequestOptions + request: CompletionRequest, + options?: ChatRequestOptions ): Promise; completionStream( - request: CompletionRequest, - options?: ChatRequestOptions + request: CompletionRequest, + options?: ChatRequestOptions ): AsyncGenerator; - embeddings(options: { model: string; input: string | string[]; diff --git a/src/client.js b/src/client.js index 00cf199..c90afbd 100644 --- a/src/client.js +++ b/src/client.js @@ -1,3 +1,6 @@ +import FilesClient from './files.js'; +import JobsClient from './jobs.js'; + const VERSION = '0.3.0'; const RETRY_STATUS_CODES = [429, 500, 502, 503, 504]; const ENDPOINT = 'https://api.mistral.ai'; @@ -79,6 +82,9 @@ class MistralClient { if (this.endpoint.indexOf('inference.azure.com')) { this.modelDefault = 'mistral'; } + + this.files = new FilesClient(this); + this.jobs = new JobsClient(this); } /** @@ -98,9 +104,10 @@ class MistralClient { * @param {*} path * @param {*} request * @param {*} signal + * @param {*} formData * @return {Promise<*>} */ - _request = async function(method, path, request, signal) { + _request = async function(method, path, request, signal, formData = null) { const url = `${this.endpoint}/${path}`; const options = { method: method, @@ -110,13 +117,18 @@ class MistralClient { 'Content-Type': 'application/json', 'Authorization': `Bearer ${this.apiKey}`, }, - body: method !== 'get' ? JSON.stringify(request) : null, signal: combineSignals([ AbortSignal.timeout(this.timeout * 1000), signal, ]), + body: method !== 'get' ? formData ?? JSON.stringify(request) : null, + timeout: this.timeout * 1000, }; + if (formData) { + delete options.headers['Content-Type']; + } + for (let attempts = 0; attempts < this.maxRetries; attempts++) { try { const response = await this._fetch(url, options); @@ -161,7 +173,7 @@ class MistralClient { } else { throw new MistralAPIError( `HTTP error! status: ${response.status} ` + - `Response: \n${await response.text()}`, + `Response: \n${await response.text()}`, ); } } catch (error) { @@ -467,16 +479,7 @@ class MistralClient { * @return {Promise} */ completion = async function( - { - model, - prompt, - suffix, - temperature, - maxTokens, - topP, - randomSeed, - stop, - }, + {model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop}, {signal} = {}, ) { const request = this._makeCompletionRequest( @@ -523,16 +526,7 @@ class MistralClient { * @return {Promise} */ completionStream = async function* ( - { - model, - prompt, - suffix, - temperature, - maxTokens, - topP, - randomSeed, - stop, - }, + {model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop}, {signal} = {}, ) { const request = this._makeCompletionRequest( diff --git a/src/files.d.ts b/src/files.d.ts new file mode 100644 index 0000000..a011f37 --- /dev/null +++ b/src/files.d.ts @@ -0,0 +1,30 @@ +export enum Purpose { + finetune = 'fine-tune', +} + +export interface FileObject { + id: string; + object: string; + bytes: number; + created_at: number; + filename: string; + purpose?: Purpose; +} + +export interface FileDeleted { + id: string; + object: string; + deleted: boolean; +} + +export class FilesClient { + constructor(client: MistralClient); + + create(options: { file: File; purpose?: string }): Promise; + + retrieve(options: { fileId: string }): Promise; + + list(): Promise; + + delete(options: { fileId: string }): Promise; +} diff --git a/src/files.js b/src/files.js new file mode 100644 index 0000000..d750592 --- /dev/null +++ b/src/files.js @@ -0,0 +1,65 @@ +/** + * Class representing a client for file operations. + */ +class FilesClient { + /** + * Create a FilesClient object. + * @param {MistralClient} client - The client object used for making requests. + */ + constructor(client) { + this.client = client; + } + + /** + * Create a new file. + * @param {File} file - The file to be created. + * @param {string} purpose - The purpose of the file. Default is 'fine-tune'. + * @return {Promise<*>} A promise that resolves to a FileObject. + * @throws {MistralAPIError} If no response is received from the server. + */ + async create({file, purpose = 'fine-tune'}) { + const formData = new FormData(); + formData.append('file', file); + formData.append('purpose', purpose); + const response = await this.client._request( + 'post', + 'v1/files', + null, + undefined, + formData, + ); + return response; + } + + /** + * Retrieve a file. + * @param {string} fileId - The ID of the file to retrieve. + * @return {Promise<*>} A promise that resolves to the file data. + */ + async retrieve({fileId}) { + const response = await this.client._request('get', `v1/files/${fileId}`); + return response; + } + + /** + * List all files. + * @return {Promise>} A promise that resolves to + * an array of FileObject. + */ + async list() { + const response = await this.client._request('get', 'v1/files'); + return response; + } + + /** + * Delete a file. + * @param {string} fileId - The ID of the file to delete. + * @return {Promise<*>} A promise that resolves to the response. + */ + async delete({fileId}) { + const response = await this.client._request('delete', `v1/files/${fileId}`); + return response; + } +} + +export default FilesClient; diff --git a/src/jobs.d.ts b/src/jobs.d.ts new file mode 100644 index 0000000..5eeebfe --- /dev/null +++ b/src/jobs.d.ts @@ -0,0 +1,86 @@ +export enum JobStatus { + QUEUED = 'QUEUED', + STARTED = 'STARTED', + RUNNING = 'RUNNING', + FAILED = 'FAILED', + SUCCESS = 'SUCCESS', + CANCELLED = 'CANCELLED', + CANCELLATION_REQUESTED = 'CANCELLATION_REQUESTED', +} + +export interface TrainingParameters { + training_steps: number; + learning_rate: number; +} + +export interface WandbIntegration { + type: Literal<'wandb'>; + project: string; + name: string | null; + api_key: string | null; + run_name: string | null; +} + +export type Integration = WandbIntegration; + +export interface Job { + id: string; + hyperparameters: TrainingParameters; + fine_tuned_model: string; + model: string; + status: JobStatus; + jobType: string; + created_at: number; + modified_at: number; + training_files: string[]; + validation_files?: string[]; + object: 'job'; + integrations: Integration[]; +} + +export interface Event { + name: string; + data?: Record; + created_at: number; +} + +export interface Metric { + train_loss: float | null; + valid_loss: float | null; + valid_mean_token_accuracy: float | null; +} + +export interface Checkpoint { + metrics: Metric; + step_number: int; + created_at: int; +} + +export interface DetailedJob extends Job { + events: Event[]; + checkpoints: Checkpoint[]; +} + +export interface Jobs { + data: Job[]; + object: 'list'; +} + +export class JobsClient { + constructor(client: MistralClient); + + create(options: { + model: string; + trainingFiles: string[]; + validationFiles?: string[]; + hyperparameters?: TrainingParameters; + suffix?: string; + integrations?: Integration[]; + }): Promise; + + retrieve(options: { jobId: string }): Promise; + + list(params?: Record): Promise; + + cancel(options: { jobId: string }): Promise; +} diff --git a/src/jobs.js b/src/jobs.js new file mode 100644 index 0000000..043c6be --- /dev/null +++ b/src/jobs.js @@ -0,0 +1,82 @@ +/** + * Class representing a client for job operations. + */ +class JobsClient { + /** + * Create a JobsClient object. + * @param {MistralClient} client - The client object used for making requests. + */ + constructor(client) { + this.client = client; + } + + /** + * Create a new job. + * @param {string} model - The model to be used for the job. + * @param {Array} trainingFiles - The list of training files. + * @param {Array} validationFiles - The list of validation files. + * @param {TrainingParameters} hyperparameters - The hyperparameters. + * @param {string} suffix - The suffix for the job. + * @param {Array} integrations - The integrations for the job. + * @return {Promise<*>} A promise that resolves to a Job object. + * @throws {MistralAPIError} If no response is received from the server. + */ + async create({ + model, + trainingFiles, + validationFiles = [], + hyperparameters = { + training_steps: 1800, + learning_rate: 1.0e-4, + }, + suffix = null, + integrations = null, + }) { + const response = await this.client._request('post', 'v1/fine_tuning/jobs', { + model, + training_files: trainingFiles, + validation_files: validationFiles, + hyperparameters, + suffix, + integrations, + }); + return response; + } + + /** + * Retrieve a job. + * @param {string} jobId - The ID of the job to retrieve. + * @return {Promise<*>} A promise that resolves to the job data. + */ + async retrieve({jobId}) { + const response = await this.client._request( + 'get', `v1/fine_tuning/jobs/${jobId}`, {}, + ); + return response; + } + + /** + * List all jobs. + * @return {Promise>} A promise that resolves to an array of Job. + */ + async list() { + const response = await this.client._request( + 'get', 'v1/fine_tuning/jobs', {}, + ); + return response; + } + + /** + * Cancel a job. + * @param {string} jobId - The ID of the job to cancel. + * @return {Promise<*>} A promise that resolves to the response. + */ + async cancel({jobId}) { + const response = await this.client._request( + 'post', `v1/fine_tuning/jobs/${jobId}/cancel`, {}, + ); + return response; + } +} + +export default JobsClient; diff --git a/tests/files.test.js b/tests/files.test.js new file mode 100644 index 0000000..fb54676 --- /dev/null +++ b/tests/files.test.js @@ -0,0 +1,65 @@ +import MistralClient from '../src/client'; +import { + mockFetch, + mockFileResponsePayload, + mockFilesResponsePayload, + mockDeletedFileResponsePayload, +} from './utils'; + +// Test the list models endpoint +describe('Mistral Client', () => { + let client; + beforeEach(() => { + client = new MistralClient(); + }); + + describe('create()', () => { + it('should return a file response object', async() => { + // Mock the fetch function + const mockResponse = mockFileResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.files.create({ + file: null, + }); + expect(response).toEqual(mockResponse); + }); + }); + + describe('retrieve()', () => { + it('should return a file response object', async() => { + // Mock the fetch function + const mockResponse = mockFileResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.files.retrieve({ + fileId: 'fileId', + }); + expect(response).toEqual(mockResponse); + }); + }); + + describe('retrieve()', () => { + it('should return a list of files response object', async() => { + // Mock the fetch function + const mockResponse = mockFilesResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.files.list(); + expect(response).toEqual(mockResponse); + }); + }); + + describe('delete()', () => { + it('should return a deleted file response object', async() => { + // Mock the fetch function + const mockResponse = mockDeletedFileResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.files.delete({ + fileId: 'fileId', + }); + expect(response).toEqual(mockResponse); + }); + }); +}); diff --git a/tests/jobs.test.js b/tests/jobs.test.js new file mode 100644 index 0000000..450ea92 --- /dev/null +++ b/tests/jobs.test.js @@ -0,0 +1,71 @@ +import MistralClient from '../src/client'; +import { + mockFetch, + mockJobResponsePayload, + mockJobsResponsePayload, + mockDeletedJobResponsePayload, +} from './utils'; + +// Test the jobs endpoint +describe('Mistral Client', () => { + let client; + beforeEach(() => { + client = new MistralClient(); + }); + + describe('createJob()', () => { + it('should return a job response object', async() => { + // Mock the fetch function + const mockResponse = mockJobResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.jobs.create({ + model: 'mistral-medium', + trainingFiles: [], + validationFiles: [], + hyperparameters: { + training_steps: 1800, + learning_rate: 1.0e-4, + }, + }); + expect(response).toEqual(mockResponse); + }); + }); + + describe('retrieveJob()', () => { + it('should return a job response object', async() => { + // Mock the fetch function + const mockResponse = mockJobResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.jobs.retrieve({ + jobId: 'jobId', + }); + expect(response).toEqual(mockResponse); + }); + }); + + describe('listJobs()', () => { + it('should return a list of jobs response object', async() => { + // Mock the fetch function + const mockResponse = mockJobsResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.jobs.list(); + expect(response).toEqual(mockResponse); + }); + }); + + describe('cancelJob()', () => { + it('should return a deleted job response object', async() => { + // Mock the fetch function + const mockResponse = mockDeletedJobResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.jobs.cancel({ + jobId: 'jobId', + }); + expect(response).toEqual(mockResponse); + }); + }); +}); diff --git a/tests/utils.js b/tests/utils.js index b49f1d0..8b00ca6 100644 --- a/tests/utils.js +++ b/tests/utils.js @@ -183,40 +183,45 @@ export function mockChatResponsePayload() { */ export function mockChatResponseStreamingPayload() { const encoder = new TextEncoder(); - const firstMessage = - [encoder.encode('data: ' + - JSON.stringify({ - id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e', - model: 'mistral-small-latest', - choices: [ - { - index: 0, - delta: {role: 'assistant'}, - finish_reason: null, - }, - ], - }) + - '\n\n')]; - const lastMessage = [encoder.encode('data: [DONE]\n\n')]; - - const dataMessages = []; - for (let i = 0; i < 10; i++) { - dataMessages.push(encoder.encode( + const firstMessage = [ + encoder.encode( 'data: ' + JSON.stringify({ id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e', - object: 'chat.completion.chunk', - created: 1703168544, model: 'mistral-small-latest', choices: [ { - index: i, - delta: {content: `stream response ${i}`}, + index: 0, + delta: {role: 'assistant'}, finish_reason: null, }, ], }) + - '\n\n'), + '\n\n', + ), + ]; + const lastMessage = [encoder.encode('data: [DONE]\n\n')]; + + const dataMessages = []; + for (let i = 0; i < 10; i++) { + dataMessages.push( + encoder.encode( + 'data: ' + + JSON.stringify({ + id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e', + object: 'chat.completion.chunk', + created: 1703168544, + model: 'mistral-small-latest', + choices: [ + { + index: i, + delta: {content: `stream response ${i}`}, + finish_reason: null, + }, + ], + }) + + '\n\n', + ), ); } @@ -255,3 +260,129 @@ export function mockEmbeddingRequest() { input: 'embed', }; } + +/** + * Mock file response payload + * @return {Object} + */ +export function mockFileResponsePayload() { + return { + id: 'fileId', + object: 'file', + bytes: 0, + created_at: 1633046400000, + filename: 'file.jsonl', + purpose: 'fine-tune', + }; +} + +/** + * Mock files response payload + * @return {Object} + */ +export function mockFilesResponsePayload() { + return { + data: [ + { + id: 'fileId', + object: 'file', + bytes: 0, + created_at: 1633046400000, + filename: 'file.jsonl', + purpose: 'fine-tune', + }, + ], + object: 'list', + }; +} + +/** + * Mock deleted file response payload + * @return {Object} + */ +export function mockDeletedFileResponsePayload() { + return { + id: 'fileId', + object: 'file', + deleted: true, + }; +} + +/** + * Mock job response payload + * @return {Object} + */ +export function mockJobResponsePayload() { + return { + id: 'jobId', + hyperparameters: { + training_steps: 1800, + learning_rate: 1.0e-4, + }, + fine_tuned_model: 'fine_tuned_model_id', + model: 'mistral-medium', + status: 'QUEUED', + job_type: 'fine_tuning', + created_at: 1633046400000, + modified_at: 1633046400000, + training_files: ['file1.jsonl', 'file2.jsonl'], + validation_files: ['file3.jsonl', 'file4.jsonl'], + object: 'job', + }; +} + +/** + * Mock jobs response payload + * @return {Object} + */ +export function mockJobsResponsePayload() { + return { + data: [ + { + id: 'jobId1', + hyperparameters: { + training_steps: 1800, + learning_rate: 1.0e-4, + }, + fine_tuned_model: 'fine_tuned_model_id1', + model: 'mistral-medium', + status: 'QUEUED', + job_type: 'fine_tuning', + created_at: 1633046400000, + modified_at: 1633046400000, + training_files: ['file1.jsonl', 'file2.jsonl'], + validation_files: ['file3.jsonl', 'file4.jsonl'], + object: 'job', + }, + { + id: 'jobId2', + hyperparameters: { + training_steps: 1800, + learning_rate: 1.0e-4, + }, + fine_tuned_model: 'fine_tuned_model_id2', + model: 'mistral-medium', + status: 'RUNNING', + job_type: 'fine_tuning', + created_at: 1633046400000, + modified_at: 1633046400000, + training_files: ['file5.jsonl', 'file6.jsonl'], + validation_files: ['file7.jsonl', 'file8.jsonl'], + object: 'job', + }, + ], + object: 'list', + }; +} + +/** + * Mock deleted job response payload + * @return {Object} + */ +export function mockDeletedJobResponsePayload() { + return { + id: 'jobId', + object: 'job', + deleted: true, + }; +}