diff --git a/packages/backend/package.json b/packages/backend/package.json index 1051fde0f..2d2729995 100644 --- a/packages/backend/package.json +++ b/packages/backend/package.json @@ -65,7 +65,7 @@ "xml-js": "^1.6.11" }, "devDependencies": { - "@podman-desktop/api": "0.0.202404101645-5d46ba5", + "@podman-desktop/api": "1.10.3", "@types/js-yaml": "^4.0.9", "@types/node": "^20", "@types/postman-collection": "^3.5.10", diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index 42f8e2c01..6c073d979 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -47,11 +47,11 @@ import { ApplicationRegistry } from '../registries/ApplicationRegistry'; import type { TaskRegistry } from '../registries/TaskRegistry'; import { Publisher } from '../utils/Publisher'; import { isQEMUMachine } from '../utils/podman'; -import { SECOND } from '../utils/inferenceUtils'; import { getModelPropertiesForEnvironment } from '../utils/modelsUtils'; import { getRandomName } from '../utils/randomUtils'; import type { BuilderManager } from './recipes/BuilderManager'; import type { PodManager } from './recipes/PodManager'; +import { SECOND } from '../workers/provider/LlamaCppPython'; export const LABEL_MODEL_ID = 'ai-lab-model-id'; export const LABEL_MODEL_PORTS = 'ai-lab-model-ports'; diff --git a/packages/backend/src/managers/inference/inferenceManager.spec.ts b/packages/backend/src/managers/inference/inferenceManager.spec.ts index 7ee3ae68f..c6d81b886 100644 --- a/packages/backend/src/managers/inference/inferenceManager.spec.ts +++ b/packages/backend/src/managers/inference/inferenceManager.spec.ts @@ -17,43 +17,36 @@ ***********************************************************************/ import { containerEngine, - provider, type Webview, type TelemetryLogger, - type ImageInfo, type ContainerInfo, type ContainerInspectInfo, - type ProviderContainerConnection, } from '@podman-desktop/api'; import type { ContainerRegistry } from '../../registries/ContainerRegistry'; import type { PodmanConnection } from '../podmanConnection'; import { beforeEach, expect, describe, test, vi } from 'vitest'; import { InferenceManager } from './inferenceManager'; import type { ModelsManager } from '../modelsManager'; -import { LABEL_INFERENCE_SERVER, INFERENCE_SERVER_IMAGE } from '../../utils/inferenceUtils'; +import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils'; import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; import type { TaskRegistry } from '../../registries/TaskRegistry'; import { Messages } from '@shared/Messages'; +import type { InferenceProviderRegistry } from '../../registries/InferenceProviderRegistry'; +import type { InferenceProvider } from '../../workers/provider/InferenceProvider'; vi.mock('@podman-desktop/api', async () => { return { containerEngine: { startContainer: vi.fn(), stopContainer: vi.fn(), - listContainers: vi.fn(), inspectContainer: vi.fn(), - pullImage: vi.fn(), - listImages: vi.fn(), - createContainer: vi.fn(), deleteContainer: vi.fn(), + listContainers: vi.fn(), }, Disposable: { from: vi.fn(), create: vi.fn(), }, - provider: { - getContainerConnections: vi.fn(), - }, }; }); @@ -87,6 +80,11 @@ const taskRegistryMock = { getTasksByLabels: vi.fn(), } as unknown as TaskRegistry; +const inferenceProviderRegistryMock = { + getAll: vi.fn(), + get: vi.fn(), +} as unknown as InferenceProviderRegistry; + const 
getInitializedInferenceManager = async (): Promise => { const manager = new InferenceManager( webviewMock, @@ -95,6 +93,7 @@ const getInitializedInferenceManager = async (): Promise => { modelsManager, telemetryMock, taskRegistryMock, + inferenceProviderRegistryMock, ); manager.init(); await vi.waitUntil(manager.isInitialize.bind(manager), { @@ -119,26 +118,6 @@ beforeEach(() => { Health: undefined, }, } as unknown as ContainerInspectInfo); - vi.mocked(provider.getContainerConnections).mockReturnValue([ - { - providerId: 'test@providerId', - connection: { - type: 'podman', - name: 'test@connection', - status: () => 'started', - }, - } as unknown as ProviderContainerConnection, - ]); - vi.mocked(containerEngine.listImages).mockResolvedValue([ - { - Id: 'dummyImageId', - engineId: 'dummyEngineId', - RepoTags: [INFERENCE_SERVER_IMAGE], - }, - ] as unknown as ImageInfo[]); - vi.mocked(containerEngine.createContainer).mockResolvedValue({ - id: 'dummyCreatedContainerId', - }); vi.mocked(taskRegistryMock.getTasksByLabels).mockReturnValue([]); vi.mocked(modelsManager.getLocalModelPath).mockReturnValue('/local/model.guff'); vi.mocked(modelsManager.uploadModelToPodmanMachine).mockResolvedValue('/mnt/path/model.guff'); @@ -233,119 +212,59 @@ describe('init Inference Manager', () => { * Testing the creation logic */ describe('Create Inference Server', () => { - test('unknown providerId', async () => { - const inferenceManager = await getInitializedInferenceManager(); - await expect( - inferenceManager.createInferenceServer( - { - providerId: 'unknown', - } as unknown as InferenceServerConfig, - 'dummyTrackingId', - ), - ).rejects.toThrowError('cannot find any started container provider.'); + test('no provider available should throw an error', async () => { + vi.mocked(inferenceProviderRegistryMock.getAll).mockReturnValue([]); - expect(provider.getContainerConnections).toHaveBeenCalled(); - }); - - test('unknown imageId', async () => { const inferenceManager = await getInitializedInferenceManager(); await expect( - inferenceManager.createInferenceServer( - { - providerId: 'test@providerId', - image: 'unknown', - } as unknown as InferenceServerConfig, - 'dummyTrackingId', - ), - ).rejects.toThrowError('image unknown not found.'); - - expect(containerEngine.listImages).toHaveBeenCalled(); + inferenceManager.createInferenceServer({ + inferenceProvider: undefined, + labels: {}, + modelsInfo: [], + port: 8888, + }), + ).rejects.toThrowError('no enabled provider could be found.'); }); - test('empty modelsInfo', async () => { + test('inference provider provided should use get from InferenceProviderRegistry', async () => { + vi.mocked(inferenceProviderRegistryMock.get).mockReturnValue({ + enabled: () => false, + } as unknown as InferenceProvider); + const inferenceManager = await getInitializedInferenceManager(); await expect( - inferenceManager.createInferenceServer( - { - providerId: 'test@providerId', - image: INFERENCE_SERVER_IMAGE, - modelsInfo: [], - } as unknown as InferenceServerConfig, - 'dummyTrackingId', - ), - ).rejects.toThrowError('Need at least one model info to start an inference server.'); + inferenceManager.createInferenceServer({ + inferenceProvider: 'dummy-inference-provider', + labels: {}, + modelsInfo: [], + port: 8888, + }), + ).rejects.toThrowError('provider requested is not enabled.'); + expect(inferenceProviderRegistryMock.get).toHaveBeenCalledWith('dummy-inference-provider'); }); - test('valid InferenceServerConfig', async () => { + test('selected inference provider should receive 
config', async () => { + const provider: InferenceProvider = { + enabled: () => true, + name: 'dummy-inference-provider', + dispose: () => {}, + perform: vi.fn().mockResolvedValue({ id: 'dummy-container-id', engineId: 'dummy-engine-id' }), + } as unknown as InferenceProvider; + vi.mocked(inferenceProviderRegistryMock.get).mockReturnValue(provider); + const inferenceManager = await getInitializedInferenceManager(); - await inferenceManager.createInferenceServer( - { - port: 8888, - providerId: 'test@providerId', - image: INFERENCE_SERVER_IMAGE, - modelsInfo: [ - { - id: 'dummyModelId', - file: { - file: 'model.guff', - path: '/mnt/path', - }, - }, - ], - } as unknown as InferenceServerConfig, - 'dummyTrackingId', - ); - expect(modelsManager.uploadModelToPodmanMachine).toHaveBeenCalledWith( - { - id: 'dummyModelId', - file: { - file: 'model.guff', - path: '/mnt/path', - }, - }, - { - trackingId: 'dummyTrackingId', - }, - ); - expect(taskRegistryMock.createTask).toHaveBeenNthCalledWith( - 1, - expect.stringContaining( - 'Pulling ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat:', - ), - 'loading', - { - trackingId: 'dummyTrackingId', - }, - ); - expect(taskRegistryMock.createTask).toHaveBeenNthCalledWith(2, 'Creating container.', 'loading', { - trackingId: 'dummyTrackingId', - }); - expect(taskRegistryMock.updateTask).toHaveBeenLastCalledWith({ - state: 'success', - }); - expect(containerEngine.createContainer).toHaveBeenCalled(); - expect(inferenceManager.getServers()).toStrictEqual([ - { - connection: { - port: 8888, - }, - container: { - containerId: 'dummyCreatedContainerId', - engineId: 'dummyEngineId', - }, - models: [ - { - file: { - file: 'model.guff', - path: '/mnt/path', - }, - id: 'dummyModelId', - }, - ], - status: 'running', - }, - ]); + const config: InferenceServerConfig = { + inferenceProvider: 'dummy-inference-provider', + labels: {}, + modelsInfo: [], + port: 8888, + }; + const result = await inferenceManager.createInferenceServer(config); + + expect(provider.perform).toHaveBeenCalledWith(config); + + expect(result).toBe('dummy-container-id'); }); }); @@ -511,33 +430,6 @@ describe('Request Create Inference Server', () => { trackingId: identifier, }); }); - - test('Pull image error should be reflected in task registry', async () => { - vi.mocked(containerEngine.pullImage).mockRejectedValue(new Error('dummy pull image error')); - - const inferenceManager = await getInitializedInferenceManager(); - inferenceManager.requestCreateInferenceServer({ - port: 8888, - providerId: 'test@providerId', - image: 'quay.io/bootsy/playground:v0', - modelsInfo: [ - { - id: 'dummyModelId', - file: { - file: 'dummyFile', - path: 'dummyPath', - }, - }, - ], - } as unknown as InferenceServerConfig); - - await vi.waitFor(() => { - expect(taskRegistryMock.updateTask).toHaveBeenLastCalledWith({ - state: 'error', - error: 'Something went wrong while trying to create an inference server Error: dummy pull image error.', - }); - }); - }); }); describe('containerRegistry events', () => { diff --git a/packages/backend/src/managers/inference/inferenceManager.ts b/packages/backend/src/managers/inference/inferenceManager.ts index e76494ffd..23cd7d6ff 100644 --- a/packages/backend/src/managers/inference/inferenceManager.ts +++ b/packages/backend/src/managers/inference/inferenceManager.ts @@ -18,21 +18,9 @@ import type { InferenceServer, InferenceServerStatus } from '@shared/src/models/IInference'; import type { PodmanConnection } from '../podmanConnection'; import { 
containerEngine, Disposable } from '@podman-desktop/api'; -import { - type ContainerInfo, - type ImageInfo, - type PullEvent, - type TelemetryLogger, - type Webview, -} from '@podman-desktop/api'; +import { type ContainerInfo, type TelemetryLogger, type Webview } from '@podman-desktop/api'; import type { ContainerRegistry, ContainerStart } from '../../registries/ContainerRegistry'; -import { - generateContainerCreateOptions, - getImageInfo, - getProviderContainerConnection, - isTransitioning, - LABEL_INFERENCE_SERVER, -} from '../../utils/inferenceUtils'; +import { isTransitioning, LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils'; import { Publisher } from '../../utils/Publisher'; import { Messages } from '@shared/Messages'; import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; @@ -40,6 +28,8 @@ import type { ModelsManager } from '../modelsManager'; import type { TaskRegistry } from '../../registries/TaskRegistry'; import { getRandomString } from '../../utils/randomUtils'; import { basename, dirname } from 'node:path'; +import type { InferenceProviderRegistry } from '../../registries/InferenceProviderRegistry'; +import type { InferenceProvider } from '../../workers/provider/InferenceProvider'; export class InferenceManager extends Publisher implements Disposable { // Inference server map (containerId -> InferenceServer) @@ -56,6 +46,7 @@ export class InferenceManager extends Publisher implements Di private modelsManager: ModelsManager, private telemetry: TelemetryLogger, private taskRegistry: TaskRegistry, + private inferenceProviderRegistry: InferenceProviderRegistry, ) { super(webview, Messages.MSG_INFERENCE_SERVERS_UPDATE, () => this.getServers()); this.#servers = new Map(); @@ -116,12 +107,19 @@ export class InferenceManager extends Publisher implements Di * @return a unique tracking identifier to follow the creation request */ requestCreateInferenceServer(config: InferenceServerConfig): string { + // create a tracking id to put in the labels const trackingId: string = getRandomString(); + + config.labels = { + ...config.labels, + trackingId: trackingId, + }; + const task = this.taskRegistry.createTask('Creating Inference server', 'loading', { trackingId: trackingId, }); - this.createInferenceServer(config, trackingId) + this.createInferenceServer(config) .then((containerId: string) => { this.taskRegistry.updateTask({ ...task, @@ -157,64 +155,46 @@ export class InferenceManager extends Publisher implements Di } /** - * Given an engineId, it will create an inference server. + * Given an engineId, it will create an inference server using an InferenceProvider. 
* @param config - * @param trackingId * * @return the containerId of the created inference server */ - async createInferenceServer(config: InferenceServerConfig, trackingId: string): Promise { + async createInferenceServer(config: InferenceServerConfig): Promise { if (!this.isInitialize()) throw new Error('Cannot start the inference server: not initialized.'); - // Fetch a provider container connection - const provider = getProviderContainerConnection(config.providerId); - - // Creating a task to follow pulling progress - const pullingTask = this.taskRegistry.createTask(`Pulling ${config.image}.`, 'loading', { trackingId: trackingId }); - - // Get the image inspect info - const imageInfo: ImageInfo = await getImageInfo(provider.connection, config.image, (_event: PullEvent) => {}); - - this.taskRegistry.updateTask({ - ...pullingTask, - state: 'success', - progress: undefined, - }); + let provider: InferenceProvider; + if (config.inferenceProvider) { + provider = this.inferenceProviderRegistry.get(config.inferenceProvider); + if (!provider.enabled()) throw new Error('provider requested is not enabled.'); + } else { + const providers: InferenceProvider[] = this.inferenceProviderRegistry + .getAll() + .filter(provider => provider.enabled()); + if (providers.length === 0) throw new Error('no enabled provider could be found.'); + provider = providers[0]; + } // upload models to podman machine if user system is supported config.modelsInfo = await Promise.all( config.modelsInfo.map(modelInfo => - this.modelsManager - .uploadModelToPodmanMachine(modelInfo, { - trackingId: trackingId, - }) - .then(path => ({ - ...modelInfo, - file: { - path: dirname(path), - file: basename(path), - }, - })), + this.modelsManager.uploadModelToPodmanMachine(modelInfo, config.labels).then(path => ({ + ...modelInfo, + file: { + path: dirname(path), + file: basename(path), + }, + })), ), ); - const containerTask = this.taskRegistry.createTask(`Creating container.`, 'loading', { trackingId: trackingId }); - - // Create container on requested engine - const result = await containerEngine.createContainer( - imageInfo.engineId, - generateContainerCreateOptions(config, imageInfo), - ); - - this.taskRegistry.updateTask({ - ...containerTask, - state: 'success', - }); + // create the inference server using the selected inference provider + const result = await provider.perform(config); // Adding a new inference server this.#servers.set(result.id, { container: { - engineId: imageInfo.engineId, + engineId: result.engineId, containerId: result.id, }, connection: { @@ -225,7 +205,7 @@ export class InferenceManager extends Publisher implements Di }); // Watch for container changes - this.watchContainerStatus(imageInfo.engineId, result.id); + this.watchContainerStatus(result.engineId, result.id); // Log usage this.telemetry.logUsage('inference.start', { diff --git a/packages/backend/src/managers/playgroundV2Manager.spec.ts b/packages/backend/src/managers/playgroundV2Manager.spec.ts index e2a845d2f..2cc121a3c 100644 --- a/packages/backend/src/managers/playgroundV2Manager.spec.ts +++ b/packages/backend/src/managers/playgroundV2Manager.spec.ts @@ -24,7 +24,6 @@ import type { InferenceServer } from '@shared/src/models/IInference'; import type { InferenceManager } from './inference/inferenceManager'; import { Messages } from '@shared/Messages'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; -import { INFERENCE_SERVER_IMAGE } from '../utils/inferenceUtils'; import type { TaskRegistry } from '../registries/TaskRegistry'; 
import type { Task, TaskState } from '@shared/src/models/ITask'; @@ -333,20 +332,21 @@ test('creating a new playground with no model served should start an inference s } as unknown as ModelInfo, 'tracking-1', ); - expect(createInferenceServerMock).toHaveBeenCalledWith( - { - image: INFERENCE_SERVER_IMAGE, - labels: {}, - modelsInfo: [ - { - id: 'model-1', - name: 'Model 1', - }, - ], - port: expect.anything(), + expect(createInferenceServerMock).toHaveBeenCalledWith({ + image: undefined, + providerId: undefined, + inferenceProvider: undefined, + labels: { + trackingId: 'tracking-1', }, - expect.anything(), - ); + modelsInfo: [ + { + id: 'model-1', + name: 'Model 1', + }, + ], + port: expect.anything(), + }); }); test('creating a new playground with the model already served should not start an inference server', async () => { diff --git a/packages/backend/src/managers/playgroundV2Manager.ts b/packages/backend/src/managers/playgroundV2Manager.ts index 116799169..d6f770285 100644 --- a/packages/backend/src/managers/playgroundV2Manager.ts +++ b/packages/backend/src/managers/playgroundV2Manager.ts @@ -117,8 +117,10 @@ export class PlaygroundV2Manager implements Disposable { await this.inferenceManager.createInferenceServer( await withDefaultConfiguration({ modelsInfo: [model], + labels: { + trackingId: trackingId, + }, }), - trackingId, ); } else if (server.status === 'stopped') { await this.inferenceManager.startInferenceServer(server.container.containerId); diff --git a/packages/backend/src/registries/InferenceProviderRegistry.ts b/packages/backend/src/registries/InferenceProviderRegistry.ts new file mode 100644 index 000000000..2dcda64fe --- /dev/null +++ b/packages/backend/src/registries/InferenceProviderRegistry.ts @@ -0,0 +1,52 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ +import { Publisher } from '../utils/Publisher'; +import type { InferenceProvider } from '../workers/provider/InferenceProvider'; +import { Disposable, type Webview } from '@podman-desktop/api'; +import { Messages } from '@shared/Messages'; + +export class InferenceProviderRegistry extends Publisher { + #providers: Map; + constructor(webview: Webview) { + super(webview, Messages.MSG_INFERENCE_PROVIDER_UPDATE, () => this.getAll().map(provider => provider.name)); + this.#providers = new Map(); + } + + register(provider: InferenceProvider): Disposable { + this.#providers.set(provider.name, provider); + + this.notify(); + return Disposable.create(() => { + this.unregister(provider.name); + }); + } + + unregister(name: string): void { + this.#providers.delete(name); + } + + getAll(): InferenceProvider[] { + return Array.from(this.#providers.values()); + } + + get(name: string): InferenceProvider { + const provider = this.#providers.get(name); + if (provider === undefined) throw new Error(`no provider with name ${name} was found.`); + return provider; + } +} diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index 40413dd1a..9fb5c4c83 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -42,8 +42,8 @@ import { engines } from '../package.json'; import { BuilderManager } from './managers/recipes/BuilderManager'; import { PodManager } from './managers/recipes/PodManager'; import { initWebview } from './webviewUtils'; - -export const AI_LAB_COLLECT_GPU_COMMAND = 'ai-lab.gpu.collect'; +import { LlamaCppPython } from './workers/provider/LlamaCppPython'; +import { InferenceProviderRegistry } from './registries/InferenceProviderRegistry'; export class Studio { readonly #extensionContext: ExtensionContext; @@ -56,6 +56,7 @@ export class Studio { modelsManager: ModelsManager | undefined; telemetry: TelemetryLogger | undefined; + #taskRegistry: TaskRegistry | undefined; #inferenceManager: InferenceManager | undefined; constructor(readonly extensionContext: ExtensionContext) { @@ -112,13 +113,19 @@ export class Studio { const gitManager = new GitManager(); const podmanConnection = new PodmanConnection(); - const taskRegistry = new TaskRegistry(this.#panel.webview); + this.#taskRegistry = new TaskRegistry(this.#panel.webview); + + // Init the inference provider registry + const inferenceProviderRegistry = new InferenceProviderRegistry(this.#panel.webview); + this.#extensionContext.subscriptions.push( + inferenceProviderRegistry.register(new LlamaCppPython(this.#taskRegistry)), + ); // Create catalog manager, responsible for loading the catalog files and watching for changes this.catalogManager = new CatalogManager(this.#panel.webview, appUserDirectory); this.catalogManager.init(); - const builderManager = new BuilderManager(taskRegistry); + const builderManager = new BuilderManager(this.#taskRegistry); this.#extensionContext.subscriptions.push(builderManager); const podManager = new PodManager(); @@ -130,7 +137,7 @@ export class Studio { this.#panel.webview, this.catalogManager, this.telemetry, - taskRegistry, + this.#taskRegistry, cancellationTokenRegistry, ); this.modelsManager.init(); @@ -139,7 +146,7 @@ export class Studio { const applicationManager = new ApplicationManager( appUserDirectory, gitManager, - taskRegistry, + this.#taskRegistry, this.#panel.webview, podmanConnection, this.catalogManager, @@ -156,7 +163,8 @@ export 
class Studio { podmanConnection, this.modelsManager, this.telemetry, - taskRegistry, + this.#taskRegistry, + inferenceProviderRegistry, ); this.#panel.onDidChangeViewState((e: WebviewPanelOnDidChangeViewStateEvent) => { @@ -172,7 +180,7 @@ export class Studio { const playgroundV2 = new PlaygroundV2Manager( this.#panel.webview, this.#inferenceManager, - taskRegistry, + this.#taskRegistry, this.telemetry, ); @@ -186,7 +194,7 @@ export class Studio { this.modelsManager, this.telemetry, localRepositoryRegistry, - taskRegistry, + this.#taskRegistry, this.#inferenceManager, playgroundV2, snippetManager, diff --git a/packages/backend/src/utils/inferenceUtils.spec.ts b/packages/backend/src/utils/inferenceUtils.spec.ts index 6010a076d..fc4e6ce0b 100644 --- a/packages/backend/src/utils/inferenceUtils.spec.ts +++ b/packages/backend/src/utils/inferenceUtils.spec.ts @@ -16,15 +16,7 @@ * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ import { vi, test, expect, describe, beforeEach } from 'vitest'; -import { - generateContainerCreateOptions, - withDefaultConfiguration, - INFERENCE_SERVER_IMAGE, - SECOND, - isTransitioning, -} from './inferenceUtils'; -import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; -import type { ImageInfo } from '@podman-desktop/api'; +import { withDefaultConfiguration, isTransitioning } from './inferenceUtils'; import { getFreeRandomPort } from './ports'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import type { InferenceServer, InferenceServerStatus } from '@shared/src/models/IInference'; @@ -35,133 +27,9 @@ vi.mock('./ports', () => ({ beforeEach(() => { vi.resetAllMocks(); - vi.mocked(getFreeRandomPort).mockResolvedValue(8888); }); -describe('generateContainerCreateOptions', () => { - test('valid arguments', () => { - const result = generateContainerCreateOptions( - { - port: 8888, - providerId: 'test@providerId', - image: INFERENCE_SERVER_IMAGE, - modelsInfo: [ - { - id: 'dummyModelId', - file: { - file: 'dummyFile', - path: 'dummyPath', - }, - }, - ], - } as unknown as InferenceServerConfig, - { - Id: 'dummyImageId', - engineId: 'dummyEngineId', - RepoTags: [INFERENCE_SERVER_IMAGE], - } as unknown as ImageInfo, - ); - expect(result).toStrictEqual({ - Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'], - Detach: true, - Env: ['MODEL_PATH=/models/dummyFile', 'HOST=0.0.0.0', 'PORT=8000'], - ExposedPorts: { - '8888': {}, - }, - HealthCheck: { - Interval: SECOND * 5, - Retries: 20, - Test: ['CMD-SHELL', 'curl -sSf localhost:8000/docs > /dev/null'], - }, - HostConfig: { - AutoRemove: false, - Mounts: [ - { - Source: 'dummyPath', - Target: '/models', - Type: 'bind', - }, - ], - PortBindings: { - '8000/tcp': [ - { - HostPort: '8888', - }, - ], - }, - SecurityOpt: ['label=disable'], - }, - Image: 'dummyImageId', - Labels: { - 'ai-lab-inference-server': '["dummyModelId"]', - }, - }); - }); - - test('model info with chat_format properties', () => { - const result = generateContainerCreateOptions( - { - port: 8888, - providerId: 'test@providerId', - image: INFERENCE_SERVER_IMAGE, - modelsInfo: [ - { - id: 'dummyModelId', - file: { - file: 'dummyFile', - path: 'dummyPath', - }, - properties: { - chatFormat: 'dummyChatFormat', - }, - }, - ], - } as unknown as InferenceServerConfig, - { - Id: 'dummyImageId', - engineId: 'dummyEngineId', - RepoTags: [INFERENCE_SERVER_IMAGE], - } as unknown as ImageInfo, - ); - - 
expect(result.Env).toContain('MODEL_CHAT_FORMAT=dummyChatFormat'); - }); - - test('model info with multiple properties', () => { - const result = generateContainerCreateOptions( - { - port: 8888, - providerId: 'test@providerId', - image: INFERENCE_SERVER_IMAGE, - modelsInfo: [ - { - id: 'dummyModelId', - file: { - file: 'dummyFile', - path: 'dummyPath', - }, - properties: { - basicProp: 'basicProp', - lotOfCamelCases: 'lotOfCamelCases', - lowercase: 'lowercase', - }, - }, - ], - } as unknown as InferenceServerConfig, - { - Id: 'dummyImageId', - engineId: 'dummyEngineId', - RepoTags: [INFERENCE_SERVER_IMAGE], - } as unknown as ImageInfo, - ); - - expect(result.Env).toContain('MODEL_BASIC_PROP=basicProp'); - expect(result.Env).toContain('MODEL_LOT_OF_CAMEL_CASES=lotOfCamelCases'); - expect(result.Env).toContain('MODEL_LOWERCASE=lowercase'); - }); -}); - describe('withDefaultConfiguration', () => { test('zero modelsInfo', async () => { await expect(withDefaultConfiguration({ modelsInfo: [] })).rejects.toThrowError( @@ -175,7 +43,7 @@ describe('withDefaultConfiguration', () => { expect(getFreeRandomPort).toHaveBeenCalledWith('0.0.0.0'); expect(result.port).toBe(8888); - expect(result.image).toBe(INFERENCE_SERVER_IMAGE); + expect(result.image).toBe(undefined); expect(result.labels).toStrictEqual({}); expect(result.providerId).toBe(undefined); }); diff --git a/packages/backend/src/utils/inferenceUtils.ts b/packages/backend/src/utils/inferenceUtils.ts index 04024b10d..0e8611c1b 100644 --- a/packages/backend/src/utils/inferenceUtils.ts +++ b/packages/backend/src/utils/inferenceUtils.ts @@ -16,7 +16,6 @@ * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ import { - type ContainerCreateOptions, containerEngine, type ContainerProviderConnection, type ImageInfo, @@ -26,18 +25,11 @@ import { type PullEvent, } from '@podman-desktop/api'; import type { CreationInferenceServerOptions, InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; -import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from './utils'; import { getFreeRandomPort } from './ports'; -import { getModelPropertiesForEnvironment } from './modelsUtils'; import type { InferenceServer } from '@shared/src/models/IInference'; -export const SECOND: number = 1_000_000_000; - export const LABEL_INFERENCE_SERVER: string = 'ai-lab-inference-server'; -export const INFERENCE_SERVER_IMAGE = - 'ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat:0.3.2'; - /** * Return container connection provider */ @@ -96,67 +88,6 @@ export async function getImageInfo( return imageInfo; } -/** - * Given an {@link InferenceServerConfig} and an {@link ImageInfo} generate a container creation options object - * @param config the config to use - * @param imageInfo the image to use - */ -export function generateContainerCreateOptions( - config: InferenceServerConfig, - imageInfo: ImageInfo, -): ContainerCreateOptions { - if (config.modelsInfo.length === 0) throw new Error('Need at least one model info to start an inference server.'); - - if (config.modelsInfo.length > 1) { - throw new Error('Currently the inference server does not support multiple models serving.'); - } - - const modelInfo = config.modelsInfo[0]; - - if (modelInfo.file === undefined) { - throw new Error('The model info file provided is undefined'); - } - - const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000']; - 
envs.push(...getModelPropertiesForEnvironment(modelInfo)); - - return { - Image: imageInfo.Id, - Detach: true, - ExposedPorts: { [`${config.port}`]: {} }, - HostConfig: { - AutoRemove: false, - Mounts: [ - { - Target: '/models', - Source: modelInfo.file.path, - Type: 'bind', - }, - ], - SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION], - PortBindings: { - '8000/tcp': [ - { - HostPort: `${config.port}`, - }, - ], - }, - }, - HealthCheck: { - // must be the port INSIDE the container not the exposed one - Test: ['CMD-SHELL', `curl -sSf localhost:8000/docs > /dev/null`], - Interval: SECOND * 5, - Retries: 4 * 5, - }, - Labels: { - ...config.labels, - [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)), - }, - Env: envs, - Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'], - }; -} - export async function withDefaultConfiguration( options: CreationInferenceServerOptions, ): Promise { @@ -164,10 +95,11 @@ export async function withDefaultConfiguration( return { port: options.port || (await getFreeRandomPort('0.0.0.0')), - image: options.image || INFERENCE_SERVER_IMAGE, + image: options.image, labels: options.labels || {}, modelsInfo: options.modelsInfo, providerId: options.providerId, + inferenceProvider: options.inferenceProvider, }; } diff --git a/packages/backend/src/workers/IWorker.ts b/packages/backend/src/workers/IWorker.ts index b66e56257..a0444044f 100644 --- a/packages/backend/src/workers/IWorker.ts +++ b/packages/backend/src/workers/IWorker.ts @@ -18,5 +18,5 @@ export interface IWorker { enabled(): boolean; - perform(trackingId: T): Promise; + perform(args: T): Promise; } diff --git a/packages/backend/src/workers/provider/InferenceProvider.spec.ts b/packages/backend/src/workers/provider/InferenceProvider.spec.ts new file mode 100644 index 000000000..ef3f03739 --- /dev/null +++ b/packages/backend/src/workers/provider/InferenceProvider.spec.ts @@ -0,0 +1,210 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ + +import { vi, describe, test, expect, beforeEach } from 'vitest'; +import type { TaskRegistry } from '../../registries/TaskRegistry'; +import { type BetterContainerCreateResult, InferenceProvider } from './InferenceProvider'; +import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; +import { containerEngine } from '@podman-desktop/api'; +import type { + ContainerProviderConnection, + ImageInfo, + ProviderContainerConnection, + ContainerCreateOptions, +} from '@podman-desktop/api'; +import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils'; +import type { TaskState } from '@shared/src/models/ITask'; + +vi.mock('../../utils/inferenceUtils', () => ({ + getProviderContainerConnection: vi.fn(), + getImageInfo: vi.fn(), + LABEL_INFERENCE_SERVER: 'ai-lab-inference-server', +})); + +vi.mock('@podman-desktop/api', () => ({ + containerEngine: { + createContainer: vi.fn(), + }, +})); + +const DummyProviderContainerConnection: ProviderContainerConnection = { + providerId: 'dummy-provider-id', + connection: { + name: 'dummy-provider-connection', + type: 'podman', + } as unknown as ContainerProviderConnection, +}; + +const DummyImageInfo: ImageInfo = { + Id: 'dummy-image-id', + engineId: 'dummy-engine-id', +} as unknown as ImageInfo; + +const taskRegistry: TaskRegistry = { + createTask: vi.fn(), + updateTask: vi.fn(), +} as unknown as TaskRegistry; + +class TestInferenceProvider extends InferenceProvider { + name: string = 'test-inference-provider'; + + constructor() { + super(taskRegistry); + } + + enabled(): boolean { + throw new Error('not implemented'); + } + + publicPullImage(providerId: string | undefined, image: string, labels: { [id: string]: string }) { + return super.pullImage(providerId, image, labels); + } + + async publicCreateContainer( + engineId: string, + containerCreateOptions: ContainerCreateOptions, + labels: { [id: string]: string } = {}, + ): Promise { + const result = await this.createContainer(engineId, containerCreateOptions, labels); + return { + id: result.id, + engineId: engineId, + }; + } + + async perform(_config: InferenceServerConfig): Promise { + throw new Error('not implemented'); + } + dispose(): void {} +} + +beforeEach(() => { + vi.resetAllMocks(); + + vi.mocked(getProviderContainerConnection).mockReturnValue(DummyProviderContainerConnection); + vi.mocked(getImageInfo).mockResolvedValue(DummyImageInfo); + vi.mocked(taskRegistry.createTask).mockImplementation( + (name: string, state: TaskState, labels: { [id: string]: string } = {}) => ({ + id: 'dummy-task-id', + name: name, + state: state, + labels: labels, + }), + ); + vi.mocked(containerEngine.createContainer).mockResolvedValue({ + id: 'dummy-container-id', + }); +}); + +describe('pullImage', () => { + test('should create a task and mark as success on completion', async () => { + const provider = new TestInferenceProvider(); + await provider.publicPullImage('dummy-provider-id', 'dummy-image', { + key: 'value', + }); + + expect(taskRegistry.createTask).toHaveBeenCalledWith('Pulling dummy-image.', 'loading', { + key: 'value', + }); + + expect(taskRegistry.updateTask).toHaveBeenCalledWith({ + id: 'dummy-task-id', + name: 'Pulling dummy-image.', + labels: { + key: 'value', + }, + state: 'success', + }); + }); + + test('should mark the task as error when pulling failed', async () => { + const provider = new TestInferenceProvider(); + 
vi.mocked(getImageInfo).mockRejectedValue(new Error('dummy test error')); + + await expect( + provider.publicPullImage('dummy-provider-id', 'dummy-image', { + key: 'value', + }), + ).rejects.toThrowError('dummy test error'); + + expect(taskRegistry.updateTask).toHaveBeenCalledWith({ + id: 'dummy-task-id', + name: 'Pulling dummy-image.', + labels: { + key: 'value', + }, + state: 'error', + error: 'Something went wrong while pulling dummy-image: Error: dummy test error', + }); + }); +}); + +describe('createContainer', () => { + test('should create a task and mark as success on completion', async () => { + const provider = new TestInferenceProvider(); + await provider.publicCreateContainer( + 'dummy-engine-id', + { + name: 'dummy-container-name', + }, + { + key: 'value', + }, + ); + + expect(taskRegistry.createTask).toHaveBeenCalledWith('Creating container.', 'loading', { + key: 'value', + }); + + expect(taskRegistry.updateTask).toHaveBeenCalledWith({ + id: 'dummy-task-id', + name: 'Creating container.', + labels: { + key: 'value', + }, + state: 'success', + }); + }); + + test('should mark the task as error when creation failed', async () => { + const provider = new TestInferenceProvider(); + vi.mocked(containerEngine.createContainer).mockRejectedValue(new Error('dummy test error')); + + await expect( + provider.publicCreateContainer( + 'dummy-provider-id', + { + name: 'dummy-container-name', + }, + { + key: 'value', + }, + ), + ).rejects.toThrowError('dummy test error'); + + expect(taskRegistry.updateTask).toHaveBeenCalledWith({ + id: 'dummy-task-id', + name: 'Creating container.', + labels: { + key: 'value', + }, + state: 'error', + error: 'Something went wrong while creating container: Error: dummy test error', + }); + }); +}); diff --git a/packages/backend/src/workers/provider/InferenceProvider.ts b/packages/backend/src/workers/provider/InferenceProvider.ts new file mode 100644 index 000000000..bcf55d48d --- /dev/null +++ b/packages/backend/src/workers/provider/InferenceProvider.ts @@ -0,0 +1,105 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ +import type { + ContainerCreateOptions, + ContainerCreateResult, + Disposable, + ImageInfo, + PullEvent, +} from '@podman-desktop/api'; +import { containerEngine } from '@podman-desktop/api'; +import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; +import type { IWorker } from '../IWorker'; +import type { TaskRegistry } from '../../registries/TaskRegistry'; +import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils'; + +export type BetterContainerCreateResult = ContainerCreateResult & { engineId: string }; + +export abstract class InferenceProvider + implements IWorker, Disposable +{ + protected constructor(private taskRegistry: TaskRegistry) {} + + abstract name: string; + abstract enabled(): boolean; + abstract perform(config: InferenceServerConfig): Promise; + abstract dispose(): void; + + protected async createContainer( + engineId: string, + containerCreateOptions: ContainerCreateOptions, + labels: { [id: string]: string }, + ): Promise { + const containerTask = this.taskRegistry.createTask(`Creating container.`, 'loading', labels); + + try { + const result = await containerEngine.createContainer(engineId, containerCreateOptions); + // update the task + containerTask.state = 'success'; + containerTask.progress = undefined; + // return the ContainerCreateResult + return { + id: result.id, + engineId: engineId, + }; + } catch (err: unknown) { + containerTask.state = 'error'; + containerTask.progress = undefined; + containerTask.error = `Something went wrong while creating container: ${String(err)}`; + throw err; + } finally { + this.taskRegistry.updateTask(containerTask); + } + } + + /** + * This method allows to pull the image, while creating a task for the user to follow progress + * @param providerId + * @param image + * @param labels + * @protected + */ + protected pullImage( + providerId: string | undefined, + image: string, + labels: { [id: string]: string }, + ): Promise { + // Creating a task to follow pulling progress + const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels); + + // Get the provider + const provider = getProviderContainerConnection(providerId); + + // get the default image info for this provider + return getImageInfo(provider.connection, image, (_event: PullEvent) => {}) + .catch((err: unknown) => { + pullingTask.state = 'error'; + pullingTask.progress = undefined; + pullingTask.error = `Something went wrong while pulling ${image}: ${String(err)}`; + throw err; + }) + .then(imageInfo => { + pullingTask.state = 'success'; + pullingTask.progress = undefined; + return imageInfo; + }) + .finally(() => { + this.taskRegistry.updateTask(pullingTask); + }); + } +} diff --git a/packages/backend/src/workers/provider/LlamaCppPython.spec.ts b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts new file mode 100644 index 000000000..b8e9b6f5d --- /dev/null +++ b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts @@ -0,0 +1,222 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ + +import { vi, describe, test, expect, beforeEach } from 'vitest'; +import type { TaskRegistry } from '../../registries/TaskRegistry'; +import { LLAMA_CPP_INFERENCE_IMAGE, LlamaCppPython, SECOND } from './LlamaCppPython'; +import type { ModelInfo } from '@shared/src/models/IModelInfo'; +import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils'; +import type { ContainerProviderConnection, ImageInfo, ProviderContainerConnection } from '@podman-desktop/api'; +import { containerEngine } from '@podman-desktop/api'; + +vi.mock('@podman-desktop/api', () => ({ + containerEngine: { + createContainer: vi.fn(), + }, +})); + +vi.mock('../../utils/inferenceUtils', () => ({ + getProviderContainerConnection: vi.fn(), + getImageInfo: vi.fn(), + LABEL_INFERENCE_SERVER: 'ai-lab-inference-server', +})); + +const taskRegistry: TaskRegistry = { + createTask: vi.fn(), + updateTask: vi.fn(), +} as unknown as TaskRegistry; + +const DummyModel: ModelInfo = { + name: 'dummy model', + id: 'dummy-model-id', + file: { + file: 'dummy-file.guff', + path: 'dummy-path', + }, + properties: {}, + description: 'dummy-desc', + hw: 'dummy-hardware', +}; + +const DummyProviderContainerConnection: ProviderContainerConnection = { + providerId: 'dummy-provider-id', + connection: { + name: 'dummy-provider-connection', + type: 'podman', + } as unknown as ContainerProviderConnection, +}; + +const DummyImageInfo: ImageInfo = { + Id: 'dummy-image-id', + engineId: 'dummy-engine-id', +} as unknown as ImageInfo; + +beforeEach(() => { + vi.resetAllMocks(); + + vi.mocked(getProviderContainerConnection).mockReturnValue(DummyProviderContainerConnection); + vi.mocked(getImageInfo).mockResolvedValue(DummyImageInfo); + vi.mocked(taskRegistry.createTask).mockReturnValue({ id: 'dummy-task-id', name: '', labels: {}, state: 'loading' }); + vi.mocked(containerEngine.createContainer).mockResolvedValue({ + id: 'dummy-container-id', + }); +}); + +test('LlamaCppPython being the default, it should always be enable', () => { + const provider = new LlamaCppPython(taskRegistry); + expect(provider.enabled()).toBeTruthy(); +}); + +describe('perform', () => { + test('config without image should use defined image', async () => { + const provider = new LlamaCppPython(taskRegistry); + + await provider.perform({ + port: 8000, + image: undefined, + labels: {}, + modelsInfo: [DummyModel], + providerId: undefined, + }); + + expect(getProviderContainerConnection).toHaveBeenCalledWith(undefined); + expect(getImageInfo).toHaveBeenCalledWith( + DummyProviderContainerConnection.connection, + LLAMA_CPP_INFERENCE_IMAGE, + expect.anything(), + ); + }); + + test('config without models should throw an error', async () => { + const provider = new LlamaCppPython(taskRegistry); + + await expect( + provider.perform({ + port: 8000, + image: undefined, + labels: {}, + modelsInfo: [], + providerId: undefined, + }), + ).rejects.toThrowError('Need at least one model info to start an inference server.'); + }); + + 
test('config model without file should throw an error', async () => { + const provider = new LlamaCppPython(taskRegistry); + + await expect( + provider.perform({ + port: 8000, + image: undefined, + labels: {}, + modelsInfo: [ + { + id: 'invalid', + } as unknown as ModelInfo, + ], + providerId: undefined, + }), + ).rejects.toThrowError('The model info file provided is undefined'); + }); + + test('valid config should produce expected CreateContainerOptions', async () => { + const provider = new LlamaCppPython(taskRegistry); + + await provider.perform({ + port: 8888, + image: undefined, + labels: {}, + modelsInfo: [DummyModel], + providerId: undefined, + }); + + expect(containerEngine.createContainer).toHaveBeenCalledWith(DummyImageInfo.engineId, { + Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'], + Detach: true, + Env: ['MODEL_PATH=/models/dummy-file.guff', 'HOST=0.0.0.0', 'PORT=8000'], + ExposedPorts: { + '8888': {}, + }, + HealthCheck: { + Interval: SECOND * 5, + Retries: 20, + Test: ['CMD-SHELL', 'curl -sSf localhost:8000/docs > /dev/null'], + }, + HostConfig: { + AutoRemove: false, + Mounts: [ + { + Source: 'dummy-path', + Target: '/models', + Type: 'bind', + }, + ], + PortBindings: { + '8000/tcp': [ + { + HostPort: '8888', + }, + ], + }, + SecurityOpt: ['label=disable'], + }, + Image: DummyImageInfo.Id, + Labels: { + 'ai-lab-inference-server': `["${DummyModel.id}"]`, + }, + }); + }); + + test('model properties should be made uppercased', async () => { + const provider = new LlamaCppPython(taskRegistry); + + await provider.perform({ + port: 8000, + image: undefined, + labels: {}, + modelsInfo: [ + { + ...DummyModel, + properties: { + basicProp: 'basicProp', + lotOfCamelCases: 'lotOfCamelCases', + lowercase: 'lowercase', + chatFormat: 'dummyChatFormat', + }, + }, + ], + providerId: undefined, + }); + + expect(containerEngine.createContainer).toHaveBeenCalledWith(DummyImageInfo.engineId, { + Env: expect.arrayContaining([ + 'MODEL_BASIC_PROP=basicProp', + 'MODEL_LOT_OF_CAMEL_CASES=lotOfCamelCases', + 'MODEL_LOWERCASE=lowercase', + 'MODEL_CHAT_FORMAT=dummyChatFormat', + ]), + Cmd: expect.anything(), + HealthCheck: expect.anything(), + HostConfig: expect.anything(), + ExposedPorts: expect.anything(), + Labels: expect.anything(), + Image: DummyImageInfo.Id, + Detach: true, + }); + }); +}); diff --git a/packages/backend/src/workers/provider/LlamaCppPython.ts b/packages/backend/src/workers/provider/LlamaCppPython.ts new file mode 100644 index 000000000..a754278bb --- /dev/null +++ b/packages/backend/src/workers/provider/LlamaCppPython.ts @@ -0,0 +1,115 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ +import type { ContainerCreateOptions, ImageInfo } from '@podman-desktop/api'; +import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; +import { type BetterContainerCreateResult, InferenceProvider } from './InferenceProvider'; +import { getModelPropertiesForEnvironment } from '../../utils/modelsUtils'; +import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils'; +import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils'; +import type { TaskRegistry } from '../../registries/TaskRegistry'; + +export const LLAMA_CPP_INFERENCE_IMAGE = + 'ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat:0.3.2'; + +export const SECOND: number = 1_000_000_000; + +export class LlamaCppPython extends InferenceProvider { + name: string; + + constructor(taskRegistry: TaskRegistry) { + super(taskRegistry); + this.name = 'llama-cpp'; + } + + dispose() {} + + public enabled = (): boolean => true; + + protected async getContainerCreateOptions( + config: InferenceServerConfig, + imageInfo: ImageInfo, + ): Promise { + if (config.modelsInfo.length === 0) throw new Error('Need at least one model info to start an inference server.'); + + if (config.modelsInfo.length > 1) { + throw new Error('Currently the inference server does not support multiple models serving.'); + } + + const modelInfo = config.modelsInfo[0]; + + if (modelInfo.file === undefined) { + throw new Error('The model info file provided is undefined'); + } + + const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000']; + envs.push(...getModelPropertiesForEnvironment(modelInfo)); + + return { + Image: imageInfo.Id, + Detach: true, + ExposedPorts: { [`${config.port}`]: {} }, + HostConfig: { + AutoRemove: false, + Mounts: [ + { + Target: '/models', + Source: modelInfo.file.path, + Type: 'bind', + }, + ], + SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION], + PortBindings: { + '8000/tcp': [ + { + HostPort: `${config.port}`, + }, + ], + }, + }, + HealthCheck: { + // must be the port INSIDE the container not the exposed one + Test: ['CMD-SHELL', `curl -sSf localhost:8000/docs > /dev/null`], + Interval: SECOND * 5, + Retries: 4 * 5, + }, + Labels: { + ...config.labels, + [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)), + }, + Env: envs, + Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'], + }; + } + + async perform(config: InferenceServerConfig): Promise { + if (!this.enabled()) throw new Error('not enabled'); + + // pull the image + const imageInfo: ImageInfo = await this.pullImage( + config.providerId, + config.image ?? 
LLAMA_CPP_INFERENCE_IMAGE, + config.labels, + ); + + // Get the container creation options + const containerCreateOptions: ContainerCreateOptions = await this.getContainerCreateOptions(config, imageInfo); + + // Create the container + return this.createContainer(imageInfo.engineId, containerCreateOptions, config.labels); + } +} diff --git a/packages/shared/Messages.ts b/packages/shared/Messages.ts index 1bef0a3cb..89a855e0b 100644 --- a/packages/shared/Messages.ts +++ b/packages/shared/Messages.ts @@ -27,4 +27,5 @@ export enum Messages { MSG_SUPPORTED_LANGUAGES_UPDATE = 'supported-languages-supported', MSG_CONVERSATIONS_UPDATE = 'conversations-update', MSG_GPUS_UPDATE = 'gpus-update', + MSG_INFERENCE_PROVIDER_UPDATE = 'inference-provider-update', } diff --git a/packages/shared/src/models/InferenceServerConfig.ts b/packages/shared/src/models/InferenceServerConfig.ts index 48df109fd..8352fca1f 100644 --- a/packages/shared/src/models/InferenceServerConfig.ts +++ b/packages/shared/src/models/InferenceServerConfig.ts @@ -28,10 +28,14 @@ export interface InferenceServerConfig { * The identifier of the container provider to use */ providerId?: string; + /** + * The name of the inference provider to use + */ + inferenceProvider?: string; /** * Image to use */ - image: string; + image?: string; /** * Labels to use for the container */ diff --git a/yarn.lock b/yarn.lock index ceb2b85d4..280edf994 100644 --- a/yarn.lock +++ b/yarn.lock @@ -391,10 +391,10 @@ dependencies: playwright "1.42.1" -"@podman-desktop/api@0.0.202404101645-5d46ba5": - version "0.0.202404101645-5d46ba5" - resolved "https://registry.yarnpkg.com/@podman-desktop/api/-/api-0.0.202404101645-5d46ba5.tgz#562aa4470057bfc7869b576d9691da2ffe49fa18" - integrity sha512-5FwFhBrOGb2Nhd0buGZCigAewnio8xlkenmIKW0Ew2jqeoKffgW6e4bpvRTsl9JeGiQqrEa5jkZP6cft+xlGfA== +"@podman-desktop/api@1.10.3": + version "1.10.3" + resolved "https://registry.yarnpkg.com/@podman-desktop/api/-/api-1.10.3.tgz#c4c17e96aa3f70acd47162cc2d02cd0f3290fd52" + integrity sha512-jc5mYPsNz59e+o+1fQR67TPUWQoIuEssMtSwOgqdV/k0lSk05p5ErotrgeKT7WVXb7XxYOx0E4MtTTY5Kf7cyw== "@podman-desktop/tests-playwright@^1.10.3": version "1.10.3"
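
For illustration only (not part of the patch): with the `InferenceProvider` abstraction and `InferenceProviderRegistry` introduced above, additional inference runtimes can be plugged in alongside `LlamaCppPython`. The sketch below shows roughly how such a provider could be written, registered, and selected. `CustomProvider`, its `custom-runtime` name, and the image URL are hypothetical placeholders; only the interfaces visible in this diff (`perform`, `enabled`, the protected `pullImage`/`createContainer` helpers, `register`, and the `inferenceProvider` config field) are assumed.

import type { ContainerCreateOptions, ImageInfo } from '@podman-desktop/api';
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
import { type BetterContainerCreateResult, InferenceProvider } from './InferenceProvider';
import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
import type { TaskRegistry } from '../../registries/TaskRegistry';

// Hypothetical image, used for illustration only.
export const CUSTOM_INFERENCE_IMAGE = 'quay.io/example/custom-inference:latest';

export class CustomProvider extends InferenceProvider {
  name: string = 'custom-runtime';

  constructor(taskRegistry: TaskRegistry) {
    super(taskRegistry);
  }

  dispose(): void {}

  enabled = (): boolean => true;

  async perform(config: InferenceServerConfig): Promise<BetterContainerCreateResult> {
    // pullImage() creates and updates a pulling task through the shared TaskRegistry helper
    const imageInfo: ImageInfo = await this.pullImage(
      config.providerId,
      config.image ?? CUSTOM_INFERENCE_IMAGE,
      config.labels,
    );

    const options: ContainerCreateOptions = {
      Image: imageInfo.Id,
      Detach: true,
      Labels: {
        ...config.labels,
        [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)),
      },
    };

    // createContainer() reports the "Creating container." task and returns the containerId/engineId pair
    return this.createContainer(imageInfo.engineId, options, config.labels);
  }
}

Registration and selection would then follow the same pattern used for LlamaCppPython in studio.ts: register the provider with the registry, and pass its name through the configuration so InferenceManager resolves it instead of falling back to the first enabled provider.

// during extension activation (sketch)
extensionContext.subscriptions.push(inferenceProviderRegistry.register(new CustomProvider(taskRegistry)));

// later, when requesting an inference server for a model
const config = await withDefaultConfiguration({
  modelsInfo: [model],
  inferenceProvider: 'custom-runtime',
});
const containerId = await inferenceManager.createInferenceServer(config);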