Merge remote-tracking branch 'origin/main' into psyche/mm-rebase-fix
psychedelicious committed Jun 7, 2024
2 parents dc13493 + 6d067e5 commit 8b34976
Showing 42 changed files with 1,659 additions and 828 deletions.
50 changes: 37 additions & 13 deletions docs/contributing/MODEL_MANAGER.md
@@ -1366,30 +1366,54 @@ the in-memory loaded model:
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |

Because the loader can return multiple model types, it is typed to
return `AnyModel`, a Union of `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, and `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are self-explanatory.

### get_model_by_key(key, [submodel]) -> LoadedModel

The `get_model_by_key()` method will retrieve the model using its
unique database key. For example:

```
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
```

`get_model_by_key()` may raise any of the following exceptions:

* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
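
A caller will usually want to distinguish these failure modes. The sketch below is illustrative only: it assumes the exception classes and a `loader` instance are already in scope (import paths vary between releases), and the recovery actions are placeholders rather than prescribed behaviour.

```
try:
    loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
except UnknownModelException:
    # The key was never registered in the model database.
    raise
except ModelNotFoundException:
    # The key exists, but no model was found at the recorded path;
    # reinstall the model before retrying.
    raise
except NotImplementedException:
    # No loader is registered for this type of model.
    raise
```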

### Using the Loaded Model in Inference

`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
the model. Use it like this:

```
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model as vae:
    image = vae.decode(latents)[0]
```

The object returned by the `LoadedModel` context manager is an
`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, and `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are self-explanatory.
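
Callers narrow this union to the concrete class they expect before using type-specific methods. A minimal sketch of the pattern, mirroring the VAE example above (`AutoencoderKL` is the usual diffusers VAE class, but the concrete type depends on the installed model):

```
from diffusers import AutoencoderKL

loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model as vae:
    # Narrow AnyModel to the class this code expects before calling its methods.
    assert isinstance(vae, AutoencoderKL)
    image = vae.decode(latents)[0]
```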

In addition, you may call `LoadedModel.model_on_device()`, a context
manager that returns a tuple of the model's state dict in CPU RAM and
the model itself in VRAM. It is used to optimize the LoRA patching and
unpatching process:

```
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
    image = vae.decode(latents)[0]
```

Since not all models have state dicts, the `state_dict` return value
can be None.
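
Putting the two pieces together, a patcher can receive that state dict so the model can be restored cheaply after inference. The following is a sketch of the pattern behind this commit's `ModelPatcher.apply_lora_unet()` call; the `lora_iterator`, the model key, and the surrounding setup are assumptions, and `state_dict` may legitimately be `None` here:

```
loaded_unet = loader.get_model_by_key('<unet model key>', SubModelType('unet'))
with (
    loaded_unet.model_on_device() as (state_dict, unet),
    # Patch the VRAM copy; the CPU state dict (possibly None) lets the patcher
    # restore the weights without an expensive copy when the context exits.
    ModelPatcher.apply_lora_unet(unet, loras=lora_iterator, model_state_dict=state_dict),
):
    ...  # run denoising with the patched UNet
```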



### Emitting model loading events

When the `context` argument is passed to `load_model_*()`, it will
17 changes: 13 additions & 4 deletions invokeai/app/invocations/compel.py
@@ -81,9 +81,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

with (
# apply all patches while the model is on the target device
text_encoder_info as text_encoder,
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -172,9 +176,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

with (
# apply all patches while the model is on the target device
text_encoder_info as text_encoder,
text_encoder_info.model_on_device() as (state_dict, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
119 changes: 73 additions & 46 deletions invokeai/app/invocations/latent.py
@@ -50,7 +50,7 @@
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
@@ -672,54 +672,52 @@ def prep_control_data(

return controlnet_data

def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]

single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
)
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))

return image_prompts

def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
ip_adapters: List[IPAdapterField],
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None

# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if not isinstance(ip_adapter, list):
ip_adapter = [ip_adapter]

if len(ip_adapter) == 0:
return None

) -> Optional[List[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
ip_adapter_data_list = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model)
)

image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]

single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]

# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
)
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))

mask = single_ip_adapter.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)

ip_adapter_data_list.append(
@@ -734,7 +732,7 @@ def prep_ip_adapter_data(
)
)

return ip_adapter_data_list
return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None

def run_t2i_adapters(
self,
@@ -855,6 +853,16 @@ def init_scheduler(
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
#
# These Invoke-supported schedulers accept a generator as of 2024-06-04:
# - DDIMScheduler
# - DDPMScheduler
# - DPMSolverMultistepScheduler
# - EulerAncestralDiscreteScheduler
# - EulerDiscreteScheduler
# - KDPM2AncestralDiscreteScheduler
# - LCMScheduler
# - TCDScheduler
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
@@ -912,6 +920,20 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
do_classifier_free_guidance=True,
)

ip_adapters: List[IPAdapterField] = []
if self.ip_adapter is not None:
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if isinstance(self.ip_adapter, list):
ip_adapters = self.ip_adapter
else:
ip_adapters = [self.ip_adapter]

# If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
# a series of image conditioning embeddings. This is being done here rather than in the
# big model context below in order to use less VRAM on low-VRAM systems.
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)

# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.models.get_config(self.unet.unet.key)

@@ -930,11 +952,15 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info as unet,
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
@@ -970,7 +996,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
ip_adapters=ip_adapters,
image_prompts=image_prompts,
exit_stack=exit_stack,
latent_height=latent_height,
latent_width=latent_width,
@@ -1285,7 +1312,7 @@ def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
title="Blend Latents",
tags=["latents", "blend"],
category="latents",
version="1.0.2",
version="1.0.3",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
@@ -1364,7 +1391,7 @@ def slerp(
TorchDevice.empty_cache()

name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)


# The Crop Latents node was copied from @skunkworxdark's implementation here:
52 changes: 50 additions & 2 deletions invokeai/backend/model_manager/load/load_base.py
@@ -4,10 +4,13 @@
"""

from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Generator, Optional, Tuple

import torch

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
@@ -21,7 +24,42 @@

@dataclass
class LoadedModelWithoutConfig:
"""Context manager object that mediates transfer from RAM<->VRAM."""
"""
Context manager object that mediates transfer from RAM<->VRAM.
This is a context manager object that has two distinct APIs:
1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model as vae:
    image = vae.decode(latents)[0]
```
2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method in a
context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Example:
```
loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
    image = vae.decode(latents)[0]
```
The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
"""

_locker: ModelLockerBase

@@ -34,6 +72,16 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
self._locker.unlock()

@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
finally:
self._locker.unlock()

@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
@@ -30,6 +30,11 @@ def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass

@abstractmethod
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
pass

@property
@abstractmethod
def model(self) -> AnyModel:
@@ -56,6 +61,11 @@ class CacheRecord(Generic[T]):
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""

key: str