Commit 8888641
Merge branch 'main' into lstein/feat/load-one-file
lstein committed Jun 27, 2024
2 parents 9037b6f + aba1608 commit 8888641
Showing 4 changed files with 33 additions and 8 deletions.
3 changes: 2 additions & 1 deletion invokeai/backend/lora.py
@@ -10,6 +10,7 @@
 from typing_extensions import Self

 from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.util.devices import TorchDevice

 from .raw_model import RawModel

@@ -521,7 +522,7 @@ def from_checkpoint(
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()

-            layer.to(device=device, dtype=dtype, non_blocking=True)
+            layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
             model.layers[layer_key] = layer

         return model
6 changes: 4 additions & 2 deletions
@@ -285,9 +285,11 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
+                    new_dict[k] = v.to(
+                        target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
+                    )
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=True)
+            cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
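Condensed, the hunk above keeps the cache's existing copy-then-load pattern and only gates non_blocking on the destination device. A minimal sketch of that pattern as a standalone helper (hypothetical code, not part of the commit; it assumes a plain nn.Module and a CPU-resident state dict):

from typing import Dict

import torch

from invokeai.backend.util.devices import TorchDevice


def load_on_device(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], target: torch.device) -> None:
    # Async copies are allowed everywhere except when the destination is MPS.
    non_blocking = TorchDevice.get_non_blocking(target)
    # Copy (not move) each state-dict tensor to the target device.
    new_dict = {k: v.to(target, copy=True, non_blocking=non_blocking) for k, v in state_dict.items()}
    # assign=True makes the module adopt the on-device copies instead of copying into its own tensors.
    model.load_state_dict(new_dict, assign=True)
    model.to(target, non_blocking=non_blocking)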
16 changes: 11 additions & 5 deletions invokeai/backend/model_patcher.py
@@ -16,6 +16,7 @@
 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.util.devices import TorchDevice

 from .lora import LoRAModelRaw
 from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
@@ -139,12 +140,15 @@ def apply_lora(
                         # We intentionally move to the target device first, then cast. Experimentally, this was found to
                         # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                         # same thing in a single call to '.to(...)'.
-                        layer.to(device=device, non_blocking=True)
-                        layer.to(dtype=torch.float32, non_blocking=True)
+                        layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
+                        layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
                         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                         layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=torch.device("cpu"), non_blocking=True)
+                        layer.to(
+                            device=TorchDevice.CPU_DEVICE,
+                            non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
+                        )

                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                         if module.weight.shape != layer_weight.shape:
@@ -153,15 +157,17 @@ def apply_lora(
                             layer_weight = layer_weight.reshape(module.weight.shape)

                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
+                        module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))

             yield  # wait for context manager exit

         finally:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
+                    model.get_submodule(module_key).weight.copy_(
+                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
+                    )

     @classmethod
     @contextmanager
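The comment in apply_lora notes that, for 16-bit CPU tensors headed to CUDA, moving to the device first and casting afterwards measured faster than a single combined .to(device, dtype=...) call. A rough micro-benchmark along those lines, purely illustrative and not part of the commit (it assumes a CUDA device is available; the tensor size and iteration count are arbitrary):

import time

import torch


def timed(fn, iters: int = 20) -> float:
    # Average wall-clock seconds per call, synchronizing CUDA around the loop.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


torch.ones(1, device="cuda")  # warm up the CUDA context before timing
cpu_tensor = torch.randn(4096, 4096, dtype=torch.float16)  # a 16-bit CPU tensor

# One combined call: move and cast together.
combined = timed(lambda: cpu_tensor.to(device="cuda", dtype=torch.float32))
# Two calls: move to the device first, then cast on the device.
split = timed(lambda: cpu_tensor.to(device="cuda").to(dtype=torch.float32))

print(f"combined .to(): {combined * 1e3:.2f} ms, move-then-cast: {split * 1e3:.2f} ms")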
16 changes: 16 additions & 0 deletions invokeai/backend/util/devices.py
@@ -42,6 +42,10 @@ def torch_dtype(device: torch.device) -> torch.dtype:
 class TorchDevice:
     """Abstraction layer for torch devices."""

+    CPU_DEVICE = torch.device("cpu")
+    CUDA_DEVICE = torch.device("cuda")
+    MPS_DEVICE = torch.device("mps")
+
     @classmethod
     def choose_torch_device(cls) -> torch.device:
         """Return the torch.device to use for accelerated inference."""
@@ -108,3 +112,15 @@ def empty_cache(cls) -> None:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @staticmethod
+    def get_non_blocking(to_device: torch.device) -> bool:
+        """Return the non_blocking flag to be used when moving a tensor to a given device.
+        MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
+        When moving _from_ MPS, we can use non-blocking operations.
+        See:
+        - https://github.com/pytorch/pytorch/issues/107455
+        - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
+        """
+        return False if to_device.type == "mps" else True
+
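Because get_non_blocking only inspects device.type, it can be probed without any accelerator attached. A small usage sketch showing how the patched call sites gate the flag (the move_tensor wrapper below is hypothetical, not code from the commit):

import torch

from invokeai.backend.util.devices import TorchDevice


def move_tensor(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Async copies are permitted everywhere except when the destination is MPS.
    return t.to(device, non_blocking=TorchDevice.get_non_blocking(device))


print(TorchDevice.get_non_blocking(torch.device("cuda")))  # True  -> non-blocking copy permitted
print(TorchDevice.get_non_blocking(torch.device("mps")))   # False -> synchronous copy to MPS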
