Skip to content

Commit

Permalink
feat: allow disabling model upload (#1662)
Browse files Browse the repository at this point in the history
* feat: allow disabling model upload

Signed-off-by: axel7083 <[email protected]>

* Apply suggestions from code review

Co-authored-by: Jeff MAURY <[email protected]>
Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
Co-authored-by: Jeff MAURY <[email protected]>
  • Loading branch information
axel7083 and jeffmaury authored Sep 10, 2024
1 parent f25fd06 commit 8af1106
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 0 deletions.
6 changes: 6 additions & 0 deletions packages/backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
"default": "",
"description": "Custom path where to download models. Note: The extension must be restarted for changes to take effect. (Default is blank)"
},
"ai-lab.modelUploadDisabled": {
"type": "boolean",
"default": false,
"description": "Disable the model upload to the podman machine",
"hidden": true
},
"ai-lab.experimentalGPU": {
"type": "boolean",
"default": false,
Expand Down
112 changes: 112 additions & 0 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import { gguf } from '@huggingface/gguf';
import type { PodmanConnection } from './podmanConnection';
import { VMType } from '@shared/src/models/IPodman';
import { getPodmanMachineName } from '../utils/podman';
import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';
import { Uploader } from '../utils/uploader';

const mocks = vi.hoisted(() => {
return {
Expand All @@ -49,6 +51,10 @@ const mocks = vi.hoisted(() => {
};
});

vi.mock('../utils/uploader', () => ({
Uploader: vi.fn(),
}));

vi.mock('@huggingface/gguf', () => ({
gguf: vi.fn(),
}));
Expand Down Expand Up @@ -109,10 +115,22 @@ const telemetryLogger = {
logError: mocks.logErrorMock,
} as unknown as TelemetryLogger;

const configurationRegistryMock: ConfigurationRegistry = {
getExtensionConfiguration: vi.fn(),
} as unknown as ConfigurationRegistry;

beforeEach(() => {
vi.resetAllMocks();
taskRegistry = new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview);

vi.mocked(configurationRegistryMock.getExtensionConfiguration).mockReturnValue({
modelUploadDisabled: false,
modelsPath: '~/downloads',
experimentalTuning: false,
apiPort: 0,
experimentalGPU: false,
});

mocks.isCompletionEventMock.mockReturnValue(true);
});

Expand Down Expand Up @@ -190,6 +208,7 @@ test('getModelsInfo should get models in local directory', async () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -240,6 +259,7 @@ test('getModelsInfo should return an empty array if the models folder does not e
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
manager.init();
manager.getLocalModelsFromDisk();
Expand Down Expand Up @@ -281,6 +301,7 @@ test('getLocalModelsFromDisk should return undefined Date and size when stat fai
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -340,6 +361,7 @@ test('getLocalModelsFromDisk should skip folders containing tmp files', async ()
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -381,6 +403,7 @@ test('loadLocalModels should post a message with the message on disk and on cata
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -432,6 +455,7 @@ test('deleteModel deletes the model folder', async () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -497,6 +521,7 @@ describe('deleting models', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -564,6 +589,7 @@ describe('deleting models', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
await manager.loadLocalModels();
await manager.deleteModel('model-id-1');
Expand Down Expand Up @@ -624,6 +650,7 @@ describe('deleting models', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

await manager.loadLocalModels();
Expand Down Expand Up @@ -658,6 +685,7 @@ describe('downloadModel', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
Expand Down Expand Up @@ -693,6 +721,7 @@ describe('downloadModel', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask');
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
Expand Down Expand Up @@ -725,6 +754,7 @@ describe('downloadModel', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);
vi.spyOn(taskRegistry, 'updateTask');
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
Expand Down Expand Up @@ -756,6 +786,7 @@ describe('downloadModel', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
Expand Down Expand Up @@ -793,6 +824,7 @@ describe('downloadModel', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
Expand Down Expand Up @@ -841,6 +873,7 @@ describe('getModelMetadata', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

await expect(() => manager.getModelMetadata('unknown-model-id')).rejects.toThrowError(
Expand All @@ -866,6 +899,7 @@ describe('getModelMetadata', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

manager.init();
Expand Down Expand Up @@ -907,6 +941,7 @@ describe('getModelMetadata', () => {
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

manager.init();
Expand All @@ -927,3 +962,80 @@ describe('getModelMetadata', () => {
});
});
});

const connectionMock: ContainerProviderConnection = {
name: 'dummy-connection',
type: 'podman',
vmType: undefined,
} as unknown as ContainerProviderConnection;

const modelMock: ModelInfo = {
id: 'test-model-id',
url: 'dummy-url',
file: {
file: 'random',
path: 'dummy-path',
},
} as unknown as ModelInfo;

describe('uploadModelToPodmanMachine', () => {
test('uploader should be used', async () => {
const performMock = vi.fn().mockResolvedValue('uploader-result');
vi.mocked(Uploader).mockReturnValue({
onEvent: vi.fn(),
perform: performMock,
} as unknown as Uploader);

const manager = new ModelsManager(
'appdir',
{
postMessage: vi.fn(),
} as unknown as Webview,
{
onUpdate: vi.fn(),
getModels: () => [],
} as unknown as CatalogManager,
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

manager.init();
const result = await manager.uploadModelToPodmanMachine(connectionMock, modelMock);
expect(result).toBe('uploader-result');
expect(performMock).toHaveBeenCalledWith(modelMock.id);
});

test('upload should be skipped when configuration disable it', async () => {
vi.mocked(configurationRegistryMock.getExtensionConfiguration).mockReturnValue({
// disable upload
modelUploadDisabled: true,
modelsPath: '~/downloads',
experimentalTuning: false,
apiPort: 0,
experimentalGPU: false,
});

const manager = new ModelsManager(
'appdir',
{
postMessage: vi.fn(),
} as unknown as Webview,
{
onUpdate: vi.fn(),
getModels: () => [],
} as unknown as CatalogManager,
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
configurationRegistryMock,
);

manager.init();
await manager.uploadModelToPodmanMachine(connectionMock, modelMock);
expect(Uploader).not.toHaveBeenCalled();
});
});
8 changes: 8 additions & 0 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import type { GGUFParseOutput } from '@huggingface/gguf';
import { gguf } from '@huggingface/gguf';
import type { PodmanConnection } from './podmanConnection';
import { VMType } from '@shared/src/models/IPodman';
import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';

export class ModelsManager implements Disposable {
#models: Map<string, ModelInfo>;
Expand All @@ -54,6 +55,7 @@ export class ModelsManager implements Disposable {
private taskRegistry: TaskRegistry,
private cancellationTokenRegistry: CancellationTokenRegistry,
private podmanConnection: PodmanConnection,
private configurationRegistry: ConfigurationRegistry,
) {
this.#models = new Map();
this.#disposables = [];
Expand Down Expand Up @@ -425,6 +427,12 @@ export class ModelsManager implements Disposable {
model: ModelInfo,
labels?: { [key: string]: string },
): Promise<string> {
// ensure the model upload is not disabled
if (this.configurationRegistry.getExtensionConfiguration().modelUploadDisabled) {
console.warn('The model upload is disabled, this may cause the inference server to take a few minutes to start.');
return getLocalModelFile(model);
}

this.taskRegistry.createTask(`Copying model ${model.name} to ${connection.name}`, 'loading', {
...labels,
'model-uploading': model.id,
Expand Down
2 changes: 2 additions & 0 deletions packages/backend/src/registries/ConfigurationRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const CONFIGURATION_SECTIONS: string[] = [
'ai-lab.experimentalGPU',
'ai-lab.apiPort',
'ai-lab.experimentalTuning',
'ai-lab.modelUploadDisabled',
];

const API_PORT_DEFAULT = 10434;
Expand All @@ -49,6 +50,7 @@ export class ConfigurationRegistry extends Publisher<ExtensionConfiguration> imp
experimentalGPU: this.#configuration.get<boolean>('experimentalGPU') ?? false,
apiPort: this.#configuration.get<number>('apiPort') ?? API_PORT_DEFAULT,
experimentalTuning: this.#configuration.get<boolean>('experimentalTuning') ?? false,
modelUploadDisabled: this.#configuration.get<boolean>('modelUploadDisabled') ?? false,
};
}

Expand Down
1 change: 1 addition & 0 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ export class Studio {
this.#taskRegistry,
this.#cancellationTokenRegistry,
this.#podmanConnection,
this.#configurationRegistry,
);
this.#modelsManager.init();
this.#extensionContext.subscriptions.push(this.#modelsManager);
Expand Down
4 changes: 4 additions & 0 deletions packages/backend/src/workers/provider/LlamaCppPython.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ beforeEach(() => {
modelsPath: 'model-path',
apiPort: 10434,
experimentalTuning: false,
modelUploadDisabled: false,
});
vi.mocked(podmanConnection.findRunningContainerProviderConnection).mockReturnValue(dummyConnection);
vi.mocked(podmanConnection.getContainerProviderConnection).mockReturnValue(dummyConnection);
Expand Down Expand Up @@ -275,6 +276,7 @@ describe('perform', () => {
modelsPath: '',
apiPort: 10434,
experimentalTuning: false,
modelUploadDisabled: false,
});

vi.mocked(gpuManager.collectGPUs).mockResolvedValue([
Expand Down Expand Up @@ -306,6 +308,7 @@ describe('perform', () => {
modelsPath: '',
apiPort: 10434,
experimentalTuning: false,
modelUploadDisabled: false,
});

vi.mocked(gpuManager.collectGPUs).mockResolvedValue([
Expand Down Expand Up @@ -339,6 +342,7 @@ describe('perform', () => {
modelsPath: '',
apiPort: 10434,
experimentalTuning: false,
modelUploadDisabled: false,
});

vi.mocked(gpuManager.collectGPUs).mockResolvedValue([
Expand Down
1 change: 1 addition & 0 deletions packages/shared/src/models/IExtensionConfiguration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ export interface ExtensionConfiguration {
modelsPath: string;
apiPort: number;
experimentalTuning: boolean;
modelUploadDisabled: boolean;
}

0 comments on commit 8af1106

Please sign in to comment.