From 8af1106baacb886e18a2c14d93a675691414c65a Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:30:32 +0200 Subject: [PATCH] feat: allow disabling model upload (#1662) * feat: allow disabling model upload Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Jeff MAURY Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> --------- Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> Co-authored-by: Jeff MAURY --- packages/backend/package.json | 6 + .../src/managers/modelsManager.spec.ts | 112 ++++++++++++++++++ .../backend/src/managers/modelsManager.ts | 8 ++ .../src/registries/ConfigurationRegistry.ts | 2 + packages/backend/src/studio.ts | 1 + .../workers/provider/LlamaCppPython.spec.ts | 4 + .../src/models/IExtensionConfiguration.ts | 1 + 7 files changed, 134 insertions(+) diff --git a/packages/backend/package.json b/packages/backend/package.json index d8b60f9ae..f15427f55 100644 --- a/packages/backend/package.json +++ b/packages/backend/package.json @@ -21,6 +21,12 @@ "default": "", "description": "Custom path where to download models. Note: The extension must be restarted for changes to take effect. (Default is blank)" }, + "ai-lab.modelUploadDisabled": { + "type": "boolean", + "default": false, + "description": "Disable the model upload to the podman machine", + "hidden": true + }, "ai-lab.experimentalGPU": { "type": "boolean", "default": false, diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts index 81d240a6d..2b0f1c27a 100644 --- a/packages/backend/src/managers/modelsManager.spec.ts +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -34,6 +34,8 @@ import { gguf } from '@huggingface/gguf'; import type { PodmanConnection } from './podmanConnection'; import { VMType } from '@shared/src/models/IPodman'; import { getPodmanMachineName } from '../utils/podman'; +import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry'; +import { Uploader } from '../utils/uploader'; const mocks = vi.hoisted(() => { return { @@ -49,6 +51,10 @@ const mocks = vi.hoisted(() => { }; }); +vi.mock('../utils/uploader', () => ({ + Uploader: vi.fn(), +})); + vi.mock('@huggingface/gguf', () => ({ gguf: vi.fn(), })); @@ -109,10 +115,22 @@ const telemetryLogger = { logError: mocks.logErrorMock, } as unknown as TelemetryLogger; +const configurationRegistryMock: ConfigurationRegistry = { + getExtensionConfiguration: vi.fn(), +} as unknown as ConfigurationRegistry; + beforeEach(() => { vi.resetAllMocks(); taskRegistry = new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview); + vi.mocked(configurationRegistryMock.getExtensionConfiguration).mockReturnValue({ + modelUploadDisabled: false, + modelsPath: '~/downloads', + experimentalTuning: false, + apiPort: 0, + experimentalGPU: false, + }); + mocks.isCompletionEventMock.mockReturnValue(true); }); @@ -190,6 +208,7 @@ test('getModelsInfo should get models in local directory', async () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); await manager.loadLocalModels(); @@ -240,6 +259,7 @@ test('getModelsInfo should return an empty array if the models folder does not e taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); manager.getLocalModelsFromDisk(); @@ -281,6 +301,7 @@ test('getLocalModelsFromDisk should return undefined Date and size when stat fai taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); await manager.loadLocalModels(); @@ -340,6 +361,7 @@ test('getLocalModelsFromDisk should skip folders containing tmp files', async () taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); await manager.loadLocalModels(); @@ -381,6 +403,7 @@ test('loadLocalModels should post a message with the message on disk and on cata taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); await manager.loadLocalModels(); @@ -432,6 +455,7 @@ test('deleteModel deletes the model folder', async () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); await manager.loadLocalModels(); @@ -497,6 +521,7 @@ describe('deleting models', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); await manager.loadLocalModels(); @@ -564,6 +589,7 @@ describe('deleting models', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); await manager.loadLocalModels(); await manager.deleteModel('model-id-1'); @@ -624,6 +650,7 @@ describe('deleting models', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); await manager.loadLocalModels(); @@ -658,6 +685,7 @@ describe('downloadModel', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); @@ -693,6 +721,7 @@ describe('downloadModel', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask'); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true); @@ -725,6 +754,7 @@ describe('downloadModel', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); vi.spyOn(taskRegistry, 'updateTask'); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true); @@ -756,6 +786,7 @@ describe('downloadModel', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); @@ -793,6 +824,7 @@ describe('downloadModel', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); @@ -841,6 +873,7 @@ describe('getModelMetadata', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); await expect(() => manager.getModelMetadata('unknown-model-id')).rejects.toThrowError( @@ -866,6 +899,7 @@ describe('getModelMetadata', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); @@ -907,6 +941,7 @@ describe('getModelMetadata', () => { taskRegistry, cancellationTokenRegistryMock, podmanConnectionMock, + configurationRegistryMock, ); manager.init(); @@ -927,3 +962,80 @@ describe('getModelMetadata', () => { }); }); }); + +const connectionMock: ContainerProviderConnection = { + name: 'dummy-connection', + type: 'podman', + vmType: undefined, +} as unknown as ContainerProviderConnection; + +const modelMock: ModelInfo = { + id: 'test-model-id', + url: 'dummy-url', + file: { + file: 'random', + path: 'dummy-path', + }, +} as unknown as ModelInfo; + +describe('uploadModelToPodmanMachine', () => { + test('uploader should be used', async () => { + const performMock = vi.fn().mockResolvedValue('uploader-result'); + vi.mocked(Uploader).mockReturnValue({ + onEvent: vi.fn(), + perform: performMock, + } as unknown as Uploader); + + const manager = new ModelsManager( + 'appdir', + { + postMessage: vi.fn(), + } as unknown as Webview, + { + onUpdate: vi.fn(), + getModels: () => [], + } as unknown as CatalogManager, + telemetryLogger, + taskRegistry, + cancellationTokenRegistryMock, + podmanConnectionMock, + configurationRegistryMock, + ); + + manager.init(); + const result = await manager.uploadModelToPodmanMachine(connectionMock, modelMock); + expect(result).toBe('uploader-result'); + expect(performMock).toHaveBeenCalledWith(modelMock.id); + }); + + test('upload should be skipped when configuration disable it', async () => { + vi.mocked(configurationRegistryMock.getExtensionConfiguration).mockReturnValue({ + // disable upload + modelUploadDisabled: true, + modelsPath: '~/downloads', + experimentalTuning: false, + apiPort: 0, + experimentalGPU: false, + }); + + const manager = new ModelsManager( + 'appdir', + { + postMessage: vi.fn(), + } as unknown as Webview, + { + onUpdate: vi.fn(), + getModels: () => [], + } as unknown as CatalogManager, + telemetryLogger, + taskRegistry, + cancellationTokenRegistryMock, + podmanConnectionMock, + configurationRegistryMock, + ); + + manager.init(); + await manager.uploadModelToPodmanMachine(connectionMock, modelMock); + expect(Uploader).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index 256987dc4..13cb797ad 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -38,6 +38,7 @@ import type { GGUFParseOutput } from '@huggingface/gguf'; import { gguf } from '@huggingface/gguf'; import type { PodmanConnection } from './podmanConnection'; import { VMType } from '@shared/src/models/IPodman'; +import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry'; export class ModelsManager implements Disposable { #models: Map; @@ -54,6 +55,7 @@ export class ModelsManager implements Disposable { private taskRegistry: TaskRegistry, private cancellationTokenRegistry: CancellationTokenRegistry, private podmanConnection: PodmanConnection, + private configurationRegistry: ConfigurationRegistry, ) { this.#models = new Map(); this.#disposables = []; @@ -425,6 +427,12 @@ export class ModelsManager implements Disposable { model: ModelInfo, labels?: { [key: string]: string }, ): Promise { + // ensure the model upload is not disabled + if (this.configurationRegistry.getExtensionConfiguration().modelUploadDisabled) { + console.warn('The model upload is disabled, this may cause the inference server to take a few minutes to start.'); + return getLocalModelFile(model); + } + this.taskRegistry.createTask(`Copying model ${model.name} to ${connection.name}`, 'loading', { ...labels, 'model-uploading': model.id, diff --git a/packages/backend/src/registries/ConfigurationRegistry.ts b/packages/backend/src/registries/ConfigurationRegistry.ts index 8fdb01f58..5cb66237b 100644 --- a/packages/backend/src/registries/ConfigurationRegistry.ts +++ b/packages/backend/src/registries/ConfigurationRegistry.ts @@ -26,6 +26,7 @@ const CONFIGURATION_SECTIONS: string[] = [ 'ai-lab.experimentalGPU', 'ai-lab.apiPort', 'ai-lab.experimentalTuning', + 'ai-lab.modelUploadDisabled', ]; const API_PORT_DEFAULT = 10434; @@ -49,6 +50,7 @@ export class ConfigurationRegistry extends Publisher imp experimentalGPU: this.#configuration.get('experimentalGPU') ?? false, apiPort: this.#configuration.get('apiPort') ?? API_PORT_DEFAULT, experimentalTuning: this.#configuration.get('experimentalTuning') ?? false, + modelUploadDisabled: this.#configuration.get('modelUploadDisabled') ?? false, }; } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index 0118d182c..7d25a1108 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -212,6 +212,7 @@ export class Studio { this.#taskRegistry, this.#cancellationTokenRegistry, this.#podmanConnection, + this.#configurationRegistry, ); this.#modelsManager.init(); this.#extensionContext.subscriptions.push(this.#modelsManager); diff --git a/packages/backend/src/workers/provider/LlamaCppPython.spec.ts b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts index 857a7f572..ab4b3b002 100644 --- a/packages/backend/src/workers/provider/LlamaCppPython.spec.ts +++ b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts @@ -97,6 +97,7 @@ beforeEach(() => { modelsPath: 'model-path', apiPort: 10434, experimentalTuning: false, + modelUploadDisabled: false, }); vi.mocked(podmanConnection.findRunningContainerProviderConnection).mockReturnValue(dummyConnection); vi.mocked(podmanConnection.getContainerProviderConnection).mockReturnValue(dummyConnection); @@ -275,6 +276,7 @@ describe('perform', () => { modelsPath: '', apiPort: 10434, experimentalTuning: false, + modelUploadDisabled: false, }); vi.mocked(gpuManager.collectGPUs).mockResolvedValue([ @@ -306,6 +308,7 @@ describe('perform', () => { modelsPath: '', apiPort: 10434, experimentalTuning: false, + modelUploadDisabled: false, }); vi.mocked(gpuManager.collectGPUs).mockResolvedValue([ @@ -339,6 +342,7 @@ describe('perform', () => { modelsPath: '', apiPort: 10434, experimentalTuning: false, + modelUploadDisabled: false, }); vi.mocked(gpuManager.collectGPUs).mockResolvedValue([ diff --git a/packages/shared/src/models/IExtensionConfiguration.ts b/packages/shared/src/models/IExtensionConfiguration.ts index 64d42f433..9268ea0ae 100644 --- a/packages/shared/src/models/IExtensionConfiguration.ts +++ b/packages/shared/src/models/IExtensionConfiguration.ts @@ -21,4 +21,5 @@ export interface ExtensionConfiguration { modelsPath: string; apiPort: number; experimentalTuning: boolean; + modelUploadDisabled: boolean; }