Redo custom attention processor to support other attention types #6550

Status: Open. Wants to merge 27 commits into base: main.

Commits (27)
- cd2dccf: Redo attention processor to support other attention types (StAlKeR7779, Jun 27, 2024)
- 9f40c2d: Remove xformers and normal attention (StAlKeR7779, Jul 27, 2024)
- 1ab8276: Fix file name (StAlKeR7779, Jul 27, 2024)
- d430e4c: Merge branch 'main' into stalker7779/new_attention_processor (StAlKeR7779, Jul 27, 2024)
- 89c37c3: Sync fixes (StAlKeR7779, Jul 27, 2024)
- e9cc750: Update app config (StAlKeR7779, Jul 28, 2024)
- 4b6d613: Remove remaining references to xformers (StAlKeR7779, Jul 29, 2024)
- d5fa938: Run api regen (StAlKeR7779, Jul 30, 2024)
- 5a9cc04: Small rearrangement (StAlKeR7779, Aug 1, 2024)
- be84746: Add assert check (StAlKeR7779, Aug 1, 2024)
- bf2f798: Fix bad generation on slice_size not factor of heads count (StAlKeR7779, Aug 2, 2024)
- 91cc89a: Use invoke slice_size values, to have less confusion (StAlKeR7779, Aug 2, 2024)
- 719daeb: Add torch-sdp scale parameter support (added in torch 2.1) (StAlKeR7779, Aug 2, 2024)
- a16fa31: Test implementation of sliced attention using torch-sdp (StAlKeR7779, Aug 2, 2024)
- 7ffceaa: Fix slice_size handling (StAlKeR7779, Aug 2, 2024)
- c7e7103: Revert "Fix bad generation on slice_size not factor of heads count" (StAlKeR7779, Aug 3, 2024)
- 302dc9f: Return normal attention, change slicing logic, remove old attention code (StAlKeR7779, Aug 3, 2024)
- 18fc36d: Suggested changes (StAlKeR7779, Aug 4, 2024)
- 6bad046: Merge branch 'main' into stalker7779/new_attention_processor (StAlKeR7779, Aug 4, 2024)
- f44e0cd: Update config docstring (StAlKeR7779, Aug 4, 2024)
- 9618b6e: Suggested changes (StAlKeR7779, Aug 6, 2024)
- 09aef43: Restore xformers (StAlKeR7779, Aug 7, 2024)
- 37dfab7: Small fixes (StAlKeR7779, Aug 7, 2024)
- 192fba4: Rewrite sliced attention, more optimizations (batched torch-sdp for ol… (StAlKeR7779, Aug 19, 2024)
- 0b1ff8f: Remove redundant alignment in batched torch-sdp execution, add comments (StAlKeR7779, Aug 19, 2024)
- 3d19cac: Suggested changes (StAlKeR7779, Aug 20, 2024)
- b947129: Edit comments (StAlKeR7779, Aug 20, 2024)
7 changes: 6 additions & 1 deletion docker/Dockerfile
@@ -43,7 +43,12 @@ RUN --mount=type=cache,target=/root/.cache/pip \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
fi &&\

pip install $extra_index_url_arg -e ".";
# xformers + triton fails to install on arm64
if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
pip install $extra_index_url_arg -e ".[xformers]"; \
else \
pip install $extra_index_url_arg -e "."; \
fi

# #### Build the Web UI ------------------------------------

8 changes: 8 additions & 0 deletions docs/installation/020_INSTALL_MANUAL.md
@@ -87,6 +87,14 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```

- If you have a CUDA GPU and want to install with `xformers`, add the `xformers` extra to the package name. Note that `xformers` is not required: PyTorch ships its own implementation of the SDP attention algorithm with comparable performance.

!!! example "Install with `xformers`"

```bash
pip install "InvokeAI[xformers]" --use-pep517
```

1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment:

=== "Linux/macOS"
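As a quick aside on the docs change above: if you want to confirm that PyTorch's built-in SDP attention is available before skipping the `xformers` extra, a minimal check along these lines should work (illustrative only, not part of this PR):

```python
import torch

# scaled_dot_product_attention ships with PyTorch 2.0+, so xformers is optional.
print("torch:", torch.__version__)
print("has SDP attention:", hasattr(torch.nn.functional, "scaled_dot_product_attention"))

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # PyTorch's flash-attention SDP kernels require compute capability sm80 or newer.
    print(f"compute capability: sm{major}{minor}, flash-capable: {major >= 8}")
```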
2 changes: 1 addition & 1 deletion flake.nix
@@ -84,7 +84,7 @@
in
{
devShells.${system} = rec {
develop = mkShell { dir = "venv"; install = "-e '.' --extra-index-url https://download.pytorch.org/whl/cu118"; };
develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
default = develop;
};
};
Expand Down
4 changes: 2 additions & 2 deletions installer/lib/installer.py
@@ -418,11 +418,11 @@ def get_torch_source() -> Tuple[str | None, str | None]:
url = "https://download.pytorch.org/whl/cpu"
elif device.value == "cuda":
# CUDA uses the default PyPi index
optional_modules = "[onnx-cuda]"
optional_modules = "[xformers,onnx-cuda]"
elif OS == "Windows":
if device.value == "cuda":
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[onnx-cuda]"
optional_modules = "[xformers,onnx-cuda]"
elif device.value == "cpu":
# CPU uses the default PyPi index, no optional modules
pass
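For clarity, a hedged sketch of how the `optional_modules` extras and index URL chosen in `get_torch_source()` above end up shaping the install command (the helper below is hypothetical, not the installer's actual code):

```python
# Hypothetical helper, for illustration only: combine the package name, the
# optional extras (e.g. "[xformers,onnx-cuda]"), and an optional extra index URL
# into pip arguments, mirroring how the installer consumes get_torch_source().
def build_pip_args(package: str, optional_modules: str | None, index_url: str | None) -> list[str]:
    args = ["install", "--use-pep517", f"{package}{optional_modules or ''}"]
    if index_url is not None:
        args += ["--extra-index-url", index_url]
    return args


print(build_pip_args("InvokeAI", "[xformers,onnx-cuda]", None))
# ['install', '--use-pep517', 'InvokeAI[xformers,onnx-cuda]']
```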
8 changes: 7 additions & 1 deletion invokeai/app/api/routers/app_info.py
@@ -1,6 +1,6 @@
import typing
from enum import Enum
from importlib.metadata import version
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from platform import python_version
from typing import Optional
@@ -56,6 +56,7 @@ class AppDependencyVersions(BaseModel):
torch: str = Field(description="PyTorch version")
torchvision: str = Field(description="PyTorch Vision version")
transformers: str = Field(description="transformers version")
xformers: Optional[str] = Field(description="xformers version")


class AppConfig(BaseModel):
@@ -74,6 +75,10 @@ async def get_version() -> AppVersion:

@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
async def get_app_deps() -> AppDependencyVersions:
try:
xformers = version("xformers")
except PackageNotFoundError:
xformers = None
return AppDependencyVersions(
accelerate=version("accelerate"),
compel=version("compel"),
@@ -87,6 +92,7 @@ async def get_app_deps() -> AppDependencyVersions:
torch=torch.version.__version__,
torchvision=version("torchvision"),
transformers=version("transformers"),
xformers=xformers,
)


6 changes: 3 additions & 3 deletions invokeai/app/services/config/config_default.py
@@ -28,7 +28,7 @@
DEFAULT_VRAM_CACHE = 0.25
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "torch-sdp"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
@@ -449,8 +449,8 @@ def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str
if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict:
del parsed_config_dict["attention_slice_size"]

# xformers attention removed, sliced moved to attention_slice_size
if attention_type in ["sliced", "xformers"]:
# sliced moved to attention_slice_size
if attention_type == "sliced":
parsed_config_dict["attention_type"] = "auto"

parsed_config_dict["schema_version"] = "4.0.3"
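To make the 4.0.2 → 4.0.3 migration above easy to follow, here is a minimal standalone sketch of just the attention-related fields (a toy function, not the project's actual `migrate_v4_0_2_to_4_0_3_config_dict`, which handles more than this):

```python
from typing import Any


def migrate_attention_fields(config: dict[str, Any]) -> dict[str, Any]:
    """Toy restatement of the hunk above: 'sliced' is no longer its own attention type."""
    migrated = dict(config)
    attention_type = migrated.get("attention_type")

    # attention_slice_size is only meaningful for the legacy "sliced" type.
    if attention_type != "sliced" and "attention_slice_size" in migrated:
        del migrated["attention_slice_size"]

    # Slicing now lives in attention_slice_size, so "sliced" falls back to "auto".
    if attention_type == "sliced":
        migrated["attention_type"] = "auto"

    migrated["schema_version"] = "4.0.3"
    return migrated


print(migrate_attention_fields({"attention_type": "sliced", "attention_slice_size": 4}))
# {'attention_type': 'auto', 'attention_slice_size': 4, 'schema_version': '4.0.3'}
```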
@@ -6,13 +6,20 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.utils.import_utils import is_xformers_available

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
from invokeai.backend.util.devices import TorchDevice

if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


@dataclass
class IPAdapterAttentionWeights:
@@ -23,7 +30,9 @@ class IPAdapterAttentionWeights:
class CustomAttnProcessor:
"""A custom implementation of attention processor that supports additional Invoke features.
This implementation is based on
AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732)
SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616)
XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113)
AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204)
Supported custom features:
- IP-Adapter
@@ -53,6 +62,9 @@ def __init__(
if self.slice_size == "auto":
self.slice_size = self._select_slice_size()

if self.attention_type == "xformers" and xformers is None:
raise ImportError("xformers attention requires xformers module to be installed.")

def _select_attention_type(self) -> str:
device = TorchDevice.choose_torch_device()
# On some MPS systems normal attention is still faster than torch-sdp; on others it is on par
@@ -61,7 +73,14 @@ def _select_attention_type(self) -> str:
# Adreitz: 260.868s vs 226.638s
if device.type == "mps":
return "normal"
else: # cuda, cpu
elif device.type == "cuda":
# Flash Attention is supported from sm80 compute capability onwards in PyTorch
# https://pytorch.org/blog/accelerated-pytorch-2/
if torch.cuda.get_device_capability("cuda")[0] < 8 and xformers is not None:
return "xformers"
else:
return "torch-sdp"
else: # cpu
return "torch-sdp"

def _select_slice_size(self) -> str:
@@ -262,6 +281,8 @@ def run_attention(
attn_call = self.run_attention_sdp
elif self.attention_type == "normal":
attn_call = self.run_attention_normal
elif self.attention_type == "xformers":
attn_call = self.run_attention_xformers
else:
raise Exception(f"Unknown attention type: {self.attention_type}")

@@ -291,6 +312,35 @@ def run_attention_normal(

return hidden_states

def run_attention_xformers(
self,
attn: Attention,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()

if attention_mask is not None:
# expand our mask's singleton query_length dimension:
# [batch*heads, 1, key_length] ->
# [batch*heads, query_length, key_length]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_length, key_length]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
attention_mask = attention_mask.expand(-1, query.shape[1], -1)

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)

return hidden_states

def run_attention_sdp(
self,
attn: Attention,
@@ -355,6 +405,10 @@ def run_attention_sliced(
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx])
del attn_slice
elif self.attention_type == "xformers":
hidden_states[start_idx:end_idx] = xformers.ops.memory_efficient_attention(
query_slice, key_slice, value_slice, attn_bias=attn_mask_slice, op=None, scale=attn.scale
)
elif self.attention_type == "torch-sdp":
if attn_mask_slice is not None:
attn_mask_slice = attn_mask_slice.unsqueeze(0)
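For reviewers, a condensed standalone restatement of the `_select_attention_type` logic added above (a sketch only; the real method lives on `CustomAttnProcessor` and also interacts with the app config and slice-size selection):

```python
import torch


def pick_attention_backend(device: torch.device, xformers_installed: bool) -> str:
    """Sketch of the auto-selection rules shown in the diff above."""
    if device.type == "mps":
        # On some MPS systems normal attention is still faster than torch-sdp.
        return "normal"
    if device.type == "cuda":
        # PyTorch's flash-attention kernels need compute capability sm80+,
        # so older CUDA GPUs prefer xformers when it is installed.
        if torch.cuda.get_device_capability(device)[0] < 8 and xformers_installed:
            return "xformers"
        return "torch-sdp"
    # CPU (and any other device type) uses torch-sdp.
    return "torch-sdp"
```

Called on an RTX 20-series card (sm75) with xformers installed, this returns "xformers"; on an sm80+ card it returns "torch-sdp".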
46 changes: 46 additions & 0 deletions invokeai/backend/util/hotfixes.py
@@ -791,3 +791,49 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0):


diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward

try:
import xformers

xformers_available = True
except Exception:
xformers_available = False


if xformers_available:
# TODO: remove when fixed in diffusers
_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention

def new_memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias=None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op=None,
):
# diffusers does not align the attn_bias shape to a multiple of 8, which xformers requires
if attn_bias is not None and type(attn_bias) is torch.Tensor:
orig_size = attn_bias.shape[-1]
new_size = ((orig_size + 7) // 8) * 8
aligned_attn_bias = torch.zeros(
(attn_bias.shape[0], attn_bias.shape[1], new_size),
device=attn_bias.device,
dtype=attn_bias.dtype,
)
aligned_attn_bias[:, :, :orig_size] = attn_bias
attn_bias = aligned_attn_bias[:, :, :orig_size]

return _xformers_memory_efficient_attention(
query=query,
key=key,
value=value,
attn_bias=attn_bias,
p=p,
scale=scale,
op=op,
)

xformers.ops.memory_efficient_attention = new_memory_efficient_attention
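A short note on the hotfix above: slicing the zero-padded buffer back to `orig_size` keeps the logical shape that diffusers passes in, while the underlying storage stays padded, so the stride along the query dimension becomes a multiple of 8, which is the alignment xformers actually checks. A small illustration (assumed sizes, not project code):

```python
import torch

# Assumed sizes, purely to illustrate the effect of the hotfix above.
batch_heads, query_len, key_len = 16, 77, 77      # 77 is not a multiple of 8
aligned_len = ((key_len + 7) // 8) * 8            # -> 80

bias = torch.zeros(batch_heads, query_len, key_len)
aligned = torch.zeros(batch_heads, query_len, aligned_len)
aligned[:, :, :key_len] = bias
view = aligned[:, :, :key_len]

print(bias.shape, bias.stride())   # torch.Size([16, 77, 77]) (5929, 77, 1)
print(view.shape, view.stride())   # torch.Size([16, 77, 77]) (6160, 80, 1)
# Same logical shape, but the view's middle-dimension stride (80) is 8-aligned,
# which is what xformers' memory_efficient_attention expects of attn_bias.
```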
5 changes: 5 additions & 0 deletions invokeai/frontend/web/src/services/api/schema.ts
@@ -725,6 +725,11 @@ export type components = {
* @description transformers version
*/
transformers: string;
/**
* Xformers
* @description xformers version
*/
xformers: string | null;
};
/**
* AppVersion
12 changes: 12 additions & 0 deletions invokeai/version/__init__.py
@@ -6,3 +6,15 @@

__app_id__ = "invoke-ai/InvokeAI"
__app_name__ = "InvokeAI"


def _ignore_xformers_triton_message_on_windows():
import logging

logging.getLogger("xformers").addFilter(
lambda record: "A matching Triton is not available" not in record.getMessage()
)


# In order to be effective, this needs to happen before anything could possibly import xformers.
_ignore_xformers_triton_message_on_windows()
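For context on the filter added above, a minimal demonstration of the pattern (the log messages below are made up): a logger-level filter drops only records containing the Triton warning and lets everything else through.

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("xformers")
logger.addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())

logger.warning("A matching Triton is not available, skipping some optimizations")  # suppressed
logger.warning("some other xformers warning")                                      # still printed
```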
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -94,6 +94,12 @@ dependencies = [
]

[project.optional-dependencies]
"xformers" = [
# Core generation dependencies, pinned for reproducible builds.
"xformers==0.0.25post1; sys_platform!='darwin'",
# Auxiliary dependencies, pinned only if necessary.
"triton; sys_platform=='linux'",
]
"onnx" = ["onnxruntime"]
"onnx-cuda" = ["onnxruntime-gpu"]
"onnx-directml" = ["onnxruntime-directml"]
3 changes: 3 additions & 0 deletions scripts/invokeai-web.py
@@ -2,10 +2,13 @@

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import logging
import os

from invokeai.app.run_app import run_app

logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())


def main():
# Change working directory to the repo root