Add control_mode parameter to ControlNet #3535

Merged · 18 commits · Jun 25, 2023

Commits
fd71502  First pass at ControlNet "guess mode" implementation. (GreggHelt2, Jun 11, 2023)
9e0e26f  Moving from ControlNet guess_mode to separate booleans for cfg_inject… (GreggHelt2, Jun 13, 2023)
8b7fac7  First pass at ControlNet "guess mode" implementation. (GreggHelt2, Jun 11, 2023)
8495764  Moving from ControlNet guess_mode to separate booleans for cfg_inject… (GreggHelt2, Jun 13, 2023)
de3e6cd  Switched over to ControlNet control_mode with 4 options: balanced, mo… (GreggHelt2, Jun 14, 2023)
a8e0490  Merge branch 'feat/controlnet-control-modes' of https://github.com/in… (GreggHelt2, Jun 14, 2023)
cfd49e3  Removing vestigial comments. (GreggHelt2, Jun 14, 2023)
5cd0e90  Renamed ControlNet control_mode option "even_more_control" to "unbala… (GreggHelt2, Jun 14, 2023)
43419ac  Merge branch 'main' into feat/controlnet-control-modes (blessedcoolant, Jun 14, 2023)
eb7047b  chore: Rebuild WebAPI (blessedcoolant, Jun 14, 2023)
6c53abc  feat: Add ControlMode to Linear UI (blessedcoolant, Jun 14, 2023)
6b8e88a  Merge branch 'main' into feat/controlnet-control-modes (blessedcoolant, Jun 14, 2023)
4ca325e  chore: Rebuild API (blessedcoolant, Jun 14, 2023)
c5faffc  Merge branch 'main' of github.com:invoke-ai/InvokeAI into feat/contro… (GreggHelt2, Jun 25, 2023)
4d4b5b5  Merge branch 'main' into feat/controlnet-control-modes (blessedcoolant, Jun 25, 2023)
132829c  fix(ui): fix path of generated schema types (psychedelicious, Jun 25, 2023)
11378a9  chore(ui): regen api schema (psychedelicious, Jun 25, 2023)
57e7197  fix(ui): add missing ControlNetInvocation type; tidy schema-derived t… (psychedelicious, Jun 25, 2023)
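
Taken together, these commits add a `control_mode` parameter (one of `balanced`, `more_prompt`, `more_control`, `unbalanced`) to both `ControlField` and `ControlNetInvocation`, thread it through `prep_control_data` into the diffusion loop, and expose it in the linear UI. As a rough orientation, a hypothetical node payload after this PR might look like the sketch below; the field names come from `ControlNetInvocation` in the diff, while the `id` and `type` values follow InvokeAI node conventions and are assumptions here, not taken from this diff:

```python
# Hypothetical ControlNet node payload after this PR. Field names are from
# ControlNetInvocation below; "id"/"type" values are assumed for illustration.
node = {
    "id": "controlnet_1",                               # assumed node id
    "type": "controlnet",                               # assumed node type name
    "control_model": "lllyasviel/sd-controlnet-canny",  # default model
    "control_weight": 1.0,
    "begin_step_percent": 0.0,
    "end_step_percent": 1.0,
    "control_mode": "more_control",                     # new in this PR
}
```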
9 changes: 6 additions & 3 deletions invokeai/app/invocations/controlnet_image_processors.py
@@ -1,7 +1,7 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
from builtins import float, bool

import numpy as np
from typing import Literal, Optional, Union, List
@@ -94,6 +94,7 @@
]

CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]

class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
@@ -104,6 +105,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")

@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
@@ -144,11 +147,11 @@ class ControlNetInvocation(BaseInvocation):
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
# fmt: on

class Config(InvocationConfig):
@@ -166,14 +169,14 @@ class Config(InvocationConfig):
}

def invoke(self, context: InvocationContext) -> ControlOutput:

return ControlOutput(
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
),
)

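
The `CONTROLNET_MODE_VALUES` `Literal` above makes pydantic reject any mode outside the four allowed values. A minimal, self-contained sketch of that behavior follows (pydantic v1 API, as InvokeAI used at the time; `ControlFieldSketch` is an illustrative stand-in, not the real class):

```python
# Minimal sketch of the Literal-constrained control_mode field (pydantic v1).
from typing import Literal
from pydantic import BaseModel, Field, ValidationError

CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]

class ControlFieldSketch(BaseModel):
    control_mode: CONTROLNET_MODE_VALUES = Field(
        default="balanced", description="The control mode to use"
    )

print(ControlFieldSketch().control_mode)                # "balanced" (default)
print(ControlFieldSketch(control_mode="more_control"))  # accepted
try:
    ControlFieldSketch(control_mode="guess_mode")       # rejected: not one of the four values
except ValidationError as err:
    print(err)
```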
10 changes: 4 additions & 6 deletions invokeai/app/invocations/latent.py
@@ -287,19 +287,14 @@ def prep_control_data(
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
if control_input is None:
# print("control input is None")
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None
elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input
else:
# print("input control is unrecognized:", type(self.control))
control_list = None
if (control_list is None):
control_data = None
@@ -341,12 +336,15 @@ def prep_control_data(
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
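
`prep_control_data` now passes `control_mode` both to `prepare_control_image` and into `ControlNetData`. The reason it must reach image preparation: under the cfg-injection modes the control image is not duplicated for classifier-free guidance, because the ControlNet only sees the conditional half of the batch. A standalone sketch of that batching rule, assuming the diff's semantics (the helper name is mine; the PR inlines this logic in `prepare_control_image`):

```python
# Sketch of the CFG batching rule for control images introduced by this PR.
import torch

def batch_control_image(image: torch.Tensor, control_mode: str,
                        do_classifier_free_guidance: bool = True) -> torch.Tensor:
    # "more_control" and "unbalanced" run the ControlNet on the conditional
    # batch only, so the control image is not duplicated for CFG.
    cfg_injection = control_mode in ("more_control", "unbalanced")
    if do_classifier_free_guidance and not cfg_injection:
        return torch.cat([image] * 2)  # one copy per CFG half
    return image

image = torch.zeros(1, 3, 512, 512)
print(batch_control_image(image, "balanced").shape[0])      # 2
print(batch_control_image(image, "more_control").shape[0])  # 1
```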
76 changes: 50 additions & 26 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -215,10 +215,12 @@ def __call__(
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: Union[float, List[float]]= Field(default=1.0)
image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")


@dataclass(frozen=True)
class ConditioningData:
@@ -599,48 +601,68 @@ def step(

# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)

# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None

if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two booleans that the higher-level
# control_mode enum combines; they are recovered from it here.
# soft_injection selects per-layer re-weighting of the ControlNet residuals
# (if True) rather than uniform default weighting (if False).
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
# cfg_injection applies the ControlNet to only the conditional batch (if True)
# rather than the default of both conditional and unconditional (if False).
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")

first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)

if cfg_injection:
control_latent_input = unet_latent_input
else:
# expand the latents input to control model if doing classifier free guidance
# (which I think is always true for now; a conditional elsewhere stops
# execution if classifier_free_guidance <= 1.0?)
control_latent_input = torch.cat([unet_latent_input] * 2)

if cfg_injection:  # applying ControlNet to the conditional batch only, not the unconditional
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight

# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
sample=control_latent_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight,
# cross_attention_kwargs,
guess_mode=False,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
@@ -653,11 +675,11 @@ def step(

# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input,
t,
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
conditioning_data.guidance_scale,
x=unet_latent_input,
sigma=t,
unconditioning=conditioning_data.unconditioned_embeddings,
conditioning=conditioning_data.text_embeddings,
unconditional_guidance_scale=conditioning_data.guidance_scale,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
@@ -962,6 +984,7 @@ def prepare_control_image(
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):

if not isinstance(image, torch.Tensor):
@@ -992,6 +1015,7 @@
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2)
return image
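
Pulling the pieces of the diff above together: each `control_mode` value reduces to the two booleans used in `step()`, `soft_injection` (passed to diffusers as `guess_mode`, enabling per-layer re-weighting of the ControlNet residuals) and `cfg_injection` (run the ControlNet on the conditional batch only, zero-padding the unconditional half of the residuals). A summary sketch of that mapping; the helper name is mine, since the PR inlines these expressions:

```python
# control_mode -> (soft_injection, cfg_injection), as derived in step() above.
def resolve_control_mode(control_mode: str) -> tuple[bool, bool]:
    soft_injection = control_mode in ("more_prompt", "more_control")  # diffusers guess_mode
    cfg_injection = control_mode in ("more_control", "unbalanced")    # conditional-only ControlNet
    return soft_injection, cfg_injection

assert resolve_control_mode("balanced") == (False, False)
assert resolve_control_mode("more_prompt") == (True, False)
assert resolve_control_mode("more_control") == (True, True)
assert resolve_control_mode("unbalanced") == (False, True)
```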
2 changes: 1 addition & 1 deletion invokeai/frontend/web/package.json
@@ -23,7 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/schema.d.ts -t",
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/api/schema.d.ts -t",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
3 changes: 2 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
@@ -524,7 +524,8 @@
"initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview",
"showPreview": "Show Preview"
"showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
},
"settings": {
"models": "Models",
@@ -1,26 +1,27 @@
import { Box, ChakraProps, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetAdded,
controlNetRemoved,
controlNetToggled,
} from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { Flex, Box, ChakraProps } from '@chakra-ui/react';
import { FaCopy, FaTrash } from 'react-icons/fa';

import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import { ChevronUpIcon } from '@chakra-ui/icons';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import IAISwitch from 'common/components/IAISwitch';
import { ChevronUpIcon } from '@chakra-ui/icons';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';

const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };

@@ -36,6 +37,7 @@ const ControlNet = (props: ControlNetProps) => {
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
@@ -137,48 +139,51 @@ const ControlNet = (props: ControlNetProps) => {
</Flex>
{isEnabled && (
<>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
flexDir: 'column',
gap: 2,
w: 'full',
h: isExpanded ? 28 : 24,
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
flexDir: 'column',
gap: 3,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ControlNetImagePreview
controlNet={props.controlNet}
height={24}
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
)}
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
</Flex>

{isExpanded && (
<>
<Box mt={2}>