Skip to content

Commit

Permalink
(ui): add fields for CLIP-L and CLIP-G, remove MainModelConfig type c…
Browse files Browse the repository at this point in the history
…hanges
  • Loading branch information
Mary Hipp authored and Mary Hipp committed Oct 30, 2024
1 parent d409f44 commit 9543149
Show file tree
Hide file tree
Showing 17 changed files with 486 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import {
isBooleanFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
isCLIPGEmbedModelFieldInputTemplate,
isCLIPLEmbedModelFieldInputInstance,
isCLIPLEmbedModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlNetModelFieldInputInstance,
Expand Down Expand Up @@ -56,6 +60,8 @@ import { memo } from 'react';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
Expand Down Expand Up @@ -135,6 +141,14 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig, MainModelConfig } from 'services/api/types';
import type { CLIPEmbedModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

Expand All @@ -19,7 +19,7 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPEmbedModelConfig | MainModelConfig | null) => {
(value: CLIPEmbedModelConfig | null) => {
if (!value) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;

const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();

const _onChange = useCallback(
(value: CLIPGEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPGEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});

return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};

export default memo(CLIPGEmbedModelFieldInputComponent);
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;

const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();

const _onChange = useCallback(
(value: CLIPLEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPLEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});

return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};

export default memo(CLIPLEmbedModelFieldInputComponent);
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } f
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';
import type { VAEModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

Expand All @@ -19,7 +19,7 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | MainModelConfig | null) => {
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type {
MainModelConfig,
T5EncoderBnbQuantizedLlmInt8bModelConfig,
T5EncoderModelConfig,
} from 'services/api/types';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

Expand All @@ -24,7 +20,7 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | MainModelConfig | null) => {
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';
import type { VAEModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

Expand All @@ -16,7 +16,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | MainModelConfig | null) => {
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
Expand Down
14 changes: 14 additions & 0 deletions invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import type {
BoardFieldValue,
BooleanFieldValue,
CLIPEmbedModelFieldValue,
CLIPGEmbedModelFieldValue,
CLIPLEmbedModelFieldValue,
ColorFieldValue,
ControlNetModelFieldValue,
EnumFieldValue,
Expand All @@ -33,6 +35,8 @@ import {
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zColorFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
Expand Down Expand Up @@ -354,6 +358,12 @@ export const nodesSlice = createSlice({
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
},
fieldCLIPLEmbedValueChanged: (state, action: FieldValueAction<CLIPLEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPLEmbedModelFieldValue);
},
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
Expand Down Expand Up @@ -420,6 +430,8 @@ export const {
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
Expand Down Expand Up @@ -527,6 +539,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldFluxVAEModelValueChanged,
// The `nodesChanged` has extra logic and is handled in its own extra reducer
// nodesChanged,
Expand Down
60 changes: 60 additions & 0 deletions invokeai/frontend/web/src/features/nodes/types/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPLEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPLEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
Expand Down Expand Up @@ -189,6 +197,8 @@ const zStatefulFieldType = z.union([
zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
Expand Down Expand Up @@ -753,6 +763,52 @@ export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmb

// #endregion

// #region CLIPLEmbedModelField

export const zCLIPLEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPLEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPLEmbedModelFieldValue,
});
const zCLIPLEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPLEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPLEmbedModelFieldValue,
});

export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;

export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
export const isCLIPLEmbedModelFieldInputInstance = (val: unknown): val is CLIPLEmbedModelFieldInputInstance =>
zCLIPLEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPLEmbedModelFieldInputTemplate = (val: unknown): val is CLIPLEmbedModelFieldInputTemplate =>
zCLIPLEmbedModelFieldInputTemplate.safeParse(val).success;

// #endregion

// #region CLIPGEmbedModelField

export const zCLIPGEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPGEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPGEmbedModelFieldValue,
});
const zCLIPGEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPGEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPGEmbedModelFieldValue,
});

export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;

export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
export const isCLIPGEmbedModelFieldInputInstance = (val: unknown): val is CLIPGEmbedModelFieldInputInstance =>
zCLIPGEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGEmbedModelFieldInputTemplate =>
zCLIPGEmbedModelFieldInputTemplate.safeParse(val).success;

// #endregion

// #region SchedulerField

export const zSchedulerFieldValue = zSchedulerField.optional();
Expand Down Expand Up @@ -845,6 +901,8 @@ export const zStatefulFieldValue = z.union([
zT5EncoderModelFieldValue,
zFluxVAEModelFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
Expand Down Expand Up @@ -912,6 +970,8 @@ const zStatefulFieldInputTemplate = z.union([
zT5EncoderModelFieldInputTemplate,
zFluxVAEModelFieldInputTemplate,
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
T5EncoderModelField: undefined,
FluxVAEModelField: undefined,
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
};

export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
Expand Down
Loading

0 comments on commit 9543149

Please sign in to comment.