diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 705bfe992ff..2d261348184 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1809,13 +1809,81 @@ "process": "Process", "apply": "Apply", "cancel": "Cancel", - "spandrel": { + "spandrel_filter": { "label": "Image-to-Image Model", "description": "Run an image-to-image model on the selected layer.", - "paramModel": "Model", - "paramAutoScale": "Auto Scale", - "paramAutoScaleDesc": "The selected model will be run until the target scale is reached.", - "paramScale": "Target Scale" + "model": "Model", + "autoScale": "Auto Scale", + "autoScaleDesc": "The selected model will be run until the target scale is reached.", + "scale": "Target Scale" + }, + "canny_edge_detection": { + "label": "Canny Edge Detection", + "description": "Generates an edge map from the selected layer using the Canny edge detection algorithm.", + "low_threshold": "Low Threshold", + "high_threshold": "High Threshold" + }, + "color_map": { + "label": "Color Map", + "description": "Creates a color map from the selected layer.", + "tile_size": "Tile Size" + }, + "content_shuffle": { + "label": "Content Shuffle", + "description": "Shuffles the content of the selected layer, similar to a 'liquify' effect.", + "scale_factor": "Scale Factor" + }, + "depth_anything_depth_estimation": { + "label": "Depth Anything", + "description": "Generates a depth map from the selected layer using a Depth Anything model.", + "model_size": "Model Size", + "model_size_small": "Small", + "model_size_small_v2": "Small v2", + "model_size_base": "Base", + "model_size_large": "Large" + }, + "dw_openpose_detection": { + "label": "DW Openpose Detection", + "description": "Detects human poses in the selected layer using the DW Openpose model.", + "draw_hands": "Draw Hands", + "draw_face": "Draw Face", + "draw_body": "Draw Body" + }, + "hed_edge_detection": { + 
"label": "HED Edge Detection", + "description": "Generates an edge map from the selected layer using the HED edge detection model.", + "scribble": "Scribble" + }, + "lineart_anime_edge_detection": { + "label": "Lineart Anime Edge Detection", + "description": "Generates an edge map from the selected layer using the Lineart Anime edge detection model." + }, + "lineart_edge_detection": { + "label": "Lineart Edge Detection", + "description": "Generates an edge map from the selected layer using the Lineart edge detection model.", + "coarse": "Coarse" + }, + "mediapipe_face_detection": { + "label": "MediaPipe Face Detection", + "description": "Detects faces in the selected layer using the MediaPipe face detection model.", + "max_faces": "Max Faces", + "min_confidence": "Min Confidence" + }, + "mlsd_detection": { + "label": "Line Segment Detection", + "description": "Generates a line segment map from the selected layer using the MLSD line segment detection model.", + "score_threshold": "Score Threshold", + "distance_threshold": "Distance Threshold" + }, + "normal_map": { + "label": "Normal Map", + "description": "Generates a normal map from the selected layer." 
+ }, + "pidi_edge_detection": { + "label": "PiDiNet Edge Detection", + "description": "Generates an edge map from the selected layer using the PiDiNet edge detection model.", + "scribble": "Scribble", + "quantize_edges": "Quantize Edges" } }, "transform": { diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 6b71fb2720c..436ed675c1e 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,4 +1,4 @@ -import type { FilterType } from 'features/controlLayers/store/types'; +import type { FilterType } from 'features/controlLayers/store/filters'; import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import type { TabName } from 'features/ui/store/uiTypes'; import type { O } from 'ts-toolbelt'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx index 877595e1c4f..6c339ce79a0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx @@ -10,7 +10,8 @@ import { selectAutoProcessFilter, settingsAutoProcessFilterToggled, } from 'features/controlLayers/store/canvasSettingsSlice'; -import { type FilterConfig, IMAGE_FILTERS } from 'features/controlLayers/store/types'; +import type { FilterConfig } from 'features/controlLayers/store/filters'; +import { IMAGE_FILTERS } from 'features/controlLayers/store/filters'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold } from 'react-icons/pi'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterCanny.tsx 
b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterCannyEdgeDetection.tsx similarity index 70% rename from invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterCanny.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterCannyEdgeDetection.tsx index 781053e95ca..f799caaff78 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterCanny.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterCannyEdgeDetection.tsx @@ -1,15 +1,15 @@ import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import type { CannyProcessorConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 'features/controlLayers/store/types'; +import type { CannyEdgeDetectionFilterConfig } from 'features/controlLayers/store/filters'; +import { IMAGE_FILTERS } from 'features/controlLayers/store/filters'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import type { FilterComponentProps } from './types'; -type Props = FilterComponentProps; -const DEFAULTS = IMAGE_FILTERS['canny_image_processor'].buildDefaults(); +type Props = FilterComponentProps; +const DEFAULTS = IMAGE_FILTERS.canny_edge_detection.buildDefaults(); -export const FilterCanny = ({ onChange, config }: Props) => { +export const FilterCannyEdgeDetection = ({ onChange, config }: Props) => { const { t } = useTranslation(); const handleLowThresholdChanged = useCallback( (v: number) => { @@ -27,7 +27,7 @@ export const FilterCanny = ({ onChange, config }: Props) => { return ( <> - {t('controlnet.lowThreshold')} + {t('controlLayers.filter.canny_edge_detection.low_threshold')} { /> - {t('controlnet.highThreshold')} + {t('controlLayers.filter.canny_edge_detection.high_threshold')} { ); }; -FilterCanny.displayName = 'FilterCanny'; +FilterCannyEdgeDetection.displayName = 'FilterCannyEdgeDetection'; diff 
--git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterColorMap.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterColorMap.tsx index 785af042234..9c56bc4e7ea 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterColorMap.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterColorMap.tsx @@ -1,19 +1,18 @@ import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import type { ColorMapProcessorConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 'features/controlLayers/store/types'; +import { type ColorMapFilterConfig, IMAGE_FILTERS } from 'features/controlLayers/store/filters'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import type { FilterComponentProps } from './types'; -type Props = FilterComponentProps; -const DEFAULTS = IMAGE_FILTERS['color_map_image_processor'].buildDefaults(); +type Props = FilterComponentProps; +const DEFAULTS = IMAGE_FILTERS.color_map.buildDefaults(); export const FilterColorMap = memo(({ onChange, config }: Props) => { const { t } = useTranslation(); const handleColorMapTileSizeChanged = useCallback( (v: number) => { - onChange({ ...config, color_map_tile_size: v }); + onChange({ ...config, tile_size: v }); }, [config, onChange] ); @@ -21,10 +20,10 @@ export const FilterColorMap = memo(({ onChange, config }: Props) => { return ( <> - {t('controlnet.colorMapTileSize')} + {t('controlLayers.filter.color_map.tile_size')} { marks /> ; -const DEFAULTS = IMAGE_FILTERS['content_shuffle_image_processor'].buildDefaults(); +type Props = FilterComponentProps; +const DEFAULTS = IMAGE_FILTERS.content_shuffle.buildDefaults(); export const FilterContentShuffle = memo(({ onChange, config }: Props) => { const { t } = useTranslation(); - const handleWChanged = useCallback( + const handleScaleFactorChanged = 
useCallback( (v: number) => { - onChange({ ...config, w: v }); - }, - [config, onChange] - ); - - const handleHChanged = useCallback( - (v: number) => { - onChange({ ...config, h: v }); - }, - [config, onChange] - ); - - const handleFChanged = useCallback( - (v: number) => { - onChange({ ...config, f: v }); + onChange({ ...config, scale_factor: v }); }, [config, onChange] ); @@ -36,40 +22,22 @@ export const FilterContentShuffle = memo(({ onChange, config }: Props) => { return ( <> - {t('controlnet.w')} + {t('controlLayers.filter.content_shuffle.scale_factor')} - - - - {t('controlnet.h')} - - - - - {t('controlnet.f')} - - ); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDWOpenpose.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDWOpenposeDetection.tsx similarity index 68% rename from invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDWOpenpose.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDWOpenposeDetection.tsx index d0f22bae203..9522823b656 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDWOpenpose.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDWOpenposeDetection.tsx @@ -1,16 +1,15 @@ import { Flex, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; -import type { DWOpenposeProcessorConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 'features/controlLayers/store/types'; +import { type DWOpenposeDetectionFilterConfig, IMAGE_FILTERS } from 'features/controlLayers/store/filters'; import type { ChangeEvent } from 'react'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import type { FilterComponentProps } from './types'; -type Props = FilterComponentProps; -const DEFAULTS = IMAGE_FILTERS['dw_openpose_image_processor'].buildDefaults(); +type Props = FilterComponentProps; 
+const DEFAULTS = IMAGE_FILTERS['dw_openpose_detection'].buildDefaults(); -export const FilterDWOpenpose = memo(({ onChange, config }: Props) => { +export const FilterDWOpenposeDetection = memo(({ onChange, config }: Props) => { const { t } = useTranslation(); const handleDrawBodyChanged = useCallback( @@ -38,15 +37,15 @@ export const FilterDWOpenpose = memo(({ onChange, config }: Props) => { <> - {t('controlnet.body')} + {t('controlLayers.filter.dw_openpose_detection.draw_body')} - {t('controlnet.face')} + {t('controlLayers.filter.dw_openpose_detection.draw_face')} - {t('controlnet.hands')} + {t('controlLayers.filter.dw_openpose_detection.draw_hands')} { ); }); -FilterDWOpenpose.displayName = 'FilterDWOpenpose'; +FilterDWOpenposeDetection.displayName = 'FilterDWOpenposeDetection'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDepthAnything.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDepthAnythingDepthEstimation.tsx similarity index 53% rename from invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDepthAnything.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDepthAnythingDepthEstimation.tsx index 46abf1da4e6..8a06512f722 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDepthAnything.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterDepthAnythingDepthEstimation.tsx @@ -1,15 +1,15 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library'; import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import type { DepthAnythingModelSize, DepthAnythingProcessorConfig } from 'features/controlLayers/store/types'; -import { isDepthAnythingModelSize } from 'features/controlLayers/store/types'; +import type { DepthAnythingFilterConfig, DepthAnythingModelSize } from 'features/controlLayers/store/filters'; +import { isDepthAnythingModelSize } from 
'features/controlLayers/store/filters'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import type { FilterComponentProps } from './types'; -type Props = FilterComponentProps; +type Props = FilterComponentProps; -export const FilterDepthAnything = memo(({ onChange, config }: Props) => { +export const FilterDepthAnythingDepthEstimation = memo(({ onChange, config }: Props) => { const { t } = useTranslation(); const handleModelSizeChange = useCallback( (v) => { @@ -23,10 +23,10 @@ export const FilterDepthAnything = memo(({ onChange, config }: Props) => { const options: { label: string; value: DepthAnythingModelSize }[] = useMemo( () => [ - { label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' }, - { label: t('controlnet.small'), value: 'small' }, - { label: t('controlnet.base'), value: 'base' }, - { label: t('controlnet.large'), value: 'large' }, + { label: t('controlLayers.filter.depth_anything_depth_estimation.model_size_small_v2'), value: 'small_v2' }, + { label: t('controlLayers.filter.depth_anything_depth_estimation.model_size_small'), value: 'small' }, + { label: t('controlLayers.filter.depth_anything_depth_estimation.model_size_base'), value: 'base' }, + { label: t('controlLayers.filter.depth_anything_depth_estimation.model_size_large'), value: 'large' }, ], [t] ); @@ -36,11 +36,11 @@ export const FilterDepthAnything = memo(({ onChange, config }: Props) => { return ( <> - {t('controlnet.modelSize')} + {t('controlLayers.filter.depth_anything_depth_estimation.model_size')} ); }); -FilterDepthAnything.displayName = 'FilterDepthAnything'; +FilterDepthAnythingDepthEstimation.displayName = 'FilterDepthAnythingDepthEstimation'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterHed.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterHEDEdgeDetection.tsx similarity index 62% rename from 
invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterHed.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterHEDEdgeDetection.tsx index 50ed535da19..810da9a4c7a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterHed.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterHEDEdgeDetection.tsx @@ -1,14 +1,14 @@ import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; -import type { HedProcessorConfig } from 'features/controlLayers/store/types'; +import type { HEDEdgeDetectionFilterConfig } from 'features/controlLayers/store/filters'; import type { ChangeEvent } from 'react'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import type { FilterComponentProps } from './types'; -type Props = FilterComponentProps; +type Props = FilterComponentProps; -export const FilterHed = memo(({ onChange, config }: Props) => { +export const FilterHEDEdgeDetection = memo(({ onChange, config }: Props) => { const { t } = useTranslation(); const handleScribbleChanged = useCallback( @@ -21,11 +21,11 @@ export const FilterHed = memo(({ onChange, config }: Props) => { return ( <> - {t('controlnet.scribble')} + {t('controlLayers.filter.hed_edge_detection.scribble')} ); }); -FilterHed.displayName = 'FilterHed'; +FilterHEDEdgeDetection.displayName = 'FilterHEDEdgeDetection'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterLineart.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterLineartEdgeDetection.tsx similarity index 60% rename from invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterLineart.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterLineartEdgeDetection.tsx index 9b6f57f9d8f..a840b65d273 100644 --- 
a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterLineart.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterLineartEdgeDetection.tsx @@ -1,14 +1,14 @@ import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; -import type { LineartProcessorConfig } from 'features/controlLayers/store/types'; +import type { LineartEdgeDetectionFilterConfig } from 'features/controlLayers/store/filters'; import type { ChangeEvent } from 'react'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import type { FilterComponentProps } from './types'; -type Props = FilterComponentProps; +type Props = FilterComponentProps; -export const FilterLineart = memo(({ onChange, config }: Props) => { +export const FilterLineartEdgeDetection = memo(({ onChange, config }: Props) => { const { t } = useTranslation(); const handleCoarseChanged = useCallback( @@ -21,11 +21,11 @@ export const FilterLineart = memo(({ onChange, config }: Props) => { return ( <> - {t('controlnet.coarse')} + {t('controlLayers.filter.lineart_edge_detection.coarse')} ); }); -FilterLineart.displayName = 'FilterLineart'; +FilterLineartEdgeDetection.displayName = 'FilterLineartEdgeDetection'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMLSDDetection.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMLSDDetection.tsx new file mode 100644 index 00000000000..3248008d7d2 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMLSDDetection.tsx @@ -0,0 +1,75 @@ +import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import type { MLSDDetectionFilterConfig } from 'features/controlLayers/store/filters'; +import { IMAGE_FILTERS } from 'features/controlLayers/store/filters'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + 
+import type { FilterComponentProps } from './types'; + +type Props = FilterComponentProps; +const DEFAULTS = IMAGE_FILTERS.mlsd_detection.buildDefaults(); + +export const FilterMLSDDetection = memo(({ onChange, config }: Props) => { + const { t } = useTranslation(); + + const onDistanceThresholdChanged = useCallback( + (v: number) => { + onChange({ ...config, distance_threshold: v }); + }, + [config, onChange] + ); + + const onScoreThresholdChanged = useCallback( + (v: number) => { + onChange({ ...config, score_threshold: v }); + }, + [config, onChange] + ); + + return ( + <> + + {t('controlLayers.filter.mlsd_detection.score_threshold')} + + + + + {t('controlLayers.filter.mlsd_detection.distance_threshold')} + + + + + ); +}); + +FilterMLSDDetection.displayName = 'FilterMLSDDetection'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMediapipeFace.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMediaPipeFaceDetection.tsx similarity index 70% rename from invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMediapipeFace.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMediaPipeFaceDetection.tsx index 6674434d0a4..6c20817644b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMediapipeFace.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMediaPipeFaceDetection.tsx @@ -1,15 +1,15 @@ import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import type { MediapipeFaceProcessorConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 'features/controlLayers/store/types'; +import type { MediaPipeFaceDetectionFilterConfig } from 'features/controlLayers/store/filters'; +import { IMAGE_FILTERS } from 'features/controlLayers/store/filters'; import { memo, useCallback } from 'react'; import { useTranslation } 
from 'react-i18next'; import type { FilterComponentProps } from './types'; -type Props = FilterComponentProps; -const DEFAULTS = IMAGE_FILTERS['mediapipe_face_processor'].buildDefaults(); +type Props = FilterComponentProps; +const DEFAULTS = IMAGE_FILTERS.mediapipe_face_detection.buildDefaults(); -export const FilterMediapipeFace = memo(({ onChange, config }: Props) => { +export const FilterMediaPipeFaceDetection = memo(({ onChange, config }: Props) => { const { t } = useTranslation(); const handleMaxFacesChanged = useCallback( @@ -29,7 +29,7 @@ export const FilterMediapipeFace = memo(({ onChange, config }: Props) => { return ( <> - {t('controlnet.maxFaces')} + {t('controlLayers.filter.mediapipe_face_detection.max_faces')} { /> - {t('controlnet.minConfidence')} + {t('controlLayers.filter.mediapipe_face_detection.min_confidence')} { ); }); -FilterMediapipeFace.displayName = 'FilterMediapipeFace'; +FilterMediaPipeFaceDetection.displayName = 'FilterMediaPipeFaceDetection'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMidasDepth.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMidasDepth.tsx deleted file mode 100644 index 9024b45a889..00000000000 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMidasDepth.tsx +++ /dev/null @@ -1,75 +0,0 @@ -import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import type { MidasDepthProcessorConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 'features/controlLayers/store/types'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; - -import type { FilterComponentProps } from './types'; - -type Props = FilterComponentProps; -const DEFAULTS = IMAGE_FILTERS['midas_depth_image_processor'].buildDefaults(); - -export const FilterMidasDepth = memo(({ onChange, config }: Props) => { - const { t } = useTranslation(); - - 
const handleAMultChanged = useCallback( - (v: number) => { - onChange({ ...config, a_mult: v }); - }, - [config, onChange] - ); - - const handleBgThChanged = useCallback( - (v: number) => { - onChange({ ...config, bg_th: v }); - }, - [config, onChange] - ); - - return ( - <> - - {t('controlnet.amult')} - - - - - {t('controlnet.bgth')} - - - - - ); -}); - -FilterMidasDepth.displayName = 'FilterMidasDepth'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMlsdImage.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMlsdImage.tsx deleted file mode 100644 index e16d77990a5..00000000000 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterMlsdImage.tsx +++ /dev/null @@ -1,75 +0,0 @@ -import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import type { MlsdProcessorConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 'features/controlLayers/store/types'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; - -import type { FilterComponentProps } from './types'; - -type Props = FilterComponentProps; -const DEFAULTS = IMAGE_FILTERS['mlsd_image_processor'].buildDefaults(); - -export const FilterMlsdImage = memo(({ onChange, config }: Props) => { - const { t } = useTranslation(); - - const handleThrDChanged = useCallback( - (v: number) => { - onChange({ ...config, thr_d: v }); - }, - [config, onChange] - ); - - const handleThrVChanged = useCallback( - (v: number) => { - onChange({ ...config, thr_v: v }); - }, - [config, onChange] - ); - - return ( - <> - - {t('controlnet.w')} - - - - - {t('controlnet.h')} - - - - - ); -}); - -FilterMlsdImage.displayName = 'FilterMlsdImage'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection.tsx 
b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection.tsx new file mode 100644 index 00000000000..05812a3f2e4 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection.tsx @@ -0,0 +1,42 @@ +import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; +import type { PiDiNetEdgeDetectionFilterConfig } from 'features/controlLayers/store/filters'; +import type { ChangeEvent } from 'react'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +import type { FilterComponentProps } from './types'; + +type Props = FilterComponentProps; + +export const FilterPiDiNetEdgeDetection = ({ onChange, config }: Props) => { + const { t } = useTranslation(); + + const onScribbleChanged = useCallback( + (e: ChangeEvent) => { + onChange({ ...config, scribble: e.target.checked }); + }, + [config, onChange] + ); + + const onQuantizeEdgesChanged = useCallback( + (e: ChangeEvent) => { + onChange({ ...config, quantize_edges: e.target.checked }); + }, + [config, onChange] + ); + + return ( + <> + + {t('controlLayers.filter.pidi_edge_detection.scribble')} + + + + {t('controlLayers.filter.pidi_edge_detection.quantize_edges')} + + + + ); +}; + +FilterPiDiNetEdgeDetection.displayName = 'FilterPiDiNetEdgeDetection'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterPidi.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterPidi.tsx deleted file mode 100644 index 1814edc0140..00000000000 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterPidi.tsx +++ /dev/null @@ -1,42 +0,0 @@ -import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; -import type { PidiProcessorConfig } from 'features/controlLayers/store/types'; -import type { ChangeEvent } from 'react'; -import { useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; - 
-import type { FilterComponentProps } from './types'; - -type Props = FilterComponentProps; - -export const FilterPidi = ({ onChange, config }: Props) => { - const { t } = useTranslation(); - - const handleScribbleChanged = useCallback( - (e: ChangeEvent) => { - onChange({ ...config, scribble: e.target.checked }); - }, - [config, onChange] - ); - - const handleSafeChanged = useCallback( - (e: ChangeEvent) => { - onChange({ ...config, safe: e.target.checked }); - }, - [config, onChange] - ); - - return ( - <> - - {t('controlnet.scribble')} - - - - {t('controlnet.safe')} - - - - ); -}; - -FilterPidi.displayName = 'FilterPidi'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSettings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSettings.tsx index 03769181ba5..411571fe09c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSettings.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSettings.tsx @@ -1,18 +1,16 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import { FilterCanny } from 'features/controlLayers/components/Filters/FilterCanny'; +import { FilterCannyEdgeDetection } from 'features/controlLayers/components/Filters/FilterCannyEdgeDetection'; import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap'; import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle'; -import { FilterDepthAnything } from 'features/controlLayers/components/Filters/FilterDepthAnything'; -import { FilterDWOpenpose } from 'features/controlLayers/components/Filters/FilterDWOpenpose'; -import { FilterHed } from 'features/controlLayers/components/Filters/FilterHed'; -import { FilterLineart } from 'features/controlLayers/components/Filters/FilterLineart'; -import { FilterMediapipeFace } from 
'features/controlLayers/components/Filters/FilterMediapipeFace'; -import { FilterMidasDepth } from 'features/controlLayers/components/Filters/FilterMidasDepth'; -import { FilterMlsdImage } from 'features/controlLayers/components/Filters/FilterMlsdImage'; -import { FilterPidi } from 'features/controlLayers/components/Filters/FilterPidi'; +import { FilterDepthAnythingDepthEstimation } from 'features/controlLayers/components/Filters/FilterDepthAnythingDepthEstimation'; +import { FilterDWOpenposeDetection } from 'features/controlLayers/components/Filters/FilterDWOpenposeDetection'; +import { FilterHEDEdgeDetection } from 'features/controlLayers/components/Filters/FilterHEDEdgeDetection'; +import { FilterLineartEdgeDetection } from 'features/controlLayers/components/Filters/FilterLineartEdgeDetection'; +import { FilterMediaPipeFaceDetection } from 'features/controlLayers/components/Filters/FilterMediaPipeFaceDetection'; +import { FilterMLSDDetection } from 'features/controlLayers/components/Filters/FilterMLSDDetection'; +import { FilterPiDiNetEdgeDetection } from 'features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection'; import { FilterSpandrel } from 'features/controlLayers/components/Filters/FilterSpandrel'; -import type { FilterConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 'features/controlLayers/store/types'; +import type { FilterConfig } from 'features/controlLayers/store/filters'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -21,48 +19,44 @@ type Props = { filterConfig: FilterConfig; onChange: (filterConfig: FilterConfig export const FilterSettings = memo(({ filterConfig, onChange }: Props) => { const { t } = useTranslation(); - if (filterConfig.type === 'canny_image_processor') { - return ; + if (filterConfig.type === 'canny_edge_detection') { + return ; } - if (filterConfig.type === 'color_map_image_processor') { + if (filterConfig.type === 'color_map') { return ; } - if 
(filterConfig.type === 'content_shuffle_image_processor') { + if (filterConfig.type === 'content_shuffle') { return ; } - if (filterConfig.type === 'depth_anything_image_processor') { - return ; + if (filterConfig.type === 'depth_anything_depth_estimation') { + return ; } - if (filterConfig.type === 'dw_openpose_image_processor') { - return ; + if (filterConfig.type === 'dw_openpose_detection') { + return ; } - if (filterConfig.type === 'hed_image_processor') { - return ; + if (filterConfig.type === 'hed_edge_detection') { + return ; } - if (filterConfig.type === 'lineart_image_processor') { - return ; + if (filterConfig.type === 'lineart_edge_detection') { + return ; } - if (filterConfig.type === 'mediapipe_face_processor') { - return ; + if (filterConfig.type === 'mediapipe_face_detection') { + return ; } - if (filterConfig.type === 'midas_depth_image_processor') { - return ; + if (filterConfig.type === 'mlsd_detection') { + return ; } - if (filterConfig.type === 'mlsd_image_processor') { - return ; - } - - if (filterConfig.type === 'pidi_image_processor') { - return ; + if (filterConfig.type === 'pidi_edge_detection') { + return ; } if (filterConfig.type === 'spandrel_filter') { @@ -72,7 +66,7 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => { return ( ); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSpandrel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSpandrel.tsx index 973f72b6110..7cd64f0c08c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSpandrel.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterSpandrel.tsx @@ -11,8 +11,8 @@ import { Tooltip, } from '@invoke-ai/ui-library'; import { useModelCombobox } from 'common/hooks/useModelCombobox'; -import type { SpandrelFilterConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS } from 
'features/controlLayers/store/types'; +import type { SpandrelFilterConfig } from 'features/controlLayers/store/filters'; +import { IMAGE_FILTERS } from 'features/controlLayers/store/filters'; import type { ChangeEvent } from 'react'; import { useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -22,7 +22,7 @@ import type { SpandrelImageToImageModelConfig } from 'services/api/types'; import type { FilterComponentProps } from './types'; type Props = FilterComponentProps; -const DEFAULTS = IMAGE_FILTERS['spandrel_filter'].buildDefaults(); +const DEFAULTS = IMAGE_FILTERS.spandrel_filter.buildDefaults(); export const FilterSpandrel = ({ onChange, config }: Props) => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterTypeSelect.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterTypeSelect.tsx index bc512dc019c..3e9da8c8421 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterTypeSelect.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/FilterTypeSelect.tsx @@ -3,8 +3,8 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import type { FilterConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS, isFilterType } from 'features/controlLayers/store/types'; +import type { FilterConfig } from 'features/controlLayers/store/filters'; +import { IMAGE_FILTERS, isFilterType } from 'features/controlLayers/store/filters'; import { selectConfigSlice } from 'features/system/store/configSlice'; import { includes, map } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; @@ -22,7 +22,7 @@ export const 
FilterTypeSelect = memo(({ filterType, onChange }: Props) => { const { t } = useTranslation(); const disabledProcessors = useAppSelector(selectDisabledProcessors); const options = useMemo(() => { - return map(IMAGE_FILTERS, ({ labelTKey }, type) => ({ value: type, label: t(labelTKey) })).filter( + return map(IMAGE_FILTERS, (data, type) => ({ value: type, label: t(`controlLayers.filter.${type}.label`) })).filter( (o) => !includes(disabledProcessors, o.value) ); }, [disabledProcessors, t]); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/types.ts b/invokeai/frontend/web/src/features/controlLayers/components/Filters/types.ts index e4132640a58..9ad420a0cf7 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/types.ts @@ -1,4 +1,4 @@ -import type { FilterConfig } from 'features/controlLayers/store/types'; +import type { FilterConfig } from 'features/controlLayers/store/filters'; export type FilterComponentProps = { onChange: (config: T) => void; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts index 2b569b69da5..a89e2102ac2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts @@ -1,17 +1,21 @@ import type { SerializableObject } from 'common/types'; +import { isErr, withResultAsync } from 'common/util/result'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; import type { CanvasManager } from 
'features/controlLayers/konva/CanvasManager'; import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectAutoProcessFilter } from 'features/controlLayers/store/canvasSettingsSlice'; -import type { CanvasImageState, FilterConfig } from 'features/controlLayers/store/types'; -import { IMAGE_FILTERS, imageDTOToImageObject } from 'features/controlLayers/store/types'; +import type { FilterConfig } from 'features/controlLayers/store/filters'; +import { getFilterForModel, IMAGE_FILTERS } from 'features/controlLayers/store/filters'; +import type { CanvasImageState } from 'features/controlLayers/store/types'; +import { imageDTOToImageObject } from 'features/controlLayers/store/types'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import type { Logger } from 'roarr'; import { getImageDTO } from 'services/api/endpoints/images'; -import type { BatchConfig, ImageDTO, S } from 'services/api/types'; +import { buildSelectModelConfig } from 'services/api/hooks/modelsByType'; +import { type BatchConfig, type ImageDTO, isControlNetOrT2IAdapterModelConfig, type S } from 'services/api/types'; import { assert } from 'tsafe'; type CanvasEntityFiltererConfig = { @@ -37,7 +41,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase { $isFiltering = atom(false); $hasProcessed = atom(false); $isProcessing = atom(false); - $filterConfig = atom(IMAGE_FILTERS.canny_image_processor.buildDefaults()); + $filterConfig = atom(IMAGE_FILTERS.canny_edge_detection.buildDefaults()); constructor(parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer) { super(); @@ -69,6 +73,18 @@ export class CanvasEntityFilterer extends CanvasModuleBase { this.log.trace('Initializing filter'); if (config) { this.$filterConfig.set(config); + } else if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) { + // If the parent is a 
control layer adapter, we should check if the model has a default filter and set it if so + const selectModelConfig = buildSelectModelConfig( + this.parent.state.controlAdapter.model.key, + isControlNetOrT2IAdapterModelConfig + ); + const modelConfig = this.manager.stateApi.runSelector(selectModelConfig); + const filter = getFilterForModel(modelConfig); + this.$filterConfig.set(filter.buildDefaults()); + } else { + // Otherwise, set the default filter + this.$filterConfig.set(IMAGE_FILTERS.canny_edge_detection.buildDefaults()); } this.$isFiltering.set(true); this.manager.stateApi.$filteringAdapter.set(this.parent); @@ -92,11 +108,12 @@ export class CanvasEntityFilterer extends CanvasModuleBase { const batch = this.buildBatchConfig(imageDTO, config, nodeId); // Listen for the filter processing completion event - const listener = async (event: S['InvocationCompleteEvent']) => { + const completedListener = async (event: S['InvocationCompleteEvent']) => { if (event.origin !== this.id || event.invocation_source_id !== nodeId) { return; } - this.manager.socket.off('invocation_complete', listener); + this.manager.socket.off('invocation_complete', completedListener); + this.manager.socket.off('invocation_error', errorListener); this.log.trace({ event } as SerializableObject, 'Handling filter processing completion'); @@ -114,13 +131,29 @@ export class CanvasEntityFilterer extends CanvasModuleBase { this.$isProcessing.set(false); this.$hasProcessed.set(true); }; + const errorListener = (event: S['InvocationErrorEvent']) => { + if (event.origin !== this.id || event.invocation_source_id !== nodeId) { + return; + } + this.manager.socket.off('invocation_complete', completedListener); + this.manager.socket.off('invocation_error', errorListener); + + this.log.error({ event } as SerializableObject, 'Error processing filter'); + this.$isProcessing.set(false); + }; - this.manager.socket.on('invocation_complete', listener); + this.manager.socket.on('invocation_complete', 
completedListener); + this.manager.socket.on('invocation_error', errorListener); this.log.trace({ batch } as SerializableObject, 'Enqueuing filter batch'); this.$isProcessing.set(true); - this.manager.stateApi.enqueueBatch(batch); + const req = this.manager.stateApi.enqueueBatch(batch); + const result = await withResultAsync(req.unwrap); + if (isErr(result)) { + this.$isProcessing.set(false); + } + req.reset(); }, this.config.processDebounceMs, { leading: true, trailing: true } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index 382f6c88169..f3b51a26930 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -191,7 +191,7 @@ export class CanvasStateApiModule extends CanvasModuleBase { * Enqueues a batch, pushing state to redux. */ enqueueBatch = (batch: BatchConfig) => { - this.store.dispatch( + return this.store.dispatch( queueApi.endpoints.enqueueBatch.initiate(batch, { fixedCacheKey: 'enqueueBatch', }) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/filters.ts b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts new file mode 100644 index 00000000000..bd138666b27 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts @@ -0,0 +1,393 @@ +import { getPrefixedId } from 'features/controlLayers/konva/util'; +import type { ImageWithDims } from 'features/controlLayers/store/types'; +import { zModelIdentifierField } from 'features/nodes/types/common'; +import type { AnyInvocation, ControlNetModelConfig, Invocation, T2IAdapterModelConfig } from 'services/api/types'; +import { assert } from 'tsafe'; +import { z } from 'zod'; + +const zCannyEdgeDetectionFilterConfig = z.object({ + type: z.literal('canny_edge_detection'), + low_threshold: 
z.number().int().gte(0).lte(255), + high_threshold: z.number().int().gte(0).lte(255), +}); +export type CannyEdgeDetectionFilterConfig = z.infer; + +const zColorMapFilterConfig = z.object({ + type: z.literal('color_map'), + tile_size: z.number().int().gte(1), +}); +export type ColorMapFilterConfig = z.infer; + +const zContentShuffleFilterConfig = z.object({ + type: z.literal('content_shuffle'), + scale_factor: z.number().int().gte(1), +}); +export type ContentShuffleFilterConfig = z.infer; + +const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']); +export type DepthAnythingModelSize = z.infer; +export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize => + zDepthAnythingModelSize.safeParse(v).success; + +const zDepthAnythingFilterConfig = z.object({ + type: z.literal('depth_anything_depth_estimation'), + model_size: zDepthAnythingModelSize, +}); +export type DepthAnythingFilterConfig = z.infer; + +const zHEDEdgeDetectionFilterConfig = z.object({ + type: z.literal('hed_edge_detection'), + scribble: z.boolean(), +}); +export type HEDEdgeDetectionFilterConfig = z.infer; + +const zLineartAnimeEdgeDetectionFilterConfig = z.object({ + type: z.literal('lineart_anime_edge_detection'), +}); +export type LineartAnimeEdgeDetectionFilterConfig = z.infer; + +const zLineartEdgeDetectionFilterConfig = z.object({ + type: z.literal('lineart_edge_detection'), + coarse: z.boolean(), +}); +export type LineartEdgeDetectionFilterConfig = z.infer; + +const zMediaPipeFaceDetectionFilterConfig = z.object({ + type: z.literal('mediapipe_face_detection'), + max_faces: z.number().int().gte(1), + min_confidence: z.number().gte(0).lte(1), +}); +export type MediaPipeFaceDetectionFilterConfig = z.infer; + +const zMLSDDetectionFilterConfig = z.object({ + type: z.literal('mlsd_detection'), + score_threshold: z.number().gte(0), + distance_threshold: z.number().gte(0), +}); +export type MLSDDetectionFilterConfig = z.infer; + +const zNormalMapFilterConfig 
= z.object({ + type: z.literal('normal_map'), +}); +export type NormalMapFilterConfig = z.infer; + +const zPiDiNetEdgeDetectionFilterConfig = z.object({ + type: z.literal('pidi_edge_detection'), + quantize_edges: z.boolean(), + scribble: z.boolean(), +}); +export type PiDiNetEdgeDetectionFilterConfig = z.infer; + +const zDWOpenposeDetectionFilterConfig = z.object({ + type: z.literal('dw_openpose_detection'), + draw_body: z.boolean(), + draw_face: z.boolean(), + draw_hands: z.boolean(), +}); +export type DWOpenposeDetectionFilterConfig = z.infer; + +const zSpandrelFilterConfig = z.object({ + type: z.literal('spandrel_filter'), + model: zModelIdentifierField.nullable(), + autoScale: z.boolean(), + scale: z.number().gte(1).lte(16), +}); +export type SpandrelFilterConfig = z.infer; + +const zFilterConfig = z.discriminatedUnion('type', [ + zCannyEdgeDetectionFilterConfig, + zColorMapFilterConfig, + zContentShuffleFilterConfig, + zDepthAnythingFilterConfig, + zHEDEdgeDetectionFilterConfig, + zLineartAnimeEdgeDetectionFilterConfig, + zLineartEdgeDetectionFilterConfig, + zMediaPipeFaceDetectionFilterConfig, + zMLSDDetectionFilterConfig, + zNormalMapFilterConfig, + zPiDiNetEdgeDetectionFilterConfig, + zDWOpenposeDetectionFilterConfig, + zSpandrelFilterConfig, +]); +export type FilterConfig = z.infer; + +const zFilterType = z.enum([ + 'canny_edge_detection', + 'color_map', + 'content_shuffle', + 'depth_anything_depth_estimation', + 'hed_edge_detection', + 'lineart_anime_edge_detection', + 'lineart_edge_detection', + 'mediapipe_face_detection', + 'mlsd_detection', + 'normal_map', + 'pidi_edge_detection', + 'dw_openpose_detection', + 'spandrel_filter', +]); +export type FilterType = z.infer; +export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success; + +type ImageFilterData = { + type: T; + buildDefaults(): Extract; + buildNode(imageDTO: ImageWithDims, config: Extract): AnyInvocation; + validateConfig?(config: Extract): boolean; +}; + 
+export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData } = { + canny_edge_detection: { + type: 'canny_edge_detection', + buildDefaults: () => ({ + type: 'canny_edge_detection', + low_threshold: 100, + high_threshold: 200, + }), + buildNode: ({ image_name }, { low_threshold, high_threshold }): Invocation<'canny_edge_detection'> => ({ + id: getPrefixedId('canny_edge_detection'), + type: 'canny_edge_detection', + image: { image_name }, + low_threshold, + high_threshold, + }), + }, + color_map: { + type: 'color_map', + buildDefaults: () => ({ + type: 'color_map', + tile_size: 64, + }), + buildNode: ({ image_name }, { tile_size }): Invocation<'color_map'> => ({ + id: getPrefixedId('color_map'), + type: 'color_map', + image: { image_name }, + tile_size, + }), + }, + content_shuffle: { + type: 'content_shuffle', + buildDefaults: () => ({ + type: 'content_shuffle', + scale_factor: 256, + }), + buildNode: ({ image_name }, { scale_factor }): Invocation<'content_shuffle'> => ({ + id: getPrefixedId('content_shuffle'), + type: 'content_shuffle', + image: { image_name }, + scale_factor, + }), + }, + depth_anything_depth_estimation: { + type: 'depth_anything_depth_estimation', + buildDefaults: () => ({ + type: 'depth_anything_depth_estimation', + model_size: 'small_v2', + }), + buildNode: ({ image_name }, { model_size }): Invocation<'depth_anything_depth_estimation'> => ({ + id: getPrefixedId('depth_anything_depth_estimation'), + type: 'depth_anything_depth_estimation', + image: { image_name }, + model_size, + }), + }, + hed_edge_detection: { + type: 'hed_edge_detection', + buildDefaults: () => ({ + type: 'hed_edge_detection', + scribble: false, + }), + buildNode: ({ image_name }, { scribble }): Invocation<'hed_edge_detection'> => ({ + id: getPrefixedId('hed_edge_detection'), + type: 'hed_edge_detection', + image: { image_name }, + scribble, + }), + }, + lineart_anime_edge_detection: { + type: 'lineart_anime_edge_detection', + buildDefaults: () => ({ + 
type: 'lineart_anime_edge_detection', + }), + buildNode: ({ image_name }): Invocation<'lineart_anime_edge_detection'> => ({ + id: getPrefixedId('lineart_anime_edge_detection'), + type: 'lineart_anime_edge_detection', + image: { image_name }, + }), + }, + lineart_edge_detection: { + type: 'lineart_edge_detection', + buildDefaults: () => ({ + type: 'lineart_edge_detection', + coarse: false, + }), + buildNode: ({ image_name }, { coarse }): Invocation<'lineart_edge_detection'> => ({ + id: getPrefixedId('lineart_edge_detection'), + type: 'lineart_edge_detection', + image: { image_name }, + coarse, + }), + }, + mediapipe_face_detection: { + type: 'mediapipe_face_detection', + buildDefaults: () => ({ + type: 'mediapipe_face_detection', + max_faces: 1, + min_confidence: 0.5, + }), + buildNode: ({ image_name }, { max_faces, min_confidence }): Invocation<'mediapipe_face_detection'> => ({ + id: getPrefixedId('mediapipe_face_detection'), + type: 'mediapipe_face_detection', + image: { image_name }, + max_faces, + min_confidence, + }), + }, + mlsd_detection: { + type: 'mlsd_detection', + buildDefaults: () => ({ + type: 'mlsd_detection', + score_threshold: 0.1, + distance_threshold: 20.0, + }), + buildNode: ({ image_name }, { score_threshold, distance_threshold }): Invocation<'mlsd_detection'> => ({ + id: getPrefixedId('mlsd_detection'), + type: 'mlsd_detection', + image: { image_name }, + score_threshold, + distance_threshold, + }), + }, + normal_map: { + type: 'normal_map', + buildDefaults: () => ({ + type: 'normal_map', + }), + buildNode: ({ image_name }): Invocation<'normal_map'> => ({ + id: getPrefixedId('normal_map'), + type: 'normal_map', + image: { image_name }, + }), + }, + pidi_edge_detection: { + type: 'pidi_edge_detection', + buildDefaults: () => ({ + type: 'pidi_edge_detection', + quantize_edges: false, + scribble: false, + }), + buildNode: ({ image_name }, { quantize_edges, scribble }): Invocation<'pidi_edge_detection'> => ({ + id: 
getPrefixedId('pidi_edge_detection'), + type: 'pidi_edge_detection', + image: { image_name }, + quantize_edges, + scribble, + }), + }, + dw_openpose_detection: { + type: 'dw_openpose_detection', + buildDefaults: () => ({ + type: 'dw_openpose_detection', + draw_body: true, + draw_face: true, + draw_hands: true, + }), + buildNode: ({ image_name }, { draw_body, draw_face, draw_hands }): Invocation<'dw_openpose_detection'> => ({ + id: getPrefixedId('dw_openpose_detection'), + type: 'dw_openpose_detection', + image: { image_name }, + draw_body, + draw_face, + draw_hands, + }), + }, + spandrel_filter: { + type: 'spandrel_filter', + buildDefaults: () => ({ + type: 'spandrel_filter', + model: null, + autoScale: true, + scale: 1, + }), + buildNode: ( + { image_name }, + { model, scale, autoScale } + ): Invocation<'spandrel_image_to_image' | 'spandrel_image_to_image_autoscale'> => { + assert(model !== null); + if (autoScale) { + const node: Invocation<'spandrel_image_to_image_autoscale'> = { + id: getPrefixedId('spandrel_image_to_image_autoscale'), + type: 'spandrel_image_to_image_autoscale', + image_to_image_model: model, + image: { image_name }, + scale, + }; + return node; + } else { + const node: Invocation<'spandrel_image_to_image'> = { + id: getPrefixedId('spandrel_image_to_image'), + type: 'spandrel_image_to_image', + image_to_image_model: model, + image: { image_name }, + }; + return node; + } + }, + validateConfig: (config): boolean => { + if (!config.model) { + return false; + } + return true; + }, + }, +} as const; + +/** + * A map of the v1 processor names to the new filter types. 
+ */ +const PROCESSOR_TO_FILTER_MAP: Record = { + canny_image_processor: 'canny_edge_detection', + mlsd_image_processor: 'mlsd_detection', + depth_anything_image_processor: 'depth_anything_depth_estimation', + normalbae_image_processor: 'normal_map', + pidi_image_processor: 'pidi_edge_detection', + lineart_image_processor: 'lineart_edge_detection', + lineart_anime_image_processor: 'lineart_anime_edge_detection', + hed_image_processor: 'hed_edge_detection', + content_shuffle_image_processor: 'content_shuffle', + dw_openpose_image_processor: 'dw_openpose_detection', + mediapipe_face_processor: 'mediapipe_face_detection', + zoe_depth_image_processor: 'depth_anything_depth_estimation', + color_map_image_processor: 'color_map', +}; + +/** + * Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default + * filter for the model type will be used. + */ +export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => { + if (!modelConfig) { + // No model, use the default filter + return IMAGE_FILTERS.canny_edge_detection; + } + + const preprocessor = modelConfig?.default_settings?.preprocessor; + if (!preprocessor) { + // No preprocessor, use the default filter + return IMAGE_FILTERS.canny_edge_detection; + } + + if (isFilterType(preprocessor)) { + // Valid filter type, use it + return IMAGE_FILTERS[preprocessor]; + } + + const filterName = PROCESSOR_TO_FILTER_MAP[preprocessor]; + if (!filterName) { + // No filter found, use the default filter + return IMAGE_FILTERS.canny_edge_detection; + } + + // Found a filter, use it + return IMAGE_FILTERS[filterName]; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.test.ts index 0e5e8c4b69c..798c1ab0b99 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.test.ts +++ 
b/invokeai/frontend/web/src/features/controlLayers/store/types.test.ts @@ -1,30 +1,26 @@ +import type { + CannyEdgeDetectionFilterConfig, + ColorMapFilterConfig, + ContentShuffleFilterConfig, + DepthAnythingFilterConfig, + DepthAnythingModelSize, + DWOpenposeDetectionFilterConfig, + FilterConfig, + FilterType, + HEDEdgeDetectionFilterConfig, + LineartAnimeEdgeDetectionFilterConfig, + LineartEdgeDetectionFilterConfig, + MediaPipeFaceDetectionFilterConfig, + MLSDDetectionFilterConfig, + NormalMapFilterConfig, + PiDiNetEdgeDetectionFilterConfig, +} from 'features/controlLayers/store/filters'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; import { describe, test } from 'vitest'; -import type { - CannyProcessorConfig, - CLIPVisionModelV2, - ColorMapProcessorConfig, - ContentShuffleProcessorConfig, - ControlModeV2, - DepthAnythingModelSize, - DepthAnythingProcessorConfig, - DWOpenposeProcessorConfig, - FilterConfig, - FilterType, - HedProcessorConfig, - IPMethodV2, - LineartAnimeProcessorConfig, - LineartProcessorConfig, - MediapipeFaceProcessorConfig, - MidasDepthProcessorConfig, - MlsdProcessorConfig, - NormalbaeProcessorConfig, - PidiProcessorConfig, - ZoeDepthProcessorConfig, -} from './types'; +import type { CLIPVisionModelV2, ControlModeV2, IPMethodV2 } from './types'; describe('Control Adapter Types', () => { test('ProcessorType', () => { @@ -40,54 +36,46 @@ describe('Control Adapter Types', () => { assert['control_mode']>, ControlModeV2>>(); }); test('DepthAnything Model Size', () => { - assert['model_size']>, DepthAnythingModelSize>>(); + assert['model_size']>, DepthAnythingModelSize>>(); }); test('Processor Configs', () => { // The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct. // The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled. 
- assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); - assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); + assert>(); }); }); // Types derived from OpenAPI -type _CannyProcessorConfig = Required< - Pick, 'id' | 'type' | 'low_threshold' | 'high_threshold'> ->; -type _ColorMapProcessorConfig = Required< - Pick, 'id' | 'type' | 'color_map_tile_size'> ->; -type _ContentShuffleProcessorConfig = Required< - Pick, 'id' | 'type' | 'w' | 'h' | 'f'> +type _CannyEdgeDetectionFilterConfig = Required< + Pick, 'type' | 'low_threshold' | 'high_threshold'> >; -type _DepthAnythingProcessorConfig = Required< - Pick, 'id' | 'type' | 'model_size'> +type _ColorMapFilterConfig = Required, 'type' | 'tile_size'>>; +type _ContentShuffleFilterConfig = Required, 'type' | 'scale_factor'>>; +type _DepthAnythingFilterConfig = Required, 'type' | 'model_size'>>; +type _HEDEdgeDetectionFilterConfig = Required, 'type' | 'scribble'>>; +type _LineartAnimeEdgeDetectionFilterConfig = Required, 'type'>>; +type _LineartEdgeDetectionFilterConfig = Required, 'type' | 'coarse'>>; +type _MediaPipeFaceDetectionFilterConfig = Required< + Pick, 'type' | 'max_faces' | 'min_confidence'> >; -type _HedProcessorConfig = Required, 'id' | 'type' | 'scribble'>>; -type _LineartAnimeProcessorConfig = Required, 'id' | 'type'>>; -type _LineartProcessorConfig = Required, 'id' | 'type' | 'coarse'>>; -type _MediapipeFaceProcessorConfig = Required< - Pick, 'id' | 'type' | 'max_faces' | 'min_confidence'> +type _MLSDDetectionFilterConfig = Required< + Pick, 'type' | 'score_threshold' | 'distance_threshold'> >; -type _MidasDepthProcessorConfig = Required< - Pick, 'id' | 'type' | 'a_mult' | 'bg_th'> +type _NormalMapFilterConfig = Required, 'type'>>; +type _DWOpenposeDetectionFilterConfig = Required< + Pick, 'type' 
| 'draw_body' | 'draw_face' | 'draw_hands'> >; -type _MlsdProcessorConfig = Required, 'id' | 'type' | 'thr_v' | 'thr_d'>>; -type _NormalbaeProcessorConfig = Required, 'id' | 'type'>>; -type _DWOpenposeProcessorConfig = Required< - Pick, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'> +type _PiDiNetEdgeDetectionFilterConfig = Required< + Pick, 'type' | 'quantize_edges' | 'scribble'> >; -type _PidiProcessorConfig = Required, 'id' | 'type' | 'safe' | 'scribble'>>; -type _ZoeDepthProcessorConfig = Required, 'id' | 'type'>>; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 9ce4da0cffc..f96f4ea285a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -4,8 +4,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common'; import type { AspectRatioState } from 'features/parameters/components/Bbox/types'; import type { ParameterHeight, ParameterLoRAModel, ParameterWidth } from 'features/parameters/types/parameterSchemas'; import { zParameterNegativePrompt, zParameterPositivePrompt } from 'features/parameters/types/parameterSchemas'; -import type { AnyInvocation, BaseModelType, ImageDTO, Invocation, S } from 'services/api/types'; -import { assert } from 'tsafe'; +import type { ImageDTO } from 'services/api/types'; import { z } from 'zod'; const zId = z.string().min(1); @@ -36,463 +35,6 @@ const zIPMethodV2 = z.enum(['full', 'style', 'composition']); export type IPMethodV2 = z.infer; export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safeParse(v).success; -const zCannyProcessorConfig = z.object({ - id: zId, - type: z.literal('canny_image_processor'), - low_threshold: z.number().int().gte(0).lte(255), - high_threshold: z.number().int().gte(0).lte(255), -}); -export type CannyProcessorConfig = z.infer; - -const zColorMapProcessorConfig = 
z.object({ - id: zId, - type: z.literal('color_map_image_processor'), - color_map_tile_size: z.number().int().gte(1), -}); -export type ColorMapProcessorConfig = z.infer; - -const zContentShuffleProcessorConfig = z.object({ - id: zId, - type: z.literal('content_shuffle_image_processor'), - w: z.number().int().gte(0), - h: z.number().int().gte(0), - f: z.number().int().gte(0), -}); -export type ContentShuffleProcessorConfig = z.infer; - -const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']); -export type DepthAnythingModelSize = z.infer; -export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize => - zDepthAnythingModelSize.safeParse(v).success; -const zDepthAnythingProcessorConfig = z.object({ - id: zId, - type: z.literal('depth_anything_image_processor'), - model_size: zDepthAnythingModelSize, -}); -export type DepthAnythingProcessorConfig = z.infer; - -const zHedProcessorConfig = z.object({ - id: zId, - type: z.literal('hed_image_processor'), - scribble: z.boolean(), -}); -export type HedProcessorConfig = z.infer; - -const zLineartAnimeProcessorConfig = z.object({ - id: zId, - type: z.literal('lineart_anime_image_processor'), -}); -export type LineartAnimeProcessorConfig = z.infer; - -const zLineartProcessorConfig = z.object({ - id: zId, - type: z.literal('lineart_image_processor'), - coarse: z.boolean(), -}); -export type LineartProcessorConfig = z.infer; - -const zMediapipeFaceProcessorConfig = z.object({ - id: zId, - type: z.literal('mediapipe_face_processor'), - max_faces: z.number().int().gte(1), - min_confidence: z.number().gte(0).lte(1), -}); -export type MediapipeFaceProcessorConfig = z.infer; - -const zMidasDepthProcessorConfig = z.object({ - id: zId, - type: z.literal('midas_depth_image_processor'), - a_mult: z.number().gte(0), - bg_th: z.number().gte(0), -}); -export type MidasDepthProcessorConfig = z.infer; - -const zMlsdProcessorConfig = z.object({ - id: zId, - type: z.literal('mlsd_image_processor'), 
- thr_v: z.number().gte(0), - thr_d: z.number().gte(0), -}); -export type MlsdProcessorConfig = z.infer; - -const zNormalbaeProcessorConfig = z.object({ - id: zId, - type: z.literal('normalbae_image_processor'), -}); -export type NormalbaeProcessorConfig = z.infer; - -const zDWOpenposeProcessorConfig = z.object({ - id: zId, - type: z.literal('dw_openpose_image_processor'), - draw_body: z.boolean(), - draw_face: z.boolean(), - draw_hands: z.boolean(), -}); -export type DWOpenposeProcessorConfig = z.infer; - -const zPidiProcessorConfig = z.object({ - id: zId, - type: z.literal('pidi_image_processor'), - safe: z.boolean(), - scribble: z.boolean(), -}); -export type PidiProcessorConfig = z.infer; - -const zZoeDepthProcessorConfig = z.object({ - id: zId, - type: z.literal('zoe_depth_image_processor'), -}); -export type ZoeDepthProcessorConfig = z.infer; - -const zSpandrelFilterConfig = z.object({ - id: zId, - type: z.literal('spandrel_filter'), - model: zModelIdentifierField.nullable(), - autoScale: z.boolean(), - scale: z.number().gte(1).lte(16), -}); -export type SpandrelFilterConfig = z.infer; - -const zFilterConfig = z.discriminatedUnion('type', [ - zCannyProcessorConfig, - zColorMapProcessorConfig, - zContentShuffleProcessorConfig, - zDepthAnythingProcessorConfig, - zHedProcessorConfig, - zLineartAnimeProcessorConfig, - zLineartProcessorConfig, - zMediapipeFaceProcessorConfig, - zMidasDepthProcessorConfig, - zMlsdProcessorConfig, - zNormalbaeProcessorConfig, - zDWOpenposeProcessorConfig, - zPidiProcessorConfig, - zZoeDepthProcessorConfig, - zSpandrelFilterConfig, -]); -export type FilterConfig = z.infer; - -const zFilterType = z.enum([ - 'canny_image_processor', - 'color_map_image_processor', - 'content_shuffle_image_processor', - 'depth_anything_image_processor', - 'hed_image_processor', - 'lineart_anime_image_processor', - 'lineart_image_processor', - 'mediapipe_face_processor', - 'midas_depth_image_processor', - 'mlsd_image_processor', - 
'normalbae_image_processor', - 'dw_openpose_image_processor', - 'pidi_image_processor', - 'zoe_depth_image_processor', - 'spandrel_filter', -]); -export type FilterType = z.infer; -export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success; - -const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height); - -type ImageFilterData = { - type: T; - labelTKey: string; - descriptionTKey: string; - buildDefaults(baseModel?: BaseModelType): Extract; - buildNode(imageDTO: ImageWithDims, config: Extract): AnyInvocation; - validateConfig?(config: Extract): boolean; -}; - -/** - * A dict of ControlNet processors, including: - * - label translation key - * - description translation key - * - a builder to create default values for the config - * - a builder to create the node for the config - * - * TODO: Generate from the OpenAPI schema - */ -export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData } = { - canny_image_processor: { - type: 'canny_image_processor', - labelTKey: 'controlnet.canny', - descriptionTKey: 'controlnet.cannyDescription', - buildDefaults: (): CannyProcessorConfig => ({ - id: 'canny_image_processor', - type: 'canny_image_processor', - low_threshold: 100, - high_threshold: 200, - }), - buildNode: (imageDTO: ImageDTO, config: CannyProcessorConfig): S['CannyImageProcessorInvocation'] => ({ - ...config, - type: 'canny_image_processor', - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - color_map_image_processor: { - type: 'color_map_image_processor', - labelTKey: 'controlnet.colorMap', - descriptionTKey: 'controlnet.colorMapDescription', - buildDefaults: (): ColorMapProcessorConfig => ({ - id: 'color_map_image_processor', - type: 'color_map_image_processor', - color_map_tile_size: 64, - }), - buildNode: (imageDTO: ImageDTO, config: ColorMapProcessorConfig): S['ColorMapImageProcessorInvocation'] => ({ 
- ...config, - type: 'color_map_image_processor', - image: { image_name: imageDTO.image_name }, - }), - }, - content_shuffle_image_processor: { - type: 'content_shuffle_image_processor', - labelTKey: 'controlnet.contentShuffle', - descriptionTKey: 'controlnet.contentShuffleDescription', - buildDefaults: (baseModel: BaseModelType): ContentShuffleProcessorConfig => ({ - id: 'content_shuffle_image_processor', - type: 'content_shuffle_image_processor', - h: baseModel === 'sdxl' ? 1024 : 512, - w: baseModel === 'sdxl' ? 1024 : 512, - f: baseModel === 'sdxl' ? 512 : 256, - }), - buildNode: ( - imageDTO: ImageDTO, - config: ContentShuffleProcessorConfig - ): S['ContentShuffleImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - depth_anything_image_processor: { - type: 'depth_anything_image_processor', - labelTKey: 'controlnet.depthAnything', - descriptionTKey: 'controlnet.depthAnythingDescription', - buildDefaults: (): DepthAnythingProcessorConfig => ({ - id: 'depth_anything_image_processor', - type: 'depth_anything_image_processor', - model_size: 'small_v2', - }), - buildNode: ( - imageDTO: ImageDTO, - config: DepthAnythingProcessorConfig - ): S['DepthAnythingImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - resolution: minDim(imageDTO), - }), - }, - hed_image_processor: { - type: 'hed_image_processor', - labelTKey: 'controlnet.hed', - descriptionTKey: 'controlnet.hedDescription', - buildDefaults: (): HedProcessorConfig => ({ - id: 'hed_image_processor', - type: 'hed_image_processor', - scribble: false, - }), - buildNode: (imageDTO: ImageDTO, config: HedProcessorConfig): S['HedImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - lineart_anime_image_processor: { - type: 
'lineart_anime_image_processor', - labelTKey: 'controlnet.lineartAnime', - descriptionTKey: 'controlnet.lineartAnimeDescription', - buildDefaults: (): LineartAnimeProcessorConfig => ({ - id: 'lineart_anime_image_processor', - type: 'lineart_anime_image_processor', - }), - buildNode: ( - imageDTO: ImageDTO, - config: LineartAnimeProcessorConfig - ): S['LineartAnimeImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - lineart_image_processor: { - type: 'lineart_image_processor', - labelTKey: 'controlnet.lineart', - descriptionTKey: 'controlnet.lineartDescription', - buildDefaults: (): LineartProcessorConfig => ({ - id: 'lineart_image_processor', - type: 'lineart_image_processor', - coarse: false, - }), - buildNode: (imageDTO: ImageDTO, config: LineartProcessorConfig): S['LineartImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - mediapipe_face_processor: { - type: 'mediapipe_face_processor', - labelTKey: 'controlnet.mediapipeFace', - descriptionTKey: 'controlnet.mediapipeFaceDescription', - buildDefaults: (): MediapipeFaceProcessorConfig => ({ - id: 'mediapipe_face_processor', - type: 'mediapipe_face_processor', - max_faces: 1, - min_confidence: 0.5, - }), - buildNode: (imageDTO: ImageDTO, config: MediapipeFaceProcessorConfig): S['MediapipeFaceProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - midas_depth_image_processor: { - type: 'midas_depth_image_processor', - labelTKey: 'controlnet.depthMidas', - descriptionTKey: 'controlnet.depthMidasDescription', - buildDefaults: (): MidasDepthProcessorConfig => ({ - id: 'midas_depth_image_processor', - type: 'midas_depth_image_processor', - a_mult: 
2, - bg_th: 0.1, - }), - buildNode: (imageDTO: ImageDTO, config: MidasDepthProcessorConfig): S['MidasDepthImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - mlsd_image_processor: { - type: 'mlsd_image_processor', - labelTKey: 'controlnet.mlsd', - descriptionTKey: 'controlnet.mlsdDescription', - buildDefaults: (): MlsdProcessorConfig => ({ - id: 'mlsd_image_processor', - type: 'mlsd_image_processor', - thr_d: 0.1, - thr_v: 0.1, - }), - buildNode: (imageDTO: ImageDTO, config: MlsdProcessorConfig): S['MlsdImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - normalbae_image_processor: { - type: 'normalbae_image_processor', - labelTKey: 'controlnet.normalBae', - descriptionTKey: 'controlnet.normalBaeDescription', - buildDefaults: (): NormalbaeProcessorConfig => ({ - id: 'normalbae_image_processor', - type: 'normalbae_image_processor', - }), - buildNode: (imageDTO: ImageDTO, config: NormalbaeProcessorConfig): S['NormalbaeImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - dw_openpose_image_processor: { - type: 'dw_openpose_image_processor', - labelTKey: 'controlnet.dwOpenpose', - descriptionTKey: 'controlnet.dwOpenposeDescription', - buildDefaults: (): DWOpenposeProcessorConfig => ({ - id: 'dw_openpose_image_processor', - type: 'dw_openpose_image_processor', - draw_body: true, - draw_face: false, - draw_hands: false, - }), - buildNode: (imageDTO: ImageDTO, config: DWOpenposeProcessorConfig): S['DWOpenposeImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - image_resolution: minDim(imageDTO), - }), - }, - pidi_image_processor: { - type: 
'pidi_image_processor', - labelTKey: 'controlnet.pidi', - descriptionTKey: 'controlnet.pidiDescription', - buildDefaults: (): PidiProcessorConfig => ({ - id: 'pidi_image_processor', - type: 'pidi_image_processor', - scribble: false, - safe: false, - }), - buildNode: (imageDTO: ImageDTO, config: PidiProcessorConfig): S['PidiImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - detect_resolution: minDim(imageDTO), - image_resolution: minDim(imageDTO), - }), - }, - zoe_depth_image_processor: { - type: 'zoe_depth_image_processor', - labelTKey: 'controlnet.depthZoe', - descriptionTKey: 'controlnet.depthZoeDescription', - buildDefaults: (): ZoeDepthProcessorConfig => ({ - id: 'zoe_depth_image_processor', - type: 'zoe_depth_image_processor', - }), - buildNode: (imageDTO: ImageDTO, config: ZoeDepthProcessorConfig): S['ZoeDepthImageProcessorInvocation'] => ({ - ...config, - image: { image_name: imageDTO.image_name }, - }), - }, - spandrel_filter: { - type: 'spandrel_filter', - labelTKey: 'controlLayers.filter.spandrel.label', - descriptionTKey: 'controlLayers.filter.spandrel.description', - buildDefaults: (): SpandrelFilterConfig => ({ - id: 'spandrel_filter', - type: 'spandrel_filter', - model: null, - autoScale: false, - scale: 2, - }), - buildNode: ( - imageDTO: ImageDTO, - config: SpandrelFilterConfig - ): Invocation<'spandrel_image_to_image' | 'spandrel_image_to_image_autoscale'> => { - const { model, scale, autoScale } = config; - assert(model !== null); - if (autoScale) { - const node: Invocation<'spandrel_image_to_image_autoscale'> = { - id: config.id, - type: 'spandrel_image_to_image_autoscale', - image_to_image_model: model, - image: { image_name: imageDTO.image_name }, - scale, - }; - return node; - } else { - const node: Invocation<'spandrel_image_to_image'> = { - id: config.id, - type: 'spandrel_image_to_image', - image_to_image_model: model, - image: { image_name: imageDTO.image_name }, - }; - return node; - } - }, - 
validateConfig: (config): boolean => { - if (!config.model) { - return false; - } - return true; - }, - }, -} as const; - const zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'view', 'bbox', 'colorPicker']); export type Tool = z.infer; diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 6bbc0e15eb6..e4c70a929a7 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -1,6 +1,12 @@ +import { createSelector, type Selector } from '@reduxjs/toolkit'; import { EMPTY_ARRAY } from 'app/store/constants'; +import type { RootState } from 'app/store/store'; import { useMemo } from 'react'; -import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import { + modelConfigsAdapterSelectors, + selectModelConfigsQuery, + useGetModelConfigsQuery, +} from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; import { isClipEmbedModelConfig, @@ -54,3 +60,46 @@ export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig); export const useEmbeddingModels = buildModelsHook(isTIModelConfig); export const useVAEModels = buildModelsHook(isVAEModelConfig); export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig); + +// const buildModelsSelector = +// (typeGuard: (config: AnyModelConfig) => config is T): Selector => +// (state) => { +// const result = selectModelConfigsQuery(state); +// if (!result.data) { +// return EMPTY_ARRAY; +// } +// return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard); +// }; +// export const selectSDMainModels = buildModelsSelector(isNonRefinerNonFluxMainModelConfig); +// export const selectMainModels = buildModelsSelector(isNonRefinerMainModelConfig); +// export const selectNonSDXLMainModels = buildModelsSelector(isNonSDXLMainModelConfig); +// export const 
selectRefinerModels = buildModelsSelector(isRefinerMainModelModelConfig);
+// export const selectFluxModels = buildModelsSelector(isFluxMainModelModelConfig);
+// export const selectSDXLModels = buildModelsSelector(isSDXLMainModelModelConfig);
+// export const selectLoRAModels = buildModelsSelector(isLoRAModelConfig);
+// export const selectControlNetAndT2IAdapterModels = buildModelsSelector(isControlNetOrT2IAdapterModelConfig);
+// export const selectControlNetModels = buildModelsSelector(isControlNetModelConfig);
+// export const selectT2IAdapterModels = buildModelsSelector(isT2IAdapterModelConfig);
+// export const selectT5EncoderModels = buildModelsSelector(isT5EncoderModelConfig);
+// export const selectClipEmbedModels = buildModelsSelector(isClipEmbedModelConfig);
+// export const selectSpandrelImageToImageModels = buildModelsSelector(isSpandrelImageToImageModelConfig);
+// export const selectIPAdapterModels = buildModelsSelector(isIPAdapterModelConfig);
+// export const selectEmbeddingModels = buildModelsSelector(isTIModelConfig);
+// export const selectVAEModels = buildModelsSelector(isVAEModelConfig);
+// export const selectFluxVAEModels = buildModelsSelector(isFluxVAEModelConfig);
+
+export const buildSelectModelConfig = <T extends AnyModelConfig>(
+  key: string,
+  typeGuard: (config: AnyModelConfig) => config is T
+): Selector<RootState, T | null> =>
+  createSelector(selectModelConfigsQuery, (result) => {
+    if (!result.data) {
+      return null;
+    }
+    return (
+      modelConfigsAdapterSelectors
+        .selectAll(result.data)
+        .filter(typeGuard)
+        .find((m) => m.key === key) ?? null
+    );
+  });