Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere #6490

Merged
merged 14 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions invokeai/backend/ip_adapter/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,16 @@ def __init__(
self.device, dtype=self.dtype
)

def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
):
if device is not None:
self.device = device
if dtype is not None:
self.dtype = dtype

self._image_proj_model.to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)

def calc_size(self):
# workaround for circular import
Expand Down
51 changes: 29 additions & 22 deletions invokeai/backend/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)


# TODO: find and debug lora/locon with bias
Expand Down Expand Up @@ -109,14 +110,15 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
super().to(device=device, dtype=dtype, non_blocking=non_blocking)

self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)

if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)


class LoHALayer(LoRALayerBase):
Expand Down Expand Up @@ -169,18 +171,19 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)

self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)

self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


class LoKRLayer(LoRALayerBase):
Expand Down Expand Up @@ -265,6 +268,7 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)

Expand All @@ -273,19 +277,19 @@ def to(
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


class FullLayer(LoRALayerBase):
Expand Down Expand Up @@ -319,10 +323,11 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)

self.weight = self.weight.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)


class IA3Layer(LoRALayerBase):
Expand Down Expand Up @@ -358,11 +363,12 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
):
super().to(device=device, dtype=dtype)

self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)


AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
Expand All @@ -388,10 +394,11 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)

def calc_size(self) -> int:
model_size = 0
Expand Down Expand Up @@ -514,7 +521,7 @@ def from_checkpoint(
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()

layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer

return model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True)
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
Expand Down
16 changes: 8 additions & 8 deletions invokeai/backend/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def apply_lora_unet(
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
Expand All @@ -83,7 +83,7 @@ def apply_lora_text_encoder(
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
yield

Expand All @@ -95,7 +95,7 @@ def apply_lora(
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[Any, None, None]:
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.

Expand Down Expand Up @@ -139,12 +139,12 @@ def apply_lora(
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32, non_blocking=True)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"))
layer.to(device=torch.device("cpu"), non_blocking=True)

assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
Expand All @@ -153,15 +153,15 @@ def apply_lora(
layer_weight = layer_weight.reshape(module.weight.shape)

assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)

yield # wait for context manager exit

finally:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)

@classmethod
@contextmanager
Expand Down
10 changes: 10 additions & 0 deletions invokeai/backend/onnx/onnx_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers

Expand Down Expand Up @@ -188,6 +189,15 @@ def __call__(self, **kwargs):
# return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs)

# compatability with RawModel ABC
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass

# compatability with diffusers load code
@classmethod
def from_pretrained(
Expand Down
18 changes: 16 additions & 2 deletions invokeai/backend/raw_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
that adds additional methods and attributes.
"""

from abc import ABC, abstractmethod
from typing import Optional

class RawModel:
"""Base class for 'Raw' model wrappers."""
import torch


class RawModel(ABC):
"""Abstract base class for 'Raw' model wrappers."""

@abstractmethod
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass
12 changes: 12 additions & 0 deletions invokeai/backend/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ def from_checkpoint(

return result

def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if not torch.cuda.is_available():
return
for emb in [self.embedding, self.embedding_2]:
if emb is not None:
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)


class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
Expand Down
Loading