diff --git a/public/components/register_model/model_deployment.tsx b/public/components/register_model/model_deployment.tsx index 27855849..97eceefd 100644 --- a/public/components/register_model/model_deployment.tsx +++ b/public/components/register_model/model_deployment.tsx @@ -3,48 +3,33 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useState } from 'react'; +import React from 'react'; import { EuiCheckbox, EuiText, EuiFormRow } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; import { useSearchParams } from '../../hooks/use_search_params'; export const ModelDeployment = () => { const searchParams = useSearchParams(); const typeParams = searchParams.get('type'); - const [checked, setChecked] = useState(false); const { control } = useFormContext<{ deployment: boolean }>(); const modelDeploymentController = useController({ name: 'deployment', control, }); + const isRegisterExternal = typeParams === 'external'; const { ref: deploymentInputRef, ...deploymentField } = modelDeploymentController.field; - const onDeploymentChange = useCallback( - (e) => { - setChecked(e.target.checked); - deploymentField.onChange(checked); - }, - [deploymentField, checked] - ); return ( - +
{Needs a description} - {(typeParams === 'upload' || typeParams === 'import') && ( - - )} - {typeParams === 'external' && ( - - )} +
); diff --git a/public/components/register_model/model_source.tsx b/public/components/register_model/model_source.tsx index a9bdc4f5..7c7d6407 100644 --- a/public/components/register_model/model_source.tsx +++ b/public/components/register_model/model_source.tsx @@ -12,15 +12,20 @@ import { EuiFormRow, EuiComboBox, } from '@elastic/eui'; - import { useController, useFormContext } from 'react-hook-form'; -import { useMonitoring } from '../monitoring/use_monitoring'; + +import { useFetcher } from '../../hooks'; +import { APIProvider } from '../../apis/api_provider'; export const ModelSource = () => { - const { allExternalConnectors } = useMonitoring(); - const connectorOptions = allExternalConnectors?.map((item) => { - return Object.assign({}, { label: item.name, value: item.id }); - }); + const { data: allConnectorsData } = useFetcher(APIProvider.getAPI('connector').getAll); + const connectorOptions = useMemo( + () => + allConnectorsData?.data?.map((item) => { + return Object.assign({}, { label: item.name, value: item.id }); + }), + [allConnectorsData] + ); const { control } = useFormContext<{ modelConnector: string }>(); const modelConnectorController = useController({ diff --git a/public/components/register_model/pretrained_model_select.tsx b/public/components/register_model/pretrained_model_select.tsx index f4d75485..763c03e7 100644 --- a/public/components/register_model/pretrained_model_select.tsx +++ b/public/components/register_model/pretrained_model_select.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useState, Fragment, useEffect } from 'react'; +import React, { useCallback, Fragment } from 'react'; import { EuiSpacer, EuiTextColor, @@ -12,10 +12,12 @@ import { EuiSelectableOption, EuiHighlight, } from '@elastic/eui'; -import { useHistory } from 'react-router-dom'; -import { generatePath } from 'react-router-dom'; +import { useHistory, generatePath } from 'react-router-dom'; +import { useObservable } from 'react-use'; + import { modelRepositoryManager } from '../../utils/model_repository_manager'; import { routerPaths } from '../../../common/router_paths'; + interface IItem { label: string; checked?: 'on' | undefined; @@ -34,43 +36,36 @@ const renderModelOption = (option: IItem, searchValue: string) => { ); }; -export const PreTrainedModelSelect = () => { - useEffect(() => { - const subscribe = modelRepositoryManager.getPreTrainedModels$().subscribe((models) => { - setModelRepoSelection( - Object.keys(models).map((name) => ({ - label: name, - description: models[name].description, - checked: undefined, - })) - ); - }); - return () => { - subscribe.unsubscribe(); - }; - }, []); - const [modelRepoSelection, setModelRepoSelection] = useState>>( - [] - ); +export const PreTrainedModelSelect = ({ + checkedPreTrainedModel, +}: { + checkedPreTrainedModel?: string; +}) => { + const preTrainedModels = useObservable(modelRepositoryManager.getPreTrainedModels$()); + const preTrainedModelOptions = preTrainedModels + ? Object.keys(preTrainedModels).map((name) => ({ + label: name, + description: preTrainedModels[name].description, + checked: checkedPreTrainedModel === name ? ('on' as const) : undefined, + })) + : []; + const history = useHistory(); const onChange = useCallback( - (modelSelection: Array>) => { - setModelRepoSelection(modelSelection); - // ShowRest(true); + (options: Array>) => { + const selectedOption = options.find((option) => option.checked === 'on'); + + if (selectedOption?.label) { + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ + selectedOption.label + }` + ); + } }, - // [ShowRest] - [] + [history] ); - useEffect(() => { - const selectedOption = modelRepoSelection.find((option) => option.checked === 'on'); - if (selectedOption?.label) { - history.push( - `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ - selectedOption?.label - }` - ); - } - }, [modelRepoSelection, history]); + return (
@@ -95,7 +90,7 @@ export const PreTrainedModelSelect = () => { 'data-test-subj': 'findModel', placeholder: 'Find model', }} - options={modelRepoSelection} + options={preTrainedModelOptions} onChange={onChange} singleSelection={true} noMatchesMessage="No model found" @@ -106,7 +101,7 @@ export const PreTrainedModelSelect = () => { 'data-test-subj': 'opensearchModelList', showIcons: true, }} - isLoading={modelRepoSelection.length === 0} + isLoading={!preTrainedModels} > {(list, search) => ( diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 7c47ea2a..49880983 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -295,7 +295,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const errorCount = Object.keys(form.formState.errors).length; const formHeader = ( <> - + {registerToModelId && ( @@ -353,7 +353,6 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo )} ))} - {/* {formType === 'import' ? nameParams && formFooter : formFooter} */} @@ -380,7 +379,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo setIsSubmitted(true)}