From 9cc1f20ad577bc17250958080fafd12e2830e296 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Apr 2024 23:26:48 -0400 Subject: [PATCH 01/45] add simplified model manager install API to InvocationContext --- .../app/services/shared/invocation_context.py | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index e533baf3bce..20b03bb28c2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,9 +1,10 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from PIL.Image import Image +from pydantic.networks import AnyHttpUrl from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES @@ -414,6 +415,101 @@ def search_by_attrs( model_format=format, ) + def install_model( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + inplace: Optional[bool] = False, + timeout: Optional[int] = 0, + ) -> str: + """Install and register a model in the database. + + Args: + source: String source; see below + config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record. + access_token: Optional access token for remote sources. + inplace: If true, installs a local model in place rather than copying + it into the models directory + timeout: How long to wait on install (in seconds). A value of 0 (default) + blocks indefinitely + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. + + Returns: + Key to the newly installed model. + + May Raise: + ValueError -- bad source + UnknownModelException -- remote model not found + InvalidModelException -- what was retrieved from remote is not a model + TimeoutError -- model could not be installed within timeout + Exception -- another error condition + """ + installer = self._services.model_manager.install + job = installer.heuristic_import( + source=source, + config=config, + access_token=access_token, + inplace=inplace, + ) + installer.wait_for_job(job, timeout) + if job.errored: + raise Exception(job.error) + key: str = job.config_out.key + return key + + def download_and_cache_model( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path. + + This can be used to single-file install models and other resources of arbitrary types + which should not get registered with the database. If the model is already + installed, the cached path will be returned. Otherwise it will be downloaded. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + + Result: + Path of the downloaded model + + May Raise: + HTTPError + TimeoutError + """ + installer = self._services.model_manager.install + path: Path = installer.download_and_cache( + source=source, + access_token=access_token, + timeout=timeout, + ) + return path + class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: From af1b57a01f7e7920c44f21b60d07c33ce2355de6 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Apr 2024 23:26:48 -0400 Subject: [PATCH 02/45] add simplified model manager install API to InvocationContext --- .../app/services/shared/invocation_context.py | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 9994d663e5e..176303b055f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,9 +1,10 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from PIL.Image import Image +from pydantic.networks import AnyHttpUrl from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES @@ -426,6 +427,101 @@ def search_by_attrs( model_format=format, ) + def install_model( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + inplace: Optional[bool] = False, + timeout: Optional[int] = 0, + ) -> str: + """Install and register a model in the database. + + Args: + source: String source; see below + config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record. + access_token: Optional access token for remote sources. + inplace: If true, installs a local model in place rather than copying + it into the models directory + timeout: How long to wait on install (in seconds). A value of 0 (default) + blocks indefinitely + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. + + Returns: + Key to the newly installed model. + + May Raise: + ValueError -- bad source + UnknownModelException -- remote model not found + InvalidModelException -- what was retrieved from remote is not a model + TimeoutError -- model could not be installed within timeout + Exception -- another error condition + """ + installer = self._services.model_manager.install + job = installer.heuristic_import( + source=source, + config=config, + access_token=access_token, + inplace=inplace, + ) + installer.wait_for_job(job, timeout) + if job.errored: + raise Exception(job.error) + key: str = job.config_out.key + return key + + def download_and_cache_model( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path. + + This can be used to single-file install models and other resources of arbitrary types + which should not get registered with the database. If the model is already + installed, the cached path will be returned. Otherwise it will be downloaded. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + + Result: + Path of the downloaded model + + May Raise: + HTTPError + TimeoutError + """ + installer = self._services.model_manager.install + path: Path = installer.download_and_cache( + source=source, + access_token=access_token, + timeout=timeout, + ) + return path + class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: From df5ebdbc4f605f1d078e47a4f235a1e70111c0cd Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 12 Apr 2024 00:55:21 -0400 Subject: [PATCH 03/45] add invocation_context.load_ckpt_from_url() method --- invokeai/app/invocations/upscale.py | 13 ++-- .../app/services/shared/invocation_context.py | 67 +++++++++++++++++-- .../image_util/realesrgan/realesrgan.py | 6 +- .../backend/model_manager/load/load_base.py | 2 +- .../model_manager/load/load_default.py | 7 +- .../load/model_cache/model_cache_base.py | 1 - .../load/model_cache/model_cache_default.py | 3 +- .../app/services/model_load/test_load_api.py | 57 ++++++++++++++++ 8 files changed, 131 insertions(+), 25 deletions(-) create mode 100644 tests/app/services/model_load/test_load_api.py diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index d687384fcbd..e09618960e8 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -1,5 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team -from pathlib import Path from typing import Literal import cv2 @@ -11,7 +10,6 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device @@ -56,7 +54,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: rrdbnet_model = None netscale = None - esrgan_model_path = None if self.model_name in [ "RealESRGAN_x4plus.pth", @@ -99,16 +96,13 @@ def invoke(self, context: InvocationContext) -> ImageOutput: context.logger.error(msg) raise ValueError(msg) - esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}") - - # Downloads the ESRGAN model if it doesn't already exist - download_with_progress_bar( - name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path + loadnet = context.models.load_ckpt_from_url( + source=ESRGAN_MODEL_URLS[self.model_name], ) upscaler = RealESRGAN( scale=netscale, - model_path=esrgan_model_path, + loadnet=loadnet.model, model=rrdbnet_model, half=False, tile=self.tile_size, @@ -118,6 +112,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # TODO: This strips the alpha... is that okay? cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) upscaled_image = upscaler.upscale(cv2_image) + pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") torch.cuda.empty_cache() diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 20b03bb28c2..5cfe9c17a19 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,11 +1,14 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from picklescan.scanner import scan_file_path from PIL.Image import Image from pydantic.networks import AnyHttpUrl +from safetensors.torch import load_file as safetensors_load_file from torch import Tensor +from torch import load as torch_load from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata @@ -476,13 +479,14 @@ def install_model( key: str = job.config_out.key return key - def download_and_cache_model( + def download_and_cache_ckpt( self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None, timeout: Optional[int] = 0, ) -> Path: - """Download the model file located at source to the models cache and return its Path. + """ + Download the model file located at source to the models cache and return its Path. This can be used to single-file install models and other resources of arbitrary types which should not get registered with the database. If the model is already @@ -510,10 +514,65 @@ def download_and_cache_model( ) return path + def load_ckpt_from_url( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None, + ) -> LoadedModel: + """ + Load and cache the model file located at the indicated URL. + + This will check the model download cache for the model designated + by the provided URL and download it if needed using download_and_cache_model(). + It will then load the model into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + ram_cache = self._services.model_manager.load.ram_cache + try: + return LoadedModel(_locker=ram_cache.get(key=str(source))) + except IndexError: + pass + + def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]: + scan_result = scan_file_path(checkpoint) + if scan_result.infected_files != 0: + raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.") + return torch_load(path, map_location="cpu") + + path = self.download_and_cache_ckpt(source, access_token, timeout) + if loader is None: + loader = ( + torch_load_file + if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) + else lambda path: safetensors_load_file(path, device="cpu") + ) + + raw_model = loader(path) + ram_cache.put(key=str(source), model=raw_model) + return LoadedModel(_locker=ram_cache.get(key=str(source))) + class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: - """Gets the app's config. + """ + Gets the app's config. Returns: The app's config. diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py index c06504b6085..7c4d90f5bd2 100644 --- a/invokeai/backend/image_util/realesrgan/realesrgan.py +++ b/invokeai/backend/image_util/realesrgan/realesrgan.py @@ -1,6 +1,5 @@ import math from enum import Enum -from pathlib import Path from typing import Any, Optional import cv2 @@ -11,6 +10,7 @@ from tqdm import tqdm from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet +from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.util.devices import choose_torch_device """ @@ -52,7 +52,7 @@ class RealESRGAN: def __init__( self, scale: int, - model_path: Path, + loadnet: AnyModel, model: RRDBNet, tile: int = 0, tile_pad: int = 10, @@ -67,8 +67,6 @@ def __init__( self.half = half self.device = choose_torch_device() - loadnet = torch.load(model_path, map_location=torch.device("cpu")) - # prefer to use params_ema if "params_ema" in loadnet: keyname = "params_ema" diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index c336926aeac..41a36d7b51a 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -23,8 +23,8 @@ class LoadedModel: """Context manager object that mediates transfer from RAM<->VRAM.""" - config: AnyModelConfig _locker: ModelLockerBase + config: Optional[AnyModelConfig] = None def __enter__(self) -> AnyModel: """Context entry.""" diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 60cc1f5e6cb..8e08f065e19 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -16,7 +16,7 @@ from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -95,7 +95,6 @@ def _convert_and_load( config.key, submodel_type=submodel_type, model=loaded_model, - size=calc_model_size_by_data(loaded_model), ) return self._ram_cache.get( @@ -126,9 +125,7 @@ def _do_convert( if subtype == submodel_type: continue if submodel := getattr(pipeline, subtype.value, None): - self._ram_cache.put( - config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel) - ) + self._ram_cache.put(config.key, submodel_type=subtype, model=submodel) return getattr(pipeline, submodel_type.value) if submodel_type else pipeline def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index eb82f87cb22..3ffd6714a1c 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -143,7 +143,6 @@ def put( self, key: str, model: T, - size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index b8312f619ad..b8fe19047c3 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -30,6 +30,7 @@ from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger @@ -157,13 +158,13 @@ def put( self, key: str, model: AnyModel, - size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" key = self._make_cache_key(key, submodel_type) if key in self._cached_models: return + size = calc_model_size_by_data(model) self.make_room(size) cache_record = CacheRecord(key, model, size) self._cached_models[key] = cache_record diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py new file mode 100644 index 00000000000..b99cc7d8242 --- /dev/null +++ b/tests/app/services/model_load/test_load_api.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import pytest + +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context +from invokeai.backend.model_manager.load.load_base import LoadedModel +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 + + +@pytest.fixture() +def mock_context( + mock_services: InvocationServices, + mm2_model_manager: ModelManagerServiceBase, +) -> InvocationContext: + mock_services.model_manager = mm2_model_manager + return build_invocation_context( + services=mock_services, + data=None, # type: ignore + cancel_event=None, # type: ignore + ) + + +def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path): + downloaded_path = mock_context.models.download_and_cache_ckpt( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert downloaded_path.is_file() + assert downloaded_path.exists() + assert downloaded_path.name == "test_embedding.safetensors" + assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache" + + downloaded_path_2 = mock_context.models.download_and_cache_ckpt( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert downloaded_path == downloaded_path_2 + + +def test_download_and_load(mock_context: InvocationContext): + loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_1, LoadedModel) + + loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_2, LoadedModel) + + with loaded_model_1 as model_1, loaded_model_2 as model_2: + assert model_1 == model_2 + assert isinstance(model_1, dict) + + +def test_install_model(mock_context: InvocationContext): + key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors") + assert key is not None + model = mock_context.models.load(key) + assert model is not None + assert model.config.key == key From 41b909cbe3761f59c180430478dcdeb8efba579e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 12 Apr 2024 21:05:23 -0400 Subject: [PATCH 04/45] port dw_openpose, depth_anything, and lama processors to new model download scheme --- .../controlnet_image_processors.py | 45 ++++++++++--------- invokeai/app/invocations/infill.py | 18 ++++---- .../app/services/shared/invocation_context.py | 4 +- .../image_util/depth_anything/__init__.py | 37 +++++---------- .../image_util/dw_openpose/__init__.py | 5 ++- .../image_util/dw_openpose/wholebody.py | 29 ++++-------- .../backend/image_util/infill_methods/lama.py | 39 +++++++--------- 7 files changed, 72 insertions(+), 105 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index a49c910eeb1..12a2ae9c966 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -137,7 +137,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: # superclass just passes through image without processing return image @@ -148,7 +148,7 @@ def load_image(self, context: InvocationContext) -> Image.Image: def invoke(self, context: InvocationContext) -> ImageOutput: raw_image = self.load_image(context) # image type should be PIL.PngImagePlugin.PngImageFile ? - processed_image = self.run_processor(raw_image) + processed_image = self.run_processor(raw_image, context) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery @@ -189,7 +189,7 @@ def load_image(self, context: InvocationContext) -> Image.Image: # Keep alpha channel for Canny processing to detect edges of transparent areas return context.images.get_pil(self.image.image_name, "RGBA") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: processed_image = get_canny_edges( image, self.low_threshold, @@ -216,7 +216,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation): # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: hed_processor = HEDProcessor() processed_image = hed_processor.run( image, @@ -243,7 +243,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation): image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) coarse: bool = InputField(default=False, description="Whether to use coarse mode") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: lineart_processor = LineartProcessor() processed_image = lineart_processor.run( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse @@ -264,7 +264,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: processor = LineartAnimeProcessor() processed_image = processor.run( image, @@ -291,7 +291,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): # depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar) midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") processed_image = midas_processor( image, @@ -318,9 +319,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") - processed_image = normalbae_processor( + processed_image: Image.Image = normalbae_processor( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution ) return processed_image @@ -337,7 +338,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") processed_image = mlsd_processor( image, @@ -360,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") processed_image = pidi_processor( image, @@ -388,7 +389,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter") f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: content_shuffle_processor = ContentShuffleDetector() processed_image = content_shuffle_processor( image, @@ -412,7 +413,7 @@ def run_processor(self, image): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") processed_image = zoe_depth_processor(image) return processed_image @@ -433,7 +434,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mediapipe_face_processor = MediapipeFaceDetector() processed_image = mediapipe_face_processor( image, @@ -461,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") processed_image = leres_processor( image, @@ -503,7 +504,7 @@ def tile_resample( np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) return np_img - def run_processor(self, img): + def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image: np_img = np.array(img, dtype=np.uint8) processed_np_image = self.tile_resample( np_img, @@ -527,7 +528,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( "ybelkada/segment-anything", subfolder="checkpoints" @@ -573,7 +574,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size) - def run_processor(self, image: Image.Image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: np_image = np.array(image, dtype=np.uint8) height, width = np_image.shape[:2] @@ -608,8 +609,8 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): ) resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - depth_anything_detector = DepthAnythingDetector() + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + depth_anything_detector = DepthAnythingDetector(context) depth_anything_detector.load_model(model_size=self.model_size) processed_image = depth_anything_detector(image=image, resolution=self.resolution) @@ -631,8 +632,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): draw_hands: bool = InputField(default=False) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - dw_openpose = DWOpenposeDetector() + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + dw_openpose = DWOpenposeDetector(context) processed_image = dw_openpose( image, draw_face=self.draw_face, diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 418bc62fdc4..edee275e722 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -38,7 +38,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") @abstractmethod - def infill(self, image: Image.Image) -> Image.Image: + def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image: """Infill the image with the specified method""" pass @@ -57,7 +57,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: return ImageOutput.build(context.images.get_dto(self.image.image_name)) # Perform Infill action - infilled_image = self.infill(input_image) + infilled_image = self.infill(input_image, context) # Create ImageDTO for Infilled Image infilled_image_dto = context.images.save(image=infilled_image) @@ -75,7 +75,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation): description="The color to use to infill", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) @@ -94,7 +94,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation): description="The seed to use for tile generation (omit for random)", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): output = infill_tile(image, seed=self.seed, tile_size=self.tile_size) return output.infilled @@ -108,7 +108,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width / self.downscale) @@ -132,8 +132,8 @@ def infill(self, image: Image.Image): class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" - def infill(self, image: Image.Image): - lama = LaMA() + def infill(self, image: Image.Image, context: InvocationContext): + lama = LaMA(context) return lama(image) @@ -141,7 +141,7 @@ def infill(self, image: Image.Image): class CV2InfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using OpenCV Inpainting""" - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): return cv2_inpaint(image) @@ -163,5 +163,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation): description="The max threshold for color", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple()) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index e97d29d308b..0d27b2520ba 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -534,10 +534,10 @@ def load_ckpt_from_url( loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None, ) -> LoadedModel: """ - Load and cache the model file located at the indicated URL. + Download, cache, and Load the model file located at the indicated URL. This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_model(). + by the provided URL and download it if needed using download_and_cache_ckpt(). It will then load the model into the RAM cache. If the optional loader argument is provided, the loader will be invoked to load the model into memory. Otherwise the method will call safetensors.torch.load_file() or diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index ccac2ba9493..560d977b559 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -1,5 +1,4 @@ -import pathlib -from typing import Literal, Union +from typing import Literal, Optional, Union import cv2 import numpy as np @@ -10,7 +9,7 @@ from torchvision.transforms import Compose from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize from invokeai.backend.util.devices import choose_torch_device @@ -20,18 +19,9 @@ logger = InvokeAILogger.get_logger(config=config) DEPTH_ANYTHING_MODELS = { - "large": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitl14.pth", - }, - "base": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitb14.pth", - }, - "small": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vits14.pth", - }, + "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", + "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", + "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", } @@ -53,18 +43,14 @@ class DepthAnythingDetector: - def __init__(self) -> None: - self.model = None + def __init__(self, context: InvocationContext) -> None: + self.context = context + self.model: Optional[DPT_DINOv2] = None self.model_size: Union[Literal["large", "base", "small"], None] = None self.device = choose_torch_device() - def load_model(self, model_size: Literal["large", "base", "small"] = "small"): - DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"] - download_with_progress_bar( - pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name, - DEPTH_ANYTHING_MODELS[model_size]["url"], - DEPTH_ANYTHING_MODEL_PATH, - ) + def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2: + depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size]) if not self.model or model_size != self.model_size: del self.model @@ -78,7 +64,8 @@ def load_model(self, model_size: Literal["large", "base", "small"] = "small"): case "large": self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) - self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) + assert self.model is not None + self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu")) self.model.eval() self.model.to(choose_torch_device()) diff --git a/invokeai/backend/image_util/dw_openpose/__init__.py b/invokeai/backend/image_util/dw_openpose/__init__.py index c258ef2c786..17ca0233c82 100644 --- a/invokeai/backend/image_util/dw_openpose/__init__.py +++ b/invokeai/backend/image_util/dw_openpose/__init__.py @@ -3,6 +3,7 @@ from controlnet_aux.util import resize_image from PIL import Image +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody @@ -39,8 +40,8 @@ class DWOpenposeDetector: Credits: https://github.com/IDEA-Research/DWPose """ - def __init__(self) -> None: - self.pose_estimation = Wholebody() + def __init__(self, context: InvocationContext) -> None: + self.pose_estimation = Wholebody(context) def __call__( self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512 diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 35d340640d3..3628b0abd55 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -4,44 +4,31 @@ import numpy as np import onnxruntime as ort +import torch from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.util.devices import choose_torch_device from .onnxdet import inference_detector from .onnxpose import inference_pose DWPOSE_MODELS = { - "yolox_l.onnx": { - "local": "any/annotators/dwpose/yolox_l.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", - }, - "dw-ll_ucoco_384.onnx": { - "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", - }, + "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", + "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", } config = get_config() class Wholebody: - def __init__(self): + def __init__(self, context: InvocationContext): device = choose_torch_device() - providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] + providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"] - DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"] - download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH) - - POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"] - download_with_progress_bar( - "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH - ) - - onnx_det = DET_MODEL_PATH - onnx_pose = POSE_MODEL_PATH + onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index fa354aeed1c..8c3f33efad1 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,4 +1,3 @@ -import gc from typing import Any import numpy as np @@ -6,9 +5,7 @@ from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar -from invokeai.backend.util.devices import choose_torch_device +from invokeai.app.services.shared.invocation_context import InvocationContext def norm_img(np_img): @@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device): class LaMA: - def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - device = choose_torch_device() - model_location = get_config().models_path / "core/misc/lama/lama.pt" - - if not model_location.exists(): - download_with_progress_bar( - name="LaMa Inpainting Model", - url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", - dest_path=model_location, - ) + def __init__(self, context: InvocationContext): + self._context = context - model = load_jit_model(model_location, device) + def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: + loaded_model = self._context.models.load_ckpt_from_url( + source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", + loader=lambda path: load_jit_model(path, "cpu"), + ) image = np.asarray(input_image.convert("RGB")) image = norm_img(image) @@ -48,20 +41,18 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: mask = np.asarray(mask) mask = np.invert(mask) mask = norm_img(mask) - mask = (mask > 0) * 1 - image = torch.from_numpy(image).unsqueeze(0).to(device) - mask = torch.from_numpy(mask).unsqueeze(0).to(device) - with torch.inference_mode(): - infilled_image = model(image, mask) + with loaded_model as model: + device = next(model.buffers()).device + image = torch.from_numpy(image).unsqueeze(0).to(device) + mask = torch.from_numpy(mask).unsqueeze(0).to(device) + + with torch.inference_mode(): + infilled_image = model(image, mask) infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = Image.fromarray(infilled_image) - del model - gc.collect() - torch.cuda.empty_cache() - return infilled_image From 3ddd7ced49979229d143edc865d32c507332a120 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 14 Apr 2024 15:54:46 -0400 Subject: [PATCH 05/45] change names of convert and download caches and add migration script --- .../app/services/config/config_default.py | 9 +- .../model_install/model_install_default.py | 17 +++- .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_10.py | 87 +++++++++++++++++++ invokeai/app/util/download_with_progress.py | 51 ----------- .../app/services/model_load/test_load_api.py | 2 +- 6 files changed, 111 insertions(+), 57 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py delete mode 100644 invokeai/app/util/download_with_progress.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index f453a56584c..4b5a2004be1 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings): patchmatch: Enable patchmatch inpaint code. models_dir: Path to the models directory. convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location. + download_cache_dir: Path to the directory that contains dynamically downloaded models. legacy_conf_dir: Path to directory of legacy checkpoint config files. db_dir: Path to InvokeAI databases directory. outputs_dir: Path to directory for outputs. @@ -146,7 +147,8 @@ class InvokeAIAppConfig(BaseSettings): # PATHS models_dir: Path = Field(default=Path("models"), description="Path to the models directory.") - convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") + convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") + download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.") legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.") db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.") outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.") @@ -303,6 +305,11 @@ def convert_cache_path(self) -> Path: """Path to the converted cache models directory, resolved to an absolute path..""" return self._resolve(self.convert_cache_dir) + @property + def download_cache_path(self) -> Path: + """Path to the downloaded models directory, resolved to an absolute path..""" + return self._resolve(self.download_cache_dir) + @property def custom_nodes_path(self) -> Path: """Path to the custom nodes directory, resolved to an absolute path..""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 20cfc1c4ff9..f1fbcdb7ba7 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -393,6 +393,11 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 rmtree(model_path) self.unregister(key) + @classmethod + def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path: + model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] + return app_config.download_cache_path / model_hash + def download_and_cache( self, source: Union[str, AnyHttpUrl], @@ -400,8 +405,7 @@ def download_and_cache( timeout: int = 0, ) -> Path: """Download the model file located at source to the models cache and return its Path.""" - model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] - model_path = self._app_config.convert_cache_path / model_hash + model_path = self._download_cache_path(source, self._app_config) # We expect the cache directory to contain one and only one downloaded file. # We don't know the file's name in advance, as it is set by the download @@ -532,8 +536,13 @@ def on_model_found(model_path: Path) -> bool: if resolved_path in installed_model_paths: return True # Skip core models entirely - these aren't registered with the model manager. - if str(resolved_path).startswith(str(self.app_config.models_path / "core")): - return False + for special_directory in [ + self.app_config.models_path / "core", + self.app_config.convert_cache_dir, + self.app_config.download_cache_dir, + ]: + if resolved_path.is_relative_to(special_directory): + return False try: model_id = self.register_path(model_path) self._logger.info(f"Registered {model_path.name} with id {model_id}") diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 1eed0b44092..61f35a3b4ea 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -12,6 +12,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -41,6 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_7()) migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_9()) + migrator.register_migration(build_migration_10(app_config=config, logger=logger)) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py new file mode 100644 index 00000000000..df341a6f2e5 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py @@ -0,0 +1,87 @@ +import pathlib +import shutil +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_install.model_install_default import ModelInstallService +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +LEGACY_CORE_MODELS = { + # OpenPose + "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true": "any/annotators/dwpose/yolox_l.onnx", + "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", + # DepthAnything + "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitl14.pth", + "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitb14.pth", + "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true": "any/annotators/depth_anything/depth_anything_vits14.pth", + # Lama inpaint + "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt": "core/misc/lama/lama.pt", + # RealESRGAN upscale + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", +} + + +class Migration10Callback: + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._rename_convert_cache() + self._migrate_downloaded_models_cache() + self._remove_unused_core_models() + + def _rename_convert_cache(self) -> None: + """Rename models/.cache to models/.convert_cache.""" + legacy_convert_path = self._app_config.root_path / "models" / ".cache" + configured_convert_dir = self._app_config.convert_cache_dir + configured_convert_path = self._app_config.convert_cache_path + # old convert dir was in use, and current convert dir has not been changed + if legacy_convert_path.exists() and configured_convert_dir == pathlib.Path("models/.convert_cache"): + self._logger.info( + f"Migrating legacy convert cache directory from {str(legacy_convert_path)} to {str(configured_convert_path)}" + ) + shutil.rmtree(configured_convert_path, ignore_errors=True) # shouldn't be needed, but just in case... + shutil.move(legacy_convert_path, configured_convert_path) + + def _migrate_downloaded_models_cache(self) -> None: + """Move used core models to modsl/.download_cache.""" + self._logger.info(f"Migrating legacy core models to {str(self._app_config.download_cache_path)}") + for url, legacy_dest in LEGACY_CORE_MODELS.items(): + legacy_dest_path = self._app_config.models_path / legacy_dest + if not legacy_dest_path.exists(): + continue + # this returns a unique directory path + new_path = ModelInstallService._download_cache_path(url, self._app_config) + new_path.mkdir(parents=True, exist_ok=True) + shutil.move(legacy_dest_path, new_path / legacy_dest_path.name) + + def _remove_unused_core_models(self) -> None: + """Remove unused core models and their directories.""" + self._logger.info("Removing defunct core models.") + for dir in ["face_restoration", "misc", "upscaling"]: + path_to_remove = self._app_config.models_path / "core" / dir + shutil.rmtree(path_to_remove, ignore_errors=True) + shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True) + + +def build_migration_10(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """ + Build the migration from database version 9 to 10. + + This migration does the following: + - Moves "core" models previously downloaded with download_with_progress_bar() into new + "models/.download_cache" directory. + - Renames "models/.cache" to "models/.convert_cache". + """ + migration_10 = Migration( + from_version=9, + to_version=10, + callback=Migration10Callback(app_config=app_config, logger=logger), + ) + + return migration_10 diff --git a/invokeai/app/util/download_with_progress.py b/invokeai/app/util/download_with_progress.py deleted file mode 100644 index 97a2abb2f6f..00000000000 --- a/invokeai/app/util/download_with_progress.py +++ /dev/null @@ -1,51 +0,0 @@ -from pathlib import Path -from urllib import request - -from tqdm import tqdm - -from invokeai.backend.util.logging import InvokeAILogger - - -class ProgressBar: - """Simple progress bar for urllib.request.urlretrieve using tqdm.""" - - def __init__(self, model_name: str = "file"): - self.pbar = None - self.name = model_name - - def __call__(self, block_num: int, block_size: int, total_size: int): - if not self.pbar: - self.pbar = tqdm( - desc=self.name, - initial=0, - unit="iB", - unit_scale=True, - unit_divisor=1000, - total=total_size, - ) - self.pbar.update(block_size) - - -def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool: - """Download a file from a URL to a destination path, with a progress bar. - If the file already exists, it will not be downloaded again. - - Exceptions are not caught. - - Args: - name (str): Name of the file being downloaded. - url (str): URL to download the file from. - dest_path (Path): Destination path to save the file to. - - Returns: - bool: True if the file was downloaded, False if it already existed. - """ - if dest_path.exists(): - return False # already downloaded - - InvokeAILogger.get_logger().info(f"Downloading {name}...") - - dest_path.parent.mkdir(parents=True, exist_ok=True) - request.urlretrieve(url, dest_path, ProgressBar(name)) - - return True diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index b99cc7d8242..59b207f2ee5 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -29,7 +29,7 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) assert downloaded_path.is_file() assert downloaded_path.exists() assert downloaded_path.name == "test_embedding.safetensors" - assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache" + assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache" downloaded_path_2 = mock_context.models.download_and_cache_ckpt( "https://www.test.foo/download/test_embedding.safetensors" From 34438ce1af2da2a0c879a83e6fe2780a9d49911c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Apr 2024 23:26:48 -0400 Subject: [PATCH 06/45] add simplified model manager install API to InvocationContext --- .../app/services/shared/invocation_context.py | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 9994d663e5e..176303b055f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,9 +1,10 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from PIL.Image import Image +from pydantic.networks import AnyHttpUrl from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES @@ -426,6 +427,101 @@ def search_by_attrs( model_format=format, ) + def install_model( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + inplace: Optional[bool] = False, + timeout: Optional[int] = 0, + ) -> str: + """Install and register a model in the database. + + Args: + source: String source; see below + config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record. + access_token: Optional access token for remote sources. + inplace: If true, installs a local model in place rather than copying + it into the models directory + timeout: How long to wait on install (in seconds). A value of 0 (default) + blocks indefinitely + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. + + Returns: + Key to the newly installed model. + + May Raise: + ValueError -- bad source + UnknownModelException -- remote model not found + InvalidModelException -- what was retrieved from remote is not a model + TimeoutError -- model could not be installed within timeout + Exception -- another error condition + """ + installer = self._services.model_manager.install + job = installer.heuristic_import( + source=source, + config=config, + access_token=access_token, + inplace=inplace, + ) + installer.wait_for_job(job, timeout) + if job.errored: + raise Exception(job.error) + key: str = job.config_out.key + return key + + def download_and_cache_model( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path. + + This can be used to single-file install models and other resources of arbitrary types + which should not get registered with the database. If the model is already + installed, the cached path will be returned. Otherwise it will be downloaded. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + + Result: + Path of the downloaded model + + May Raise: + HTTPError + TimeoutError + """ + installer = self._services.model_manager.install + path: Path = installer.download_and_cache( + source=source, + access_token=access_token, + timeout=timeout, + ) + return path + class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: From c140d3b1df3fb256f2baf789b7968c6f941b4ef0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 12 Apr 2024 00:55:21 -0400 Subject: [PATCH 07/45] add invocation_context.load_ckpt_from_url() method --- invokeai/app/invocations/upscale.py | 13 ++-- .../app/services/shared/invocation_context.py | 67 +++++++++++++++++-- .../image_util/realesrgan/realesrgan.py | 6 +- .../backend/model_manager/load/load_base.py | 2 +- .../model_manager/load/load_default.py | 7 +- .../load/model_cache/model_cache_base.py | 1 - .../load/model_cache/model_cache_default.py | 3 +- .../app/services/model_load/test_load_api.py | 57 ++++++++++++++++ 8 files changed, 131 insertions(+), 25 deletions(-) create mode 100644 tests/app/services/model_load/test_load_api.py diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index d687384fcbd..e09618960e8 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -1,5 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team -from pathlib import Path from typing import Literal import cv2 @@ -11,7 +10,6 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device @@ -56,7 +54,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: rrdbnet_model = None netscale = None - esrgan_model_path = None if self.model_name in [ "RealESRGAN_x4plus.pth", @@ -99,16 +96,13 @@ def invoke(self, context: InvocationContext) -> ImageOutput: context.logger.error(msg) raise ValueError(msg) - esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}") - - # Downloads the ESRGAN model if it doesn't already exist - download_with_progress_bar( - name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path + loadnet = context.models.load_ckpt_from_url( + source=ESRGAN_MODEL_URLS[self.model_name], ) upscaler = RealESRGAN( scale=netscale, - model_path=esrgan_model_path, + loadnet=loadnet.model, model=rrdbnet_model, half=False, tile=self.tile_size, @@ -118,6 +112,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # TODO: This strips the alpha... is that okay? cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) upscaled_image = upscaler.upscale(cv2_image) + pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") torch.cuda.empty_cache() diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 176303b055f..e97d29d308b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,11 +1,14 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from picklescan.scanner import scan_file_path from PIL.Image import Image from pydantic.networks import AnyHttpUrl +from safetensors.torch import load_file as safetensors_load_file from torch import Tensor +from torch import load as torch_load from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata @@ -488,13 +491,14 @@ def install_model( key: str = job.config_out.key return key - def download_and_cache_model( + def download_and_cache_ckpt( self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None, timeout: Optional[int] = 0, ) -> Path: - """Download the model file located at source to the models cache and return its Path. + """ + Download the model file located at source to the models cache and return its Path. This can be used to single-file install models and other resources of arbitrary types which should not get registered with the database. If the model is already @@ -522,10 +526,65 @@ def download_and_cache_model( ) return path + def load_ckpt_from_url( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None, + ) -> LoadedModel: + """ + Load and cache the model file located at the indicated URL. + + This will check the model download cache for the model designated + by the provided URL and download it if needed using download_and_cache_model(). + It will then load the model into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + ram_cache = self._services.model_manager.load.ram_cache + try: + return LoadedModel(_locker=ram_cache.get(key=str(source))) + except IndexError: + pass + + def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]: + scan_result = scan_file_path(checkpoint) + if scan_result.infected_files != 0: + raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.") + return torch_load(path, map_location="cpu") + + path = self.download_and_cache_ckpt(source, access_token, timeout) + if loader is None: + loader = ( + torch_load_file + if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) + else lambda path: safetensors_load_file(path, device="cpu") + ) + + raw_model = loader(path) + ram_cache.put(key=str(source), model=raw_model) + return LoadedModel(_locker=ram_cache.get(key=str(source))) + class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: - """Gets the app's config. + """ + Gets the app's config. Returns: The app's config. diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py index c06504b6085..7c4d90f5bd2 100644 --- a/invokeai/backend/image_util/realesrgan/realesrgan.py +++ b/invokeai/backend/image_util/realesrgan/realesrgan.py @@ -1,6 +1,5 @@ import math from enum import Enum -from pathlib import Path from typing import Any, Optional import cv2 @@ -11,6 +10,7 @@ from tqdm import tqdm from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet +from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.util.devices import choose_torch_device """ @@ -52,7 +52,7 @@ class RealESRGAN: def __init__( self, scale: int, - model_path: Path, + loadnet: AnyModel, model: RRDBNet, tile: int = 0, tile_pad: int = 10, @@ -67,8 +67,6 @@ def __init__( self.half = half self.device = choose_torch_device() - loadnet = torch.load(model_path, map_location=torch.device("cpu")) - # prefer to use params_ema if "params_ema" in loadnet: keyname = "params_ema" diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index c336926aeac..41a36d7b51a 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -23,8 +23,8 @@ class LoadedModel: """Context manager object that mediates transfer from RAM<->VRAM.""" - config: AnyModelConfig _locker: ModelLockerBase + config: Optional[AnyModelConfig] = None def __enter__(self) -> AnyModel: """Context entry.""" diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 6774fc29894..451770c0cbe 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -16,7 +16,7 @@ from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -95,7 +95,6 @@ def _convert_and_load( config.key, submodel_type=submodel_type, model=loaded_model, - size=calc_model_size_by_data(loaded_model), ) return self._ram_cache.get( @@ -126,9 +125,7 @@ def _do_convert( if subtype == submodel_type: continue if submodel := getattr(pipeline, subtype.value, None): - self._ram_cache.put( - config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel) - ) + self._ram_cache.put(config.key, submodel_type=subtype, model=submodel) return getattr(pipeline, submodel_type.value) if submodel_type else pipeline def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index a8c2dd3e92e..ec77bbe477c 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -143,7 +143,6 @@ def put( self, key: str, model: T, - size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 2ba52d466ce..919a7c43968 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -30,6 +30,7 @@ from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger @@ -157,13 +158,13 @@ def put( self, key: str, model: AnyModel, - size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" key = self._make_cache_key(key, submodel_type) if key in self._cached_models: return + size = calc_model_size_by_data(model) self.make_room(size) cache_record = CacheRecord(key, model, size) self._cached_models[key] = cache_record diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py new file mode 100644 index 00000000000..b99cc7d8242 --- /dev/null +++ b/tests/app/services/model_load/test_load_api.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import pytest + +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context +from invokeai.backend.model_manager.load.load_base import LoadedModel +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 + + +@pytest.fixture() +def mock_context( + mock_services: InvocationServices, + mm2_model_manager: ModelManagerServiceBase, +) -> InvocationContext: + mock_services.model_manager = mm2_model_manager + return build_invocation_context( + services=mock_services, + data=None, # type: ignore + cancel_event=None, # type: ignore + ) + + +def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path): + downloaded_path = mock_context.models.download_and_cache_ckpt( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert downloaded_path.is_file() + assert downloaded_path.exists() + assert downloaded_path.name == "test_embedding.safetensors" + assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache" + + downloaded_path_2 = mock_context.models.download_and_cache_ckpt( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert downloaded_path == downloaded_path_2 + + +def test_download_and_load(mock_context: InvocationContext): + loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_1, LoadedModel) + + loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_2, LoadedModel) + + with loaded_model_1 as model_1, loaded_model_2 as model_2: + assert model_1 == model_2 + assert isinstance(model_1, dict) + + +def test_install_model(mock_context: InvocationContext): + key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors") + assert key is not None + model = mock_context.models.load(key) + assert model is not None + assert model.config.key == key From 3ead827d61fb3c935e3396d43455c12b6b59018f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 12 Apr 2024 21:05:23 -0400 Subject: [PATCH 08/45] port dw_openpose, depth_anything, and lama processors to new model download scheme --- .../controlnet_image_processors.py | 45 ++++++++++--------- invokeai/app/invocations/infill.py | 18 ++++---- .../app/services/shared/invocation_context.py | 4 +- .../image_util/depth_anything/__init__.py | 37 +++++---------- .../image_util/dw_openpose/__init__.py | 5 ++- .../image_util/dw_openpose/wholebody.py | 29 ++++-------- .../backend/image_util/infill_methods/lama.py | 39 +++++++--------- 7 files changed, 72 insertions(+), 105 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index a49c910eeb1..12a2ae9c966 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -137,7 +137,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: # superclass just passes through image without processing return image @@ -148,7 +148,7 @@ def load_image(self, context: InvocationContext) -> Image.Image: def invoke(self, context: InvocationContext) -> ImageOutput: raw_image = self.load_image(context) # image type should be PIL.PngImagePlugin.PngImageFile ? - processed_image = self.run_processor(raw_image) + processed_image = self.run_processor(raw_image, context) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery @@ -189,7 +189,7 @@ def load_image(self, context: InvocationContext) -> Image.Image: # Keep alpha channel for Canny processing to detect edges of transparent areas return context.images.get_pil(self.image.image_name, "RGBA") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: processed_image = get_canny_edges( image, self.low_threshold, @@ -216,7 +216,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation): # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: hed_processor = HEDProcessor() processed_image = hed_processor.run( image, @@ -243,7 +243,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation): image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) coarse: bool = InputField(default=False, description="Whether to use coarse mode") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: lineart_processor = LineartProcessor() processed_image = lineart_processor.run( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse @@ -264,7 +264,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: processor = LineartAnimeProcessor() processed_image = processor.run( image, @@ -291,7 +291,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): # depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar) midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") processed_image = midas_processor( image, @@ -318,9 +319,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") - processed_image = normalbae_processor( + processed_image: Image.Image = normalbae_processor( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution ) return processed_image @@ -337,7 +338,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") processed_image = mlsd_processor( image, @@ -360,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") processed_image = pidi_processor( image, @@ -388,7 +389,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter") f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: content_shuffle_processor = ContentShuffleDetector() processed_image = content_shuffle_processor( image, @@ -412,7 +413,7 @@ def run_processor(self, image): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") processed_image = zoe_depth_processor(image) return processed_image @@ -433,7 +434,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mediapipe_face_processor = MediapipeFaceDetector() processed_image = mediapipe_face_processor( image, @@ -461,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") processed_image = leres_processor( image, @@ -503,7 +504,7 @@ def tile_resample( np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) return np_img - def run_processor(self, img): + def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image: np_img = np.array(img, dtype=np.uint8) processed_np_image = self.tile_resample( np_img, @@ -527,7 +528,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( "ybelkada/segment-anything", subfolder="checkpoints" @@ -573,7 +574,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size) - def run_processor(self, image: Image.Image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: np_image = np.array(image, dtype=np.uint8) height, width = np_image.shape[:2] @@ -608,8 +609,8 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): ) resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - depth_anything_detector = DepthAnythingDetector() + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + depth_anything_detector = DepthAnythingDetector(context) depth_anything_detector.load_model(model_size=self.model_size) processed_image = depth_anything_detector(image=image, resolution=self.resolution) @@ -631,8 +632,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): draw_hands: bool = InputField(default=False) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - dw_openpose = DWOpenposeDetector() + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + dw_openpose = DWOpenposeDetector(context) processed_image = dw_openpose( image, draw_face=self.draw_face, diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 418bc62fdc4..edee275e722 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -38,7 +38,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") @abstractmethod - def infill(self, image: Image.Image) -> Image.Image: + def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image: """Infill the image with the specified method""" pass @@ -57,7 +57,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: return ImageOutput.build(context.images.get_dto(self.image.image_name)) # Perform Infill action - infilled_image = self.infill(input_image) + infilled_image = self.infill(input_image, context) # Create ImageDTO for Infilled Image infilled_image_dto = context.images.save(image=infilled_image) @@ -75,7 +75,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation): description="The color to use to infill", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) @@ -94,7 +94,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation): description="The seed to use for tile generation (omit for random)", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): output = infill_tile(image, seed=self.seed, tile_size=self.tile_size) return output.infilled @@ -108,7 +108,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width / self.downscale) @@ -132,8 +132,8 @@ def infill(self, image: Image.Image): class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" - def infill(self, image: Image.Image): - lama = LaMA() + def infill(self, image: Image.Image, context: InvocationContext): + lama = LaMA(context) return lama(image) @@ -141,7 +141,7 @@ def infill(self, image: Image.Image): class CV2InfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using OpenCV Inpainting""" - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): return cv2_inpaint(image) @@ -163,5 +163,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation): description="The max threshold for color", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple()) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index e97d29d308b..0d27b2520ba 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -534,10 +534,10 @@ def load_ckpt_from_url( loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None, ) -> LoadedModel: """ - Load and cache the model file located at the indicated URL. + Download, cache, and Load the model file located at the indicated URL. This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_model(). + by the provided URL and download it if needed using download_and_cache_ckpt(). It will then load the model into the RAM cache. If the optional loader argument is provided, the loader will be invoked to load the model into memory. Otherwise the method will call safetensors.torch.load_file() or diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index ccac2ba9493..560d977b559 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -1,5 +1,4 @@ -import pathlib -from typing import Literal, Union +from typing import Literal, Optional, Union import cv2 import numpy as np @@ -10,7 +9,7 @@ from torchvision.transforms import Compose from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize from invokeai.backend.util.devices import choose_torch_device @@ -20,18 +19,9 @@ logger = InvokeAILogger.get_logger(config=config) DEPTH_ANYTHING_MODELS = { - "large": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitl14.pth", - }, - "base": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitb14.pth", - }, - "small": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vits14.pth", - }, + "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", + "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", + "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", } @@ -53,18 +43,14 @@ class DepthAnythingDetector: - def __init__(self) -> None: - self.model = None + def __init__(self, context: InvocationContext) -> None: + self.context = context + self.model: Optional[DPT_DINOv2] = None self.model_size: Union[Literal["large", "base", "small"], None] = None self.device = choose_torch_device() - def load_model(self, model_size: Literal["large", "base", "small"] = "small"): - DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"] - download_with_progress_bar( - pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name, - DEPTH_ANYTHING_MODELS[model_size]["url"], - DEPTH_ANYTHING_MODEL_PATH, - ) + def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2: + depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size]) if not self.model or model_size != self.model_size: del self.model @@ -78,7 +64,8 @@ def load_model(self, model_size: Literal["large", "base", "small"] = "small"): case "large": self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) - self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) + assert self.model is not None + self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu")) self.model.eval() self.model.to(choose_torch_device()) diff --git a/invokeai/backend/image_util/dw_openpose/__init__.py b/invokeai/backend/image_util/dw_openpose/__init__.py index c258ef2c786..17ca0233c82 100644 --- a/invokeai/backend/image_util/dw_openpose/__init__.py +++ b/invokeai/backend/image_util/dw_openpose/__init__.py @@ -3,6 +3,7 @@ from controlnet_aux.util import resize_image from PIL import Image +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody @@ -39,8 +40,8 @@ class DWOpenposeDetector: Credits: https://github.com/IDEA-Research/DWPose """ - def __init__(self) -> None: - self.pose_estimation = Wholebody() + def __init__(self, context: InvocationContext) -> None: + self.pose_estimation = Wholebody(context) def __call__( self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512 diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 35d340640d3..3628b0abd55 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -4,44 +4,31 @@ import numpy as np import onnxruntime as ort +import torch from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.util.devices import choose_torch_device from .onnxdet import inference_detector from .onnxpose import inference_pose DWPOSE_MODELS = { - "yolox_l.onnx": { - "local": "any/annotators/dwpose/yolox_l.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", - }, - "dw-ll_ucoco_384.onnx": { - "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", - }, + "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", + "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", } config = get_config() class Wholebody: - def __init__(self): + def __init__(self, context: InvocationContext): device = choose_torch_device() - providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] + providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"] - DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"] - download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH) - - POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"] - download_with_progress_bar( - "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH - ) - - onnx_det = DET_MODEL_PATH - onnx_pose = POSE_MODEL_PATH + onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index fa354aeed1c..8c3f33efad1 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,4 +1,3 @@ -import gc from typing import Any import numpy as np @@ -6,9 +5,7 @@ from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar -from invokeai.backend.util.devices import choose_torch_device +from invokeai.app.services.shared.invocation_context import InvocationContext def norm_img(np_img): @@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device): class LaMA: - def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - device = choose_torch_device() - model_location = get_config().models_path / "core/misc/lama/lama.pt" - - if not model_location.exists(): - download_with_progress_bar( - name="LaMa Inpainting Model", - url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", - dest_path=model_location, - ) + def __init__(self, context: InvocationContext): + self._context = context - model = load_jit_model(model_location, device) + def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: + loaded_model = self._context.models.load_ckpt_from_url( + source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", + loader=lambda path: load_jit_model(path, "cpu"), + ) image = np.asarray(input_image.convert("RGB")) image = norm_img(image) @@ -48,20 +41,18 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: mask = np.asarray(mask) mask = np.invert(mask) mask = norm_img(mask) - mask = (mask > 0) * 1 - image = torch.from_numpy(image).unsqueeze(0).to(device) - mask = torch.from_numpy(mask).unsqueeze(0).to(device) - with torch.inference_mode(): - infilled_image = model(image, mask) + with loaded_model as model: + device = next(model.buffers()).device + image = torch.from_numpy(image).unsqueeze(0).to(device) + mask = torch.from_numpy(mask).unsqueeze(0).to(device) + + with torch.inference_mode(): + infilled_image = model(image, mask) infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = Image.fromarray(infilled_image) - del model - gc.collect() - torch.cuda.empty_cache() - return infilled_image From fa6efac436a40d1ef890467dd338f78589185828 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 14 Apr 2024 15:54:46 -0400 Subject: [PATCH 09/45] change names of convert and download caches and add migration script --- .../app/services/config/config_default.py | 9 +- .../model_install/model_install_default.py | 17 +++- .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_10.py | 87 +++++++++++++++++++ invokeai/app/util/download_with_progress.py | 51 ----------- .../app/services/model_load/test_load_api.py | 2 +- 6 files changed, 111 insertions(+), 57 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py delete mode 100644 invokeai/app/util/download_with_progress.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index f453a56584c..4b5a2004be1 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings): patchmatch: Enable patchmatch inpaint code. models_dir: Path to the models directory. convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location. + download_cache_dir: Path to the directory that contains dynamically downloaded models. legacy_conf_dir: Path to directory of legacy checkpoint config files. db_dir: Path to InvokeAI databases directory. outputs_dir: Path to directory for outputs. @@ -146,7 +147,8 @@ class InvokeAIAppConfig(BaseSettings): # PATHS models_dir: Path = Field(default=Path("models"), description="Path to the models directory.") - convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") + convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") + download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.") legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.") db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.") outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.") @@ -303,6 +305,11 @@ def convert_cache_path(self) -> Path: """Path to the converted cache models directory, resolved to an absolute path..""" return self._resolve(self.convert_cache_dir) + @property + def download_cache_path(self) -> Path: + """Path to the downloaded models directory, resolved to an absolute path..""" + return self._resolve(self.download_cache_dir) + @property def custom_nodes_path(self) -> Path: """Path to the custom nodes directory, resolved to an absolute path..""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 20cfc1c4ff9..f1fbcdb7ba7 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -393,6 +393,11 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 rmtree(model_path) self.unregister(key) + @classmethod + def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path: + model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] + return app_config.download_cache_path / model_hash + def download_and_cache( self, source: Union[str, AnyHttpUrl], @@ -400,8 +405,7 @@ def download_and_cache( timeout: int = 0, ) -> Path: """Download the model file located at source to the models cache and return its Path.""" - model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] - model_path = self._app_config.convert_cache_path / model_hash + model_path = self._download_cache_path(source, self._app_config) # We expect the cache directory to contain one and only one downloaded file. # We don't know the file's name in advance, as it is set by the download @@ -532,8 +536,13 @@ def on_model_found(model_path: Path) -> bool: if resolved_path in installed_model_paths: return True # Skip core models entirely - these aren't registered with the model manager. - if str(resolved_path).startswith(str(self.app_config.models_path / "core")): - return False + for special_directory in [ + self.app_config.models_path / "core", + self.app_config.convert_cache_dir, + self.app_config.download_cache_dir, + ]: + if resolved_path.is_relative_to(special_directory): + return False try: model_id = self.register_path(model_path) self._logger.info(f"Registered {model_path.name} with id {model_id}") diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 1eed0b44092..61f35a3b4ea 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -12,6 +12,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -41,6 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_7()) migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_9()) + migrator.register_migration(build_migration_10(app_config=config, logger=logger)) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py new file mode 100644 index 00000000000..df341a6f2e5 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py @@ -0,0 +1,87 @@ +import pathlib +import shutil +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_install.model_install_default import ModelInstallService +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +LEGACY_CORE_MODELS = { + # OpenPose + "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true": "any/annotators/dwpose/yolox_l.onnx", + "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", + # DepthAnything + "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitl14.pth", + "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitb14.pth", + "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true": "any/annotators/depth_anything/depth_anything_vits14.pth", + # Lama inpaint + "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt": "core/misc/lama/lama.pt", + # RealESRGAN upscale + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", +} + + +class Migration10Callback: + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._rename_convert_cache() + self._migrate_downloaded_models_cache() + self._remove_unused_core_models() + + def _rename_convert_cache(self) -> None: + """Rename models/.cache to models/.convert_cache.""" + legacy_convert_path = self._app_config.root_path / "models" / ".cache" + configured_convert_dir = self._app_config.convert_cache_dir + configured_convert_path = self._app_config.convert_cache_path + # old convert dir was in use, and current convert dir has not been changed + if legacy_convert_path.exists() and configured_convert_dir == pathlib.Path("models/.convert_cache"): + self._logger.info( + f"Migrating legacy convert cache directory from {str(legacy_convert_path)} to {str(configured_convert_path)}" + ) + shutil.rmtree(configured_convert_path, ignore_errors=True) # shouldn't be needed, but just in case... + shutil.move(legacy_convert_path, configured_convert_path) + + def _migrate_downloaded_models_cache(self) -> None: + """Move used core models to modsl/.download_cache.""" + self._logger.info(f"Migrating legacy core models to {str(self._app_config.download_cache_path)}") + for url, legacy_dest in LEGACY_CORE_MODELS.items(): + legacy_dest_path = self._app_config.models_path / legacy_dest + if not legacy_dest_path.exists(): + continue + # this returns a unique directory path + new_path = ModelInstallService._download_cache_path(url, self._app_config) + new_path.mkdir(parents=True, exist_ok=True) + shutil.move(legacy_dest_path, new_path / legacy_dest_path.name) + + def _remove_unused_core_models(self) -> None: + """Remove unused core models and their directories.""" + self._logger.info("Removing defunct core models.") + for dir in ["face_restoration", "misc", "upscaling"]: + path_to_remove = self._app_config.models_path / "core" / dir + shutil.rmtree(path_to_remove, ignore_errors=True) + shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True) + + +def build_migration_10(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """ + Build the migration from database version 9 to 10. + + This migration does the following: + - Moves "core" models previously downloaded with download_with_progress_bar() into new + "models/.download_cache" directory. + - Renames "models/.cache" to "models/.convert_cache". + """ + migration_10 = Migration( + from_version=9, + to_version=10, + callback=Migration10Callback(app_config=app_config, logger=logger), + ) + + return migration_10 diff --git a/invokeai/app/util/download_with_progress.py b/invokeai/app/util/download_with_progress.py deleted file mode 100644 index 97a2abb2f6f..00000000000 --- a/invokeai/app/util/download_with_progress.py +++ /dev/null @@ -1,51 +0,0 @@ -from pathlib import Path -from urllib import request - -from tqdm import tqdm - -from invokeai.backend.util.logging import InvokeAILogger - - -class ProgressBar: - """Simple progress bar for urllib.request.urlretrieve using tqdm.""" - - def __init__(self, model_name: str = "file"): - self.pbar = None - self.name = model_name - - def __call__(self, block_num: int, block_size: int, total_size: int): - if not self.pbar: - self.pbar = tqdm( - desc=self.name, - initial=0, - unit="iB", - unit_scale=True, - unit_divisor=1000, - total=total_size, - ) - self.pbar.update(block_size) - - -def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool: - """Download a file from a URL to a destination path, with a progress bar. - If the file already exists, it will not be downloaded again. - - Exceptions are not caught. - - Args: - name (str): Name of the file being downloaded. - url (str): URL to download the file from. - dest_path (Path): Destination path to save the file to. - - Returns: - bool: True if the file was downloaded, False if it already existed. - """ - if dest_path.exists(): - return False # already downloaded - - InvokeAILogger.get_logger().info(f"Downloading {name}...") - - dest_path.parent.mkdir(parents=True, exist_ok=True) - request.urlretrieve(url, dest_path, ProgressBar(name)) - - return True diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index b99cc7d8242..59b207f2ee5 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -29,7 +29,7 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) assert downloaded_path.is_file() assert downloaded_path.exists() assert downloaded_path.name == "test_embedding.safetensors" - assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache" + assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache" downloaded_path_2 = mock_context.models.download_and_cache_ckpt( "https://www.test.foo/download/test_embedding.safetensors" From d72f272f1682763170af3627f7089cbe5a4aee64 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Apr 2024 23:53:30 -0400 Subject: [PATCH 10/45] Address change requests in first round of PR reviews. Pending: - Move model install calls into model manager and create passthrus in invocation_context. - Consider splitting load_model_from_url() into a call to get the path and a call to load the path. --- invokeai/app/invocations/upscale.py | 28 ++++----- .../model_install/model_install_default.py | 6 +- .../app/services/shared/invocation_context.py | 61 ------------------ .../migrations/migration_10.py | 62 ++++++++----------- .../model_manager/load/load_default.py | 3 +- invokeai/backend/util/util.py | 21 +++++++ .../app/services/model_load/test_load_api.py | 6 -- 7 files changed, 64 insertions(+), 123 deletions(-) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index b8acfcb7bfd..29cf7819de3 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -11,7 +11,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN -from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithBoard, WithMetadata @@ -96,22 +95,21 @@ def invoke(self, context: InvocationContext) -> ImageOutput: source=ESRGAN_MODEL_URLS[self.model_name], ) - upscaler = RealESRGAN( - scale=netscale, - loadnet=loadnet.model, - model=rrdbnet_model, - half=False, - tile=self.tile_size, - ) - - # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL - # TODO: This strips the alpha... is that okay? - cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) - upscaled_image = upscaler.upscale(cv2_image) + with loadnet as loadnet_model: + upscaler = RealESRGAN( + scale=netscale, + loadnet=loadnet_model, + model=rrdbnet_model, + half=False, + tile=self.tile_size, + ) - pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") + # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL + # TODO: This strips the alpha... is that okay? + cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) + upscaled_image = upscaler.upscale(cv2_image) - TorchDevice.empty_cache() + pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") image_dto = context.images.save(image=pil_image) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index e259cebace1..c4127acf7a6 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -6,7 +6,6 @@ import signal import threading import time -from hashlib import sha256 from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree @@ -44,6 +43,7 @@ from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util import InvokeAILogger from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.util import slugify from .model_install_base import ( MODEL_SOURCE_TO_TYPE_MAP, @@ -396,8 +396,8 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 @classmethod def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path: - model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] - return app_config.download_cache_path / model_hash + escaped_source = slugify(str(source)) + return app_config.download_cache_path / escaped_source def download_and_cache( self, diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 0d27b2520ba..50551efa31a 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -430,67 +430,6 @@ def search_by_attrs( model_format=format, ) - def install_model( - self, - source: str, - config: Optional[Dict[str, Any]] = None, - access_token: Optional[str] = None, - inplace: Optional[bool] = False, - timeout: Optional[int] = 0, - ) -> str: - """Install and register a model in the database. - - Args: - source: String source; see below - config: Optional dict. Any fields in this dict - will override corresponding autoassigned probe fields in the - model's config record. - access_token: Optional access token for remote sources. - inplace: If true, installs a local model in place rather than copying - it into the models directory - timeout: How long to wait on install (in seconds). A value of 0 (default) - blocks indefinitely - - The source can be: - 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) - 2. An http or https URL (`https://foo.bar/foo`) - 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) - - We extend the HuggingFace repo_id syntax to include the variant and the - subfolder or path. The following are acceptable alternatives: - stabilityai/stable-diffusion-v4 - stabilityai/stable-diffusion-v4:fp16 - stabilityai/stable-diffusion-v4:fp16:vae - stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors - stabilityai/stable-diffusion-v4:onnx:vae - - Because a local file path can look like a huggingface repo_id, the logic - first checks whether the path exists on disk, and if not, it is treated as - a parseable huggingface repo. - - Returns: - Key to the newly installed model. - - May Raise: - ValueError -- bad source - UnknownModelException -- remote model not found - InvalidModelException -- what was retrieved from remote is not a model - TimeoutError -- model could not be installed within timeout - Exception -- another error condition - """ - installer = self._services.model_manager.install - job = installer.heuristic_import( - source=source, - config=config, - access_token=access_token, - inplace=inplace, - ) - installer.wait_for_job(job, timeout) - if job.errored: - raise Exception(job.error) - key: str = job.config_out.key - return key - def download_and_cache_ckpt( self, source: Union[str, AnyHttpUrl], diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py index df341a6f2e5..4c4f742d4c0 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py @@ -1,28 +1,26 @@ -import pathlib import shutil import sqlite3 from logging import Logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install.model_install_default import ModelInstallService from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -LEGACY_CORE_MODELS = { +LEGACY_CORE_MODELS = [ # OpenPose - "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true": "any/annotators/dwpose/yolox_l.onnx", - "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", + "any/annotators/dwpose/yolox_l.onnx", + "any/annotators/dwpose/dw-ll_ucoco_384.onnx", # DepthAnything - "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitl14.pth", - "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitb14.pth", - "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true": "any/annotators/depth_anything/depth_anything_vits14.pth", + "any/annotators/depth_anything/depth_anything_vitl14.pth", + "any/annotators/depth_anything/depth_anything_vitb14.pth", + "any/annotators/depth_anything/depth_anything_vits14.pth", # Lama inpaint - "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt": "core/misc/lama/lama.pt", + "core/misc/lama/lama.pt", # RealESRGAN upscale - "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", - "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", - "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", - "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", -} + "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", +] class Migration10Callback: @@ -31,34 +29,24 @@ def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: self._logger = logger def __call__(self, cursor: sqlite3.Cursor) -> None: - self._rename_convert_cache() - self._migrate_downloaded_models_cache() + self._remove_convert_cache() + self._remove_downloaded_models() self._remove_unused_core_models() - def _rename_convert_cache(self) -> None: + def _remove_convert_cache(self) -> None: """Rename models/.cache to models/.convert_cache.""" + self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.") legacy_convert_path = self._app_config.root_path / "models" / ".cache" - configured_convert_dir = self._app_config.convert_cache_dir - configured_convert_path = self._app_config.convert_cache_path - # old convert dir was in use, and current convert dir has not been changed - if legacy_convert_path.exists() and configured_convert_dir == pathlib.Path("models/.convert_cache"): - self._logger.info( - f"Migrating legacy convert cache directory from {str(legacy_convert_path)} to {str(configured_convert_path)}" - ) - shutil.rmtree(configured_convert_path, ignore_errors=True) # shouldn't be needed, but just in case... - shutil.move(legacy_convert_path, configured_convert_path) + shutil.rmtree(legacy_convert_path, ignore_errors=True) - def _migrate_downloaded_models_cache(self) -> None: - """Move used core models to modsl/.download_cache.""" - self._logger.info(f"Migrating legacy core models to {str(self._app_config.download_cache_path)}") - for url, legacy_dest in LEGACY_CORE_MODELS.items(): - legacy_dest_path = self._app_config.models_path / legacy_dest - if not legacy_dest_path.exists(): - continue - # this returns a unique directory path - new_path = ModelInstallService._download_cache_path(url, self._app_config) - new_path.mkdir(parents=True, exist_ok=True) - shutil.move(legacy_dest_path, new_path / legacy_dest_path.name) + def _remove_downloaded_models(self) -> None: + """Remove models from their old locations; they will re-download when needed.""" + self._logger.info( + "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache." + ) + for model_path in LEGACY_CORE_MODELS: + legacy_dest_path = self._app_config.models_path / model_path + legacy_dest_path.unlink(missing_ok=True) def _remove_unused_core_models(self) -> None: """Remove unused core models and their directories.""" diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 459808b4554..16b9e3646ee 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.util import slugify # TO DO: The loader is not thread safe! @@ -84,7 +85,7 @@ def _convert_and_load( except IndexError: pass - cache_path: Path = self._convert_cache.cache_path(config.key) + cache_path: Path = self._convert_cache.cache_path(slugify(model_path)) if self._needs_conversion(config, model_path, cache_path): loaded_model = self._do_convert(config, model_path, cache_path, submodel_type) else: diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 7d0d9d03f76..2b2bee34583 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -1,6 +1,8 @@ import base64 import io import os +import re +import unicodedata import warnings from pathlib import Path @@ -12,6 +14,25 @@ GIG = 1073741824 +def slugify(value: str, allow_unicode: bool = False) -> str: + """ + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + + Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py + """ + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") + value = re.sub(r"[/]", "_", value.lower()) + value = re.sub(r"[^\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") + + def directory_size(directory: Path) -> int: """ Return the aggregate size of all files in a directory (bytes). diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 59b207f2ee5..463be86c681 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -49,9 +49,3 @@ def test_download_and_load(mock_context: InvocationContext): assert isinstance(model_1, dict) -def test_install_model(mock_context: InvocationContext): - key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors") - assert key is not None - model = mock_context.models.load(key) - assert model is not None - assert model.config.key == key From 70903ef05760e82a5814c432e6a0b48475eba467 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 28 Apr 2024 11:33:23 -0400 Subject: [PATCH 11/45] refactor load_ckpt_from_url() --- invokeai/app/invocations/infill.py | 4 +- .../model_install/model_install_base.py | 7 +- .../model_install/model_install_default.py | 4 +- .../services/model_load/model_load_base.py | 27 +++- .../services/model_load/model_load_default.py | 53 +++++++- .../model_manager/model_manager_base.py | 36 ++++++ .../model_manager/model_manager_default.py | 39 +++++- .../app/services/shared/invocation_context.py | 117 ++++++++++-------- .../backend/image_util/infill_methods/lama.py | 16 +-- invokeai/backend/model_manager/config.py | 2 +- .../app/services/model_load/test_load_api.py | 2 - 11 files changed, 235 insertions(+), 72 deletions(-) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index edee275e722..50fec229940 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -133,7 +133,9 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" def infill(self, image: Image.Image, context: InvocationContext): - lama = LaMA(context) + # Note that this accesses a protected attribute to get to the model manager service. + # Is there a better way? + lama = LaMA(context._services.model_manager) return lama(image) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 0ea901fb462..388f4a5ba27 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -468,7 +468,12 @@ def sync_model_path(self, key: str) -> AnyModelConfig: """ @abstractmethod - def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: + def download_and_cache_ckpt( + self, + source: str | AnyHttpUrl, + access_token: Optional[str] = None, + timeout: int = 0, + ) -> Path: """ Download the model file located at source to the models cache and return its Path. diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index c4127acf7a6..32c86ad3a30 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -399,9 +399,9 @@ def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: Invoke escaped_source = slugify(str(source)) return app_config.download_cache_path / escaped_source - def download_and_cache( + def download_and_cache_ckpt( self, - source: Union[str, AnyHttpUrl], + source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: int = 0, ) -> Path: diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index cc80333e932..d59f7a370d1 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -2,7 +2,10 @@ """Base class for model loader.""" from abc import ABC, abstractmethod -from typing import Optional +from pathlib import Path +from typing import Callable, Dict, Optional + +from torch import Tensor from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType @@ -38,3 +41,25 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]: @abstractmethod def convert_cache(self) -> ModelConvertCacheBase: """Return the checkpoint convert cache used by this loader.""" + + @abstractmethod + def load_ckpt_from_path( + self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None + ) -> LoadedModel: + """ + Load the checkpoint-format model file located at the indicated Path. + + This will load an arbitrary model file into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + model_path: A pathlib.Path to a checkpoint-style models file + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 21d3c56f36b..a87b6123ce9 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -1,7 +1,13 @@ # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team """Implementation of model loader service.""" -from typing import Optional, Type +from pathlib import Path +from typing import Callable, Dict, Optional, Type + +from picklescan.scanner import scan_file_path +from safetensors.torch import load_file as safetensors_load_file +from torch import Tensor +from torch import load as torch_load from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker @@ -88,6 +94,51 @@ def load_model( ) return loaded_model + def load_ckpt_from_path( + self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None + ) -> LoadedModel: + """ + Load the checkpoint-format model file located at the indicated Path. + + This will load an arbitrary model file into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + model_path: A pathlib.Path to a checkpoint-style models file + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + cache_key = str(model_path) + ram_cache = self.ram_cache + try: + return LoadedModel(_locker=ram_cache.get(key=cache_key)) + except IndexError: + pass + + def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]: + scan_result = scan_file_path(checkpoint) + if scan_result.infected_files != 0: + raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.") + result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu") + return result + + if loader is None: + loader = ( + torch_load_file + if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) + else lambda path: safetensors_load_file(path, device="cpu") + ) + + raw_model = loader(model_path) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModel(_locker=ram_cache.get(key=cache_key)) + def _emit_load_event( self, context_data: InvocationContextData, diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index af1b68e1ec3..7a5f433aca0 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,11 +1,15 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Dict, Optional import torch +from pydantic.networks import AnyHttpUrl from typing_extensions import Self from invokeai.app.services.invoker import Invoker +from invokeai.backend.model_manager.load import LoadedModel from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase @@ -66,3 +70,35 @@ def start(self, invoker: Invoker) -> None: @abstractmethod def stop(self, invoker: Invoker) -> None: pass + + @abstractmethod + def load_ckpt_from_url( + self, + source: str | AnyHttpUrl, + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, + ) -> LoadedModel: + """ + Download, cache, and Load the model file located at the indicated URL. + + This will check the model download cache for the model designated + by the provided URL and download it if needed using download_and_cache_ckpt(). + It will then load the model into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 1a2b9a34022..57c409c066d 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,13 +1,15 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" -from typing import Optional +from pathlib import Path +from typing import Callable, Dict, Optional import torch +from pydantic.networks import AnyHttpUrl from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry +from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -100,3 +102,36 @@ def build_model_manager( event_bus=events, ) return cls(store=model_record_service, install=installer, load=loader) + + def load_ckpt_from_url( + self, + source: str | AnyHttpUrl, + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, + ) -> LoadedModel: + """ + Download, cache, and Load the model file located at the indicated URL. + + This will check the model download cache for the model designated + by the provided URL and download it if needed using download_and_cache_ckpt(). + It will then load the model into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout) + return self.load.load_ckpt_from_path(model_path=model_path, loader=loader) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 50551efa31a..485be2ba914 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,14 +1,11 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Union -from picklescan.scanner import scan_file_path +import torch from PIL.Image import Image from pydantic.networks import AnyHttpUrl -from safetensors.torch import load_file as safetensors_load_file -from torch import Tensor -from torch import load as torch_load from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata @@ -263,7 +260,7 @@ def get_path(self, image_name: str, thumbnail: bool = False) -> Path: class TensorsInterface(InvocationContextInterface): - def save(self, tensor: Tensor) -> str: + def save(self, tensor: torch.Tensor) -> str: """Saves a tensor, returning its name. Args: @@ -276,7 +273,7 @@ def save(self, tensor: Tensor) -> str: name = self._services.tensors.save(obj=tensor) return name - def load(self, name: str) -> Tensor: + def load(self, name: str) -> torch.Tensor: """Loads a tensor by name. Args: @@ -316,8 +313,10 @@ def load(self, name: str) -> ConditioningFieldData: class ModelsInterface(InvocationContextInterface): + """Common API for loading, downloading and managing models.""" + def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool: - """Checks if a model exists. + """Check if a model exists. Args: identifier: The key or ModelField representing the model. @@ -326,14 +325,18 @@ def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool: True if the model exists, False if not. """ if isinstance(identifier, str): - return self._services.model_manager.store.exists(identifier) - - return self._services.model_manager.store.exists(identifier.key) + # For some reason, Mypy is not getting the type annotations for many of + # the model manager service calls and raises a "returning Any in typed + # context" error. Hence the extra typing hints here and below. + result: bool = self._services.model_manager.store.exists(identifier) + else: + result = self._services.model_manager.store.exists(identifier.key) + return result def load( self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None ) -> LoadedModel: - """Loads a model. + """Load a model. Args: identifier: The key or ModelField representing the model. @@ -342,22 +345,22 @@ def load( Returns: An object representing the loaded model. """ - # The model manager emits events as it loads the model. It needs the context data to build # the event payloads. if isinstance(identifier, str): model = self._services.model_manager.store.get_model(identifier) - return self._services.model_manager.load.load_model(model, submodel_type, self._data) + result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data) else: _submodel_type = submodel_type or identifier.submodel_type model = self._services.model_manager.store.get_model(identifier.key) - return self._services.model_manager.load.load_model(model, _submodel_type, self._data) + result = self._services.model_manager.load.load_model(model, _submodel_type, self._data) + return result def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None ) -> LoadedModel: - """Loads a model by its attributes. + """Load a model by its attributes. Args: name: Name of the model. @@ -369,7 +372,6 @@ def load_by_attrs( Returns: An object representing the loaded model. """ - configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type) if len(configs) == 0: raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}") @@ -377,10 +379,11 @@ def load_by_attrs( if len(configs) > 1: raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") - return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) + result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) + return result def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: - """Gets a model's config. + """Get a model's config. Args: identifier: The key or ModelField representing the model. @@ -389,12 +392,13 @@ def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModel The model's config. """ if isinstance(identifier, str): - return self._services.model_manager.store.get_model(identifier) - - return self._services.model_manager.store.get_model(identifier.key) + result: AnyModelConfig = self._services.model_manager.store.get_model(identifier) + else: + result = self._services.model_manager.store.get_model(identifier.key) + return result def search_by_path(self, path: Path) -> list[AnyModelConfig]: - """Searches for models by path. + """Search for models by path. Args: path: The path to search for. @@ -402,7 +406,8 @@ def search_by_path(self, path: Path) -> list[AnyModelConfig]: Returns: A list of models that match the path. """ - return self._services.model_manager.store.search_by_path(path) + result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path) + return result def search_by_attrs( self, @@ -411,7 +416,7 @@ def search_by_attrs( type: Optional[ModelType] = None, format: Optional[ModelFormat] = None, ) -> list[AnyModelConfig]: - """Searches for models by attributes. + """Search for models by attributes. Args: name: The name to search for (exact match). @@ -422,13 +427,13 @@ def search_by_attrs( Returns: A list of models that match the attributes. """ - - return self._services.model_manager.store.search_by_attr( + result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr( model_name=name, base_model=base, model_type=type, model_format=format, ) + return result def download_and_cache_ckpt( self, @@ -451,26 +456,49 @@ def download_and_cache_ckpt( out long downloads. Result: - Path of the downloaded model + Path to the downloaded model May Raise: HTTPError TimeoutError """ installer = self._services.model_manager.install - path: Path = installer.download_and_cache( + path: Path = installer.download_and_cache_ckpt( source=source, access_token=access_token, timeout=timeout, ) return path + def load_ckpt_from_path( + self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None + ) -> LoadedModel: + """ + Load the checkpoint-format model file located at the indicated Path. + + This will load an arbitrary model file into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + model_path: A pathlib.Path to a checkpoint-style models file + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader) + return result + def load_ckpt_from_url( self, - source: Union[str, AnyHttpUrl], + source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0, - loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None, + loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ Download, cache, and Load the model file located at the indicated URL. @@ -495,29 +523,10 @@ def load_ckpt_from_url( Returns: A LoadedModel object. """ - ram_cache = self._services.model_manager.load.ram_cache - try: - return LoadedModel(_locker=ram_cache.get(key=str(source))) - except IndexError: - pass - - def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]: - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.") - return torch_load(path, map_location="cpu") - - path = self.download_and_cache_ckpt(source, access_token, timeout) - if loader is None: - loader = ( - torch_load_file - if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) - else lambda path: safetensors_load_file(path, device="cpu") - ) - - raw_model = loader(path) - ram_cache.put(key=str(source), model=raw_model) - return LoadedModel(_locker=ram_cache.get(key=str(source))) + result: LoadedModel = self._services.model_manager.load_ckpt_from_url( + source=source, access_token=access_token, timeout=timeout, loader=loader + ) + return result class ConfigInterface(InvocationContextInterface): diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index 8c3f33efad1..c7fea497ca8 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,11 +1,13 @@ -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.shared.invocation_context import InvocationContext + +if TYPE_CHECKING: + from invokeai.app.services.model_manager import ModelManagerServiceBase def norm_img(np_img): @@ -16,20 +18,20 @@ def norm_img(np_img): return np_img -def load_jit_model(url_or_path, device): +def load_jit_model(url_or_path, device) -> torch.nn.Module: model_path = url_or_path logger.info(f"Loading model from: {model_path}") - model = torch.jit.load(model_path, map_location="cpu").to(device) + model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore model.eval() return model class LaMA: - def __init__(self, context: InvocationContext): - self._context = context + def __init__(self, model_manager: "ModelManagerServiceBase"): + self._model_manager = model_manager def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - loaded_model = self._context.models.load_ckpt_from_url( + loaded_model = self._model_manager.load_ckpt_from_url( source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", loader=lambda path: load_jit_model(path, "cpu"), ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 82f88c0e817..1a5d95b7d82 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -36,7 +36,7 @@ # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]] class InvalidModelConfigException(Exception): diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 463be86c681..167c2a09dfa 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -47,5 +47,3 @@ def test_download_and_load(mock_context: InvocationContext): with loaded_model_1 as model_1, loaded_model_2 as model_2: assert model_1 == model_2 assert isinstance(model_1, dict) - - From a26667d3ca9110d07507a8466fdc6339f72a36e3 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 28 Apr 2024 12:24:36 -0400 Subject: [PATCH 12/45] make download and convert cache keys safe for filename length --- .../convert_cache/convert_cache_default.py | 2 ++ .../model_manager/load/load_default.py | 3 +- invokeai/backend/util/util.py | 12 ++++++-- .../app/services/model_load/test_load_api.py | 29 +++++++++++++++---- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py index 8dc2aff74b7..cf6448c0568 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -7,6 +7,7 @@ from invokeai.backend.util import GIG, directory_size from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util.util import safe_filename from .convert_cache_base import ModelConvertCacheBase @@ -35,6 +36,7 @@ def max_size(self, value: float) -> None: def cache_path(self, key: str) -> Path: """Return the path for a model with the indicated key.""" + key = safe_filename(self._cache_path, key) return self._cache_path / key def make_room(self, size: float) -> None: diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 16b9e3646ee..a63cc66a86c 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -19,7 +19,6 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import TorchDevice -from invokeai.backend.util.util import slugify # TO DO: The loader is not thread safe! @@ -85,7 +84,7 @@ def _convert_and_load( except IndexError: pass - cache_path: Path = self._convert_cache.cache_path(slugify(model_path)) + cache_path: Path = self._convert_cache.cache_path(str(model_path)) if self._needs_conversion(config, model_path, cache_path): loaded_model = self._do_convert(config, model_path, cache_path, submodel_type) else: diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 2b2bee34583..8ffa1ee7df9 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -18,7 +18,8 @@ def slugify(value: str, allow_unicode: bool = False) -> str: """ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated dashes to single dashes. Remove characters that aren't alphanumerics, - underscores, or hyphens. Convert to lowercase. Also strip leading and + underscores, or hyphens. Replace slashes with underscores. + Convert to lowercase. Also strip leading and trailing whitespace, dashes, and underscores. Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py @@ -29,10 +30,17 @@ def slugify(value: str, allow_unicode: bool = False) -> str: else: value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") value = re.sub(r"[/]", "_", value.lower()) - value = re.sub(r"[^\w\s-]", "", value.lower()) + value = re.sub(r"[^.\w\s-]", "", value.lower()) return re.sub(r"[-\s]+", "-", value).strip("-_") +def safe_filename(directory: Path, value: str) -> str: + """Make a string safe to use as a filename.""" + escaped_string = slugify(value) + max_name_length = os.pathconf(directory, "PC_NAME_MAX") + return escaped_string[len(escaped_string) - max_name_length :] + + def directory_size(directory: Path) -> int: """ Return the aggregate size of all files in a directory (bytes). diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 167c2a09dfa..7eb09fb3754 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest +import torch from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.model_manager import ModelManagerServiceBase @@ -22,7 +23,7 @@ def mock_context( ) -def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path): +def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None: downloaded_path = mock_context.models.download_and_cache_ckpt( "https://www.test.foo/download/test_embedding.safetensors" ) @@ -37,13 +38,29 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) assert downloaded_path == downloaded_path_2 -def test_download_and_load(mock_context: InvocationContext): +def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None: + downloaded_path = mock_context.models.download_and_cache_ckpt( + "https://www.test.foo/download/test_embedding.safetensors" + ) + loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path) + assert isinstance(loaded_model_1, LoadedModel) + + loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path) + assert isinstance(loaded_model_2, LoadedModel) + assert loaded_model_1.model is loaded_model_2.model + + loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file) + assert isinstance(loaded_model_3, LoadedModel) + assert loaded_model_1.model is not loaded_model_3.model + assert isinstance(loaded_model_1.model, dict) + assert isinstance(loaded_model_3.model, dict) + assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) + + +def test_download_and_load(mock_context: InvocationContext) -> None: loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") assert isinstance(loaded_model_1, LoadedModel) loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") assert isinstance(loaded_model_2, LoadedModel) - - with loaded_model_1 as model_1, loaded_model_2 as model_2: - assert model_1 == model_2 - assert isinstance(model_1, dict) + assert loaded_model_1.model is loaded_model_2.model # should be cached copy From 7c39929758a8df48b8bc9973766a5d47a3ee2b1c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 28 Apr 2024 13:41:06 -0400 Subject: [PATCH 13/45] support VRAM caching of dict models that lack `to()` --- .../app/services/model_load/model_load_base.py | 2 +- .../services/model_load/model_load_default.py | 2 +- .../app/services/shared/invocation_context.py | 4 ++-- .../load/model_cache/model_cache_default.py | 17 ++++++++--------- .../load/model_cache/model_locker.py | 7 ------- 5 files changed, 12 insertions(+), 20 deletions(-) diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index d59f7a370d1..32fc62fa5bc 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -58,7 +58,7 @@ def load_ckpt_from_path( Args: model_path: A pathlib.Path to a checkpoint-style models file - loader: A Callable that expects a Path and returns a Dict[str|int, Any] + loader: A Callable that expects a Path and returns a Dict[str, Tensor] Returns: A LoadedModel object. diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index a87b6123ce9..af211c260e5 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -109,7 +109,7 @@ def load_ckpt_from_path( Args: model_path: A pathlib.Path to a checkpoint-style models file - loader: A Callable that expects a Path and returns a Dict[str|int, Any] + loader: A Callable that expects a Path and returns a Dict[str, Tensor] Returns: A LoadedModel object. diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 485be2ba914..bfdbf1e0259 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -437,7 +437,7 @@ def search_by_attrs( def download_and_cache_ckpt( self, - source: Union[str, AnyHttpUrl], + source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0, ) -> Path: @@ -501,7 +501,7 @@ def load_ckpt_from_url( loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ - Download, cache, and Load the model file located at the indicated URL. + Download, cache, and load the model file located at the indicated URL. This will check the model download cache for the model designated by the provided URL and download it if needed using download_and_cache_ckpt(). diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 62bb766cd63..bd7b2ffc7a8 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -252,23 +252,22 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device May raise a torch.cuda.OutOfMemoryError """ - # These attributes are not in the base ModelMixin class but in various derived classes. - # Some models don't have these attributes, in which case they run in RAM/CPU. self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") - if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): - return - - source_device = cache_entry.model.device + model = cache_entry.model - # Note: We compare device types only so that 'cuda' == 'cuda:0'. - # This would need to be revised to support multi-GPU. + source_device = model.device if hasattr(model, "device") else self.storage_device if torch.device(source_device).type == torch.device(target_device).type: return start_model_to_time = time.time() snapshot_before = self._capture_memory_snapshot() try: - cache_entry.model.to(target_device) + if hasattr(model, "to"): + model.to(target_device) + elif isinstance(model, dict): + for _, v in model.items(): + if hasattr(v, "to"): + v.to(target_device) except Exception as e: # blow away cache entry self._delete_cache_entry(cache_entry) raise e diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index a2759877739..e3cb7c8fff2 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -29,10 +29,6 @@ def model(self) -> AnyModel: def lock(self) -> AnyModel: """Move the model into the execution device (GPU) and lock it.""" - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this code to move it into GPU! self._cache_entry.lock() try: if self._cache.lazy_offloading: @@ -55,9 +51,6 @@ def lock(self) -> AnyModel: def unlock(self) -> None: """Call upon exit from context.""" - if not hasattr(self.model, "to"): - return - self._cache_entry.unlock() if not self._cache.lazy_offloading: self._cache.offload_unlocked_models(self._cache_entry.size) From 57c831442e4b41521b6d096d5805c3718a3fd764 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 28 Apr 2024 14:42:40 -0400 Subject: [PATCH 14/45] fix safe_filename() on windows --- invokeai/backend/util/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 8ffa1ee7df9..1ee89dcc661 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -37,7 +37,7 @@ def slugify(value: str, allow_unicode: bool = False) -> str: def safe_filename(directory: Path, value: str) -> str: """Make a string safe to use as a filename.""" escaped_string = slugify(value) - max_name_length = os.pathconf(directory, "PC_NAME_MAX") + max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256 return escaped_string[len(escaped_string) - max_name_length :] From fcb071f30caa5de52ab6e951e39dfbf6bf80e458 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 Apr 2024 08:12:51 +1000 Subject: [PATCH 15/45] feat(backend): lift managed model loading out of lama class --- invokeai/app/invocations/infill.py | 10 +++-- .../backend/image_util/infill_methods/lama.py | 43 ++++++++----------- 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 50fec229940..f8358d1df5c 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -133,10 +133,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" def infill(self, image: Image.Image, context: InvocationContext): - # Note that this accesses a protected attribute to get to the model manager service. - # Is there a better way? - lama = LaMA(context._services.model_manager) - return lama(image) + with context.models.load_ckpt_from_url( + source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", + loader=LaMA.load_jit_model, + ) as model: + lama = LaMA(model) + return lama(image) @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index c7fea497ca8..cd5838d1f2b 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,13 +1,12 @@ -from typing import TYPE_CHECKING, Any +from pathlib import Path +from typing import Any import numpy as np import torch from PIL import Image import invokeai.backend.util.logging as logger - -if TYPE_CHECKING: - from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.backend.model_manager.config import AnyModel def norm_img(np_img): @@ -18,24 +17,11 @@ def norm_img(np_img): return np_img -def load_jit_model(url_or_path, device) -> torch.nn.Module: - model_path = url_or_path - logger.info(f"Loading model from: {model_path}") - model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore - model.eval() - return model - - class LaMA: - def __init__(self, model_manager: "ModelManagerServiceBase"): - self._model_manager = model_manager + def __init__(self, model: AnyModel): + self._model = model def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - loaded_model = self._model_manager.load_ckpt_from_url( - source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", - loader=lambda path: load_jit_model(path, "cpu"), - ) - image = np.asarray(input_image.convert("RGB")) image = norm_img(image) @@ -45,16 +31,23 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: mask = norm_img(mask) mask = (mask > 0) * 1 - with loaded_model as model: - device = next(model.buffers()).device - image = torch.from_numpy(image).unsqueeze(0).to(device) - mask = torch.from_numpy(mask).unsqueeze(0).to(device) + device = next(self._model.buffers()).device + image = torch.from_numpy(image).unsqueeze(0).to(device) + mask = torch.from_numpy(mask).unsqueeze(0).to(device) - with torch.inference_mode(): - infilled_image = model(image, mask) + with torch.inference_mode(): + infilled_image = self._model(image, mask) infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = Image.fromarray(infilled_image) return infilled_image + + @staticmethod + def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module: + model_path = url_or_path + logger.info(f"Loading model from: {model_path}") + model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore + model.eval() + return model From 1fe90c357cb3f216bd4f2789180911c3520c6505 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 Apr 2024 08:56:00 +1000 Subject: [PATCH 16/45] feat(backend): lift managed model loading out of depthanything class --- .../controlnet_image_processors.py | 18 ++++--- .../image_util/depth_anything/__init__.py | 53 ++++++++----------- 2 files changed, 35 insertions(+), 36 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index fd34f18e76e..6c647c7ed14 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -2,6 +2,7 @@ # initial implementation by Gregg Helt, 2023 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux from builtins import bool, float +from pathlib import Path from typing import Dict, List, Literal, Union import cv2 @@ -37,11 +38,12 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES from invokeai.backend.image_util.canny import get_canny_edges -from invokeai.backend.image_util.depth_anything import DepthAnythingDetector +from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector from invokeai.backend.image_util.hed import HEDProcessor from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output @@ -603,11 +605,15 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res) def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: - depth_anything_detector = DepthAnythingDetector(context) - depth_anything_detector.load_model(model_size=self.model_size) - - processed_image = depth_anything_detector(image=image, resolution=self.resolution) - return processed_image + def loader(model_path: Path): + return DepthAnythingDetector.load_model( + model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device() + ) + + with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model: + depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device()) + processed_image = depth_anything_detector(image=image, resolution=self.resolution) + return processed_image @invocation( diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index 2d88c45485e..1adcc6b2029 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -1,4 +1,5 @@ -from typing import Literal, Optional, Union +from pathlib import Path +from typing import Literal import cv2 import numpy as np @@ -9,10 +10,8 @@ from torchvision.transforms import Compose from invokeai.app.services.config.config_default import get_config -from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger config = get_config() @@ -43,33 +42,27 @@ class DepthAnythingDetector: - def __init__(self, context: InvocationContext) -> None: - self.context = context - self.model: Optional[DPT_DINOv2] = None - self.model_size: Union[Literal["large", "base", "small"], None] = None - self.device = TorchDevice.choose_torch_device() - - def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2: - depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size]) - - if not self.model or model_size != self.model_size: - del self.model - self.model_size = model_size - - match self.model_size: - case "small": - self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) - case "base": - self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) - case "large": - self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) - - assert self.model is not None - self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu")) - self.model.eval() - - self.model.to(self.device) - return self.model + def __init__(self, model: DPT_DINOv2, device: torch.device) -> None: + self.model = model + self.device = device + + @staticmethod + def load_model( + model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small" + ) -> DPT_DINOv2: + match model_size: + case "small": + model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) + case "base": + model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) + case "large": + model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) + + model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu")) + model.eval() + + model.to(device) + return model def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image: if not self.model: From 38df6f37028a8d6cc81055aa35e4ee3e5a716cf7 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 2 May 2024 21:22:33 -0400 Subject: [PATCH 17/45] fix ruff error --- invokeai/app/invocations/controlnet_image_processors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e9793b94247..267f72ec2b3 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -43,8 +43,8 @@ from invokeai.backend.image_util.hed import HEDProcessor from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.image_util.util import np_to_pil, pil_to_np +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output From e9a20051bd52860298a1b10fac2b40672f855a79 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 3 May 2024 18:08:53 -0400 Subject: [PATCH 18/45] refactor DWOpenPose and add type hints --- .../controlnet_image_processors.py | 8 +++- .../image_util/dw_openpose/__init__.py | 48 +++++++++++++++---- .../backend/image_util/dw_openpose/utils.py | 8 ++-- .../image_util/dw_openpose/wholebody.py | 13 ++--- 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 267f72ec2b3..971179ac93a 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -39,7 +39,7 @@ from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize from invokeai.backend.image_util.canny import get_canny_edges from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector -from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector +from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector from invokeai.backend.image_util.hed import HEDProcessor from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor @@ -633,7 +633,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: - dw_openpose = DWOpenposeDetector(context) + mm = context.models + onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) + + dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose) processed_image = dw_openpose( image, draw_face=self.draw_face, diff --git a/invokeai/backend/image_util/dw_openpose/__init__.py b/invokeai/backend/image_util/dw_openpose/__init__.py index 17ca0233c82..cfd3ea4b0da 100644 --- a/invokeai/backend/image_util/dw_openpose/__init__.py +++ b/invokeai/backend/image_util/dw_openpose/__init__.py @@ -1,31 +1,53 @@ +from pathlib import Path +from typing import Dict + import numpy as np import torch from controlnet_aux.util import resize_image from PIL import Image -from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose +from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody - -def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512): +DWPOSE_MODELS = { + "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", + "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", +} + + +def draw_pose( + pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]], + H: int, + W: int, + draw_face: bool = True, + draw_body: bool = True, + draw_hands: bool = True, + resolution: int = 512, +) -> Image.Image: bodies = pose["bodies"] faces = pose["faces"] hands = pose["hands"] + + assert isinstance(bodies, dict) candidate = bodies["candidate"] + + assert isinstance(bodies, dict) subset = bodies["subset"] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) if draw_body: canvas = draw_bodypose(canvas, candidate, subset) if draw_hands: + assert isinstance(hands, np.ndarray) canvas = draw_handpose(canvas, hands) if draw_face: - canvas = draw_facepose(canvas, faces) + assert isinstance(hands, np.ndarray) + canvas = draw_facepose(canvas, faces) # type: ignore - dwpose_image = resize_image( + dwpose_image: Image.Image = resize_image( canvas, resolution, ) @@ -40,11 +62,16 @@ class DWOpenposeDetector: Credits: https://github.com/IDEA-Research/DWPose """ - def __init__(self, context: InvocationContext) -> None: - self.pose_estimation = Wholebody(context) + def __init__(self, onnx_det: Path, onnx_pose: Path) -> None: + self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose) def __call__( - self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512 + self, + image: Image.Image, + draw_face: bool = False, + draw_body: bool = True, + draw_hands: bool = False, + resolution: int = 512, ) -> Image.Image: np_image = np.array(image) H, W, C = np_image.shape @@ -80,3 +107,6 @@ def __call__( return draw_pose( pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution ) + + +__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"] diff --git a/invokeai/backend/image_util/dw_openpose/utils.py b/invokeai/backend/image_util/dw_openpose/utils.py index 428672ab312..dc142dfa71c 100644 --- a/invokeai/backend/image_util/dw_openpose/utils.py +++ b/invokeai/backend/image_util/dw_openpose/utils.py @@ -5,11 +5,13 @@ import cv2 import matplotlib import numpy as np +import numpy.typing as npt eps = 0.01 +NDArrayInt = npt.NDArray[np.uint8] -def draw_bodypose(canvas, candidate, subset): +def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt: H, W, C = canvas.shape candidate = np.array(candidate) subset = np.array(subset) @@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset): return canvas -def draw_handpose(canvas, all_hand_peaks): +def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt: H, W, C = canvas.shape edges = [ @@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks): return canvas -def draw_facepose(canvas, all_lmks): +def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt: H, W, C = canvas.shape for lmks in all_lmks: lmks = np.array(lmks) diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 0f66af2c779..3f77f20b9c2 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -2,33 +2,26 @@ # Modified pathing to suit Invoke +from pathlib import Path + import numpy as np import onnxruntime as ort from invokeai.app.services.config.config_default import get_config -from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.util.devices import TorchDevice from .onnxdet import inference_detector from .onnxpose import inference_pose -DWPOSE_MODELS = { - "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", - "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", -} - config = get_config() class Wholebody: - def __init__(self, context: InvocationContext): + def __init__(self, onnx_det: Path, onnx_pose: Path): device = TorchDevice.choose_torch_device() providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"] - onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) - onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) - self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) From f211c95dbcd08770e388356cdcd5a0988dadf4fc Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 5 May 2024 21:00:31 -0400 Subject: [PATCH 19/45] move access token regex matching into download queue --- invokeai/app/api/dependencies.py | 2 +- .../app/services/download/download_default.py | 16 +++++- .../model_install/model_install_default.py | 9 +--- .../model_manager/model_manager_base.py | 5 -- .../model_manager/model_manager_default.py | 7 +-- .../app/services/shared/invocation_context.py | 9 +--- .../services/download/test_download_queue.py | 50 +++++++++++++++++++ 7 files changed, 69 insertions(+), 29 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9a6c7416f69..0cfcf2f3b71 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -93,7 +93,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger conditioning = ObjectSerializerForwardCache( ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) - download_queue_service = DownloadQueueService(event_bus=events) + download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events) model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_manager = ModelManagerService.build_model_manager( app_config=configuration, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7d8229fba1d..d9ab2c7f351 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -15,6 +15,7 @@ from requests import HTTPError from tqdm import tqdm +from invokeai.app.services.config import InvokeAIAppConfig, get_config from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp from invokeai.backend.util.logging import InvokeAILogger @@ -40,15 +41,18 @@ class DownloadQueueService(DownloadQueueServiceBase): def __init__( self, max_parallel_dl: int = 5, + app_config: Optional[InvokeAIAppConfig] = None, event_bus: Optional[EventServiceBase] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ Initialize DownloadQueue. + :param app_config: InvokeAIAppConfig object :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. """ + self._app_config = app_config or get_config() self._jobs: Dict[int, DownloadJob] = {} self._next_job_id = 0 self._queue: PriorityQueue[DownloadJob] = PriorityQueue() @@ -139,7 +143,7 @@ def download( source=source, dest=dest, priority=priority, - access_token=access_token, + access_token=access_token or self._lookup_access_token(source), ) self.submit_download_job( job, @@ -333,6 +337,16 @@ def _validate_filename(self, directory: str, filename: str) -> bool: def _in_progress_path(self, path: Path) -> Path: return path.with_name(path.name + ".downloading") + def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]: + # Pull the token from config if it exists and matches the URL + print(self._app_config) + token = None + for pair in self._app_config.remote_api_tokens or []: + if re.search(pair.url_regex, str(source)): + token = pair.token + break + return token + def _signal_job_started(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.RUNNING if job.on_start: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 92512baec91..1a08624f8e3 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -222,16 +222,9 @@ def heuristic_import( access_token=access_token, ) elif re.match(r"^https?://[^/]+", source): - # Pull the token from config if it exists and matches the URL - _token = access_token - if _token is None: - for pair in self.app_config.remote_api_tokens or []: - if re.search(pair.url_regex, source): - _token = pair.token - break source_obj = URLModelSource( url=AnyHttpUrl(source), - access_token=_token, + access_token=access_token, ) else: raise ValueError(f"Unsupported model source: '{source}'") diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 7a5f433aca0..d16c00302ee 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -75,8 +75,6 @@ def stop(self, invoker: Invoker) -> None: def load_ckpt_from_url( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -94,9 +92,6 @@ def load_ckpt_from_url( Args: source: A URL or a string that can be converted in one. Repo_ids do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 57c409c066d..ed274266f39 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -106,8 +106,6 @@ def build_model_manager( def load_ckpt_from_url( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -125,13 +123,10 @@ def load_ckpt_from_url( Args: source: A URL or a string that can be converted in one. Repo_ids do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: A LoadedModel object. """ - model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout) + model_path = self.install.download_and_cache_ckpt(source=source) return self.load.load_ckpt_from_path(model_path=model_path, loader=loader) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index bfdbf1e0259..c7602760f71 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -496,8 +496,6 @@ def load_ckpt_from_path( def load_ckpt_from_url( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -515,17 +513,12 @@ def load_ckpt_from_url( Args: source: A URL or a string that can be converted in one. Repo_ids do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: A LoadedModel object. """ - result: LoadedModel = self._services.model_manager.load_ckpt_from_url( - source=source, access_token=access_token, timeout=timeout, loader=loader - ) + result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader) return result diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 307238fd611..07c473b1832 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -2,14 +2,19 @@ import re import time +from contextlib import contextmanager from pathlib import Path +from typing import Generator import pytest from pydantic.networks import AnyHttpUrl from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession +from invokeai.app.services.config import get_config +from invokeai.app.services.config.config_default import URLRegexTokenPair from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings @@ -34,6 +39,17 @@ def session() -> Session: ), ) + sess.mount( + "http://www.huggingface.co/foo.txt", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": 'filename="foo.safetensors"', + }, + ), + ) + # here are some malformed URLs to test # missing the content length sess.mount( @@ -205,3 +221,37 @@ def handler(signum, frame): assert events[-1].event_name == "download_cancelled" assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" queue.stop() + + +@contextmanager +def clear_config() -> Generator[None, None, None]: + try: + yield None + finally: + get_config.cache_clear() + + +def test_tokens(tmp_path: Path, session: Session): + with clear_config(): + config = get_config() + config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")] + queue = DownloadQueueService(requests_session=session) + queue.start() + # this one has an access token assigned + job1 = queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + ) + # this one doesn't + job2 = queue.download( + source=AnyHttpUrl( + "http://www.huggingface.co/foo.txt", + ), + dest=tmp_path, + ) + queue.join() + # this token is defined in the temporary root invokeai.yaml + # see tests/backend/model_manager/data/invokeai_root/invokeai.yaml + assert job1.access_token == "cv_12345" + assert job2.access_token is None + queue.stop() From b48d4a049deedf0103f47f6d50a3ef0ebe94d165 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 8 May 2024 21:21:01 -0700 Subject: [PATCH 20/45] bad implementation of diffusers folder download --- .../model_install/model_install_base.py | 16 +++++++++ .../model_install/model_install_default.py | 33 ++++++++++++++----- invokeai/backend/model_manager/config.py | 2 +- .../metadata/fetch/huggingface.py | 2 +- .../model_install/test_model_install.py | 24 ++++++++++++-- 5 files changed, 64 insertions(+), 13 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 388f4a5ba27..ccb8e3772e2 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -177,6 +177,7 @@ class ModelInstallJob(BaseModel): ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _do_install: Optional[bool] = PrivateAttr(default=True) _exception: Optional[Exception] = PrivateAttr(default=None) def set_error(self, e: Exception) -> None: @@ -407,6 +408,21 @@ def import_model( """ + @abstractmethod + def download_diffusers_model( + self, + source: HFModelSource, + download_to: Path, + ) -> ModelInstallJob: + """ + Download, but do not install, a diffusers model. + + :param source: An HFModelSource object containing a repo_id + :param download_to: Path to directory that will contain the downloaded model. + + Returns: a ModelInstallJob + """ + @abstractmethod def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: """Return the ModelInstallJob(s) corresponding to the provided source.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1a08624f8e3..fe932649c4c 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -249,6 +249,9 @@ def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = N self._install_jobs.append(install_job) return install_job + def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob: + return self._import_from_hf(source, download_path=download_to) + def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 return self._install_jobs @@ -641,7 +644,12 @@ def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[st inplace=source.inplace or False, ) - def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: + def _import_from_hf( + self, + source: HFModelSource, + config: Optional[Dict[str, Any]] = None, + download_path: Optional[Path] = None, + ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests source.access_token = source.access_token or HfFolder.get_token() if not source.access_token: @@ -660,9 +668,14 @@ def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any] config=config, remote_files=remote_files, metadata=metadata, + download_path=download_path, ) - def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: + def _import_from_url( + self, + source: URLModelSource, + config: Optional[Dict[str, Any]], + ) -> ModelInstallJob: # URLs from HuggingFace will be handled specially metadata = None fetcher = None @@ -676,6 +689,7 @@ def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, An self._logger.debug(f"metadata={metadata}") if metadata and isinstance(metadata, ModelMetadataWithFiles): remote_files = metadata.download_urls(session=self._session) + print(remote_files) else: remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)] return self._import_remote_model( @@ -691,13 +705,14 @@ def _import_remote_model( remote_files: List[RemoteModelFile], metadata: Optional[AnyModelRepoMetadata], config: Optional[Dict[str, Any]], + download_path: Optional[Path] = None, # if defined, download only - don't install! ) -> ModelInstallJob: # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. # Currently the tmpdir isn't automatically removed at exit because it is # being held in a daemon thread. if len(remote_files) == 0: raise ValueError(f"{source}: No downloadable files found") - tmpdir = Path( + destdir = download_path or Path( mkdtemp( dir=self._app_config.models_path, prefix=TMPDIR_PREFIX, @@ -708,7 +723,7 @@ def _import_remote_model( source=source, config_in=config or {}, source_metadata=metadata, - local_path=tmpdir, # local path may change once the download has started due to content-disposition handling + local_path=destdir, # local path may change once the download has started due to content-disposition handling bytes=0, total_bytes=0, ) @@ -722,9 +737,10 @@ def _import_remote_model( root = Path(".") subfolder = Path(".") - # we remember the path up to the top of the tmpdir so that it may be + # we remember the path up to the top of the destdir so that it may be # removed safely at the end of the install process. - install_job._install_tmpdir = tmpdir + install_job._install_tmpdir = destdir + install_job._do_install = download_path is None assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below files_string = "file" if len(remote_files) == 1 else "file" @@ -736,7 +752,7 @@ def _import_remote_model( self._logger.debug(f"Downloading {url} => {path}") install_job.total_bytes += model_file.size assert hasattr(source, "access_token") - dest = tmpdir / path.parent + dest = destdir / path.parent dest.mkdir(parents=True, exist_ok=True) download_job = DownloadJob( source=url, @@ -805,7 +821,8 @@ def _download_complete_callback(self, download_job: DownloadJob) -> None: # are there any more active jobs left in this task? if install_job.downloading and all(x.complete for x in install_job.download_parts): self._signal_job_downloads_done(install_job) - self._put_in_queue(install_job) + if install_job._do_install: + self._put_in_queue(install_job) # Let other threads know that the number of downloads has changed self._download_cache.pop(download_job.source, None) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 9e30d960165..e3c99c5644a 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - Default = "" # model files without "fp16" or other qualifier - empty str + Default = "" # model files without "fp16" or other qualifier FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 4e3625fdbe6..ab78b3e0640 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -83,7 +83,7 @@ def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyMod assert s.size is not None files.append( RemoteModelFile( - url=hf_hub_url(id, s.rfilename, revision=variant), + url=hf_hub_url(id, s.rfilename, revision=variant or "main"), path=Path(name, s.rfilename), size=s.size, sha256=s.lfs.get("sha256") if s.lfs else None, diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index c755d3c491d..ba844552408 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -14,6 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ( + HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, @@ -21,7 +22,13 @@ URLModelSource, ) from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException -from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType +from invokeai.backend.model_manager.config import ( + BaseModelType, + InvalidModelConfigException, + ModelFormat, + ModelRepoVariant, + ModelType, +) from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 OS = platform.uname().system @@ -247,7 +254,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: @pytest.mark.timeout(timeout=20, method="thread") -def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: +def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) bus = mm2_installer.event_bus @@ -278,6 +285,17 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co } +@pytest.mark.timeout(timeout=20, method="thread") +def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None: + source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) + job = mm2_installer.download_diffusers_model(source, tmp_path) + mm2_installer.wait_for_installs(timeout=5) + print(job.local_path) + assert job.status == InstallStatus.DOWNLOADS_DONE + assert (tmp_path / "sdxl-turbo").exists() + assert (tmp_path / "sdxl-turbo" / "model_index.json").exists() + + def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) job = mm2_installer.import_model(source) @@ -327,7 +345,7 @@ def raise_runtime_error(*args, **kwargs): }, ], ) -@pytest.mark.timeout(timeout=40, method="thread") +@pytest.mark.timeout(timeout=20, method="thread") def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): """Test whether or not type is respected on configs when passed to heuristic import.""" assert "name" in model_params and "type" in model_params From 0bf14c2830c915f635dd624bc6816d1f963dcd88 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 12 May 2024 20:14:00 -0600 Subject: [PATCH 21/45] add multifile_download() method to download service --- invokeai/app/services/download/__init__.py | 9 +- .../app/services/download/download_base.py | 106 ++++++--- .../app/services/download/download_default.py | 201 ++++++++++++++---- .../model_install/model_install_default.py | 1 - .../model_manager/metadata/metadata_base.py | 2 +- .../services/download/test_download_queue.py | 112 +++++++++- 6 files changed, 349 insertions(+), 82 deletions(-) diff --git a/invokeai/app/services/download/__init__.py b/invokeai/app/services/download/__init__.py index 371c531387d..33b0025809c 100644 --- a/invokeai/app/services/download/__init__.py +++ b/invokeai/app/services/download/__init__.py @@ -1,10 +1,17 @@ """Init file for download queue.""" -from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException +from .download_base import ( + DownloadJob, + DownloadJobStatus, + DownloadQueueServiceBase, + MultiFileDownloadJob, + UnknownJobIDException, +) from .download_default import DownloadQueueService, TqdmProgress __all__ = [ "DownloadJob", + "MultiFileDownloadJob", "DownloadQueueServiceBase", "DownloadQueueService", "TqdmProgress", diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index 2ac13b825fe..9bb2300eb64 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -5,11 +5,13 @@ from enum import Enum from functools import total_ordering from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Set from pydantic import BaseModel, Field, PrivateAttr from pydantic.networks import AnyHttpUrl +from invokeai.backend.model_manager.metadata import RemoteModelFile + class DownloadJobStatus(str, Enum): """State of a download job.""" @@ -33,30 +35,19 @@ class ServiceInactiveException(Exception): """This exception is raised when user attempts to initiate a download before the service is started.""" -DownloadEventHandler = Callable[["DownloadJob"], None] -DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] +DownloadEventHandler = Callable[["DownloadJobBase"], None] +DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None] +MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None] +MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None] -@total_ordering -class DownloadJob(BaseModel): - """Class to monitor and control a model download request.""" - # required variables to be passed in on creation - source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.") - dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path") - access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") - # automatically assigned on creation - id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel - priority: int = Field(default=10, description="Queue priority; lower values are higher priority") +class DownloadJobBase(BaseModel): + """Base of classes to monitor and control downloads.""" - # set internally during download process + dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path") + download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") - download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file") - job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started") - job_ended: Optional[str] = Field( - default=None, description="Timestamp for when the download job ende1d (completed or errored)" - ) - content_type: Optional[str] = Field(default=None, description="Content type of downloaded file") bytes: int = Field(default=0, description="Bytes downloaded so far") total_bytes: int = Field(default=0, description="Total file size (bytes)") @@ -74,14 +65,6 @@ class DownloadJob(BaseModel): _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None) _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None) - def __hash__(self) -> int: - """Return hash of the string representation of this object, for indexing.""" - return hash(str(self)) - - def __le__(self, other: "DownloadJob") -> bool: - """Return True if this job's priority is less than another's.""" - return self.priority <= other.priority - def cancel(self) -> None: """Call to cancel the job.""" self._cancelled = True @@ -98,6 +81,11 @@ def complete(self) -> bool: """Return true if job completed without errors.""" return self.status == DownloadJobStatus.COMPLETED + @property + def waiting(self) -> bool: + """Return true if the job is waiting to run.""" + return self.status == DownloadJobStatus.WAITING + @property def running(self) -> bool: """Return true if the job is running.""" @@ -154,6 +142,39 @@ def set_callbacks( self._on_cancelled = on_cancelled +@total_ordering +class DownloadJob(DownloadJobBase): + """Class to monitor and control a model download request.""" + + # required variables to be passed in on creation + source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.") + access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") + # automatically assigned on creation + id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel + priority: int = Field(default=10, description="Queue priority; lower values are higher priority") + + # set internally during download process + job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started") + job_ended: Optional[str] = Field( + default=None, description="Timestamp for when the download job ende1d (completed or errored)" + ) + content_type: Optional[str] = Field(default=None, description="Content type of downloaded file") + + def __hash__(self) -> int: + """Return hash of the string representation of this object, for indexing.""" + return hash(str(self)) + + def __le__(self, other: "DownloadJob") -> bool: + """Return True if this job's priority is less than another's.""" + return self.priority <= other.priority + + +class MultiFileDownloadJob(DownloadJobBase): + """Class to monitor and control multifile downloads.""" + + download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.") + + class DownloadQueueServiceBase(ABC): """Multithreaded queue for downloading models via URL.""" @@ -201,6 +222,33 @@ def download( """ pass + @abstractmethod + def multifile_download( + self, + parts: Set[RemoteModelFile], + dest: Path, + access_token: Optional[str] = None, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadExceptionHandler] = None, + ) -> MultiFileDownloadJob: + """ + Create and enqueue a multifile download job. + + :param parts: Set of URL / filename pairs + :param dest: Path to download to. See below. + :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated + events. + :returns: A MultiFileDownloadJob object for monitoring the state of the download. + + The `dest` argument is a Path object pointing to a directory. All downloads + with be placed inside this directory. The callbacks will receive the + MultiFileDownloadJob. + """ + pass + @abstractmethod def submit_download_job( self, @@ -262,7 +310,7 @@ def join(self) -> None: pass @abstractmethod - def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: """Wait until the indicated download job has reached a terminal state. This will block until the indicated install job has completed, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index d9ab2c7f351..9ab452c1ef3 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -18,6 +18,7 @@ from invokeai.app.services.config import InvokeAIAppConfig, get_config from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp +from invokeai.backend.model_manager.metadata import RemoteModelFile from invokeai.backend.util.logging import InvokeAILogger from .download_base import ( @@ -27,6 +28,9 @@ DownloadJobCancelledException, DownloadJobStatus, DownloadQueueServiceBase, + MultiFileDownloadEventHandler, + MultiFileDownloadExceptionHandler, + MultiFileDownloadJob, ServiceInactiveException, UnknownJobIDException, ) @@ -54,10 +58,11 @@ def __init__( """ self._app_config = app_config or get_config() self._jobs: Dict[int, DownloadJob] = {} + self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {} self._next_job_id = 0 self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() - self._job_completed_event = threading.Event() + self._job_terminated_event = threading.Event() self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") @@ -155,6 +160,49 @@ def download( ) return job + def multifile_download( + self, + parts: Set[RemoteModelFile], + dest: Path, + access_token: Optional[str] = None, + on_start: Optional[MultiFileDownloadEventHandler] = None, + on_progress: Optional[MultiFileDownloadEventHandler] = None, + on_complete: Optional[MultiFileDownloadEventHandler] = None, + on_cancelled: Optional[MultiFileDownloadEventHandler] = None, + on_error: Optional[MultiFileDownloadExceptionHandler] = None, + ) -> MultiFileDownloadJob: + mfdj = MultiFileDownloadJob(dest=dest) + mfdj.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + + for part in parts: + url = part.url + path = dest / part.path + assert path.is_relative_to(dest), "only relative download paths accepted" + job = DownloadJob( + source=url, + dest=path, + access_token=access_token, + ) + mfdj.download_parts.add(job) + self._download_part2parent[job.source] = mfdj + + for download_job in mfdj.download_parts: + self.submit_download_job( + download_job, + on_start=self._mfd_started, + on_progress=self._mfd_progress, + on_complete=self._mfd_complete, + on_cancelled=self._mfd_cancelled, + on_error=self._mfd_error, + ) + return mfdj + def join(self) -> None: """Wait for all jobs to complete.""" self._queue.join() @@ -187,7 +235,7 @@ def cancel_job(self, job: DownloadJob) -> None: If it is running it will be stopped. job.status will be set to DownloadJobStatus.CANCELLED """ - with self._lock: + if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]: job.cancel() def cancel_all_jobs(self) -> None: @@ -196,12 +244,12 @@ def cancel_all_jobs(self) -> None: if not job.in_terminal_state: self.cancel_job(job) - def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: - if self._job_completed_event.wait(timeout=0.25): # in case we miss an event - self._job_completed_event.clear() + if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event + self._job_terminated_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") return job @@ -230,22 +278,25 @@ def _download_next_item(self) -> None: job.job_started = get_iso_timestamp() self._do_download(job) self._signal_job_complete(job) - except (OSError, HTTPError) as excp: - job.error_type = excp.__class__.__name__ + f"({str(excp)})" - job.error = traceback.format_exc() - self._signal_job_error(job, excp) except DownloadJobCancelledException: self._signal_job_cancelled(job) self._cleanup_cancelled_job(job) - + except Exception as excp: + job.error_type = excp.__class__.__name__ + f"({str(excp)})" + job.error = traceback.format_exc() + self._signal_job_error(job, excp) finally: job.job_ended = get_iso_timestamp() - self._job_completed_event.set() # signal a change to terminal state + self._job_terminated_event.set() # signal a change to terminal state + self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it + self._job_terminated_event.set() self._queue.task_done() + self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") def _do_download(self, job: DownloadJob) -> None: """Do the actual download.""" + url = job.source header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} open_mode = "wb" @@ -339,7 +390,6 @@ def _in_progress_path(self, path: Path) -> Path: def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]: # Pull the token from config if it exists and matches the URL - print(self._app_config) token = None for pair in self._app_config.remote_api_tokens or []: if re.search(pair.url_regex, str(source)): @@ -349,25 +399,13 @@ def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]: def _signal_job_started(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.RUNNING - if job.on_start: - try: - job.on_start(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_start") if self._event_bus: assert job.download_path self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix()) def _signal_job_progress(self, job: DownloadJob) -> None: - if job.on_progress: - try: - job.on_progress(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_progress") if self._event_bus: assert job.download_path self._event_bus.emit_download_progress( @@ -379,13 +417,7 @@ def _signal_job_progress(self, job: DownloadJob) -> None: def _signal_job_complete(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.COMPLETED - if job.on_complete: - try: - job.on_complete(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_complete") if self._event_bus: assert job.download_path self._event_bus.emit_download_complete( @@ -396,26 +428,21 @@ def _signal_job_cancelled(self, job: DownloadJob) -> None: if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: return job.status = DownloadJobStatus.CANCELLED - if job.on_cancelled: - try: - job.on_cancelled(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_cancelled") if self._event_bus: self._event_bus.emit_download_cancelled(str(job.source)) + # if multifile download, then signal the parent + if parent_job := self._download_part2parent.get(job.source, None): + if not parent_job.in_terminal_state: + parent_job.status = DownloadJobStatus.CANCELLED + self._execute_cb(parent_job, "on_cancelled") + def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None: job.status = DownloadJobStatus.ERROR self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}") - if job.on_error: - try: - job.on_error(job, excp) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_error", excp) + if self._event_bus: assert job.error_type assert job.error @@ -430,6 +457,86 @@ def _cleanup_cancelled_job(self, job: DownloadJob) -> None: except OSError as excp: self._logger.warning(excp) + ######################################## + # callbacks used for multifile downloads + ######################################## + def _mfd_started(self, download_job: DownloadJob) -> None: + self._logger.info(f"File download started: {download_job.source}") + with self._lock: + mf_job = self._download_part2parent[download_job.source] + if mf_job.waiting: + mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) + mf_job.status = DownloadJobStatus.RUNNING + self._execute_cb(mf_job, "on_start") + + def _mfd_progress(self, download_job: DownloadJob) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + if mf_job.cancelled: + for part in mf_job.download_parts: + self.cancel_job(part) + elif mf_job.running: + mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) + mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts) + self._execute_cb(mf_job, "on_progress") + + def _mfd_complete(self, download_job: DownloadJob) -> None: + self._logger.info(f"Download complete: {download_job.source}") + with self._lock: + mf_job = self._download_part2parent[download_job.source] + + # are there any more active jobs left in this task? + if mf_job.running and all(x.complete for x in mf_job.download_parts): + mf_job.status = DownloadJobStatus.COMPLETED + self._execute_cb(mf_job, "on_complete") + + # we're done with this sub-job + self._job_terminated_event.set() + + def _mfd_cancelled(self, download_job: DownloadJob) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + assert mf_job is not None + + if not mf_job.in_terminal_state: + self._logger.warning(f"Download cancelled: {download_job.source}") + mf_job.cancel() + + for s in mf_job.download_parts: + self.cancel_job(s) + + def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + assert mf_job is not None + if not mf_job.in_terminal_state: + mf_job.status = download_job.status + mf_job.error = download_job.error + mf_job.error_type = download_job.error_type + self._execute_cb(mf_job, "on_error", excp) + self._logger.error( + f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}" + ) + for s in [x for x in mf_job.download_parts if x.running]: + self.cancel_job(s) + self._download_part2parent.pop(download_job.source) + self._job_terminated_event.set() + + def _execute_cb( + self, + job: DownloadJob | MultiFileDownloadJob, + callback_name: str, + excp: Optional[Exception] = None, + ) -> None: + if callback := getattr(job, callback_name, None): + args = [job, excp] if excp else [job] + try: + callback(*args) + except Exception as e: + self._logger.error( + f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}" + ) + def get_pc_name_max(directory: str) -> int: if hasattr(os, "pathconf"): diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index fe932649c4c..f59c7b9f850 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -689,7 +689,6 @@ def _import_from_url( self._logger.debug(f"metadata={metadata}") if metadata and isinstance(metadata, ModelMetadataWithFiles): remote_files = metadata.download_urls(session=self._session) - print(remote_files) else: remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)] return self._import_remote_model( diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 585c0fa31cb..4abf020538b 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -37,7 +37,7 @@ class RemoteModelFile(BaseModel): url: AnyHttpUrl = Field(description="The url to download this model file") path: Path = Field(description="The path to the file, relative to the model root") - size: int = Field(description="The size of this file, in bytes") + size: Optional[int] = Field(description="The size of this file, in bytes", default=0) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 07c473b1832..578dbd29074 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -4,7 +4,7 @@ import time from contextlib import contextmanager from pathlib import Path -from typing import Generator +from typing import Generator, Optional import pytest from pydantic.networks import AnyHttpUrl @@ -13,7 +13,8 @@ from invokeai.app.services.config import get_config from invokeai.app.services.config.config_default import URLRegexTokenPair -from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService +from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob +from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService @@ -67,11 +68,116 @@ def session() -> Session: return sess +@pytest.mark.timeout(timeout=10, method="thread") +def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: + fetcher = HuggingFaceMetadataFetch(mm2_session) + metadata = fetcher.from_id("stabilityai/sdxl-turbo") + events = set() + + def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: + print(f"bytes = {job.bytes}") + events.add(job.status) + + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + job = queue.multifile_download( + parts=metadata.download_urls(session=mm2_session), + dest=tmp_path, + on_start=event_handler, + on_progress=event_handler, + on_complete=event_handler, + on_error=event_handler, + ) + assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase" + queue.join() + + assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.bytes > 0, "expected download bytes to be positive" + assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert Path( + tmp_path, "sdxl-turbo/model_index.json" + ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" + assert Path( + tmp_path, "sdxl-turbo/text_encoder/config.json" + ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist" + + assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} + queue.stop() + + +@pytest.mark.timeout(timeout=10, method="thread") +def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None: + fetcher = HuggingFaceMetadataFetch(mm2_session) + metadata = fetcher.from_id("stabilityai/sdxl-turbo") + events = set() + + def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: + events.add(job.status) + + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + files = metadata.download_urls(session=mm2_session) + # this will give a 404 error + files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken"))) + job = queue.multifile_download( + parts=files, + dest=tmp_path, + on_start=event_handler, + on_progress=event_handler, + on_complete=event_handler, + on_error=event_handler, + ) + queue.join() + + assert job.status == DownloadJobStatus("error"), "expected job status to be errored" + assert "HTTPError(NOT FOUND)" in job.error_type + assert DownloadJobStatus.ERROR in events + queue.stop() + + +@pytest.mark.timeout(timeout=15, method="thread") +def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None: + event_bus = TestEventService() + + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) + queue.start() + + cancelled = False + + def cancelled_callback(job: DownloadJob) -> None: + nonlocal cancelled + cancelled = True + + def handler(signum, frame): + raise TimeoutError("Join took too long to return") + + fetcher = HuggingFaceMetadataFetch(mm2_session) + metadata = fetcher.from_id("stabilityai/sdxl-turbo") + + job = queue.multifile_download( + parts=metadata.download_urls(session=mm2_session), + dest=tmp_path, + on_cancelled=cancelled_callback, + ) + queue.cancel_job(job) + queue.join() + + assert job.status == DownloadJobStatus.CANCELLED + assert cancelled + events = event_bus.events + assert "download_cancelled" in [x.event_name for x in events] + queue.stop() + + @pytest.mark.timeout(timeout=20, method="thread") def test_basic_queue_download(tmp_path: Path, session: Session) -> None: events = set() - def event_handler(job: DownloadJob) -> None: + def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: events.add(job.status) queue = DownloadQueueService( From 287c679f7b14b8faa00e93110b9b03a0ba46c35c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 13 May 2024 18:31:40 -0400 Subject: [PATCH 22/45] clean up type checking for single file and multifile download job callbacks --- .../app/services/download/download_base.py | 25 +- .../app/services/download/download_default.py | 19 +- .../services/download/test_download_queue.py | 216 +++++++++--------- 3 files changed, 137 insertions(+), 123 deletions(-) diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index 9bb2300eb64..3e415091c7c 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -5,7 +5,7 @@ from enum import Enum from functools import total_ordering from pathlib import Path -from typing import Any, Callable, List, Optional, Set +from typing import Any, Callable, List, Optional, Set, Union from pydantic import BaseModel, Field, PrivateAttr from pydantic.networks import AnyHttpUrl @@ -35,12 +35,12 @@ class ServiceInactiveException(Exception): """This exception is raised when user attempts to initiate a download before the service is started.""" -DownloadEventHandler = Callable[["DownloadJobBase"], None] -DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None] - +SingleFileDownloadEventHandler = Callable[["DownloadJob"], None] +SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None] MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None] - +DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler] +DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler] class DownloadJobBase(BaseModel): """Base of classes to monitor and control downloads.""" @@ -228,6 +228,7 @@ def multifile_download( parts: Set[RemoteModelFile], dest: Path, access_token: Optional[str] = None, + submit_job: bool = True, on_start: Optional[DownloadEventHandler] = None, on_progress: Optional[DownloadEventHandler] = None, on_complete: Optional[DownloadEventHandler] = None, @@ -239,6 +240,11 @@ def multifile_download( :param parts: Set of URL / filename pairs :param dest: Path to download to. See below. + :param access_token: Access token to download the indicated files. If not provided, + each file's URL may be matched to an access token using the config file matching + system. + :param submit_job: If true [default] then submit the job for execution. Otherwise, + you will need to pass the job to submit_multifile_download(). :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated events. :returns: A MultiFileDownloadJob object for monitoring the state of the download. @@ -249,6 +255,15 @@ def multifile_download( """ pass + @abstractmethod + def submit_multifile_download(self, job: MultiFileDownloadJob) -> None: + """ + Enqueue a previously-created multi-file download job. + + :param job: A MultiFileDownloadJob created with multifile_download() + """ + pass + @abstractmethod def submit_download_job( self, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 9ab452c1ef3..3f55e9a2540 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -25,11 +25,10 @@ DownloadEventHandler, DownloadExceptionHandler, DownloadJob, + DownloadJobBase, DownloadJobCancelledException, DownloadJobStatus, DownloadQueueServiceBase, - MultiFileDownloadEventHandler, - MultiFileDownloadExceptionHandler, MultiFileDownloadJob, ServiceInactiveException, UnknownJobIDException, @@ -165,11 +164,11 @@ def multifile_download( parts: Set[RemoteModelFile], dest: Path, access_token: Optional[str] = None, - on_start: Optional[MultiFileDownloadEventHandler] = None, - on_progress: Optional[MultiFileDownloadEventHandler] = None, - on_complete: Optional[MultiFileDownloadEventHandler] = None, - on_cancelled: Optional[MultiFileDownloadEventHandler] = None, - on_error: Optional[MultiFileDownloadExceptionHandler] = None, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadExceptionHandler] = None, ) -> MultiFileDownloadJob: mfdj = MultiFileDownloadJob(dest=dest) mfdj.set_callbacks( @@ -191,8 +190,11 @@ def multifile_download( ) mfdj.download_parts.add(job) self._download_part2parent[job.source] = mfdj + self.submit_multifile_download(mfdj) + return mfdj - for download_job in mfdj.download_parts: + def submit_multifile_download(self, job: MultiFileDownloadJob) -> None: + for download_job in job.download_parts: self.submit_download_job( download_job, on_start=self._mfd_started, @@ -201,7 +203,6 @@ def multifile_download( on_cancelled=self._mfd_cancelled, on_error=self._mfd_error, ) - return mfdj def join(self) -> None: """Wait for all jobs to complete.""" diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 578dbd29074..393cd54a033 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -69,111 +69,6 @@ def session() -> Session: @pytest.mark.timeout(timeout=10, method="thread") -def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: - fetcher = HuggingFaceMetadataFetch(mm2_session) - metadata = fetcher.from_id("stabilityai/sdxl-turbo") - events = set() - - def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: - print(f"bytes = {job.bytes}") - events.add(job.status) - - queue = DownloadQueueService( - requests_session=mm2_session, - ) - queue.start() - job = queue.multifile_download( - parts=metadata.download_urls(session=mm2_session), - dest=tmp_path, - on_start=event_handler, - on_progress=event_handler, - on_complete=event_handler, - on_error=event_handler, - ) - assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase" - queue.join() - - assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" - assert job.bytes > 0, "expected download bytes to be positive" - assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" - assert Path( - tmp_path, "sdxl-turbo/model_index.json" - ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" - assert Path( - tmp_path, "sdxl-turbo/text_encoder/config.json" - ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist" - - assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} - queue.stop() - - -@pytest.mark.timeout(timeout=10, method="thread") -def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None: - fetcher = HuggingFaceMetadataFetch(mm2_session) - metadata = fetcher.from_id("stabilityai/sdxl-turbo") - events = set() - - def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: - events.add(job.status) - - queue = DownloadQueueService( - requests_session=mm2_session, - ) - queue.start() - files = metadata.download_urls(session=mm2_session) - # this will give a 404 error - files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken"))) - job = queue.multifile_download( - parts=files, - dest=tmp_path, - on_start=event_handler, - on_progress=event_handler, - on_complete=event_handler, - on_error=event_handler, - ) - queue.join() - - assert job.status == DownloadJobStatus("error"), "expected job status to be errored" - assert "HTTPError(NOT FOUND)" in job.error_type - assert DownloadJobStatus.ERROR in events - queue.stop() - - -@pytest.mark.timeout(timeout=15, method="thread") -def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None: - event_bus = TestEventService() - - queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) - queue.start() - - cancelled = False - - def cancelled_callback(job: DownloadJob) -> None: - nonlocal cancelled - cancelled = True - - def handler(signum, frame): - raise TimeoutError("Join took too long to return") - - fetcher = HuggingFaceMetadataFetch(mm2_session) - metadata = fetcher.from_id("stabilityai/sdxl-turbo") - - job = queue.multifile_download( - parts=metadata.download_urls(session=mm2_session), - dest=tmp_path, - on_cancelled=cancelled_callback, - ) - queue.cancel_job(job) - queue.join() - - assert job.status == DownloadJobStatus.CANCELLED - assert cancelled - events = event_bus.events - assert "download_cancelled" in [x.event_name for x in events] - queue.stop() - - -@pytest.mark.timeout(timeout=20, method="thread") def test_basic_queue_download(tmp_path: Path, session: Session) -> None: events = set() @@ -203,7 +98,7 @@ def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_errors(tmp_path: Path, session: Session) -> None: queue = DownloadQueueService( requests_session=session, @@ -225,7 +120,7 @@ def test_errors(tmp_path: Path, session: Session) -> None: queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_event_bus(tmp_path: Path, session: Session) -> None: event_bus = TestEventService() @@ -261,7 +156,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None: queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: queue = DownloadQueueService( requests_session=session, @@ -293,7 +188,7 @@ def broken_callback(job: DownloadJob) -> None: queue.stop() -@pytest.mark.timeout(timeout=15, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_cancel(tmp_path: Path, session: Session) -> None: event_bus = TestEventService() @@ -328,6 +223,109 @@ def handler(signum, frame): assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" queue.stop() +@pytest.mark.timeout(timeout=10, method="thread") +def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: + fetcher = HuggingFaceMetadataFetch(mm2_session) + metadata = fetcher.from_id("stabilityai/sdxl-turbo") + events = set() + + def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: + print(f"bytes = {job.bytes}") + events.add(job.status) + + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + job = queue.multifile_download( + parts=metadata.download_urls(session=mm2_session), + dest=tmp_path, + on_start=event_handler, + on_progress=event_handler, + on_complete=event_handler, + on_error=event_handler, + ) + assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase" + queue.join() + + assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.bytes > 0, "expected download bytes to be positive" + assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert Path( + tmp_path, "sdxl-turbo/model_index.json" + ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" + assert Path( + tmp_path, "sdxl-turbo/text_encoder/config.json" + ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist" + + assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} + queue.stop() + + +@pytest.mark.timeout(timeout=10, method="thread") +def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None: + fetcher = HuggingFaceMetadataFetch(mm2_session) + metadata = fetcher.from_id("stabilityai/sdxl-turbo") + events = set() + + def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: + events.add(job.status) + + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + files = metadata.download_urls(session=mm2_session) + # this will give a 404 error + files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken"))) + job = queue.multifile_download( + parts=files, + dest=tmp_path, + on_start=event_handler, + on_progress=event_handler, + on_complete=event_handler, + on_error=event_handler, + ) + queue.join() + + assert job.status == DownloadJobStatus("error"), "expected job status to be errored" + assert "HTTPError(NOT FOUND)" in job.error_type + assert DownloadJobStatus.ERROR in events + queue.stop() + + +@pytest.mark.timeout(timeout=10, method="thread") +def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None: + event_bus = TestEventService() + + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) + queue.start() + + cancelled = False + + def cancelled_callback(job: DownloadJob) -> None: + nonlocal cancelled + cancelled = True + + def handler(signum, frame): + raise TimeoutError("Join took too long to return") + + fetcher = HuggingFaceMetadataFetch(mm2_session) + metadata = fetcher.from_id("stabilityai/sdxl-turbo") + + job = queue.multifile_download( + parts=metadata.download_urls(session=mm2_session), + dest=tmp_path, + on_cancelled=cancelled_callback, + ) + queue.cancel_job(job) + queue.join() + + assert job.status == DownloadJobStatus.CANCELLED + assert cancelled + events = event_bus.events + assert "download_cancelled" in [x.event_name for x in events] + queue.stop() @contextmanager def clear_config() -> Generator[None, None, None]: From f29c406fed4f398a4134643f356f402e9a4426f5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 13 May 2024 22:49:15 -0400 Subject: [PATCH 23/45] refactor model_install to work with refactored download queue --- .../app/services/download/download_base.py | 12 +- .../app/services/download/download_default.py | 45 +++-- .../model_install/model_install_base.py | 24 +-- .../model_install/model_install_default.py | 167 +++++++----------- .../model_manager/metadata/metadata_base.py | 3 + .../services/download/test_download_queue.py | 117 +++++------- .../model_install/test_model_install.py | 37 +++- .../model_manager/model_manager_fixtures.py | 41 +++++ 8 files changed, 226 insertions(+), 220 deletions(-) diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index 3e415091c7c..4880ab98b89 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -42,9 +42,13 @@ class ServiceInactiveException(Exception): DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler] DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler] + class DownloadJobBase(BaseModel): """Base of classes to monitor and control downloads.""" + # automatically assigned on creation + id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel + dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path") download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") @@ -149,8 +153,6 @@ class DownloadJob(DownloadJobBase): # required variables to be passed in on creation source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.") access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") - # automatically assigned on creation - id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel priority: int = Field(default=10, description="Queue priority; lower values are higher priority") # set internally during download process @@ -225,7 +227,7 @@ def download( @abstractmethod def multifile_download( self, - parts: Set[RemoteModelFile], + parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True, @@ -315,7 +317,7 @@ def prune_jobs(self) -> None: pass @abstractmethod - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """Cancel the job, clearing partial downloads and putting it into ERROR state.""" pass @@ -325,7 +327,7 @@ def join(self) -> None: pass @abstractmethod - def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Wait until the indicated download job has reached a terminal state. This will block until the indicated install job has completed, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 3f55e9a2540..4555477004f 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -113,18 +113,16 @@ def submit_download_job( raise ServiceInactiveException( "The download service is not currently accepting requests. Please call start() to initialize the service." ) - with self._lock: - job.id = self._next_job_id - self._next_job_id += 1 - job.set_callbacks( - on_start=on_start, - on_progress=on_progress, - on_complete=on_complete, - on_cancelled=on_cancelled, - on_error=on_error, - ) - self._jobs[job.id] = job - self._queue.put(job) + job.id = self._next_id() + job.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + self._jobs[job.id] = job + self._queue.put(job) def download( self, @@ -161,16 +159,17 @@ def download( def multifile_download( self, - parts: Set[RemoteModelFile], + parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, + submit_job: bool = True, on_start: Optional[DownloadEventHandler] = None, on_progress: Optional[DownloadEventHandler] = None, on_complete: Optional[DownloadEventHandler] = None, on_cancelled: Optional[DownloadEventHandler] = None, on_error: Optional[DownloadExceptionHandler] = None, ) -> MultiFileDownloadJob: - mfdj = MultiFileDownloadJob(dest=dest) + mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id()) mfdj.set_callbacks( on_start=on_start, on_progress=on_progress, @@ -190,7 +189,8 @@ def multifile_download( ) mfdj.download_parts.add(job) self._download_part2parent[job.source] = mfdj - self.submit_multifile_download(mfdj) + if submit_job: + self.submit_multifile_download(mfdj) return mfdj def submit_multifile_download(self, job: MultiFileDownloadJob) -> None: @@ -208,6 +208,12 @@ def join(self) -> None: """Wait for all jobs to complete.""" self._queue.join() + def _next_id(self) -> int: + with self._lock: + id = self._next_job_id + self._next_job_id += 1 + return id + def list_jobs(self) -> List[DownloadJob]: """List all the jobs.""" return list(self._jobs.values()) @@ -229,7 +235,7 @@ def id_to_job(self, id: int) -> DownloadJob: except KeyError as excp: raise UnknownJobIDException("Unrecognized job") from excp - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """ Cancel the indicated job. @@ -245,7 +251,7 @@ def cancel_all_jobs(self) -> None: if not job.in_terminal_state: self.cancel_job(job) - def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: @@ -468,6 +474,11 @@ def _mfd_started(self, download_job: DownloadJob) -> None: if mf_job.waiting: mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) mf_job.status = DownloadJobStatus.RUNNING + assert download_job.download_path is not None + path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest) + mf_job.download_path = ( + mf_job.dest / path_relative_to_destdir.parts[0] + ) # keep just the first component of the path self._execute_cb(mf_job, "on_start") def _mfd_progress(self, download_job: DownloadJob) -> None: diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index ccb8e3772e2..68cf9591e0f 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -6,14 +6,14 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Set, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.networks import AnyHttpUrl from typing_extensions import Annotated from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase @@ -166,9 +166,6 @@ class ModelInstallJob(BaseModel): source_metadata: Optional[AnyModelRepoMetadata] = Field( default=None, description="Metadata provided by the model source" ) - download_parts: Set[DownloadJob] = Field( - default_factory=set, description="Download jobs contributing to this install" - ) error: Optional[str] = Field( default=None, description="On an error condition, this field will contain the text of the exception" ) @@ -177,7 +174,7 @@ class ModelInstallJob(BaseModel): ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) - _do_install: Optional[bool] = PrivateAttr(default=True) + _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None) def set_error(self, e: Exception) -> None: @@ -408,21 +405,6 @@ def import_model( """ - @abstractmethod - def download_diffusers_model( - self, - source: HFModelSource, - download_to: Path, - ) -> ModelInstallJob: - """ - Download, but do not install, a diffusers model. - - :param source: An HFModelSource object containing a repo_id - :param download_to: Path to directory that will contain the downloaded model. - - Returns: a ModelInstallJob - """ - @abstractmethod def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: """Return the ModelInstallJob(s) corresponding to the provided source.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index f59c7b9f850..2ad321c260b 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -9,7 +9,7 @@ from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, Union import torch import yaml @@ -18,7 +18,7 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase @@ -89,7 +89,7 @@ def __init__( self._downloads_changed_event = threading.Event() self._install_completed_event = threading.Event() self._download_queue = download_queue - self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} + self._download_cache: Dict[int, ModelInstallJob] = {} self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -249,9 +249,6 @@ def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = N self._install_jobs.append(install_job) return install_job - def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob: - return self._import_from_hf(source, download_path=download_to) - def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 return self._install_jobs @@ -291,8 +288,9 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" job.cancel() - with self._lock: - self._cancel_download_parts(job) + self._logger.warning(f"Cancelling {job.source}") + if dj := job._download_job: + self._download_queue.cancel_job(dj) def prune_jobs(self) -> None: """Prune all completed and errored jobs.""" @@ -340,7 +338,7 @@ def _migrate_yaml(self) -> None: legacy_config_path = stanza.get("config") if legacy_config_path: # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir. - legacy_config_path: Path = self._app_config.root_path / legacy_config_path + legacy_config_path = self._app_config.root_path / legacy_config_path if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path): legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path) config["config_path"] = str(legacy_config_path) @@ -476,16 +474,19 @@ def _register_or_install(self, job: ModelInstallJob) -> None: job.config_out = self.record_store.get_model(key) self._signal_job_completed(job) - def _set_error(self, job: ModelInstallJob, excp: Exception) -> None: - if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): - job.set_error( + def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None: + download_job = install_job._download_job + if download_job and any( + x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts + ): + install_job.set_error( InvalidModelConfigException( - f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." + f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." ) ) else: - job.set_error(excp) - self._signal_job_errored(job) + install_job.set_error(excp) + self._signal_job_errored(install_job) # -------------------------------------------------------------------------------------------- # Internal functions that manage the models directory @@ -511,7 +512,6 @@ def _register_orphaned_models(self) -> None: This is typically only used during testing with a new DB or when using the memory DB, because those are the only situations in which we may have orphaned models in the models directory. """ - installed_model_paths = { (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models() } @@ -648,7 +648,6 @@ def _import_from_hf( self, source: HFModelSource, config: Optional[Dict[str, Any]] = None, - download_path: Optional[Path] = None, ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests source.access_token = source.access_token or HfFolder.get_token() @@ -668,7 +667,6 @@ def _import_from_hf( config=config, remote_files=remote_files, metadata=metadata, - download_path=download_path, ) def _import_from_url( @@ -704,14 +702,10 @@ def _import_remote_model( remote_files: List[RemoteModelFile], metadata: Optional[AnyModelRepoMetadata], config: Optional[Dict[str, Any]], - download_path: Optional[Path] = None, # if defined, download only - don't install! ) -> ModelInstallJob: - # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. - # Currently the tmpdir isn't automatically removed at exit because it is - # being held in a daemon thread. if len(remote_files) == 0: raise ValueError(f"{source}: No downloadable files found") - destdir = download_path or Path( + destdir = Path( mkdtemp( dir=self._app_config.models_path, prefix=TMPDIR_PREFIX, @@ -726,6 +720,9 @@ def _import_remote_model( bytes=0, total_bytes=0, ) + # remember the temporary directory for later removal + install_job._install_tmpdir = destdir + # In the event that there is a subfolder specified in the source, # we need to remove it from the destination path in order to avoid # creating unwanted subfolders @@ -739,39 +736,31 @@ def _import_remote_model( # we remember the path up to the top of the destdir so that it may be # removed safely at the end of the install process. install_job._install_tmpdir = destdir - install_job._do_install = download_path is None - assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below - files_string = "file" if len(remote_files) == 1 else "file" - self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") - self._logger.debug(f"remote_files={remote_files}") + parts: List[RemoteModelFile] = [] for model_file in remote_files: - url = model_file.url - path = root / model_file.path.relative_to(subfolder) - self._logger.debug(f"Downloading {url} => {path}") + assert install_job.total_bytes is not None + assert model_file.size is not None install_job.total_bytes += model_file.size - assert hasattr(source, "access_token") - dest = destdir / path.parent - dest.mkdir(parents=True, exist_ok=True) - download_job = DownloadJob( - source=url, - dest=dest, - access_token=source.access_token, - ) - self._download_cache[download_job.source] = install_job # matches a download job to an install job - install_job.download_parts.add(download_job) - - # only start the jobs once install_job.download_parts is fully populated - for download_job in install_job.download_parts: - self._download_queue.submit_download_job( - download_job, - on_start=self._download_started_callback, - on_progress=self._download_progress_callback, - on_complete=self._download_complete_callback, - on_error=self._download_error_callback, - on_cancelled=self._download_cancelled_callback, - ) + parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder))) + multifile_job = self._download_queue.multifile_download( + parts=parts, + dest=destdir, + access_token=source.access_token, + submit_job=False, + on_start=self._download_started_callback, + on_progress=self._download_progress_callback, + on_complete=self._download_complete_callback, + on_error=self._download_error_callback, + on_cancelled=self._download_cancelled_callback, + ) + self._download_cache[multifile_job.id] = install_job + install_job._download_job = multifile_job + files_string = "file" if len(remote_files) == 1 else "file" + self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + self._logger.debug(f"remote_files={remote_files}") + self._download_queue.submit_multifile_download(multifile_job) return install_job def _stat_size(self, path: Path) -> int: @@ -786,86 +775,59 @@ def _stat_size(self, path: Path) -> int: # ------------------------------------------------------------------ # Callbacks are executed by the download queue in a separate thread # ------------------------------------------------------------------ - def _download_started_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download started: {download_job.source}") + def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] + install_job = self._download_cache[download_job.id] install_job.status = InstallStatus.DOWNLOADING assert download_job.download_path - if install_job.local_path == install_job._install_tmpdir: - partial_path = download_job.download_path.relative_to(install_job._install_tmpdir) - dest_name = partial_path.parts[0] - install_job.local_path = install_job._install_tmpdir / dest_name + if install_job.local_path == install_job._install_tmpdir: # first time + install_job.local_path = download_job.download_path + install_job.total_bytes = download_job.total_bytes - # Update the total bytes count for remote sources. - if not install_job.total_bytes: - install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts) - - def _download_progress_callback(self, download_job: DownloadJob) -> None: + def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] + install_job = self._download_cache[download_job.id] if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() - self._cancel_download_parts(install_job) + self._download_queue.cancel_job(download_job) else: # update sizes - install_job.bytes = sum(x.bytes for x in install_job.download_parts) + install_job.bytes = sum(x.bytes for x in download_job.download_parts) self._signal_job_downloading(install_job) - def _download_complete_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download complete: {download_job.source}") + def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] - - # are there any more active jobs left in this task? - if install_job.downloading and all(x.complete for x in install_job.download_parts): - self._signal_job_downloads_done(install_job) - if install_job._do_install: - self._put_in_queue(install_job) + install_job = self._download_cache.pop(download_job.id) + self._signal_job_downloads_done(install_job) + self._put_in_queue(install_job) # this starts the installation and registration # Let other threads know that the number of downloads has changed - self._download_cache.pop(download_job.source, None) self._downloads_changed_event.set() - def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: + def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.source, None) + install_job = self._download_cache.pop(download_job.id) assert install_job is not None assert excp is not None install_job.set_error(excp) - self._logger.error( - f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}" - ) - self._cancel_download_parts(install_job) + self._download_queue.cancel_job(download_job) # Let other threads know that the number of downloads has changed self._downloads_changed_event.set() - def _download_cancelled_callback(self, download_job: DownloadJob) -> None: + def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.source, None) + install_job = self._download_cache.pop(download_job.id, None) if not install_job: return self._downloads_changed_event.set() - self._logger.warning(f"Model download canceled: {download_job.source}") # if install job has already registered an error, then do not replace its status with cancelled if not install_job.errored: install_job.cancel() - self._cancel_download_parts(install_job) # Let other threads know that the number of downloads has changed self._downloads_changed_event.set() - def _cancel_download_parts(self, install_job: ModelInstallJob) -> None: - # on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks - # do not lock here because it gets called within a locked context - for s in install_job.download_parts: - self._download_queue.cancel_job(s) - - if all(x.in_terminal_state for x in install_job.download_parts): - # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources - self._put_in_queue(install_job) - # ------------------------------------------------------------------------------------------------ # Internal methods that put events on the event bus # ------------------------------------------------------------------------------------------------ @@ -877,6 +839,7 @@ def _signal_job_running(self, job: ModelInstallJob) -> None: def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: + assert job._download_job is not None parts: List[Dict[str, str | int]] = [ { "url": str(x.source), @@ -884,7 +847,7 @@ def _signal_job_downloading(self, job: ModelInstallJob) -> None: "bytes": x.bytes, "total_bytes": x.total_bytes, } - for x in job.download_parts + for x in job._download_job.download_parts ] assert job.bytes is not None assert job.total_bytes is not None @@ -929,7 +892,13 @@ def _signal_job_cancelled(self, job: ModelInstallJob) -> None: self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) @staticmethod - def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: + def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]: + """ + Return a metadata fetcher appropriate for provided url. + + This used to be more useful, but the number of supported model + sources has been reduced to HuggingFace alone. + """ if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): return HuggingFaceMetadataFetch raise ValueError(f"Unsupported model source: '{url}'") diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 4abf020538b..f9f5335d175 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -40,6 +40,9 @@ class RemoteModelFile(BaseModel): size: Optional[int] = Field(description="The size of this file, in bytes", default=0) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) + def __hash__(self) -> int: + return hash(str(self)) + class ModelMetadataBase(BaseModel): """Base class for model metadata information.""" diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 393cd54a033..564d9c30a07 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -4,79 +4,33 @@ import time from contextlib import contextmanager from pathlib import Path -from typing import Generator, Optional +from typing import Any, Generator, Optional import pytest from pydantic.networks import AnyHttpUrl from requests.sessions import Session -from requests_testadapter import TestAdapter, TestSession +from requests_testadapter import TestAdapter from invokeai.app.services.config import get_config from invokeai.app.services.config.config_default import URLRegexTokenPair from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob -from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile +from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings -TestAdapter.__test__ = False # type: ignore - - -@pytest.fixture -def session() -> Session: - sess = TestSession() - for i in ["12345", "9999", "54321"]: - content = ( - b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) - ) # for pause tests, must make content large - sess.mount( - f"http://www.civitai.com/models/{i}", - TestAdapter( - content, - headers={ - "Content-Length": len(content), - "Content-Disposition": f'filename="mock{i}.safetensors"', - }, - ), - ) - - sess.mount( - "http://www.huggingface.co/foo.txt", - TestAdapter( - content, - headers={ - "Content-Length": len(content), - "Content-Disposition": 'filename="foo.safetensors"', - }, - ), - ) - - # here are some malformed URLs to test - # missing the content length - sess.mount( - "http://www.civitai.com/models/missing", - TestAdapter( - b"Missing content length", - headers={ - "Content-Disposition": 'filename="missing.txt"', - }, - ), - ) - # not found test - sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) - - return sess +TestAdapter.__test__ = False @pytest.mark.timeout(timeout=10, method="thread") -def test_basic_queue_download(tmp_path: Path, session: Session) -> None: +def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None: events = set() def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: events.add(job.status) queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() job = queue.download( @@ -92,6 +46,7 @@ def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: queue.join() assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.download_path == tmp_path / "mock12345.safetensors" assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} @@ -99,9 +54,9 @@ def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_errors(tmp_path: Path, session: Session) -> None: +def test_errors(tmp_path: Path, mm2_session: Session) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -121,10 +76,10 @@ def test_errors(tmp_path: Path, session: Session) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_event_bus(tmp_path: Path, session: Session) -> None: +def test_event_bus(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), @@ -157,9 +112,9 @@ def test_event_bus(tmp_path: Path, session: Session) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: +def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -189,10 +144,10 @@ def broken_callback(job: DownloadJob) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_cancel(tmp_path: Path, session: Session) -> None: +def test_cancel(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() cancelled = False @@ -204,9 +159,6 @@ def cancelled_callback(job: DownloadJob) -> None: nonlocal cancelled cancelled = True - def handler(signum, frame): - raise TimeoutError("Join took too long to return") - job = queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), dest=tmp_path, @@ -223,14 +175,15 @@ def handler(signum, frame): assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" queue.stop() + @pytest.mark.timeout(timeout=10, method="thread") def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: fetcher = HuggingFaceMetadataFetch(mm2_session) metadata = fetcher.from_id("stabilityai/sdxl-turbo") + assert isinstance(metadata, ModelMetadataWithFiles) events = set() def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: - print(f"bytes = {job.bytes}") events.add(job.status) queue = DownloadQueueService( @@ -251,6 +204,7 @@ def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Except assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" assert job.bytes > 0, "expected download bytes to be positive" assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert job.download_path == tmp_path / "sdxl-turbo" assert Path( tmp_path, "sdxl-turbo/model_index.json" ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" @@ -266,6 +220,7 @@ def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Except def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None: fetcher = HuggingFaceMetadataFetch(mm2_session) metadata = fetcher.from_id("stabilityai/sdxl-turbo") + assert isinstance(metadata, ModelMetadataWithFiles) events = set() def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: @@ -289,13 +244,14 @@ def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Except queue.join() assert job.status == DownloadJobStatus("error"), "expected job status to be errored" + assert job.error_type is not None assert "HTTPError(NOT FOUND)" in job.error_type assert DownloadJobStatus.ERROR in events queue.stop() @pytest.mark.timeout(timeout=10, method="thread") -def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None: +def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None: event_bus = TestEventService() queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) @@ -307,11 +263,9 @@ def cancelled_callback(job: DownloadJob) -> None: nonlocal cancelled cancelled = True - def handler(signum, frame): - raise TimeoutError("Join took too long to return") - fetcher = HuggingFaceMetadataFetch(mm2_session) metadata = fetcher.from_id("stabilityai/sdxl-turbo") + assert isinstance(metadata, ModelMetadataWithFiles) job = queue.multifile_download( parts=metadata.download_urls(session=mm2_session), @@ -327,6 +281,29 @@ def handler(signum, frame): assert "download_cancelled" in [x.event_name for x in events] queue.stop() + +def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + job = queue.multifile_download( + parts=[ + RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors")) + ], + dest=tmp_path, + ) + assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase" + queue.join() + + assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.bytes > 0, "expected download bytes to be positive" + assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert job.download_path == tmp_path / "mock12345.safetensors" + assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" + queue.stop() + + @contextmanager def clear_config() -> Generator[None, None, None]: try: @@ -335,11 +312,11 @@ def clear_config() -> Generator[None, None, None]: get_config.cache_clear() -def test_tokens(tmp_path: Path, session: Session): +def test_tokens(tmp_path: Path, mm2_session: Session): with clear_config(): config = get_config() config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")] - queue = DownloadQueueService(requests_session=session) + queue = DownloadQueueService(requests_session=mm2_session) queue.start() # this one has an access token assigned job1 = queue.download( diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index ba844552408..31d09d1029e 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -286,14 +286,36 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con @pytest.mark.timeout(timeout=20, method="thread") -def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None: +def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: + # TODO: Test subfolder download source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) - job = mm2_installer.download_diffusers_model(source, tmp_path) - mm2_installer.wait_for_installs(timeout=5) - print(job.local_path) - assert job.status == InstallStatus.DOWNLOADS_DONE - assert (tmp_path / "sdxl-turbo").exists() - assert (tmp_path / "sdxl-turbo" / "model_index.json").exists() + + bus = mm2_installer.event_bus + store = mm2_installer.record_store + assert isinstance(bus, EventServiceBase) + assert store is not None + + job = mm2_installer.import_model(source) + job_list = mm2_installer.wait_for_installs(timeout=10) + assert len(job_list) == 1 + assert job.complete + assert job.config_out + + key = job.config_out.key + model_record = store.get_model(key) + assert (mm2_app_config.models_path / model_record.path).exists() + assert model_record.type == ModelType.Main + assert model_record.format == ModelFormat.Diffusers + + assert hasattr(bus, "events") # the dummyeventservice has this + assert len(bus.events) >= 3 + event_names = {x.event_name for x in bus.events} + assert event_names == { + "model_install_downloading", + "model_install_downloads_done", + "model_install_running", + "model_install_completed", + } def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: @@ -327,7 +349,6 @@ def raise_runtime_error(*args, **kwargs): assert job.error == "Test error" -# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test @pytest.mark.parametrize( "model_params", [ diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 980f6ea17b2..0301101a195 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -317,4 +317,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: }, ), ) + + for i in ["12345", "9999", "54321"]: + content = ( + b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) + ) # for pause tests, must make content large + sess.mount( + f"http://www.civitai.com/models/{i}", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": f'filename="mock{i}.safetensors"', + }, + ), + ) + + sess.mount( + "http://www.huggingface.co/foo.txt", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": 'filename="foo.safetensors"', + }, + ), + ) + + # here are some malformed URLs to test + # missing the content length + sess.mount( + "http://www.civitai.com/models/missing", + TestAdapter( + b"Missing content length", + headers={ + "Content-Disposition": 'filename="missing.txt"', + }, + ), + ) + # not found test + sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) + return sess From 911a24479b15e4858beda43817d91b22e6a32b51 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 16 May 2024 07:18:33 -0400 Subject: [PATCH 24/45] add tests for model install file size reporting --- .../app/services/model_install/model_install_default.py | 9 ++++++--- tests/app/services/model_install/test_model_install.py | 5 +++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 2ad321c260b..1d77b2c6e1e 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -855,8 +855,8 @@ def _signal_job_downloading(self, job: ModelInstallJob) -> None: str(job.source), local_path=job.local_path.as_posix(), parts=parts, - bytes=job.bytes, - total_bytes=job.total_bytes, + bytes=sum(x["bytes"] for x in parts), + total_bytes=sum(x["total_bytes"] for x in parts), id=job.id, ) @@ -875,7 +875,10 @@ def _signal_job_completed(self, job: ModelInstallJob) -> None: assert job.local_path is not None assert job.config_out is not None key = job.config_out.key - self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id) + self._event_bus.emit_model_install_completed(source=str(job.source), + key=key, + id=job.id, + total_bytes=job.bytes) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 31d09d1029e..f73b8275343 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -317,6 +317,11 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con "model_install_completed", } + completed_events = [x for x in bus.events if x.event_name == "model_install_completed"] + downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"] + assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"] + assert job.total_bytes == completed_events[0].payload["total_bytes"] + assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"]) def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) From 2dae5eb7ad6da3fb7e90a948e903430eacc507eb Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 16 May 2024 22:26:18 -0400 Subject: [PATCH 25/45] more refactoring; HF subfolders not working --- docs/contributing/MODEL_MANAGER.md | 15 +- .../model_install/model_install_base.py | 9 +- .../model_install/model_install_default.py | 169 ++++++++++-------- .../model_records/model_records_base.py | 8 +- .../model_install/test_model_install.py | 14 +- 5 files changed, 113 insertions(+), 102 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index c12046293c0..d53198b98e4 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the following initialization pattern: ``` -from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.config import get_config from invokeai.app.services.model_records import ModelRecordServiceSQL from invokeai.app.services.model_install import ModelInstallService from invokeai.app.services.download import DownloadQueueService -from invokeai.app.services.shared.sqlite import SqliteDatabase +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.backend.util.logging import InvokeAILogger -config = InvokeAIAppConfig.get_config() -config.parse_args() +config = get_config() logger = InvokeAILogger.get_logger(config=config) -db = SqliteDatabase(config, logger) +db = SqliteDatabase(config.db_path, logger) record_store = ModelRecordServiceSQL(db) queue = DownloadQueueService() queue.start() -installer = ModelInstallService(app_config=config, +installer = ModelInstallService(app_config=config, record_store=record_store, - download_queue=queue - ) + download_queue=queue + ) installer.start() ``` diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 68cf9591e0f..b622c8dade0 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -466,17 +466,14 @@ def sync_model_path(self, key: str) -> AnyModelConfig: """ @abstractmethod - def download_and_cache_ckpt( + def download_and_cache_model( self, - source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: int = 0, + source: str, ) -> Path: """ Download the model file located at source to the models cache and return its Path. - :param source: A Url or a string that can be converted into one. - :param access_token: Optional access token to access restricted resources. + :param source: A string representing a URL or repo_id. The model file will be downloaded into the system-wide model cache (`models/.cache`) if it isn't already there. Note that the model cache diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1d77b2c6e1e..a6bb7ad10da 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -9,7 +9,7 @@ from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml @@ -18,7 +18,7 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase @@ -208,26 +208,12 @@ def heuristic_import( access_token: Optional[str] = None, inplace: Optional[bool] = False, ) -> ModelInstallJob: - variants = "|".join(ModelRepoVariant.__members__.values()) - hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" - source_obj: Optional[StringLikeSource] = None - - if Path(source).exists(): # A local file or directory - source_obj = LocalModelSource(path=Path(source), inplace=inplace) - elif match := re.match(hf_repoid_re, source): - source_obj = HFModelSource( - repo_id=match.group(1), - variant=match.group(2) if match.group(2) else None, # pass None rather than '' - subfolder=Path(match.group(3)) if match.group(3) else None, - access_token=access_token, - ) - elif re.match(r"^https?://[^/]+", source): - source_obj = URLModelSource( - url=AnyHttpUrl(source), - access_token=access_token, - ) - else: - raise ValueError(f"Unsupported model source: '{source}'") + """Install a model using pattern matching to infer the type of source.""" + source_obj = self._guess_source(source) + if isinstance(source_obj, LocalModelSource): + source_obj.inplace = inplace + elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource): + source_obj.access_token = access_token return self.import_model(source_obj, config) def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 @@ -383,37 +369,86 @@ def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: Invoke escaped_source = slugify(str(source)) return app_config.download_cache_path / escaped_source - def download_and_cache_ckpt( + def download_and_cache_model( self, - source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: int = 0, + source: str, ) -> Path: """Download the model file located at source to the models cache and return its Path.""" - model_path = self._download_cache_path(source, self._app_config) + model_path = self._download_cache_path(str(source), self._app_config) - # We expect the cache directory to contain one and only one downloaded file. + # We expect the cache directory to contain one and only one downloaded file or directory. # We don't know the file's name in advance, as it is set by the download # content-disposition header. if model_path.exists(): - contents = [x for x in model_path.iterdir() if x.is_file()] + contents: List[Path] = list(model_path.iterdir()) if len(contents) > 0: return contents[0] model_path.mkdir(parents=True, exist_ok=True) - job = self._download_queue.download( - source=AnyHttpUrl(str(source)), + model_source = self._guess_source(source) + remote_files, _ = self._remote_files_from_source(model_source) + job = self._download_queue.multifile_download( + parts=remote_files, dest=model_path, - access_token=access_token, - on_progress=TqdmProgress().update, ) - self._download_queue.wait_for_job(job, timeout) + files_string = "file" if len(remote_files) == 1 else "file" + self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) if job.complete: assert job.download_path is not None return job.download_path else: raise Exception(job.error) + def _remote_files_from_source( + self, source: ModelSource + ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]: + metadata = None + if isinstance(source, HFModelSource): + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) + assert isinstance(metadata, ModelMetadataWithFiles) + return metadata.download_urls( + variant=source.variant or self._guess_variant(), + subfolder=source.subfolder, + session=self._session, + ), metadata + + if isinstance(source, URLModelSource): + try: + fetcher = self.get_fetcher_from_url(str(source.url)) + kwargs: dict[str, Any] = {"session": self._session} + metadata = fetcher(**kwargs).from_url(source.url) + assert isinstance(metadata, ModelMetadataWithFiles) + return metadata.download_urls(session=self._session), metadata + except ValueError: + pass + + return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None + + raise Exception(f"No files associated with {source}") + + def _guess_source(self, source: str) -> ModelSource: + """Turn a source string into a ModelSource object.""" + variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=match.group(2) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=AnyHttpUrl(source), + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return source_obj + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -650,18 +685,9 @@ def _import_from_hf( config: Optional[Dict[str, Any]] = None, ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests - source.access_token = source.access_token or HfFolder.get_token() - if not source.access_token: - self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) - assert isinstance(metadata, ModelMetadataWithFiles) - remote_files = metadata.download_urls( - variant=source.variant or self._guess_variant(), - subfolder=source.subfolder, - session=self._session, - ) - + if source.access_token is None: + source.access_token = HfFolder.get_token() + remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, config=config, @@ -674,21 +700,7 @@ def _import_from_url( source: URLModelSource, config: Optional[Dict[str, Any]], ) -> ModelInstallJob: - # URLs from HuggingFace will be handled specially - metadata = None - fetcher = None - try: - fetcher = self.get_fetcher_from_url(str(source.url)) - except ValueError: - pass - kwargs: dict[str, Any] = {"session": self._session} - if fetcher is not None: - metadata = fetcher(**kwargs).from_url(source.url) - self._logger.debug(f"metadata={metadata}") - if metadata and isinstance(metadata, ModelMetadataWithFiles): - remote_files = metadata.download_urls(session=self._session) - else: - remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)] + remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, config=config, @@ -733,26 +745,17 @@ def _import_remote_model( root = Path(".") subfolder = Path(".") - # we remember the path up to the top of the destdir so that it may be - # removed safely at the end of the install process. - install_job._install_tmpdir = destdir - parts: List[RemoteModelFile] = [] for model_file in remote_files: assert install_job.total_bytes is not None assert model_file.size is not None install_job.total_bytes += model_file.size parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder))) - multifile_job = self._download_queue.multifile_download( + multifile_job = self._multifile_download( parts=parts, dest=destdir, access_token=source.access_token, - submit_job=False, - on_start=self._download_started_callback, - on_progress=self._download_progress_callback, - on_complete=self._download_complete_callback, - on_error=self._download_error_callback, - on_cancelled=self._download_cancelled_callback, + submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict ) self._download_cache[multifile_job.id] = install_job install_job._download_job = multifile_job @@ -772,6 +775,21 @@ def _stat_size(self, path: Path) -> int: size += sum(self._stat_size(Path(root, x)) for x in files) return size + def _multifile_download( + self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True + ) -> MultiFileDownloadJob: + return self._download_queue.multifile_download( + parts=parts, + dest=dest, + access_token=access_token, + submit_job=submit_job, + on_start=self._download_started_callback, + on_progress=self._download_progress_callback, + on_complete=self._download_complete_callback, + on_error=self._download_error_callback, + on_cancelled=self._download_cancelled_callback, + ) + # ------------------------------------------------------------------ # Callbacks are executed by the download queue in a separate thread # ------------------------------------------------------------------ @@ -875,10 +893,9 @@ def _signal_job_completed(self, job: ModelInstallJob) -> None: assert job.local_path is not None assert job.config_out is not None key = job.config_out.key - self._event_bus.emit_model_install_completed(source=str(job.source), - key=key, - id=job.id, - total_bytes=job.bytes) + self._event_bus.emit_model_install_completed( + source=str(job.source), key=key, id=job.id, total_bytes=job.bytes + ) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 094ade63838..57531cf3c19 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -12,15 +12,13 @@ from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager import ( +from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, - ModelFormat, - ModelType, -) -from invokeai.backend.model_manager.config import ( ControlAdapterDefaultSettings, MainModelDefaultSettings, + ModelFormat, + ModelType, ModelVariantType, SchedulerPredictionType, ) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index f73b8275343..ca8616238f2 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -222,7 +222,7 @@ def test_delete_register( store.get_model(key) -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) @@ -253,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: ] -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) @@ -285,9 +285,8 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con } -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: - # TODO: Test subfolder download source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) bus = mm2_installer.event_bus @@ -323,6 +322,7 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con assert job.total_bytes == completed_events[0].payload["total_bytes"] assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"]) + def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) job = mm2_installer.import_model(source) @@ -371,7 +371,7 @@ def raise_runtime_error(*args, **kwargs): }, ], ) -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): """Test whether or not type is respected on configs when passed to heuristic import.""" assert "name" in model_params and "type" in model_params @@ -387,7 +387,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode } assert "repo_id" in model_params install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) - mm2_installer.wait_for_job(install_job1, timeout=20) + mm2_installer.wait_for_job(install_job1, timeout=10) if model_params["type"] != "embedding": assert install_job1.errored assert install_job1.error_type == "InvalidModelConfigException" @@ -396,6 +396,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) - mm2_installer.wait_for_job(install_job2, timeout=20) + mm2_installer.wait_for_job(install_job2, timeout=10) assert install_job2.complete assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out From d968c6f379dec510eed914b185c3872d3196e7d2 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 17 May 2024 22:29:19 -0400 Subject: [PATCH 26/45] refactor multifile download code --- docs/contributing/DOWNLOAD_QUEUE.md | 63 ++++++++- docs/contributing/MODEL_MANAGER.md | 38 ++++++ .../controlnet_image_processors.py | 6 +- invokeai/app/invocations/infill.py | 2 +- invokeai/app/invocations/upscale.py | 2 +- .../model_install/model_install_default.py | 125 ++++++++++-------- .../services/model_load/model_load_base.py | 4 +- .../services/model_load/model_load_default.py | 14 +- .../model_manager/model_manager_base.py | 2 +- .../model_manager/model_manager_default.py | 56 ++++---- .../app/services/shared/invocation_context.py | 51 ++----- .../services/download/test_download_queue.py | 13 ++ .../app/services/model_load/test_load_api.py | 32 +++-- 13 files changed, 263 insertions(+), 145 deletions(-) diff --git a/docs/contributing/DOWNLOAD_QUEUE.md b/docs/contributing/DOWNLOAD_QUEUE.md index d43c670d2c8..960180961e9 100644 --- a/docs/contributing/DOWNLOAD_QUEUE.md +++ b/docs/contributing/DOWNLOAD_QUEUE.md @@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects specify the source and destination of the download, and keep track of the progress of the download. -The only job type currently implemented is `DownloadJob`, a pydantic object with the +Two job types are defined. `DownloadJob` and +`MultiFileDownloadJob`. The former is a pydantic object with the following fields: | **Field** | **Type** | **Default** | **Description** | @@ -138,7 +139,7 @@ following fields: | `dest` | Path | | Where to download to | | `access_token` | str | | [optional] string containing authentication token for access | | `on_start` | Callable | | [optional] callback when the download starts | -| `on_progress` | Callable | | [optional] callback called at intervals during download progress | +| `on_progress` | Callable | | [optional] callback called at intervals during download progress | | `on_complete` | Callable | | [optional] callback called after successful download completion | | `on_error` | Callable | | [optional] callback called after an error occurs | | `id` | int | auto assigned | Job ID, an integer >= 0 | @@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an `error_type` field of "DownloadJobCancelledException". In addition, the job's `cancelled` property will be set to True. +The `MultiFileDownloadJob` is used for diffusers model downloads, +which contain multiple files and directories under a common root: + +| **Field** | **Type** | **Default** | **Description** | +|----------------|-----------------|---------------|-----------------| +| _Fields passed in at job creation time_ | +| `download_parts` | Set[DownloadJob]| | Component download jobs | +| `dest` | Path | | Where to download to | +| `on_start` | Callable | | [optional] callback when the download starts | +| `on_progress` | Callable | | [optional] callback called at intervals during download progress | +| `on_complete` | Callable | | [optional] callback called after successful download completion | +| `on_error` | Callable | | [optional] callback called after an error occurs | +| `id` | int | auto assigned | Job ID, an integer >= 0 | +| _Fields updated over the course of the download task_ +| `status` | DownloadJobStatus| | Status code | +| `download_path` | Path | | Path to the root of the downloaded files | +| `bytes` | int | 0 | Bytes downloaded so far | +| `total_bytes` | int | 0 | Total size of the file at the remote site | +| `error_type` | str | | String version of the exception that caused an error during download | +| `error` | str | | String version of the traceback associated with an error | +| `cancelled` | bool | False | Set to true if the job was cancelled by the caller| + +Note that the MultiFileDownloadJob does not support the `priority`, +`job_started`, `job_ended` or `content_type` attributes. You can get +these from the individual download jobs in `download_parts`. + + ### Callbacks Download jobs can be associated with a series of callbacks, each with @@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with running jobs with `cancel_all_jobs()`, and wait for all jobs to finish with `join()`. -#### job = queue.download(source, dest, priority, access_token) +#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error) Create a new download job and put it on the queue, returning the DownloadJob object. +#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error) + +This is similar to download(), but instead of taking a single source, +it accepts a `parts` argument consisting of a list of +`RemoteModelFile` objects. Each part corresponds to a URL/Path pair, +where the URL is the location of the remote file, and the Path is the +destination. + +`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and +consists of a url/path pair. Note that the path *must* be relative. + +The method returns a `MultiFileDownloadJob`. + + +``` +from invokeai.backend.model_manager.metadata import RemoteModelFile +remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors'', + path='my_model/textencoder/pytorch_model.safetensors' + ) +remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt', + path='my_model/vae/diffusers_model.safetensors' + ) +job = queue.multifile_download(parts=[remote_file_1, remote_file_2], + dest='/tmp/downloads', + on_progress=TqdmProgress().update) +queue.wait_for_job(job) +print(f"The files were downloaded to {job.download_path}") +``` + #### jobs = queue.list_jobs() Return a list of all active and inactive `DownloadJob`s. diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index d53198b98e4..fbc9079d49e 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1577,3 +1577,41 @@ This method takes a model key, looks it up using the `ModelRecordServiceBase` object in `mm.store`, and passes the returned model configuration to `load_model_by_config()`. It may raise a `NotImplementedException`. + +## Invocation Context Model Manager API + +Within invocations, the following methods are available from the +`InvocationContext` object: + +### context.download_and_cache_model(source) -> Path + +This method accepts a `source` of a model, downloads and caches it +locally, and returns a Path to the local model. The source can be a +local file or directory, a URL, or a HuggingFace repo_id. + +In the case of HuggingFace repo_id, the following variants are +recognized: + +* stabilityai/stable-diffusion-v4 -- default model +* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant +* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder +* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder + +You can also point at an arbitrary individual file within a repo_id +directory using this syntax: + +* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + +### context.load_and_cache_model(source, [loader]) -> LoadedModel + +This method takes a model source, downloads it, caches it, and then +loads it into the RAM cache for use in inference. The optional loader +is a Callable that accepts a Path to the object, and returns a +`Dict[str, torch.Tensor]`. If no loader is provided, then the method +will use `torch.load()` for a .ckpt or .bin checkpoint file, +`safetensors.torch.load_file()` for a safetensors checkpoint file, or +`*.from_pretrained()` for a directory that looks like a +diffusers directory. + + + diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 971179ac93a..e69f4b54ad6 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -611,7 +611,7 @@ def loader(model_path: Path): model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device() ) - with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model: + with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model: depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device()) processed_image = depth_anything_detector(image=image, resolution=self.resolution) return processed_image @@ -634,8 +634,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mm = context.models - onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) - onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) + onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose) processed_image = dw_openpose( diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index f8358d1df5c..ddd11cf93f8 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -133,7 +133,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" def infill(self, image: Image.Image, context: InvocationContext): - with context.models.load_ckpt_from_url( + with context.models.load_and_cache_model( source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", loader=LaMA.load_jit_model, ) as model: diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 29cf7819de3..670082f1200 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -91,7 +91,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: context.logger.error(msg) raise ValueError(msg) - loadnet = context.models.load_ckpt_from_url( + loadnet = context.models.load_and_cache_model( source=ESRGAN_MODEL_URLS[self.model_name], ) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index a6bb7ad10da..cde9a6502ea 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -387,12 +387,13 @@ def download_and_cache_model( model_path.mkdir(parents=True, exist_ok=True) model_source = self._guess_source(source) remote_files, _ = self._remote_files_from_source(model_source) - job = self._download_queue.multifile_download( - parts=remote_files, + job = self._multifile_download( dest=model_path, + remote_files=remote_files, + subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None, ) - files_string = "file" if len(remote_files) == 1 else "file" - self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") self._download_queue.wait_for_job(job) if job.complete: assert job.download_path is not None @@ -734,26 +735,12 @@ def _import_remote_model( ) # remember the temporary directory for later removal install_job._install_tmpdir = destdir + install_job.total_bytes = sum((x.size or 0) for x in remote_files) - # In the event that there is a subfolder specified in the source, - # we need to remove it from the destination path in order to avoid - # creating unwanted subfolders - if isinstance(source, HFModelSource) and source.subfolder: - root = Path(remote_files[0].path.parts[0]) - subfolder = root / source.subfolder - else: - root = Path(".") - subfolder = Path(".") - - parts: List[RemoteModelFile] = [] - for model_file in remote_files: - assert install_job.total_bytes is not None - assert model_file.size is not None - install_job.total_bytes += model_file.size - parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder))) multifile_job = self._multifile_download( - parts=parts, + remote_files=remote_files, dest=destdir, + subfolder=source.subfolder if isinstance(source, HFModelSource) else None, access_token=source.access_token, submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict ) @@ -776,8 +763,35 @@ def _stat_size(self, path: Path) -> int: return size def _multifile_download( - self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True + self, + remote_files: List[RemoteModelFile], + dest: Path, + subfolder: Optional[Path] = None, + access_token: Optional[str] = None, + submit_job: bool = True, ) -> MultiFileDownloadJob: + # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and + # we are installing the "vae" subfolder, we do not want to create an additional folder level, such + # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo". + # So what we do is to synthesize a folder named "sdxl-turbo_vae" here. + if subfolder: + top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/" + path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/ + path_to_add = Path(f"{top}_{subfolder}") + else: + path_to_remove = Path(".") + path_to_add = Path(".") + + parts: List[RemoteModelFile] = [] + for model_file in remote_files: + assert model_file.size is not None + parts.append( + RemoteModelFile( + url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json + path=path_to_add / model_file.path.relative_to(path_to_remove), + ) + ) + return self._download_queue.multifile_download( parts=parts, dest=dest, @@ -795,56 +809,53 @@ def _multifile_download( # ------------------------------------------------------------------ def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.id] - install_job.status = InstallStatus.DOWNLOADING + if install_job := self._download_cache.get(download_job.id, None): + install_job.status = InstallStatus.DOWNLOADING - assert download_job.download_path - if install_job.local_path == install_job._install_tmpdir: # first time - install_job.local_path = download_job.download_path - install_job.total_bytes = download_job.total_bytes + assert download_job.download_path + if install_job.local_path == install_job._install_tmpdir: # first time + install_job.local_path = download_job.download_path + install_job.total_bytes = download_job.total_bytes def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.id] - if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() - self._download_queue.cancel_job(download_job) - else: - # update sizes - install_job.bytes = sum(x.bytes for x in download_job.download_parts) - self._signal_job_downloading(install_job) + if install_job := self._download_cache.get(download_job.id, None): + if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() + self._download_queue.cancel_job(download_job) + else: + # update sizes + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + self._signal_job_downloading(install_job) def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.id) - self._signal_job_downloads_done(install_job) - self._put_in_queue(install_job) # this starts the installation and registration + if install_job := self._download_cache.pop(download_job.id, None): + self._signal_job_downloads_done(install_job) + self._put_in_queue(install_job) # this starts the installation and registration - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.id) - assert install_job is not None - assert excp is not None - install_job.set_error(excp) - self._download_queue.cancel_job(download_job) + if install_job := self._download_cache.pop(download_job.id, None): + assert excp is not None + install_job.set_error(excp) + self._download_queue.cancel_job(download_job) - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.id, None) - if not install_job: - return - self._downloads_changed_event.set() - # if install job has already registered an error, then do not replace its status with cancelled - if not install_job.errored: - install_job.cancel() - - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + if install_job := self._download_cache.pop(download_job.id, None): + self._downloads_changed_event.set() + # if install job has already registered an error, then do not replace its status with cancelled + if not install_job.errored: + install_job.cancel() + + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() # ------------------------------------------------------------------------------------------------ # Internal methods that put events on the event bus diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 32fc62fa5bc..7de36793fb6 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -43,11 +43,11 @@ def convert_cache(self) -> ModelConvertCacheBase: """Return the checkpoint convert cache used by this loader.""" @abstractmethod - def load_ckpt_from_path( + def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None ) -> LoadedModel: """ - Load the checkpoint-format model file located at the indicated Path. + Load the model file or directory located at the indicated Path. This will load an arbitrary model file into the RAM cache. If the optional loader argument is provided, the loader will be invoked to load the model into diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index af211c260e5..cd14235ee0a 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -20,6 +20,8 @@ ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from .model_load_base import ModelLoadServiceBase @@ -94,7 +96,7 @@ def load_model( ) return loaded_model - def load_ckpt_from_path( + def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None ) -> LoadedModel: """ @@ -128,6 +130,16 @@ def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]: result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu") return result + def diffusers_load_directory(directory: Path) -> AnyModel: + load_class = GenericDiffusersLoader( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self.convert_cache, + ).get_hf_load_class(directory) + result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + return result + if loader is None: loader = ( torch_load_file diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index d16c00302ee..063979ebe65 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -72,7 +72,7 @@ def stop(self, invoker: Invoker) -> None: pass @abstractmethod - def load_ckpt_from_url( + def load_model_from_url( self, source: str | AnyHttpUrl, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index ed274266f39..dd78f1f3b2e 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -64,6 +64,34 @@ def stop(self, invoker: Invoker) -> None: if hasattr(service, "stop"): service.stop(invoker) + def load_model_from_url( + self, + source: str | AnyHttpUrl, + loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, + ) -> LoadedModel: + """ + Download, cache, and Load the model file located at the indicated URL. + + This will check the model download cache for the model designated + by the provided URL and download it if needed using download_and_cache_ckpt(). + It will then load the model into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + model_path = self.install.download_and_cache_model(source=str(source)) + return self.load.load_model_from_path(model_path=model_path, loader=loader) + @classmethod def build_model_manager( cls, @@ -102,31 +130,3 @@ def build_model_manager( event_bus=events, ) return cls(store=model_record_service, install=installer, load=loader) - - def load_ckpt_from_url( - self, - source: str | AnyHttpUrl, - loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, - ) -> LoadedModel: - """ - Download, cache, and Load the model file located at the indicated URL. - - This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_ckpt(). - It will then load the model into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. - - Be aware that the LoadedModel object will have a `config` attribute of None. - - Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. - loader: A Callable that expects a Path and returns a Dict[str|int, Any] - - Returns: - A LoadedModel object. - """ - model_path = self.install.download_and_cache_ckpt(source=source) - return self.load.load_ckpt_from_path(model_path=model_path, loader=loader) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c7602760f71..32d32e227b4 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -435,11 +435,9 @@ def search_by_attrs( ) return result - def download_and_cache_ckpt( + def download_and_cache_model( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, ) -> Path: """ Download the model file located at source to the models cache and return its Path. @@ -449,12 +447,7 @@ def download_and_cache_ckpt( installed, the cached path will be returned. Otherwise it will be downloaded. Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. - + source: A model path, URL or repo_id. Result: Path to the downloaded model @@ -463,39 +456,14 @@ def download_and_cache_ckpt( TimeoutError """ installer = self._services.model_manager.install - path: Path = installer.download_and_cache_ckpt( + path: Path = installer.download_and_cache_model( source=source, - access_token=access_token, - timeout=timeout, ) return path - def load_ckpt_from_path( - self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None - ) -> LoadedModel: - """ - Load the checkpoint-format model file located at the indicated Path. - - This will load an arbitrary model file into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. - - Be aware that the LoadedModel object will have a `config` attribute of None. - - Args: - model_path: A pathlib.Path to a checkpoint-style models file - loader: A Callable that expects a Path and returns a Dict[str|int, Any] - - Returns: - A LoadedModel object. - """ - result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader) - return result - - def load_ckpt_from_url( + def load_and_cache_model( self, - source: str | AnyHttpUrl, + source: Path | str | AnyHttpUrl, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -511,14 +479,17 @@ def load_ckpt_from_url( Be aware that the LoadedModel object will have a `config` attribute of None. Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. + source: A model Path, URL, or repoid. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: A LoadedModel object. """ - result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader) + result: LoadedModel = ( + self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) + if isinstance(source, Path) + else self._services.model_manager.load_model_from_url(source=source, loader=loader) + ) return result diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 564d9c30a07..c9317163c8a 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -304,6 +304,19 @@ def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None: queue.stop() +def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + + with pytest.raises(AssertionError) as error: + queue.multifile_download( + parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))], + dest=tmp_path, + ) + assert str(error.value) == "only relative download paths accepted" + + @contextmanager def clear_config() -> Generator[None, None, None]: try: diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 7eb09fb3754..bd3a67a8944 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -24,7 +24,7 @@ def mock_context( def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None: - downloaded_path = mock_context.models.download_and_cache_ckpt( + downloaded_path = mock_context.models.download_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) assert downloaded_path.is_file() @@ -32,24 +32,24 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) assert downloaded_path.name == "test_embedding.safetensors" assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache" - downloaded_path_2 = mock_context.models.download_and_cache_ckpt( + downloaded_path_2 = mock_context.models.download_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) assert downloaded_path == downloaded_path_2 def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None: - downloaded_path = mock_context.models.download_and_cache_ckpt( + downloaded_path = mock_context.models.download_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) - loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path) + loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path) assert isinstance(loaded_model_1, LoadedModel) - loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path) + loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path) assert isinstance(loaded_model_2, LoadedModel) assert loaded_model_1.model is loaded_model_2.model - loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file) + loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file) assert isinstance(loaded_model_3, LoadedModel) assert loaded_model_1.model is not loaded_model_3.model assert isinstance(loaded_model_1.model, dict) @@ -58,9 +58,25 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) - def test_download_and_load(mock_context: InvocationContext) -> None: - loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + loaded_model_1 = mock_context.models.load_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) assert isinstance(loaded_model_1, LoadedModel) - loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + loaded_model_2 = mock_context.models.load_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) assert isinstance(loaded_model_2, LoadedModel) assert loaded_model_1.model is loaded_model_2.model # should be cached copy + + +def test_download_diffusers(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo") + assert (model_path / "model_index.json").exists() + assert (model_path / "vae").is_dir() + + +def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae") + assert model_path.is_dir() + assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() From 8aebc29b91d0723848aafe04b2f84ff015e42ade Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 17 May 2024 22:48:54 -0400 Subject: [PATCH 27/45] fix test to run on 32bit cpu --- tests/app/services/model_load/test_load_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index bd3a67a8944..57d0fed3419 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -79,4 +79,4 @@ def test_download_diffusers(mock_context: InvocationContext) -> None: def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None: model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae") assert model_path.is_dir() - assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() + assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (model_path / "diffusion_pytorch_model.safetensors").exists() From e77c7e40b7981d04f21cbb1ec6258d9c632ce0dc Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 17 May 2024 22:53:45 -0400 Subject: [PATCH 28/45] fix ruff error --- tests/app/services/model_load/test_load_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 57d0fed3419..2af321d60f2 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -79,4 +79,6 @@ def test_download_diffusers(mock_context: InvocationContext) -> None: def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None: model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae") assert model_path.is_dir() - assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (model_path / "diffusion_pytorch_model.safetensors").exists() + assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or ( + model_path / "diffusion_pytorch_model.safetensors" + ).exists() From cd12ca6e85c5fd9836da847abccc743024b3289b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 27 May 2024 22:40:01 -0400 Subject: [PATCH 29/45] add migration_11; fix typo --- invokeai/app/services/model_install/model_install_default.py | 5 +++-- invokeai/app/services/shared/sqlite/sqlite_util.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index c78a09ce87b..cd4bce41080 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -750,8 +750,8 @@ def _import_remote_model( self._download_cache[multifile_job.id] = install_job install_job._download_job = multifile_job - files_string = "file" if len(remote_files) == 1 else "file" - self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})") self._logger.debug(f"remote_files={remote_files}") self._download_queue.submit_multifile_download(multifile_job) return install_job @@ -828,6 +828,7 @@ def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> Non else: # update sizes install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts) install_job.download_parts = download_job.download_parts self._signal_job_downloading(install_job) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index cadf09f4575..3b5f4473066 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_9()) migrator.register_migration(build_migration_10()) + migrator.register_migration(build_migration_11(app_config=config, logger=logger)) migrator.run_migrations() return db From ead1748c544696faf49c8e0d73366feaf34463a9 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 28 May 2024 19:30:42 -0400 Subject: [PATCH 30/45] issue a download progress event when install download starts --- .../services/model_install/model_install_default.py | 8 +++++--- .../app/services/model_install/test_model_install.py | 11 ++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index cd4bce41080..3b8a408e97f 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -815,10 +815,13 @@ def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None if install_job := self._download_cache.get(download_job.id, None): install_job.status = InstallStatus.DOWNLOADING - assert download_job.download_path if install_job.local_path == install_job._install_tmpdir: # first time + assert download_job.download_path install_job.local_path = download_job.download_path - install_job.total_bytes = download_job.total_bytes + install_job.download_parts = download_job.download_parts + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.total_bytes = download_job.total_bytes + self._signal_job_downloading(install_job) def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: @@ -829,7 +832,6 @@ def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> Non # update sizes install_job.bytes = sum(x.bytes for x in download_job.download_parts) install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts) - install_job.download_parts = download_job.download_parts self._signal_job_downloading(install_job) def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 5c9f908ccc8..9602a79a278 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -251,11 +251,12 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: model_record = store.get_model(key) assert (mm2_app_config.models_path / model_record.path).exists() - assert len(bus.events) == 4 - assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) - assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent) - assert isinstance(bus.events[2], ModelInstallStartedEvent) - assert isinstance(bus.events[3], ModelInstallCompleteEvent) + assert len(bus.events) == 5 + assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) # download starts + assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses + assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed + assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started + assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed @pytest.mark.timeout(timeout=10, method="thread") From 132bbf330a0218549b773bae7e5d8f36cd4c6e29 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 08:35:23 +1000 Subject: [PATCH 31/45] tidy(app): remove unnecessary changes in invocation_context - Any mypy issues are a misconfiguration of mypy - Use simple conditionals instead of ternaries - Consistent & standards-compliant docstring formatting - Use `dict` instead of `typing.Dict` --- .../app/services/shared/invocation_context.py | 68 ++++++++----------- 1 file changed, 28 insertions(+), 40 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c932e66989e..08ca207118d 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,10 +1,10 @@ from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union -import torch from PIL.Image import Image from pydantic.networks import AnyHttpUrl +from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata @@ -268,7 +268,7 @@ def get_path(self, image_name: str, thumbnail: bool = False) -> Path: class TensorsInterface(InvocationContextInterface): - def save(self, tensor: torch.Tensor) -> str: + def save(self, tensor: Tensor) -> str: """Saves a tensor, returning its name. Args: @@ -281,7 +281,7 @@ def save(self, tensor: torch.Tensor) -> str: name = self._services.tensors.save(obj=tensor) return name - def load(self, name: str) -> torch.Tensor: + def load(self, name: str) -> Tensor: """Loads a tensor by name. Args: @@ -333,13 +333,9 @@ def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool: True if the model exists, False if not. """ if isinstance(identifier, str): - # For some reason, Mypy is not getting the type annotations for many of - # the model manager service calls and raises a "returning Any in typed - # context" error. Hence the extra typing hints here and below. - result: bool = self._services.model_manager.store.exists(identifier) + return self._services.model_manager.store.exists(identifier) else: - result = self._services.model_manager.store.exists(identifier.key) - return result + return self._services.model_manager.store.exists(identifier.key) def load( self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None @@ -353,6 +349,7 @@ def load( Returns: An object representing the loaded model. """ + # The model manager emits events as it loads the model. It needs the context data to build # the event payloads. @@ -379,6 +376,7 @@ def load_by_attrs( Returns: An object representing the loaded model. """ + configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type) if len(configs) == 0: raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}") @@ -398,10 +396,9 @@ def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModel The model's config. """ if isinstance(identifier, str): - result: AnyModelConfig = self._services.model_manager.store.get_model(identifier) + return self._services.model_manager.store.get_model(identifier) else: - result = self._services.model_manager.store.get_model(identifier.key) - return result + return self._services.model_manager.store.get_model(identifier.key) def search_by_path(self, path: Path) -> list[AnyModelConfig]: """Search for models by path. @@ -412,8 +409,7 @@ def search_by_path(self, path: Path) -> list[AnyModelConfig]: Returns: A list of models that match the path. """ - result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path) - return result + return self._services.model_manager.store.search_by_path(path) def search_by_attrs( self, @@ -433,13 +429,13 @@ def search_by_attrs( Returns: A list of models that match the attributes. """ - result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr( + + return self._services.model_manager.store.search_by_attr( model_name=name, base_model=base, model_type=type, model_format=format, ) - return result def download_and_cache_model( self, @@ -453,24 +449,18 @@ def download_and_cache_model( installed, the cached path will be returned. Otherwise it will be downloaded. Args: - source: A model path, URL or repo_id. - Result: - Path to the downloaded model + source: A model path, URL or repo_id. - May Raise: - HTTPError - TimeoutError + Returns: + Path to the downloaded model """ - installer = self._services.model_manager.install - path: Path = installer.download_and_cache_model( - source=source, - ) - return path + + return self._services.model_manager.install.download_and_cache_model(source=source) def load_and_cache_model( self, source: Path | str | AnyHttpUrl, - loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, + loader: Optional[Callable[[Path], dict[str, Tensor]]] = None, ) -> LoadedModel: """ Download, cache, and load the model file located at the indicated URL. @@ -485,24 +475,22 @@ def load_and_cache_model( Be aware that the LoadedModel object will have a `config` attribute of None. Args: - source: A model Path, URL, or repoid. - loader: A Callable that expects a Path and returns a Dict[str|int, Any] + source: A model Path, URL, or repoid. + loader: A Callable that expects a Path and returns a dict[str|int, Any] Returns: - A LoadedModel object. + A LoadedModel object. """ - result: LoadedModel = ( - self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) - if isinstance(source, Path) - else self._services.model_manager.load_model_from_url(source=source, loader=loader) - ) - return result + + if isinstance(source, Path): + return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) + else: + return self._services.model_manager.load_model_from_url(source=source, loader=loader) class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: - """ - Gets the app's config. + """Gets the app's config. Returns: The app's config. From e3a70e598e0c566f2f43bc67588edc0d87c78738 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 08:40:29 +1000 Subject: [PATCH 32/45] docs(app): simplify docstring in invocation_context --- invokeai/app/services/shared/invocation_context.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 08ca207118d..27a29f6646f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -465,12 +465,10 @@ def load_and_cache_model( """ Download, cache, and load the model file located at the indicated URL. - This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_ckpt(). - It will then load the model into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. + If the model is already downloaded, it will be loaded from the cache. + + If the a loader callable is provided, it will be invoked to load the model. Otherwise, + `safetensors.torch.load_file()` or `torch.load()` will be called to load the model. Be aware that the LoadedModel object will have a `config` attribute of None. From b1244400234cd3a44d4b060751c758d231a8da98 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 08:51:21 +1000 Subject: [PATCH 33/45] tidy(mm): move `load_model_from_url` from mm to invocation context --- .../model_manager/model_manager_base.py | 31 ----------------- .../model_manager/model_manager_default.py | 34 ++----------------- .../app/services/shared/invocation_context.py | 4 ++- 3 files changed, 5 insertions(+), 64 deletions(-) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 063979ebe65..af1b68e1ec3 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,15 +1,11 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team from abc import ABC, abstractmethod -from pathlib import Path -from typing import Callable, Dict, Optional import torch -from pydantic.networks import AnyHttpUrl from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.backend.model_manager.load import LoadedModel from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase @@ -70,30 +66,3 @@ def start(self, invoker: Invoker) -> None: @abstractmethod def stop(self, invoker: Invoker) -> None: pass - - @abstractmethod - def load_model_from_url( - self, - source: str | AnyHttpUrl, - loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, - ) -> LoadedModel: - """ - Download, cache, and Load the model file located at the indicated URL. - - This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_ckpt(). - It will then load the model into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. - - Be aware that the LoadedModel object will have a `config` attribute of None. - - Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. - loader: A Callable that expects a Path and returns a Dict[str|int, Any] - - Returns: - A LoadedModel object. - """ diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index dd78f1f3b2e..1a2b9a34022 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,15 +1,13 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" -from pathlib import Path -from typing import Callable, Dict, Optional +from typing import Optional import torch -from pydantic.networks import AnyHttpUrl from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -64,34 +62,6 @@ def stop(self, invoker: Invoker) -> None: if hasattr(service, "stop"): service.stop(invoker) - def load_model_from_url( - self, - source: str | AnyHttpUrl, - loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, - ) -> LoadedModel: - """ - Download, cache, and Load the model file located at the indicated URL. - - This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_ckpt(). - It will then load the model into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. - - Be aware that the LoadedModel object will have a `config` attribute of None. - - Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. - loader: A Callable that expects a Path and returns a Dict[str|int, Any] - - Returns: - A LoadedModel object. - """ - model_path = self.install.download_and_cache_model(source=str(source)) - return self.load.load_model_from_path(model_path=model_path, loader=loader) - @classmethod def build_model_manager( cls, diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 27a29f6646f..b0d1ee4d2fd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -483,7 +483,9 @@ def load_and_cache_model( if isinstance(source, Path): return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) else: - return self._services.model_manager.load_model_from_url(source=source, loader=loader) + model_path = self._services.model_manager.install.download_and_cache_model(source=str(source)) + return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) + class ConfigInterface(InvocationContextInterface): From ccdecf21a3dda1dc051eed59a5567aa60406f2f2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:41:17 +1000 Subject: [PATCH 34/45] tidy(nodes): cnet processors - Set `self._context=context` instead of changing the type signature of `run_processor` - Tidy a few typing things --- .../controlnet_image_processors.py | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e533583829e..1e4ad672bfd 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -132,7 +132,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image @@ -141,9 +141,10 @@ def load_image(self, context: InvocationContext) -> Image.Image: return context.images.get_pil(self.image.image_name, "RGB") def invoke(self, context: InvocationContext) -> ImageOutput: + self._context = context raw_image = self.load_image(context) # image type should be PIL.PngImagePlugin.PngImageFile ? - processed_image = self.run_processor(raw_image, context) + processed_image = self.run_processor(raw_image) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery @@ -184,7 +185,7 @@ def load_image(self, context: InvocationContext) -> Image.Image: # Keep alpha channel for Canny processing to detect edges of transparent areas return context.images.get_pil(self.image.image_name, "RGBA") - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: processed_image = get_canny_edges( image, self.low_threshold, @@ -211,7 +212,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation): # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: hed_processor = HEDProcessor() processed_image = hed_processor.run( image, @@ -238,7 +239,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation): image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) coarse: bool = InputField(default=False, description="Whether to use coarse mode") - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: lineart_processor = LineartProcessor() processed_image = lineart_processor.run( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse @@ -259,7 +260,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: processor = LineartAnimeProcessor() processed_image = processor.run( image, @@ -286,7 +287,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): # depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar) midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") processed_image = midas_processor( @@ -314,9 +315,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") - processed_image: Image.Image = normalbae_processor( + processed_image = normalbae_processor( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution ) return processed_image @@ -333,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") processed_image = mlsd_processor( image, @@ -356,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") processed_image = pidi_processor( image, @@ -384,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter") f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter") - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: content_shuffle_processor = ContentShuffleDetector() processed_image = content_shuffle_processor( image, @@ -408,7 +409,7 @@ def run_processor(self, image: Image.Image, context: InvocationContext) -> Image class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") processed_image = zoe_depth_processor(image) return processed_image @@ -429,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: mediapipe_face_processor = MediapipeFaceDetector() processed_image = mediapipe_face_processor( image, @@ -457,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") processed_image = leres_processor( image, @@ -499,8 +500,8 @@ def tile_resample( np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) return np_img - def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image: - np_img = np.array(img, dtype=np.uint8) + def run_processor(self, image: Image.Image) -> Image.Image: + np_img = np.array(image, dtype=np.uint8) processed_np_image = self.tile_resample( np_img, # res=self.tile_size, @@ -523,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( "ybelkada/segment-anything", subfolder="checkpoints" @@ -569,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: np_image = np.array(image, dtype=np.uint8) height, width = np_image.shape[:2] @@ -604,13 +605,15 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): ) resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def run_processor(self, image: Image.Image) -> Image.Image: def loader(model_path: Path): return DepthAnythingDetector.load_model( model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device() ) - with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model: + with self._context.models.load_and_cache_model( + source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader + ) as model: depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device()) processed_image = depth_anything_detector(image=image, resolution=self.resolution) return processed_image @@ -631,10 +634,9 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): draw_hands: bool = InputField(default=False) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: - mm = context.models - onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"]) - onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) + def run_processor(self, image: Image.Image) -> Image.Image: + onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose) processed_image = dw_openpose( From 521f907f587854612960ef5a27cc85a465f59d03 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:43:25 +1000 Subject: [PATCH 35/45] tidy(nodes): infill - Set `self._context=context` instead of passing it as an arg --- invokeai/app/invocations/infill.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index ddd11cf93f8..f188ecf8f7d 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -38,26 +38,27 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") @abstractmethod - def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image: + def infill(self, image: Image.Image) -> Image.Image: """Infill the image with the specified method""" pass - def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]: + def load_image(self) -> tuple[Image.Image, bool]: """Process the image to have an alpha channel before being infilled""" - image = context.images.get_pil(self.image.image_name) + image = self._context.images.get_pil(self.image.image_name) has_alpha = True if image.mode == "RGBA" else False return image, has_alpha def invoke(self, context: InvocationContext) -> ImageOutput: + self._context = context # Retrieve and process image to be infilled - input_image, has_alpha = self.load_image(context) + input_image, has_alpha = self.load_image() # If the input image has no alpha channel, return it if has_alpha is False: return ImageOutput.build(context.images.get_dto(self.image.image_name)) # Perform Infill action - infilled_image = self.infill(input_image, context) + infilled_image = self.infill(input_image) # Create ImageDTO for Infilled Image infilled_image_dto = context.images.save(image=infilled_image) @@ -75,7 +76,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation): description="The color to use to infill", ) - def infill(self, image: Image.Image, context: InvocationContext): + def infill(self, image: Image.Image): solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) @@ -94,7 +95,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation): description="The seed to use for tile generation (omit for random)", ) - def infill(self, image: Image.Image, context: InvocationContext): + def infill(self, image: Image.Image): output = infill_tile(image, seed=self.seed, tile_size=self.tile_size) return output.infilled @@ -108,7 +109,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def infill(self, image: Image.Image, context: InvocationContext): + def infill(self, image: Image.Image): resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width / self.downscale) @@ -132,8 +133,8 @@ def infill(self, image: Image.Image, context: InvocationContext): class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" - def infill(self, image: Image.Image, context: InvocationContext): - with context.models.load_and_cache_model( + def infill(self, image: Image.Image): + with self._context.models.load_and_cache_model( source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", loader=LaMA.load_jit_model, ) as model: @@ -145,7 +146,7 @@ def infill(self, image: Image.Image, context: InvocationContext): class CV2InfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using OpenCV Inpainting""" - def infill(self, image: Image.Image, context: InvocationContext): + def infill(self, image: Image.Image): return cv2_inpaint(image) @@ -167,5 +168,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation): description="The max threshold for color", ) - def infill(self, image: Image.Image, context: InvocationContext): + def infill(self, image: Image.Image): return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple()) From 6cc6a452741f9eef5394300b06a0f2c3e4a4b4a0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:05:52 +1000 Subject: [PATCH 36/45] feat(download): add type for callback_name Just a bit of typo protection in lieu of full type safety for these methods, which is difficult due to the typing of `DownloadEventHandler`. --- invokeai/app/services/download/download_default.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 5025255c913..4640a656dca 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,7 +8,7 @@ import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set import requests from pydantic.networks import AnyHttpUrl @@ -528,7 +528,13 @@ def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None def _execute_cb( self, job: DownloadJob | MultiFileDownloadJob, - callback_name: str, + callback_name: Literal[ + "on_start", + "on_progress", + "on_complete", + "on_cancelled", + "on_error", + ], excp: Optional[Exception] = None, ) -> None: if callback := getattr(job, callback_name, None): From c58ac1e80d14d23b9e00b0944830a7994c64bfaa Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:11:08 +1000 Subject: [PATCH 37/45] tidy(mm): minor formatting --- invokeai/app/services/model_install/model_install_base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 76b77f0419b..734f05d574e 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -241,10 +241,7 @@ def sync_model_path(self, key: str) -> AnyModelConfig: """ @abstractmethod - def download_and_cache_model( - self, - source: str, - ) -> Path: + def download_and_cache_model(self, source: str) -> Path: """ Download the model file located at source to the models cache and return its Path. From aa9695e377acb6f3485e564c4a117063889ada9c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:15:53 +1000 Subject: [PATCH 38/45] tidy(download): `_download_job` -> `_multifile_job` --- .../services/model_install/model_install_common.py | 2 +- .../services/model_install/model_install_default.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py index 751b5baa4be..c1538f543dc 100644 --- a/invokeai/app/services/model_install/model_install_common.py +++ b/invokeai/app/services/model_install/model_install_common.py @@ -162,7 +162,7 @@ class ModelInstallJob(BaseModel): ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) - _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) + _multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None) def set_error(self, e: Exception) -> None: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 3b8a408e97f..b561744ff47 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -278,7 +278,7 @@ def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" job.cancel() self._logger.warning(f"Cancelling {job.source}") - if dj := job._download_job: + if dj := job._multifile_job: self._download_queue.cancel_job(dj) def prune_jobs(self) -> None: @@ -514,9 +514,9 @@ def _register_or_install(self, job: ModelInstallJob) -> None: self._signal_job_completed(job) def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None: - download_job = install_job._download_job - if download_job and any( - x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts + multifile_download_job = install_job._multifile_job + if multifile_download_job and any( + x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts ): install_job.set_error( InvalidModelConfigException( @@ -748,7 +748,7 @@ def _import_remote_model( submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict ) self._download_cache[multifile_job.id] = install_job - install_job._download_job = multifile_job + install_job._multifile_job = multifile_job files_string = "file" if len(remote_files) == 1 else "files" self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})") @@ -875,7 +875,7 @@ def _signal_job_running(self, job: ModelInstallJob) -> None: def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: - assert job._download_job is not None + assert job._multifile_job is not None assert job.bytes is not None assert job.total_bytes is not None self._event_bus.emit_model_install_download_progress(job) From 99413256ce1f847ec42918ed1af49fad3c54736a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:43:09 +1000 Subject: [PATCH 39/45] tidy(mm): pass enum member instead of string --- invokeai/app/services/model_install/model_install_default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index b561744ff47..de9bb8eed88 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -442,7 +442,7 @@ def _guess_source(self, source: str) -> ModelSource: elif match := re.match(hf_repoid_re, source): source_obj = HFModelSource( repo_id=match.group(1), - variant=match.group(2) if match.group(2) else None, # pass None rather than '' + variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than '' subfolder=Path(match.group(3)) if match.group(3) else None, ) elif re.match(r"^https?://[^/]+", source): From c7f22b6a3b36e0b0bf7b2320ece7db1c84c22a3e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:46:28 +1000 Subject: [PATCH 40/45] tidy(mm): remove extraneous docstring It's inherited from the ABC. --- .../services/model_load/model_load_default.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 221a042da5c..776620edca0 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -87,23 +87,6 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None ) -> LoadedModel: - """ - Load the checkpoint-format model file located at the indicated Path. - - This will load an arbitrary model file into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. - - Be aware that the LoadedModel object will have a `config` attribute of None. - - Args: - model_path: A pathlib.Path to a checkpoint-style models file - loader: A Callable that expects a Path and returns a Dict[str, Tensor] - - Returns: - A LoadedModel object. - """ cache_key = str(model_path) ram_cache = self.ram_cache try: From e7513f60887b358055e23631024e701d3ffca20d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:56:04 +1000 Subject: [PATCH 41/45] docs(mm): add comment in `move_model_to_device` --- .../model_manager/load/model_cache/model_cache_default.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 10c0210052b..335a15a5c81 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -261,6 +261,7 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device if torch.device(source_device).type == torch.device(target_device).type: return + # Some models don't have a `to` method, in which case they run in RAM/CPU. if not hasattr(cache_entry.model, "to"): return From a9962fd104823458f658f989eb73eca1f8a81444 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:53:20 +1000 Subject: [PATCH 42/45] chore: ruff --- invokeai/app/services/shared/invocation_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b0d1ee4d2fd..260bf6a61fe 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -487,7 +487,6 @@ def load_and_cache_model( return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) - class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: """Gets the app's config. From f81b8bc9f6452e20672ef6ee52b567a0be77e206 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 3 Jun 2024 20:31:05 -0400 Subject: [PATCH 43/45] add support for generic loading of diffusers directories --- .../app/services/model_load/model_load_base.py | 7 ++++--- .../services/model_load/model_load_default.py | 13 ++++++++----- .../app/services/shared/invocation_context.py | 8 ++++---- invokeai/backend/model_manager/load/__init__.py | 3 ++- .../backend/model_manager/load/load_base.py | 10 ++++++++-- .../load/model_loaders/generic_diffusers.py | 9 +++------ tests/app/services/model_load/test_load_api.py | 17 +++++++++++------ .../model_manager/model_manager_fixtures.py | 4 ++++ 8 files changed, 44 insertions(+), 27 deletions(-) diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 22d815483ee..f84b1dae139 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -8,7 +8,7 @@ from torch import Tensor from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType -from invokeai.backend.model_manager.load import LoadedModel +from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase @@ -38,7 +38,7 @@ def convert_cache(self) -> ModelConvertCacheBase: @abstractmethod def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None - ) -> LoadedModel: + ) -> LoadedModelWithoutConfig: """ Load the model file or directory located at the indicated Path. @@ -47,7 +47,8 @@ def load_model_from_path( memory. Otherwise the method will call safetensors.torch.load_file() or torch.load() as appropriate to the file suffix. - Be aware that the LoadedModel object will have a `config` attribute of None. + Be aware that this returns a LoadedModelWithoutConfig object, which is the same as + LoadedModel, but without the config attribute. Args: model_path: A pathlib.Path to a checkpoint-style models file diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 776620edca0..113334ea0d5 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -14,6 +14,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import ( LoadedModel, + LoadedModelWithoutConfig, ModelLoaderRegistry, ModelLoaderRegistryBase, ) @@ -85,12 +86,12 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo return loaded_model def load_model_from_path( - self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None - ) -> LoadedModel: + self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None + ) -> LoadedModelWithoutConfig: cache_key = str(model_path) ram_cache = self.ram_cache try: - return LoadedModel(_locker=ram_cache.get(key=cache_key)) + return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key)) except IndexError: pass @@ -113,11 +114,13 @@ def diffusers_load_directory(directory: Path) -> AnyModel: if loader is None: loader = ( - torch_load_file + diffusers_load_directory + if model_path.is_dir() + else torch_load_file if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) else lambda path: safetensors_load_file(path, device="cpu") ) raw_model = loader(model_path) ram_cache.put(key=cache_key, model=raw_model) - return LoadedModel(_locker=ram_cache.get(key=cache_key)) + return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key)) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 260bf6a61fe..931fc40b822 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -16,7 +16,7 @@ from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -461,7 +461,7 @@ def load_and_cache_model( self, source: Path | str | AnyHttpUrl, loader: Optional[Callable[[Path], dict[str, Tensor]]] = None, - ) -> LoadedModel: + ) -> LoadedModelWithoutConfig: """ Download, cache, and load the model file located at the indicated URL. @@ -470,14 +470,14 @@ def load_and_cache_model( If the a loader callable is provided, it will be invoked to load the model. Otherwise, `safetensors.torch.load_file()` or `torch.load()` will be called to load the model. - Be aware that the LoadedModel object will have a `config` attribute of None. + Be aware that the LoadedModelWithoutConfig object has no `config` attribute Args: source: A model Path, URL, or repoid. loader: A Callable that expects a Path and returns a dict[str|int, Any] Returns: - A LoadedModel object. + A LoadedModelWithoutConfig object. """ if isinstance(source, Path): diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index f47a2c43684..25125f43fb0 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -7,7 +7,7 @@ from pathlib import Path from .convert_cache.convert_cache_default import ModelConvertCache -from .load_base import LoadedModel, ModelLoaderBase +from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase from .load_default import ModelLoader from .model_cache.model_cache_default import ModelCache from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase @@ -19,6 +19,7 @@ __all__ = [ "LoadedModel", + "LoadedModelWithoutConfig", "ModelCache", "ModelConvertCache", "ModelLoaderBase", diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 41a36d7b51a..a7c080ed2b0 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -20,11 +20,10 @@ @dataclass -class LoadedModel: +class LoadedModelWithoutConfig: """Context manager object that mediates transfer from RAM<->VRAM.""" _locker: ModelLockerBase - config: Optional[AnyModelConfig] = None def __enter__(self) -> AnyModel: """Context entry.""" @@ -41,6 +40,13 @@ def model(self) -> AnyModel: return self._locker.model +@dataclass +class LoadedModel(LoadedModelWithoutConfig): + """Context manager object that mediates transfer from RAM<->VRAM.""" + + config: Optional[AnyModelConfig] = None + + # TODO(MM2): # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't # know about. I think the problem may be related to this class being an ABC. diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index a4874b33cea..6320797b8af 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -65,14 +65,11 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy else: try: config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config.get("_class_name", None) - if class_name: + if class_name := config.get("_class_name"): result = self._hf_definition_to_type(module="diffusers", class_name=class_name) - if config.get("model_type", None) == "clip_vision_model": - class_name = config.get("architectures") - assert class_name is not None + elif class_name := config.get("architectures"): result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) - if not class_name: + else: raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") except KeyError as e: raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index a10bc4d66ac..9671c8c6c3b 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -2,11 +2,12 @@ import pytest import torch +from diffusers import AutoencoderTiny from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.model_manager import ModelManagerServiceBase from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context -from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 @@ -43,30 +44,34 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) - "https://www.test.foo/download/test_embedding.safetensors" ) loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path) - assert isinstance(loaded_model_1, LoadedModel) + assert isinstance(loaded_model_1, LoadedModelWithoutConfig) loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path) - assert isinstance(loaded_model_2, LoadedModel) + assert isinstance(loaded_model_2, LoadedModelWithoutConfig) assert loaded_model_1.model is loaded_model_2.model loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file) - assert isinstance(loaded_model_3, LoadedModel) + assert isinstance(loaded_model_3, LoadedModelWithoutConfig) assert loaded_model_1.model is not loaded_model_3.model assert isinstance(loaded_model_1.model, dict) assert isinstance(loaded_model_3.model, dict) assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) +def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None: + loaded_model = mock_context.models.load_and_cache_model(vae_directory) + assert isinstance(loaded_model, LoadedModelWithoutConfig) + assert isinstance(loaded_model.model, AutoencoderTiny) def test_download_and_load(mock_context: InvocationContext) -> None: loaded_model_1 = mock_context.models.load_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) - assert isinstance(loaded_model_1, LoadedModel) + assert isinstance(loaded_model_1, LoadedModelWithoutConfig) loaded_model_2 = mock_context.models.load_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) - assert isinstance(loaded_model_2, LoadedModel) + assert isinstance(loaded_model_2, LoadedModelWithoutConfig) assert loaded_model_1.model is loaded_model_2.model # should be cached copy diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index dc2ad2f1e40..ee66c459b8e 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -60,6 +60,10 @@ def mm2_model_files(tmp_path_factory) -> Path: def embedding_file(mm2_model_files: Path) -> Path: return mm2_model_files / "test_embedding.safetensors" +@pytest.fixture +def vae_directory(mm2_model_files: Path) -> Path: + return mm2_model_files / "taesdxl" + @pytest.fixture def diffusers_dir(mm2_model_files: Path) -> Path: From 9f9379682e63c2993df4207ad176d8c971768478 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 3 Jun 2024 20:33:31 -0400 Subject: [PATCH 44/45] ruff fixes --- tests/app/services/model_load/test_load_api.py | 2 ++ tests/backend/model_manager/model_manager_fixtures.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 9671c8c6c3b..8dd948692d2 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -57,11 +57,13 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) - assert isinstance(loaded_model_3.model, dict) assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) + def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None: loaded_model = mock_context.models.load_and_cache_model(vae_directory) assert isinstance(loaded_model, LoadedModelWithoutConfig) assert isinstance(loaded_model.model, AutoencoderTiny) + def test_download_and_load(mock_context: InvocationContext) -> None: loaded_model_1 = mock_context.models.load_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index ee66c459b8e..9ce272fc42a 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -60,6 +60,7 @@ def mm2_model_files(tmp_path_factory) -> Path: def embedding_file(mm2_model_files: Path) -> Path: return mm2_model_files / "test_embedding.safetensors" + @pytest.fixture def vae_directory(mm2_model_files: Path) -> Path: return mm2_model_files / "taesdxl" From dc134935c8be752eabf61fb71a7802c736e0f3dc Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 6 Jun 2024 00:31:41 -0400 Subject: [PATCH 45/45] replace load_and_cache_model() with load_remote_model() and load_local_odel() --- docs/contributing/MODEL_MANAGER.md | 42 ++++++++++----- .../controlnet_image_processors.py | 2 +- invokeai/app/invocations/infill.py | 2 +- invokeai/app/invocations/upscale.py | 2 +- .../model_install/model_install_base.py | 4 +- .../model_install/model_install_default.py | 7 +-- .../services/model_load/model_load_base.py | 6 +-- .../services/model_load/model_load_default.py | 31 +++++------ .../app/services/shared/invocation_context.py | 52 ++++++++++++++----- .../migrations/migration_11.py | 4 +- .../app/services/model_load/test_load_api.py | 17 +++--- .../model_manager/model_manager_fixtures.py | 8 +-- 12 files changed, 107 insertions(+), 70 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index fbc9079d49e..dfa724fee89 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1585,9 +1585,9 @@ Within invocations, the following methods are available from the ### context.download_and_cache_model(source) -> Path -This method accepts a `source` of a model, downloads and caches it -locally, and returns a Path to the local model. The source can be a -local file or directory, a URL, or a HuggingFace repo_id. +This method accepts a `source` of a remote model, downloads and caches +it locally, and then returns a Path to the local model. The source can +be a direct download URL or a HuggingFace repo_id. In the case of HuggingFace repo_id, the following variants are recognized: @@ -1602,16 +1602,34 @@ directory using this syntax: * stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors -### context.load_and_cache_model(source, [loader]) -> LoadedModel +### context.load_local_model(model_path, [loader]) -> LoadedModel -This method takes a model source, downloads it, caches it, and then -loads it into the RAM cache for use in inference. The optional loader -is a Callable that accepts a Path to the object, and returns a -`Dict[str, torch.Tensor]`. If no loader is provided, then the method -will use `torch.load()` for a .ckpt or .bin checkpoint file, -`safetensors.torch.load_file()` for a safetensors checkpoint file, or -`*.from_pretrained()` for a directory that looks like a -diffusers directory. +This method loads a local model from the indicated path, returning a +`LoadedModel`. The optional loader is a Callable that accepts a Path +to the object, and returns a `AnyModel` object. If no loader is +provided, then the method will use `torch.load()` for a .ckpt or .bin +checkpoint file, `safetensors.torch.load_file()` for a safetensors +checkpoint file, or `cls.from_pretrained()` for a directory that looks +like a diffusers directory. + +### context.load_remote_model(source, [loader]) -> LoadedModel + +This method accepts a `source` of a remote model, downloads and caches +it locally, loads it, and returns a `LoadedModel`. The source can be a +direct download URL or a HuggingFace repo_id. + +In the case of HuggingFace repo_id, the following variants are +recognized: + +* stabilityai/stable-diffusion-v4 -- default model +* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant +* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder +* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder + +You can also point at an arbitrary individual file within a repo_id +directory using this syntax: + +* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1e4ad672bfd..c0b332f27b6 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -611,7 +611,7 @@ def loader(model_path: Path): model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device() ) - with self._context.models.load_and_cache_model( + with self._context.models.load_remote_model( source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader ) as model: depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device()) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index f188ecf8f7d..7e1a2ee322f 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -134,7 +134,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" def infill(self, image: Image.Image): - with self._context.models.load_and_cache_model( + with self._context.models.load_remote_model( source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", loader=LaMA.load_jit_model, ) as model: diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 670082f1200..f93060f8d34 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -91,7 +91,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: context.logger.error(msg) raise ValueError(msg) - loadnet = context.models.load_and_cache_model( + loadnet = context.models.load_remote_model( source=ESRGAN_MODEL_URLS[self.model_name], ) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 734f05d574e..20afaeaa505 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +from pydantic.networks import AnyHttpUrl + from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase @@ -241,7 +243,7 @@ def sync_model_path(self, key: str) -> AnyModelConfig: """ @abstractmethod - def download_and_cache_model(self, source: str) -> Path: + def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path: """ Download the model file located at source to the models cache and return its Path. diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index de9bb8eed88..39e38a593f1 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -15,6 +15,7 @@ import yaml from huggingface_hub import HfFolder from pydantic.networks import AnyHttpUrl +from pydantic_core import Url from requests import Session from invokeai.app.services.config import InvokeAIAppConfig @@ -374,7 +375,7 @@ def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: Invoke def download_and_cache_model( self, - source: str, + source: str | AnyHttpUrl, ) -> Path: """Download the model file located at source to the models cache and return its Path.""" model_path = self._download_cache_path(str(source), self._app_config) @@ -388,7 +389,7 @@ def download_and_cache_model( return contents[0] model_path.mkdir(parents=True, exist_ok=True) - model_source = self._guess_source(source) + model_source = self._guess_source(str(source)) remote_files, _ = self._remote_files_from_source(model_source) job = self._multifile_download( dest=model_path, @@ -447,7 +448,7 @@ def _guess_source(self, source: str) -> ModelSource: ) elif re.match(r"^https?://[^/]+", source): source_obj = URLModelSource( - url=AnyHttpUrl(source), + url=Url(source), ) else: raise ValueError(f"Unsupported model source: '{source}'") diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index f84b1dae139..da567721956 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -3,9 +3,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Callable, Dict, Optional - -from torch import Tensor +from typing import Callable, Optional from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig @@ -37,7 +35,7 @@ def convert_cache(self) -> ModelConvertCacheBase: @abstractmethod def load_model_from_path( - self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None + self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None ) -> LoadedModelWithoutConfig: """ Load the model file or directory located at the indicated Path. diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 113334ea0d5..70674813785 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -2,11 +2,10 @@ """Implementation of model loader service.""" from pathlib import Path -from typing import Callable, Dict, Optional, Type +from typing import Callable, Optional, Type from picklescan.scanner import scan_file_path from safetensors.torch import load_file as safetensors_load_file -from torch import Tensor from torch import load as torch_load from invokeai.app.services.config import InvokeAIAppConfig @@ -86,7 +85,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo return loaded_model def load_model_from_path( - self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None + self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None ) -> LoadedModelWithoutConfig: cache_key = str(model_path) ram_cache = self.ram_cache @@ -95,11 +94,11 @@ def load_model_from_path( except IndexError: pass - def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]: + def torch_load_file(checkpoint: Path) -> AnyModel: scan_result = scan_file_path(checkpoint) if scan_result.infected_files != 0: raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.") - result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu") + result = torch_load(checkpoint, map_location="cpu") return result def diffusers_load_directory(directory: Path) -> AnyModel: @@ -109,18 +108,16 @@ def diffusers_load_directory(directory: Path) -> AnyModel: ram_cache=self._ram_cache, convert_cache=self.convert_cache, ).get_hf_load_class(directory) - result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) - return result - - if loader is None: - loader = ( - diffusers_load_directory - if model_path.is_dir() - else torch_load_file - if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) - else lambda path: safetensors_load_file(path, device="cpu") - ) - + return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + + loader = loader or ( + diffusers_load_directory + if model_path.is_dir() + else torch_load_file + if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) + else lambda path: safetensors_load_file(path, device="cpu") + ) + assert loader is not None raw_model = loader(model_path) ram_cache.put(key=cache_key, model=raw_model) return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key)) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 931fc40b822..01662335e46 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -15,7 +15,14 @@ from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, +) from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -449,21 +456,42 @@ def download_and_cache_model( installed, the cached path will be returned. Otherwise it will be downloaded. Args: - source: A model path, URL or repo_id. + source: A URL that points to the model, or a huggingface repo_id. Returns: Path to the downloaded model """ - return self._services.model_manager.install.download_and_cache_model(source=source) - def load_and_cache_model( + def load_local_model( self, - source: Path | str | AnyHttpUrl, - loader: Optional[Callable[[Path], dict[str, Tensor]]] = None, + model_path: Path, + loader: Optional[Callable[[Path], AnyModel]] = None, ) -> LoadedModelWithoutConfig: """ - Download, cache, and load the model file located at the indicated URL. + Load the model file located at the indicated path + + If a loader callable is provided, it will be invoked to load the model. Otherwise, + `safetensors.torch.load_file()` or `torch.load()` will be called to load the model. + + Be aware that the LoadedModelWithoutConfig object has no `config` attribute + + Args: + path: A model Path + loader: A Callable that expects a Path and returns a dict[str|int, Any] + + Returns: + A LoadedModelWithoutConfig object. + """ + return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) + + def load_remote_model( + self, + source: str | AnyHttpUrl, + loader: Optional[Callable[[Path], AnyModel]] = None, + ) -> LoadedModelWithoutConfig: + """ + Download, cache, and load the model file located at the indicated URL or repo_id. If the model is already downloaded, it will be loaded from the cache. @@ -473,18 +501,14 @@ def load_and_cache_model( Be aware that the LoadedModelWithoutConfig object has no `config` attribute Args: - source: A model Path, URL, or repoid. + source: A URL or huggingface repoid. loader: A Callable that expects a Path and returns a dict[str|int, Any] Returns: A LoadedModelWithoutConfig object. """ - - if isinstance(source, Path): - return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) - else: - model_path = self._services.model_manager.install.download_and_cache_model(source=str(source)) - return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) + model_path = self._services.model_manager.install.download_and_cache_model(source=str(source)) + return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) class ConfigInterface(InvocationContextInterface): diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py index 3b616e2b824..f66374e0b1e 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py @@ -59,14 +59,12 @@ def _remove_unused_core_models(self) -> None: def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: """ - Build the migration from database version 9 to 10. + Build the migration from database version 10 to 11. This migration does the following: - Moves "core" models previously downloaded with download_with_progress_bar() into new "models/.download_cache" directory. - Renames "models/.cache" to "models/.convert_cache". - - Adds `error_type` and `error_message` columns to the session queue table. - - Renames the `error` column to `error_traceback`. """ migration_11 = Migration( from_version=10, diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 8dd948692d2..6f2c7bd931b 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -43,14 +43,14 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) - downloaded_path = mock_context.models.download_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) - loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path) + loaded_model_1 = mock_context.models.load_local_model(downloaded_path) assert isinstance(loaded_model_1, LoadedModelWithoutConfig) - loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path) + loaded_model_2 = mock_context.models.load_local_model(downloaded_path) assert isinstance(loaded_model_2, LoadedModelWithoutConfig) assert loaded_model_1.model is loaded_model_2.model - loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file) + loaded_model_3 = mock_context.models.load_local_model(embedding_file) assert isinstance(loaded_model_3, LoadedModelWithoutConfig) assert loaded_model_1.model is not loaded_model_3.model assert isinstance(loaded_model_1.model, dict) @@ -58,21 +58,18 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) - assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) +@pytest.mark.skip(reason="This requires a test model to load") def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None: - loaded_model = mock_context.models.load_and_cache_model(vae_directory) + loaded_model = mock_context.models.load_local_model(vae_directory) assert isinstance(loaded_model, LoadedModelWithoutConfig) assert isinstance(loaded_model.model, AutoencoderTiny) def test_download_and_load(mock_context: InvocationContext) -> None: - loaded_model_1 = mock_context.models.load_and_cache_model( - "https://www.test.foo/download/test_embedding.safetensors" - ) + loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors") assert isinstance(loaded_model_1, LoadedModelWithoutConfig) - loaded_model_2 = mock_context.models.load_and_cache_model( - "https://www.test.foo/download/test_embedding.safetensors" - ) + loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors") assert isinstance(loaded_model_2, LoadedModelWithoutConfig) assert loaded_model_1.model is loaded_model_2.model # should be cached copy diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 9ce272fc42a..f82239298e1 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -61,9 +61,11 @@ def embedding_file(mm2_model_files: Path) -> Path: return mm2_model_files / "test_embedding.safetensors" -@pytest.fixture -def vae_directory(mm2_model_files: Path) -> Path: - return mm2_model_files / "taesdxl" +# Can be used to test diffusers model directory loading, but +# the test file adds ~10MB of space. +# @pytest.fixture +# def vae_directory(mm2_model_files: Path) -> Path: +# return mm2_model_files / "taesdxl" @pytest.fixture