Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update UI to display VAE override for model as dropdown of installed options #5815

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';

Check failure on line 1 in invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/VaeSelect.tsx

View workflow job for this annotation

GitHub Actions / lint-frontend

Run autofix to sort these imports!
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import { useGetVaeModelsQuery } from '../../../../../services/api/endpoints/models';

Check failure on line 8 in invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/VaeSelect.tsx

View workflow job for this annotation

GitHub Actions / lint-frontend

Relative import path '../../../../../services/api/endpoints/models' should be replaced with 'services/api/endpoints/models'
import { useTranslation } from 'react-i18next';
import { GroupBase } from 'chakra-react-select';

Check failure on line 10 in invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/VaeSelect.tsx

View workflow job for this annotation

GitHub Actions / lint-frontend

All imports in the declaration are only used as types. Use `import type`
import { map, reduce, groupBy } from 'lodash-es';

const VaeSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { t } = useTranslation();
const { field } = useController(props);
const { data } = useGetVaeModelsQuery();
const base = useWatch({ control: props.control, name: 'base' });

const onChange = useCallback<ComboboxOnChange>(
(value) => {
if (!value) {
field.onChange(null);
return;
}

field.onChange(value.value);
},
[field]
);

const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!data) {
return [];
}
const modelEntitiesArray = map(data.entities);
const groupedModels = groupBy(modelEntitiesArray, 'base');
const _options = reduce(
groupedModels,
(acc, val, label) => {
acc.push({
label,
options: val.map((model) => ({
label: model.name,
value: model.path,
isDisabled: base !== model.base,
})),
});
return acc;
},
[] as GroupBase<ComboboxOption>[]
);
_options.sort((a) => (a.label === base ? -1 : 1));
return _options;
}, [data, base]);

const value = useMemo(
() => options.flatMap((o) => o.options).find((m) => (field.value ? m.value === field.value : false)) ?? null,
[options, field.value]
);

return (
<Combobox
isClearable
value={value}
options={options}
onChange={onChange}
placeholder={value ? value.value : t('models.defaultVAE')}
/>
);
};

export default typedMemo(VaeSelect);
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect';
import VaeSelect from './Fields/VaeSelect';

export const ModelEdit = () => {
const dispatch = useAppDispatch();
Expand All @@ -39,39 +40,6 @@ export const ModelEdit = () => {

const { t } = useTranslation();

// const modelData = useMemo(() => {
// if (!data) {
// return null;
// }
// const modelFormat = data.format;
// const modelType = data.type;

// if (modelType === 'main') {
// if (modelFormat === 'diffusers') {
// return data as DiffusersModelConfig;
// } else if (modelFormat === 'checkpoint') {
// return data as CheckpointModelConfig;
// }
// }

// switch (modelType) {
// case 'lora':
// return data as LoRAModelConfig;
// case 'embedding':
// return data as TextualInversionModelConfig;
// case 't2i_adapter':
// return data as T2IAdapterModelConfig;
// case 'ip_adapter':
// return data as IPAdapterModelConfig;
// case 'controlnet':
// return data as ControlNetModelConfig;
// case 'vae':
// return data as VAEModelConfig;
// default:
// return null;
// }
// }, [data]);

const {
register,
handleSubmit,
Expand Down Expand Up @@ -243,8 +211,8 @@ export const ModelEdit = () => {
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
<FormLabel>{t('modelManager.vae')}</FormLabel>
<VaeSelect control={control} name="vae" />
</FormControl>
</Flex>
</>
Expand Down
Loading