Redo custom attention processor to support other attention types #6550

Open

StAlKeR7779 wants to merge 27 commits into main from stalker7779/new_attention_processor

Commits (changes shown from 17 of 27 commits)
cd2dccf
Redo attention processor to support other attention types
StAlKeR7779 Jun 27, 2024
9f40c2d
Remove xformers and normal attention
StAlKeR7779 Jul 27, 2024
1ab8276
Fix file name
StAlKeR7779 Jul 27, 2024
d430e4c
Merge branch 'main' into stalker7779/new_attention_processor
StAlKeR7779 Jul 27, 2024
89c37c3
Sync fixes
StAlKeR7779 Jul 27, 2024
e9cc750
Update app config
StAlKeR7779 Jul 28, 2024
4b6d613
Remove remaining references to xformers
StAlKeR7779 Jul 29, 2024
d5fa938
Run api regen
StAlKeR7779 Jul 30, 2024
5a9cc04
Small rearrangement
StAlKeR7779 Aug 1, 2024
be84746
Add assert check
StAlKeR7779 Aug 1, 2024
bf2f798
Fix bad generation on slice_size not factor of heads count
StAlKeR7779 Aug 2, 2024
91cc89a
Use invoke slice_size values, to have less confusion
StAlKeR7779 Aug 2, 2024
719daeb
Add torch-sdp scale parameter support(added in torch 2.1)
StAlKeR7779 Aug 2, 2024
a16fa31
Test implementation of sliced attention using torch-sdp
StAlKeR7779 Aug 2, 2024
7ffceaa
Fix slice_size handling
StAlKeR7779 Aug 2, 2024
c7e7103
Revert "Fix bad generation on slice_size not factor of heads count"
StAlKeR7779 Aug 3, 2024
302dc9f
Return normal attention, change slicing logic, remove old attention code
StAlKeR7779 Aug 3, 2024
18fc36d
Suggested changes
StAlKeR7779 Aug 4, 2024
6bad046
Merge branch 'main' into stalker7779/new_attention_processor
StAlKeR7779 Aug 4, 2024
f44e0cd
Update config docstring
StAlKeR7779 Aug 4, 2024
9618b6e
Suggested changes
StAlKeR7779 Aug 6, 2024
09aef43
Restore xformers
StAlKeR7779 Aug 7, 2024
37dfab7
Small fixes
StAlKeR7779 Aug 7, 2024
192fba4
Rewrite sliced attention, more optimizations(batched torch-sdp for ol…
StAlKeR7779 Aug 19, 2024
0b1ff8f
Remove redundant alignment in batched torch-sdp execution, add comments
StAlKeR7779 Aug 19, 2024
3d19cac
Suggested changes
StAlKeR7779 Aug 20, 2024
b947129
Edit comments
StAlKeR7779 Aug 20, 2024
7 changes: 1 addition & 6 deletions docker/Dockerfile
@@ -43,12 +43,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
fi &&\

# xformers + triton fails to install on arm64
if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
pip install $extra_index_url_arg -e ".[xformers]"; \
else \
pip install $extra_index_url_arg -e "."; \
fi
pip install $extra_index_url_arg -e ".";

# #### Build the Web UI ------------------------------------

2 changes: 1 addition & 1 deletion flake.nix
@@ -84,7 +84,7 @@
in
{
devShells.${system} = rec {
develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
develop = mkShell { dir = "venv"; install = "-e '.' --extra-index-url https://download.pytorch.org/whl/cu118"; };
default = develop;
};
};
4 changes: 2 additions & 2 deletions installer/lib/installer.py
@@ -418,11 +418,11 @@ def get_torch_source() -> Tuple[str | None, str | None]:
url = "https://download.pytorch.org/whl/cpu"
elif device.value == "cuda":
# CUDA uses the default PyPi index
optional_modules = "[xformers,onnx-cuda]"
optional_modules = "[onnx-cuda]"
elif OS == "Windows":
if device.value == "cuda":
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-cuda]"
optional_modules = "[onnx-cuda]"
elif device.value == "cpu":
# CPU uses the default PyPi index, no optional modules
pass
8 changes: 1 addition & 7 deletions invokeai/app/api/routers/app_info.py
@@ -1,6 +1,6 @@
import typing
from enum import Enum
from importlib.metadata import PackageNotFoundError, version
from importlib.metadata import version
from pathlib import Path
from platform import python_version
from typing import Optional
@@ -56,7 +56,6 @@ class AppDependencyVersions(BaseModel):
torch: str = Field(description="PyTorch version")
torchvision: str = Field(description="PyTorch Vision version")
transformers: str = Field(description="transformers version")
xformers: Optional[str] = Field(description="xformers version")


class AppConfig(BaseModel):
@@ -75,10 +74,6 @@ async def get_version() -> AppVersion:

@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
async def get_app_deps() -> AppDependencyVersions:
try:
xformers = version("xformers")
except PackageNotFoundError:
xformers = None
return AppDependencyVersions(
accelerate=version("accelerate"),
compel=version("compel"),
@@ -92,7 +87,6 @@ async def get_app_deps() -> AppDependencyVersions:
torch=torch.version.__version__,
torchvision=version("torchvision"),
transformers=version("transformers"),
xformers=xformers,
)


4 changes: 2 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
@@ -55,7 +55,7 @@
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
@@ -810,7 +810,7 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
attention_processor_cls=CustomAttnProcessor,
),
unet=None,
scheduler=scheduler,
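For orientation, here is a minimal sketch (not the PR's implementation) of how a processor class like `CustomAttnProcessor` can be installed on every attention layer of a diffusers UNet; the `UNetAttentionPatcher.apply_custom_attention()` context manager used later in this diff is assumed to perform a swap of this kind and to restore the original processors on exit:

```python
# Minimal sketch, assuming diffusers' standard attention-processor API; the stock
# AttnProcessor2_0 stands in here for the PR's CustomAttnProcessor.
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel


def install_attn_processor(unet: UNet2DConditionModel, processor_cls=AttnProcessor2_0):
    """Swap `processor_cls` onto all attention layers and return the originals."""
    original = unet.attn_processors  # mapping: layer name -> current processor
    unet.set_attn_processor({name: processor_cls() for name in original})
    return original  # pass back to unet.set_attn_processor() to undo the swap
```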
52 changes: 47 additions & 5 deletions invokeai/app/services/config/config_default.py
@@ -13,6 +13,7 @@
from typing import Any, Literal, Optional

import psutil
import torch
import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
@@ -28,11 +29,11 @@
DEFAULT_VRAM_CACHE = 0.25
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
ATTENTION_TYPE = Literal["auto", "normal", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
CONFIG_SCHEMA_VERSION = "4.0.3"


def get_default_ram_cache_size() -> float:
@@ -107,7 +108,7 @@ class InvokeAIAppConfig(BaseSettings):
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `torch-sdp`
attention_slice_size: Slice size.<br>Valid values: `auto`, `none`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
@@ -181,7 +182,7 @@ class InvokeAIAppConfig(BaseSettings):
# GENERATION
sequential_guidance: bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.")
attention_type: ATTENTION_TYPE = Field(default="auto", description="Attention type.")
attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size, valid when attention_type=="sliced".')
attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size')
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
@@ -433,6 +434,44 @@ def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
return parsed_config_dict


def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.2 config dictionary to a v4.0.3 config dictionary.

Args:
config_dict: A dictionary of settings from a v4.0.2 config file.

Returns:
A config dict with the settings migrated to v4.0.3.
"""
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
attention_type = parsed_config_dict.get("attention_type", None)

# attention_slice_size now controls whether sliced attention is enabled
if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict:
del parsed_config_dict["attention_slice_size"]

# xformers attention was removed; on MPS, normal attention works better
if attention_type == "xformers":
if torch.backends.mps.is_available():
parsed_config_dict["attention_type"] = "normal"
else:
parsed_config_dict["attention_type"] = "torch-sdp"

# sliced attention is now enabled via `attention_slice_size`
if attention_type == "sliced":
if torch.backends.mps.is_available():
parsed_config_dict["attention_type"] = "normal"
else:
parsed_config_dict["attention_type"] = "torch-sdp"

# if attention_slice_size is not in the config, default to balanced
if "attention_slice_size" not in parsed_config_dict:
parsed_config_dict["attention_slice_size"] = "balanced"

parsed_config_dict["schema_version"] = "4.0.3"
return parsed_config_dict


def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.

@@ -458,6 +497,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.2":
migrated = True
loaded_config_dict = migrate_v4_0_2_to_4_0_3_config_dict(loaded_config_dict)

if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
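To make the new migration step concrete, here is how a v4.0.2 config that selected `xformers` would be rewritten on a machine without MPS (a sketch; on an Apple Silicon host `attention_type` would become `normal` instead):

```python
# Illustrative use of the migration added above; assumes a non-MPS host so the
# xformers branch falls through to torch-sdp.
from invokeai.app.services.config.config_default import migrate_v4_0_2_to_4_0_3_config_dict

old = {"schema_version": "4.0.2", "attention_type": "xformers", "attention_slice_size": 4}
new = migrate_v4_0_2_to_4_0_3_config_dict(old)

# attention_slice_size is dropped because the old type was not "sliced", then the
# new default "balanced" is filled in; xformers is mapped to torch-sdp.
assert new == {
    "schema_version": "4.0.3",
    "attention_type": "torch-sdp",
    "attention_slice_size": "balanced",
}
```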
91 changes: 7 additions & 84 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -1,31 +1,25 @@
from __future__ import annotations

import math
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union

import einops
import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@@ -167,66 +161,6 @@ def __init__(

self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)

def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
config = get_config()
if config.attention_type == "xformers":
self.enable_xformers_memory_efficient_attention()
return
elif config.attention_type == "sliced":
slice_size = config.attention_slice_size
if slice_size == "auto":
slice_size = auto_detect_slice_size(latents)
elif slice_size == "balanced":
slice_size = "auto"
self.enable_attention_slicing(slice_size=slice_size)
return
elif config.attention_type == "normal":
self.disable_attention_slicing()
return
elif config.attention_type == "torch-sdp":
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return
else:
raise Exception("torch-sdp attention slicing not available")

# the remainder if this code is called when attention_type=='auto'
if self.unet.device.type == "cuda":
if is_xformers_available():
self.enable_xformers_memory_efficient_attention()
return
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return

if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
max_size_required_for_baddbmm = (
16
* latents.size(dim=2)
* latents.size(dim=3)
* latents.size(dim=2)
* latents.size(dim=3)
* bytes_per_element_needed_for_baddbmm_duplication
)
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size="max")
elif torch.backends.mps.is_available():
# diffusers recommends always enabling for mps
self.enable_attention_slicing(slice_size="max")
else:
self.disable_attention_slicing()

def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")

@@ -321,8 +255,6 @@ def latents_from_embeddings(
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)

self._adjust_memory_efficient_attention(latents)

# Handle mask guidance (a.k.a. inpainting).
mask_guidance: AddsMaskGuidance | None = None
if mask is not None and not is_inpainting_model(self.unet):
@@ -347,23 +279,14 @@
is_gradient_mask=is_gradient_mask,
)

use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
attn_ctx = nullcontext()

if use_ip_adapter or use_regional_prompting:
ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
if use_ip_adapter
else None
)
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
ip_adapters: Optional[List[UNetIPAdapterData]] = None
if ip_adapter_data is not None:
ip_adapters = [
{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data
]

with attn_ctx:
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model):
callback(
PipelineIntermediateState(
step=-1,
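Several commits above rebuild sliced attention on top of `torch.nn.functional.scaled_dot_product_attention` (including the `scale` keyword added in torch 2.1). The following is a minimal sketch of the slicing idea, not the PR's exact code: attention is computed a few heads at a time to bound peak memory, and the result matches a single full-size call.

```python
# Minimal sketch of sliced attention via torch-sdp (not the PR's implementation).
import torch
import torch.nn.functional as F


def sliced_sdp_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, slice_size: int) -> torch.Tensor:
    """q, k, v: [batch, heads, seq, head_dim]; process `slice_size` heads per step."""
    out = torch.empty_like(q)
    for start in range(0, q.shape[1], slice_size):
        end = start + slice_size  # Python slicing clips the last, possibly short, slice
        out[:, start:end] = F.scaled_dot_product_attention(
            q[:, start:end],
            k[:, start:end],
            v[:, start:end],
            scale=q.shape[-1] ** -0.5,  # explicit default; the kwarg exists since torch 2.1
        )
    return out


q = k = v = torch.randn(1, 8, 64, 40)
full = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(sliced_sdp_attention(q, k, v, slice_size=3), full, atol=1e-5)
```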