Skip to content

Commit

Permalink
Merge branch 'main' into update-prevention-messaging
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious authored Jun 26, 2024
2 parents fe1536d + cd9dfef commit 94669b9
Show file tree
Hide file tree
Showing 5 changed files with 730 additions and 250 deletions.
169 changes: 102 additions & 67 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings

Expand All @@ -65,6 +66,9 @@ def get_scheduler(
scheduler_name: str,
seed: int,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
Expand Down Expand Up @@ -182,8 +186,8 @@ def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
raise ValueError("cfg_scale must be greater than 1")
return v

@staticmethod
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
Expand All @@ -203,8 +207,9 @@ def _get_text_embeddings_and_masks(

return text_embeddings, text_embeddings_masks

@staticmethod
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
Expand All @@ -228,8 +233,8 @@ def _preprocess_regional_prompt_mask(
resized_mask = tf(mask)
return resized_mask

@staticmethod
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
Expand Down Expand Up @@ -279,7 +284,9 @@ def _concat_regional_text_embeddings(
)
)
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
)

cur_text_embedding_len += text_embedding_info.embeds.shape[1]
Expand All @@ -301,60 +308,63 @@ def _concat_regional_text_embeddings(
)
return BasicConditioningInfo(embeds=text_embedding), regions

@staticmethod
def get_conditioning_data(
self,
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]

cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)

cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)

if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"

conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
)
return conditioning_data

@staticmethod
def create_pipeline(
self,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
Expand All @@ -377,38 +387,38 @@ def __init__(self) -> None:
requires_safety_checker=False,
)

@staticmethod
def prep_control_data(
self,
context: InvocationContext,
control_input: Optional[Union[ControlField, List[ControlField]]],
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
control_list = None
if control_list is None:
raise ValueError(f"Unexpected control_input type: {type(control_input)}")

if len(control_list) == 0:
return None
# After above handling, any control that is not None should now be of type list[ControlField].

# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR

controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)

# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
Expand All @@ -429,7 +439,7 @@ def prep_control_data(
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
Expand Down Expand Up @@ -583,15 +593,15 @@ def run_t2i_adapters(

# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int, Dict[str, Any]]:
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
Expand All @@ -615,9 +625,7 @@ def init_scheduler(
t_start_idx *= scheduler.order
t_end_idx *= scheduler.order

init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order

scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
Expand All @@ -639,7 +647,7 @@ def init_scheduler(
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})

return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
return timesteps, scheduler_step_kwargs

def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
Expand All @@ -656,31 +664,52 @@ def prep_inpaint_mask(

return 1 - mask, masked_latents, self.denoise_mask.gradient

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed = None
@staticmethod
def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
function handles preparing these values accordingly.
Expected workflows:
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
- Image-to-Image Denoising: `noise` and `latents` are both provided.
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
I haven't considered.
"""
noise = None
if self.noise is not None:
noise = context.tensors.load(self.noise.latents_name)
seed = self.noise.seed

if self.latents is not None:
latents = context.tensors.load(self.latents.latents_name)
if seed is None:
seed = self.latents.seed

if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name)

if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name)
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise Exception("'latents' or 'noise' must be provided!")
raise ValueError("'latents' or 'noise' must be provided!")

if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

if seed is None:
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0

return seed, noise, latents

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)

# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
Expand All @@ -706,7 +735,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)

# get the unet's config so that we can pass the base to dispatch_progress()
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

def step_callback(state: PipelineIntermediateState) -> None:
Expand Down Expand Up @@ -754,7 +783,15 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)

controlnet_data = self.prep_control_data(
Expand All @@ -776,7 +813,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
dtype=unet.dtype,
)

num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
timesteps, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
Expand All @@ -788,13 +825,11 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
Expand Down
Loading

0 comments on commit 94669b9

Please sign in to comment.