diff --git a/packages/frontend/src/lib/notification/ContainerConnectionWrapper.spec.ts b/packages/frontend/src/lib/notification/ContainerConnectionWrapper.spec.ts index 4502cb1c0..8156887b9 100644 --- a/packages/frontend/src/lib/notification/ContainerConnectionWrapper.spec.ts +++ b/packages/frontend/src/lib/notification/ContainerConnectionWrapper.spec.ts @@ -29,6 +29,14 @@ import { VMType } from '@shared/src/models/IPodman'; vi.mock('../../utils/client', async () => ({ studioClient: { checkContainerConnectionStatusAndResources: vi.fn(), + getExtensionConfiguration: vi.fn(), + }, + rpcBrowser: { + subscribe: (): unknown => { + return { + unsubscribe: (): void => {}, + }; + }, }, })); @@ -55,6 +63,13 @@ beforeEach(() => { canRedirect: false, status: 'running', }); + vi.mocked(studioClient.getExtensionConfiguration).mockResolvedValue({ + experimentalGPU: false, + apiPort: 0, + experimentalTuning: false, + modelsPath: '', + modelUploadDisabled: false, + }); }); test('model without memory should not check for status', async () => { diff --git a/packages/frontend/src/lib/notification/ContainerConnectionWrapper.svelte b/packages/frontend/src/lib/notification/ContainerConnectionWrapper.svelte index cf1e4c209..8b4e406ca 100644 --- a/packages/frontend/src/lib/notification/ContainerConnectionWrapper.svelte +++ b/packages/frontend/src/lib/notification/ContainerConnectionWrapper.svelte @@ -6,12 +6,21 @@ import type { import type { ModelCheckerContext, ModelInfo } from '@shared/src/models/IModelInfo'; import ContainerConnectionStatusInfo from './ContainerConnectionStatusInfo.svelte'; import { studioClient } from '/@/utils/client'; +import { configuration } from '/@/stores/extensionConfiguration'; +import { fromStore } from 'svelte/store'; +import GPUEnabledMachine from '/@/lib/notification/GPUEnabledMachine.svelte'; +import { VMType } from '@shared/src/models/IPodman'; export let containerProviderConnection: ContainerProviderConnectionInfo | undefined = undefined; export let model: ModelInfo | undefined = undefined; export let checkContext: ModelCheckerContext = 'inference'; let connectionInfo: ContainerConnectionInfo | undefined; +let gpuWarningRequired = false; + +function shouldRecommendGPU(connection: ContainerProviderConnectionInfo): boolean { + return connection.vmType === VMType.APPLEHV || connection.vmType === VMType.APPLEHV_LABEL; +} $: if (typeof model?.memory === 'number' && containerProviderConnection) { studioClient .checkContainerConnectionStatusAndResources({ @@ -26,11 +35,17 @@ $: if (typeof model?.memory === 'number' && containerProviderConnection) { connectionInfo = undefined; console.error(err); }); + if (fromStore(configuration)?.current?.experimentalGPU && shouldRecommendGPU(containerProviderConnection)) { + gpuWarningRequired = true; + } } else { connectionInfo = undefined; } +{#if gpuWarningRequired} + +{/if} {#if connectionInfo} {/if} diff --git a/packages/frontend/src/lib/notification/GPUEnabledMachine.spec.ts b/packages/frontend/src/lib/notification/GPUEnabledMachine.spec.ts new file mode 100644 index 000000000..51325320f --- /dev/null +++ b/packages/frontend/src/lib/notification/GPUEnabledMachine.spec.ts @@ -0,0 +1,54 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ + +import '@testing-library/jest-dom/vitest'; +import { beforeEach, expect, test, vi } from 'vitest'; +import { render, screen } from '@testing-library/svelte'; +import { studioClient } from '/@/utils/client'; +import GPUEnabledMachine from '/@/lib/notification/GPUEnabledMachine.svelte'; + +vi.mock('/@/utils/client', async () => { + return { + studioClient: { + navigateToResources: vi.fn(), + }, + }; +}); + +beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(studioClient.navigateToResources).mockResolvedValue(undefined); +}); + +test('should show navigation to resources', async () => { + render(GPUEnabledMachine); + + const banner = screen.getByLabelText('GPU machine banner'); + expect(banner).toBeInTheDocument(); + const titleDiv = screen.getByLabelText('title'); + expect(titleDiv).toBeInTheDocument(); + expect(titleDiv.textContent).equals('Non GPU enabled machine'); + const descriptionDiv = screen.getByLabelText('description'); + expect(descriptionDiv).toBeInTheDocument(); + expect(descriptionDiv.textContent).equals( + `The selected Podman machine is not GPU enabled. On MacOS, you can run GPU workloads using the krunkit\n environment. Do you want to create a GPU enabled machine ?`, + ); + + const btnUpdate = screen.queryByRole('button', { name: 'Create GPU enabled machine' }); + expect(btnUpdate).toBeInTheDocument(); +}); diff --git a/packages/frontend/src/lib/notification/GPUEnabledMachine.svelte b/packages/frontend/src/lib/notification/GPUEnabledMachine.svelte new file mode 100644 index 000000000..4dca441b6 --- /dev/null +++ b/packages/frontend/src/lib/notification/GPUEnabledMachine.svelte @@ -0,0 +1,31 @@ + + +
+
+
+ +
+
+ Non GPU enabled machine + The selected Podman machine is not GPU enabled. On MacOS, you can run GPU workloads using the krunkit + environment. Do you want to create a GPU enabled machine ? +
+
+ +
+
+
diff --git a/packages/frontend/src/pages/CreateService.spec.ts b/packages/frontend/src/pages/CreateService.spec.ts index 4cdbc8b4a..aaba2b19c 100644 --- a/packages/frontend/src/pages/CreateService.spec.ts +++ b/packages/frontend/src/pages/CreateService.spec.ts @@ -95,6 +95,14 @@ vi.mock('../utils/client', async () => ({ requestCreateInferenceServer: vi.fn(), getHostFreePort: vi.fn(), checkContainerConnectionStatusAndResources: vi.fn(), + getExtensionConfiguration: vi.fn(), + }, + rpcBrowser: { + subscribe: (): unknown => { + return { + unsubscribe: (): void => {}, + }; + }, }, })); @@ -140,6 +148,13 @@ beforeEach(() => { mocks.getInferenceServersMock.mockReturnValue([ { container: { containerId: 'dummyContainerId' } } as InferenceServer, ]); + vi.mocked(studioClient.getExtensionConfiguration).mockResolvedValue({ + experimentalGPU: false, + apiPort: 0, + experimentalTuning: false, + modelsPath: '', + modelUploadDisabled: false, + }); window.HTMLElement.prototype.scrollIntoView = vi.fn(); });