Skip to content

Commit

Permalink
feat(ui): add dynamic prompts to t2i tab
Browse files Browse the repository at this point in the history
- add param accordion for dynamic prompts
- update graphs
  • Loading branch information
psychedelicious committed Jun 26, 2023
1 parent 9cfac41 commit 6390af2
Show file tree
Hide file tree
Showing 29 changed files with 479 additions and 576 deletions.
4 changes: 4 additions & 0 deletions invokeai/frontend/web/src/app/store/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';

import { listenerMiddleware } from './middleware/listenerMiddleware';

Expand All @@ -48,6 +49,7 @@ const allReducers = {
controlNet: controlNetReducer,
boards: boardsReducer,
// session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer,
[api.reducerPath]: api.reducer,
};

Expand All @@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system',
'ui',
'controlNet',
'dynamicPrompts',
// 'boards',
// 'hotkeys',
// 'config',
Expand Down Expand Up @@ -100,3 +103,4 @@ export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;
8 changes: 8 additions & 0 deletions invokeai/frontend/web/src/app/types/invokeai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ export type AppConfig = {
fineStep: number;
coarseStep: number;
};
dynamicPrompts: {
maxPrompts: {
initial: number;
min: number;
sliderMax: number;
inputMax: number;
};
};
};
};

Expand Down
10 changes: 9 additions & 1 deletion invokeai/frontend/web/src/common/components/IAISwitch.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ const IAISwitch = (props: Props) => {
{...formControlProps}
>
{label && (
<FormLabel my={1} flexGrow={1} {...formLabelProps}>
<FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
}}
{...formLabelProps}
>
{label}
</FormLabel>
)}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react';

const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;

return { isEnabled };
},
defaultSelectorOptions
);

const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);

const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);

return (
<IAICollapse
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsMaxPrompts />
<ParamDynamicPromptsCombinatorial />
</Flex>
</IAICollapse>
);
};

export default ParamDynamicPromptsCollapse;
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import IAISwitch from 'common/components/IAISwitch';

const selector = createSelector(
stateSelector,
(state) => {
const { combinatorial } = state.dynamicPrompts;

return { combinatorial };
},
defaultSelectorOptions
);

const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector);
const dispatch = useAppDispatch();

const handleChange = useCallback(() => {
dispatch(combinatorialToggled());
}, [dispatch]);

return (
<IAISwitch
label="Combinatorial Generation"
isChecked={combinatorial}
onChange={handleChange}
/>
);
};

export default ParamDynamicPromptsCombinatorial;
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';

const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;

return { maxPrompts, min, sliderMax, inputMax };
},
defaultSelectorOptions
);

const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
const dispatch = useAppDispatch();

const handleChange = useCallback(
(v: number) => {
dispatch(maxPromptsChanged(v));
},
[dispatch]
);

const handleReset = useCallback(() => {
dispatch(maxPromptsReset());
}, [dispatch]);

return (
<IAISlider
label="Max Prompts"
min={min}
max={sliderMax}
value={maxPrompts}
onChange={handleChange}
sliderNumberInputProps={{ max: inputMax }}
withSliderMarks
withInput
inputReadOnly
withReset
handleReset={handleReset}
/>
);
};

export default ParamDynamicPromptsMaxPrompts;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
//
50 changes: 50 additions & 0 deletions invokeai/frontend/web/src/features/dynamicPrompts/store/slice.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';

export interface DynamicPromptsState {
isEnabled: boolean;
maxPrompts: number;
combinatorial: boolean;
}

export const initialDynamicPromptsState: DynamicPromptsState = {
isEnabled: false,
maxPrompts: 100,
combinatorial: true,
};

const initialState: DynamicPromptsState = initialDynamicPromptsState;

export const dynamicPromptsSlice = createSlice({
name: 'dynamicPrompts',
initialState,
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
},
maxPromptsReset: (state) => {
state.maxPrompts = initialDynamicPromptsState.maxPrompts;
},
combinatorialToggled: (state) => {
state.combinatorial = !state.combinatorial;
},
isEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
},
extraReducers: (builder) => {
//
},
});

export const {
isEnabledToggled,
maxPromptsChanged,
maxPromptsReset,
combinatorialToggled,
} = dynamicPromptsSlice.actions;

export default dynamicPromptsSlice.reducer;

export const dynamicPromptsSelector = (state: RootState) =>
state.dynamicPrompts;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { filter, forEach, size } from 'lodash-es';
import { filter } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
Expand All @@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
(c.processorType === 'none' && Boolean(c.controlImage)))
);

// Add ControlNet
if (isControlNetEnabled && validControlNets.length > 0) {
if (size(controlNets) > 1) {
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length > 1) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
Expand All @@ -36,10 +36,9 @@ export const addControlNetToLinearGraph = (
});
}

forEach(controlNets, (controlNet) => {
validControlNets.forEach((controlNet) => {
const {
controlNetId,
isEnabled,
controlImage,
processedControlImage,
beginStepPct,
Expand All @@ -50,11 +49,6 @@ export const addControlNetToLinearGraph = (
weight,
} = controlNet;

if (!isEnabled) {
// Skip disabled ControlNets
return;
}

const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
Expand Down Expand Up @@ -82,7 +76,8 @@ export const addControlNetToLinearGraph = (

graph.nodes[controlNetNode.id] = controlNetNode;

if (size(controlNets) > 1) {
if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
Expand All @@ -91,6 +86,7 @@ export const addControlNetToLinearGraph = (
},
});
} else {
// otherwise, link directly to the base node
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
Expand Down
Loading

0 comments on commit 6390af2

Please sign in to comment.