Partial Loading PR1: Tidy ModelCache #7492

Merged · 9 commits · Dec 24, 2024
1 change: 0 additions & 1 deletion docs/contributing/MODEL_MANAGER.md
@@ -1364,7 +1364,6 @@ the in-memory loaded model:
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |

### get_model_by_key(key, [submodel]) -> LoadedModel

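A minimal usage sketch against the documented API above; `loader` stands in for the model-load service and the model key is illustrative. Note that after this PR the `locker` attribute is gone and locking happens through the `LoadedModel` context manager itself:

```python
loaded = loader.get_model_by_key("f13dd932")   # -> LoadedModel; key is illustrative
print(loaded.config.name, loaded.config.base)  # the configuration record, as in the table above
with loaded as model:                          # entering the context moves the model into VRAM
    ...                                        # run inference with `model`
```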
2 changes: 1 addition & 1 deletion invokeai/app/api/routers/model_manager.py
@@ -37,7 +37,7 @@
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
@@ -20,7 +20,7 @@
NodeExecutionStatsSummary,
)
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache import CacheStats
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats

# Size of 1GB in bytes.
GB = 2**30
4 changes: 2 additions & 2 deletions invokeai/app/services/model_load/model_load_base.py
@@ -7,7 +7,7 @@

from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache


class ModelLoadServiceBase(ABC):
@@ -24,7 +24,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo

@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""

@abstractmethod
13 changes: 6 additions & 7 deletions invokeai/app/services/model_load/model_load_default.py
@@ -18,7 +18,7 @@
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
def __init__(
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
@@ -45,7 +45,7 @@ def start(self, invoker: Invoker) -> None:
self._invoker = invoker

@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""
return self._ram_cache

@@ -78,9 +78,8 @@ def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
except IndexError:
pass

@@ -109,5 +108,5 @@ def diffusers_load_directory(directory: Path) -> AnyModel:
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
self._ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
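The change above replaces the old `_locker` handle with an explicit `(cache_record, cache)` pair. A condensed sketch of the resulting cache-or-load flow; the function name `load_by_path` and its parameters are illustrative, not the service's exact code:

```python
from pathlib import Path
from typing import Callable

from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache


def load_by_path(ram_cache: ModelCache, loader: Callable[[Path], AnyModel], model_path: Path) -> LoadedModelWithoutConfig:
    """Sketch of the cache-or-load pattern used by load_model_from_path()."""
    cache_key = str(model_path)
    try:
        # Cache hit: wrap the existing cache record together with the cache that owns it.
        return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache)
    except IndexError:
        pass  # Cache miss: fall through and load from disk.

    raw_model = loader(model_path)                 # e.g. a safetensors or diffusers loader callable
    ram_cache.put(key=cache_key, model=raw_model)  # insert, then fetch the canonical CacheRecord
    return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache)
```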
@@ -16,7 +16,8 @@
from invokeai.app.services.model_load.model_load_default import ModelLoadService
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/load/__init__.py
@@ -8,7 +8,7 @@

from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase

# This registers the subclasses that implement loaders of specific model types
54 changes: 23 additions & 31 deletions invokeai/backend/model_manager/load/load_base.py
@@ -5,7 +5,6 @@

from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple
@@ -18,19 +17,17 @@
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache


@dataclass
class LoadedModelWithoutConfig:
"""
Context manager object that mediates transfer from RAM<->VRAM.
"""Context manager object that mediates transfer from RAM<->VRAM.

This is a context manager object that has two distinct APIs:

1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
```

2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method in a
context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
other model patchers to return the model to its unpatched state without expensive copy and restore operations.

Example:
```
@@ -55,43 +48,42 @@
image = vae.decode(latents)[0]
```

The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
do not have a state_dict, in which case this value will be None.
"""

_locker: ModelLockerBase
def __init__(self, cache_record: CacheRecord, cache: ModelCache):
self._cache_record = cache_record
self._cache = cache

def __enter__(self) -> AnyModel:
"""Context entry."""
self._locker.lock()
self._cache.lock(self._cache_record.key)
return self.model

def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
self._locker.unlock()
self._cache.unlock(self._cache_record.key)

@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
self._cache.lock(self._cache_record.key)
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
yield (self._cache_record.state_dict, self._cache_record.model)
finally:
self._locker.unlock()
self._cache.unlock(self._cache_record.key)

@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._locker.model
return self._cache_record.model


@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""

config: Optional[AnyModelConfig] = None
def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
super().__init__(cache_record=cache_record, cache=cache)
self.config = config


# TODO(MM2):
@@ -110,7 +102,7 @@ def __init__(
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
pass
@@ -138,6 +130,6 @@ def get_size_fs(

@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the ram cache associated with this loader."""
pass
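The docstring's inline examples are cut off by the collapsed regions above, so here is a hedged reconstruction of the two usage patterns it describes. `vae_info` stands for any `LoadedModel` or `LoadedModelWithoutConfig` returned by the loader and `latents` for a latent tensor; both are assumed to exist.

```python
# 1. Older, deprecated API: the loaded-model object itself is the context manager.
with vae_info as vae:
    image = vae.decode(latents)[0]

# 2. Newer, recommended API: model_on_device() yields (state_dict, on_device_model).
# The yielded state dict is the read-only RAM copy; patch the on-device model copy
# (e.g. when applying LoRAs) rather than modifying the state dict.
with vae_info.model_on_device() as (state_dict, vae):
    image = vae.decode(latents)[0]
```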
24 changes: 10 additions & 14 deletions invokeai/backend/model_manager/load/load_default.py
@@ -14,7 +14,8 @@
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@@ -28,7 +29,7 @@ def __init__(
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
self._app_config = app_config
@@ -54,22 +55,22 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")

with skip_torch_weight_init():
locker = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, _locker=locker)
cache_record = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)

@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the ram cache associated with this loader."""
return self._ram_cache

def _get_model_path(self, config: AnyModelConfig) -> Path:
model_base = self._app_config.models_path
return (model_base / config.path).resolve()

def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
try:
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
except IndexError:
pass

@@ -78,16 +79,11 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubMod
loaded_model = self._load_model(config, submodel_type)

self._ram_cache.put(
config.key,
submodel_type=submodel_type,
get_model_cache_key(config.key, submodel_type),
model=loaded_model,
)

return self._ram_cache.get(
key=config.key,
submodel_type=submodel_type,
stats_name=stats_name,
)
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)

def get_size_fs(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
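The substantive change in this file is the keying scheme: submodel-aware cache keys are now built with `get_model_cache_key()` before calling `get()`/`put()`, instead of passing `submodel_type` through the cache API. A condensed sketch follows; `load_and_cache` and `load_model_from_disk` are illustrative placeholders for the loader's methods, not the exact code:

```python
from typing import Optional

from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key


def load_and_cache(ram_cache: ModelCache, config: AnyModelConfig,
                   submodel_type: Optional[SubModelType] = None) -> CacheRecord:
    """Sketch of the submodel-aware keying now used by ModelLoader._load_and_cache()."""
    key = get_model_cache_key(config.key, submodel_type)  # folds the submodel into the cache key
    stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
    try:
        return ram_cache.get(key=key, stats_name=stats_name)  # cache hit
    except IndexError:
        pass  # cache miss

    loaded_model = load_model_from_disk(config, submodel_type)  # placeholder for self._load_model(...)
    ram_cache.put(key, model=loaded_model)
    return ram_cache.get(key=key, stats_name=stats_name)
```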
6 changes: 0 additions & 6 deletions invokeai/backend/model_manager/load/model_cache/__init__.py
@@ -1,6 +0,0 @@
"""Init file for ModelCache."""

from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
from .model_cache_default import ModelCache # noqa F401

_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]
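With the package-level re-exports removed, callers in this PR import directly from the concrete modules. For reference, the replacement import paths used throughout the diff are:

```python
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
```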
50 changes: 50 additions & 0 deletions invokeai/backend/model_manager/load/model_cache/cache_record.py
@@ -0,0 +1,50 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass
class CacheRecord:
"""
Elements of the cache:

key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM

Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.

The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""

key: str
model: Any
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0

def lock(self) -> None:
"""Lock this record."""
self._locks += 1

def unlock(self) -> None:
"""Unlock this record."""
self._locks -= 1
assert self._locks >= 0

@property
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0
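A small sketch of how the lock counter on the new `CacheRecord` behaves. The field values are illustrative; in practice records are created and locked by `ModelCache`, not by user code:

```python
import torch

from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord

# Illustrative record; in practice these are created inside ModelCache.put().
record = CacheRecord(
    key="abc123:vae",             # hypothetical cache key
    model=torch.nn.Linear(4, 4),  # stand-in for a real model
    device=torch.device("cpu"),
    state_dict=None,              # some loadable models have no state dict
    size=256,                     # model size as tracked by the cache (illustrative value)
)

record.lock()
assert record.locked     # a locked record must not be evicted or moved
record.unlock()
assert not record.locked
```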
15 changes: 15 additions & 0 deletions invokeai/backend/model_manager/load/model_cache/cache_stats.py
@@ -0,0 +1,15 @@
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""

hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
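A brief sketch of how a consumer might read these counters (the invocation-stats service imports `CacheStats` earlier in this diff). The field updates below are illustrative, not the cache's actual bookkeeping:

```python
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats

stats = CacheStats()

# Illustrative bookkeeping, roughly what a cache implementation records:
stats.hits += 1
stats.misses += 1
stats.in_cache = 2
stats.cache_size = 8 * 2**30                           # total cache size in bytes
stats.high_watermark = max(stats.high_watermark, 6 * 2**30)
stats.loaded_model_sizes["abc123:vae"] = 335_304_388   # hypothetical key and size

hit_rate = stats.hits / max(1, stats.hits + stats.misses)
print(f"hit rate: {hit_rate:.0%}, models in cache: {stats.in_cache}")
```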