Skip to content

Commit

Permalink
Merge branch 'main' into model-manager-ui-30
Browse files Browse the repository at this point in the history
  • Loading branch information
blessedcoolant committed Jun 18, 2023
2 parents 9fda21c + a11946f commit 91016d8
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 86 deletions.
3 changes: 2 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@
"general": "General",
"generation": "Generation",
"ui": "User Interface",
"availableSchedulers": "Available Schedulers"
"favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited"
},
"toast": {
"serverError": "Server Error",
Expand Down
57 changes: 42 additions & 15 deletions invokeai/frontend/web/src/app/constants.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,54 @@
// TODO: use Enums?
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';

export const SCHEDULERS = [
'ddim',
'lms',
'lms_k',
// zod needs the array to be `as const` to infer the type correctly
// this is the source of the `SchedulerParam` type, which is generated by zod
export const SCHEDULER_NAMES_AS_CONST = [
'euler',
'euler_k',
'euler_a',
'deis',
'ddim',
'ddpm',
'dpmpp_2s',
'dpmpp_2s_k',
'dpmpp_2m',
'dpmpp_2m_k',
'heun',
'kdpm_2',
'kdpm_2_a',
'deis',
'ddpm',
'lms',
'pndm',
'heun',
'heun_k',
'unipc',
'euler_k',
'dpmpp_2s_k',
'dpmpp_2m_k',
'heun_k',
'lms_k',
'euler_a',
'kdpm_2_a',
] as const;

export type Scheduler = (typeof SCHEDULERS)[number];
export const DEFAULT_SCHEDULER_NAME = 'euler';

export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST];

export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
euler: 'Euler',
deis: 'DEIS',
ddim: 'DDIM',
ddpm: 'DDPM',
dpmpp_2s: 'DPM++ 2S',
dpmpp_2m: 'DPM++ 2M',
heun: 'Heun',
kdpm_2: 'KDPM 2',
lms: 'LMS',
pndm: 'PNDM',
unipc: 'UniPC',
euler_k: 'Euler Karras',
dpmpp_2s_k: 'DPM++ 2S Karras',
dpmpp_2m_k: 'DPM++ 2M Karras',
heun_k: 'Heun Karras',
lms_k: 'LMS Karras',
euler_a: 'Euler Ancestral',
kdpm_2_a: 'KDPM 2 Ancestral',
};

export type Scheduler = (typeof SCHEDULER_NAMES)[number];

// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { memo } from 'react';

type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
};

const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect
searchable={searchable}
styles={() => ({
label: {
color: 'var(--invokeai-colors-base-300)',
fontWeight: 'normal',
},
searchInput: {
'::placeholder': {
color: 'var(--invokeai-colors-base-700)',
},
},
input: {
backgroundColor: 'var(--invokeai-colors-base-900)',
borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
'&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
'&:focus-within': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
},
value: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
button: {
color: 'var(--invokeai-colors-base-100)',
},
'&:hover': {
backgroundColor: 'var(--invokeai-colors-base-700)',
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)',
borderColor: 'var(--invokeai-colors-base-700)',
},
item: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-200)',
padding: 6,
'&[data-hovered]': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
'&[data-active]': {
backgroundColor: 'var(--invokeai-colors-base-750)',
'&:hover': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
},
'&[data-selected]': {
color: 'var(--invokeai-colors-base-50)',
backgroundColor: 'var(--invokeai-colors-accent-650)',
fontWeight: 600,
'&:hover': {
backgroundColor: 'var(--invokeai-colors-accent-600)',
},
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: 'var(--invokeai-colors-base-100)',
},
},
})}
{...rest}
/>
</Tooltip>
);
};

export default memo(IAIMantineMultiSelect);
Original file line number Diff line number Diff line change
@@ -1,43 +1,48 @@
import { createSelector } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
const allSchedulers: string[] = ui.schedulers
.slice()
.sort((a, b) => a.localeCompare(b));
const { scheduler } = generation;
const { favoriteSchedulers: enabledSchedulers } = ui;

const data = SCHEDULER_NAMES.map((schedulerName) => ({
value: schedulerName,
label: SCHEDULER_LABEL_MAP[schedulerName as SchedulerParam],
group: enabledSchedulers.includes(schedulerName)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));

return {
scheduler: generation.scheduler,
allSchedulers,
scheduler,
data,
};
},
defaultSelectorOptions
);

const ParamScheduler = () => {
const { allSchedulers, scheduler } = useAppSelector(selector);

const dispatch = useAppDispatch();
const { t } = useTranslation();
const { scheduler, data } = useAppSelector(selector);

const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(setScheduler(v as Scheduler));
dispatch(setScheduler(v as SchedulerParam));
},
[dispatch]
);
Expand All @@ -46,7 +51,7 @@ const ParamScheduler = () => {
<IAIMantineSelect
label={t('parameters.scheduler')}
value={scheduler}
data={allSchedulers}
data={data}
onChange={handleChange}
/>
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { Box, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ModelSelect from 'features/system/components/ModelSelect';
import { memo } from 'react';
import ParamScheduler from './ParamScheduler';

const ParamSchedulerAndModel = () => {
return (
<Flex gap={3} w="full">
<Box w="16rem">
<Box w="20rem">
<ParamScheduler />
</Box>
<Box w="full">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es';
import { ImageDTO } from 'services/api';
Expand All @@ -19,6 +17,7 @@ import {
StrengthParam,
WidthParam,
} from './parameterZodSchemas';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';

export interface GenerationState {
cfgScale: CfgScaleParam;
Expand Down Expand Up @@ -50,7 +49,6 @@ export interface GenerationState {
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: ModelParam;
currentModelType: ModelLoaderTypes;
shouldUseSeamless: boolean;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
Expand All @@ -65,7 +63,7 @@ export const initialGenerationState: GenerationState = {
perlin: 0,
positivePrompt: '',
negativePrompt: '',
scheduler: 'euler',
scheduler: DEFAULT_SCHEDULER_NAME,
seamBlur: 16,
seamSize: 96,
seamSteps: 30,
Expand All @@ -85,7 +83,6 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0,
model: '',
currentModelType: 'sd1_model_loader',
shouldUseSeamless: false,
seamlessXAxis: true,
seamlessYAxis: true,
Expand Down Expand Up @@ -136,7 +133,7 @@ export const generationSlice = createSlice({
setWidth: (state, action: PayloadAction<number>) => {
state.width = action.payload;
},
setScheduler: (state, action: PayloadAction<Scheduler>) => {
setScheduler: (state, action: PayloadAction<SchedulerParam>) => {
state.scheduler = action.payload;
},
setSeed: (state, action: PayloadAction<number>) => {
Expand Down Expand Up @@ -220,9 +217,6 @@ export const generationSlice = createSlice({
modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload;
},
setCurrentModelType: (state, action: PayloadAction<ModelLoaderTypes>) => {
state.currentModelType = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => {
Expand Down Expand Up @@ -283,7 +277,6 @@ export const {
setVerticalSymmetrySteps,
initialImageChanged,
modelSelected,
setCurrentModelType,
setShouldUseNoiseSettings,
setSeamless,
setSeamlessXAxis,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
import { NUMPY_RAND_MAX, SCHEDULER_NAMES_AS_CONST } from 'app/constants';
import { z } from 'zod';

/**
Expand Down Expand Up @@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
/**
* Zod schema for scheduler parameter
*/
export const zScheduler = z.enum(SCHEDULERS);
export const zScheduler = z.enum(SCHEDULER_NAMES_AS_CONST);
/**
* Type alias for scheduler parameter, inferred from its zod schema
*/
Expand Down
Loading

0 comments on commit 91016d8

Please sign in to comment.