Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): add dynamic prompts to t2i tab #3588

Merged
merged 1 commit into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions invokeai/frontend/web/src/app/store/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';

import { listenerMiddleware } from './middleware/listenerMiddleware';

Expand All @@ -48,6 +49,7 @@ const allReducers = {
controlNet: controlNetReducer,
boards: boardsReducer,
// session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer,
[api.reducerPath]: api.reducer,
};

Expand All @@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system',
'ui',
'controlNet',
'dynamicPrompts',
// 'boards',
// 'hotkeys',
// 'config',
Expand Down Expand Up @@ -100,3 +103,4 @@ export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;
8 changes: 8 additions & 0 deletions invokeai/frontend/web/src/app/types/invokeai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ export type AppConfig = {
fineStep: number;
coarseStep: number;
};
dynamicPrompts: {
maxPrompts: {
initial: number;
min: number;
sliderMax: number;
inputMax: number;
};
};
};
};

Expand Down
10 changes: 9 additions & 1 deletion invokeai/frontend/web/src/common/components/IAISwitch.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ const IAISwitch = (props: Props) => {
{...formControlProps}
>
{label && (
<FormLabel my={1} flexGrow={1} {...formLabelProps}>
<FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
}}
{...formLabelProps}
>
{label}
</FormLabel>
)}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react';

const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;

return { isEnabled };
},
defaultSelectorOptions
);

const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);

const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);

return (
<IAICollapse
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsMaxPrompts />
<ParamDynamicPromptsCombinatorial />
</Flex>
</IAICollapse>
);
};

export default ParamDynamicPromptsCollapse;
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import IAISwitch from 'common/components/IAISwitch';

const selector = createSelector(
stateSelector,
(state) => {
const { combinatorial } = state.dynamicPrompts;

return { combinatorial };
},
defaultSelectorOptions
);

const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector);
const dispatch = useAppDispatch();

const handleChange = useCallback(() => {
dispatch(combinatorialToggled());
}, [dispatch]);

return (
<IAISwitch
label="Combinatorial Generation"
isChecked={combinatorial}
onChange={handleChange}
/>
);
};

export default ParamDynamicPromptsCombinatorial;
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';

const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;

return { maxPrompts, min, sliderMax, inputMax };
},
defaultSelectorOptions
);

const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
const dispatch = useAppDispatch();

const handleChange = useCallback(
(v: number) => {
dispatch(maxPromptsChanged(v));
},
[dispatch]
);

const handleReset = useCallback(() => {
dispatch(maxPromptsReset());
}, [dispatch]);

return (
<IAISlider
label="Max Prompts"
min={min}
max={sliderMax}
value={maxPrompts}
onChange={handleChange}
sliderNumberInputProps={{ max: inputMax }}
withSliderMarks
withInput
inputReadOnly
withReset
handleReset={handleReset}
/>
);
};

export default ParamDynamicPromptsMaxPrompts;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
//
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';

export interface DynamicPromptsState {
isEnabled: boolean;
maxPrompts: number;
combinatorial: boolean;
}

export const initialDynamicPromptsState: DynamicPromptsState = {
isEnabled: false,
maxPrompts: 100,
combinatorial: true,
};

const initialState: DynamicPromptsState = initialDynamicPromptsState;

export const dynamicPromptsSlice = createSlice({
name: 'dynamicPrompts',
initialState,
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
},
maxPromptsReset: (state) => {
state.maxPrompts = initialDynamicPromptsState.maxPrompts;
},
combinatorialToggled: (state) => {
state.combinatorial = !state.combinatorial;
},
isEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
},
extraReducers: (builder) => {
//
},
});

export const {
isEnabledToggled,
maxPromptsChanged,
maxPromptsReset,
combinatorialToggled,
} = dynamicPromptsSlice.actions;

export default dynamicPromptsSlice.reducer;

export const dynamicPromptsSelector = (state: RootState) =>
state.dynamicPrompts;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { filter, forEach, size } from 'lodash-es';
import { filter } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
Expand All @@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
(c.processorType === 'none' && Boolean(c.controlImage)))
);

// Add ControlNet
if (isControlNetEnabled && validControlNets.length > 0) {
if (size(controlNets) > 1) {
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length > 1) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
Expand All @@ -36,10 +36,9 @@ export const addControlNetToLinearGraph = (
});
}

forEach(controlNets, (controlNet) => {
validControlNets.forEach((controlNet) => {
const {
controlNetId,
isEnabled,
controlImage,
processedControlImage,
beginStepPct,
Expand All @@ -50,11 +49,6 @@ export const addControlNetToLinearGraph = (
weight,
} = controlNet;

if (!isEnabled) {
// Skip disabled ControlNets
return;
}

const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
Expand Down Expand Up @@ -82,7 +76,8 @@ export const addControlNetToLinearGraph = (

graph.nodes[controlNetNode.id] = controlNetNode;

if (size(controlNets) > 1) {
if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
Expand All @@ -91,6 +86,7 @@ export const addControlNetToLinearGraph = (
},
});
} else {
// otherwise, link directly to the base node
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
Expand Down
Loading