Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere (#6490)

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <[email protected]>

* documentation fixes requested during penultimate review

* add non-blocking=True parameters to several torch.nn.Module.to() calls, for slight performance increases

* fix ruff errors

* prevent crash on non-cuda-enabled systems

---------

Co-authored-by: Lincoln Stein <[email protected]>
Co-authored-by: Kent Keirsey <[email protected]>
Co-authored-by: Ryan Dick <[email protected]>
4 people authored Jun 13, 2024
1 parent 568a484 commit a3cb5da
Showing 7 changed files with 84 additions and 38 deletions.
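
The common thread in the diffs below is passing `non_blocking=True` to `torch.Tensor.to()` and `torch.nn.Module.to()` on the RAM<->VRAM path so host-to-device copies can overlap with other work. As a rough illustration only (none of this code is from the commit, and the helper name is invented), the sketch below shows the pattern and its main caveat: the copy is only truly asynchronous when the source tensor sits in pinned host memory, and a synchronize is needed before the result is relied upon.

```python
# Illustrative sketch, not part of this commit: non-blocking host-to-device copies.
import torch


def copy_state_dict_to_device(
    state_dict: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    # non_blocking=True only makes the copy asynchronous when the source tensor
    # is in pinned (page-locked) host memory; otherwise it silently behaves
    # like an ordinary synchronous copy.
    return {k: v.to(device, copy=True, non_blocking=True) for k, v in state_dict.items()}


if __name__ == "__main__" and torch.cuda.is_available():
    cpu_sd = {"weight": torch.randn(1024, 1024, dtype=torch.float16).pin_memory()}
    gpu_sd = copy_state_dict_to_device(cpu_sd, torch.device("cuda"))
    torch.cuda.synchronize()  # wait for the queued copies before using gpu_sd
```
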
11 changes: 7 additions & 4 deletions invokeai/backend/ip_adapter/ip_adapter.py
@@ -125,13 +125,16 @@ def __init__(
self.device, dtype=self.dtype
)

def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
):
if device is not None:
self.device = device
if dtype is not None:
self.dtype = dtype

self._image_proj_model.to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)

def calc_size(self):
# workaround for circular import
51 changes: 29 additions & 22 deletions invokeai/backend/lora.py
@@ -61,9 +61,10 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)


# TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
super().to(device=device, dtype=dtype, non_blocking=non_blocking)

self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)

if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)


class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)

self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)

self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)

@@ -273,19 +277,19 @@ def to(
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)

self.weight = self.weight.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)


class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
):
super().to(device=device, dtype=dtype)

self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)


AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)

def calc_size(self) -> int:
model_size = 0
@@ -514,7 +521,7 @@ def from_checkpoint(
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()

layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer

return model
@@ -285,9 +285,9 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True)
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
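
For context on the hunk above: the cache keeps a CPU copy of the state dict, re-materializes device tensors with `copy=True, non_blocking=True`, and hands them to `load_state_dict(..., assign=True)` so the module adopts the new tensors instead of copying into its existing storage. A hedged sketch of that pattern, with invented names (`CachedModel`, `move_to`) rather than the real InvokeAI classes:

```python
# Hedged sketch of the caching pattern above; class and method names are invented.
import torch


class CachedModel:
    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model
        # Authoritative CPU copy of the weights, so the VRAM copy can be
        # rebuilt later without touching the original tensors.
        self.cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}

    def move_to(self, device: torch.device) -> None:
        new_dict = {
            k: v.to(device, copy=True, non_blocking=True)
            for k, v in self.cpu_state_dict.items()
        }
        # assign=True (PyTorch 2.1+) swaps the new tensors into the module's
        # parameters rather than copying element-wise into existing storage.
        self.model.load_state_dict(new_dict, assign=True)
        self.model.to(device, non_blocking=True)
```
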
16 changes: 8 additions & 8 deletions invokeai/backend/model_patcher.py
@@ -67,7 +67,7 @@ def apply_lora_unet(
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
@@ -83,7 +83,7 @@ def apply_lora_text_encoder(
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
yield

@@ -95,7 +95,7 @@ def apply_lora(
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[Any, None, None]:
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
@@ -139,12 +139,12 @@ def apply_lora(
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32, non_blocking=True)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"))
layer.to(device=torch.device("cpu"), non_blocking=True)

assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
Expand All @@ -153,15 +153,15 @@ def apply_lora(
layer_weight = layer_weight.reshape(module.weight.shape)

assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)

yield # wait for context manager exit

finally:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)

@classmethod
@contextmanager
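
The `apply_lora` hunks above follow a patch-then-restore flow: back up the original weight, move the LoRA layer to the target device, cast it to float32 in a separate call (the in-code comment notes that moving first and casting second was measured to be faster than one combined `.to()` for fp16 CPU tensors headed to CUDA), add the scaled delta, and restore the backup on exit. A simplified, hedged sketch of that flow follows; the helper and its arguments are illustrative, not the InvokeAI API.

```python
# Hedged sketch of the patch/restore flow in apply_lora; names are illustrative.
from contextlib import contextmanager

import torch


@contextmanager
def patch_weight(module: torch.nn.Module, delta: torch.Tensor, scale: float, device: torch.device):
    # Assumes module.weight already lives on `device` and delta matches its shape.
    original = module.weight.detach().to("cpu", copy=True)  # CPU backup of the unpatched weight
    try:
        with torch.no_grad():
            # Move first, then cast: reported faster than one combined .to(device, dtype)
            # call for 16-bit CPU tensors being moved to CUDA.
            delta = delta.to(device=device, non_blocking=True)
            delta = delta.to(dtype=torch.float32, non_blocking=True)
            module.weight += (delta * scale).to(dtype=module.weight.dtype)
        yield
    finally:
        with torch.no_grad():
            module.weight.copy_(original, non_blocking=True)
```
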
10 changes: 10 additions & 0 deletions invokeai/backend/onnx/onnx_runtime.py
@@ -6,6 +6,7 @@

import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers

@@ -188,6 +189,15 @@ def __call__(self, **kwargs):
# return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs)

# compatibility with RawModel ABC
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass

# compatibility with diffusers load code
@classmethod
def from_pretrained(
18 changes: 16 additions & 2 deletions invokeai/backend/raw_model.py
@@ -10,6 +10,20 @@
that adds additional methods and attributes.
"""

from abc import ABC, abstractmethod
from typing import Optional

class RawModel:
"""Base class for 'Raw' model wrappers."""
import torch


class RawModel(ABC):
"""Abstract base class for 'Raw' model wrappers."""

@abstractmethod
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass
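
With `RawModel` now an ABC whose `to()` is abstract, every wrapper (LoRA models, IP adapters, textual inversions, the ONNX wrapper) must provide a `to()` with this signature, even if only as a no-op. A minimal, hypothetical subclass for illustration (not part of the commit):

```python
# Hypothetical RawModel subclass, for illustration only.
from typing import Optional

import torch


class ExampleRawModel(RawModel):
    def __init__(self, weight: torch.Tensor) -> None:
        self.weight = weight

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        # Tensor.to() returns a new tensor, so the result must be reassigned.
        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
```
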
12 changes: 12 additions & 0 deletions invokeai/backend/textual_inversion.py
@@ -65,6 +65,18 @@ def from_checkpoint(

return result

def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if not torch.cuda.is_available():
return
for emb in [self.embedding, self.embedding_2]:
if emb is not None:
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)


class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
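
The `torch.cuda.is_available()` guard in the new `to()` above corresponds to the commit bullet "prevent crash on non-cuda-enabled systems". A hedged, generic way to express the same idea (the helper name is invented, not from this commit) is to request a non-blocking copy only when a CUDA target is actually involved:

```python
# Hedged sketch of guarding the non_blocking fast path; the helper is invented.
from typing import Optional

import torch


def safe_to(
    t: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # Only request an asynchronous copy when moving to a CUDA device;
    # everywhere else fall back to a plain synchronous .to().
    non_blocking = device is not None and device.type == "cuda" and torch.cuda.is_available()
    return t.to(device=device, dtype=dtype, non_blocking=non_blocking)
```
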
