Add simplified model manager install API to InvocationContext #6132

Merged Jun 8, 2024 (60 commits)
9cc1f20
add simplified model manager install API to InvocationContext
Apr 4, 2024
af1b57a
add simplified model manager install API to InvocationContext
Apr 4, 2024
df5ebdb
add invocation_context.load_ckpt_from_url() method
Apr 12, 2024
3a26c7b
fix merge conflicts
Apr 12, 2024
41b909c
port dw_openpose, depth_anything, and lama processors to new model do…
Apr 13, 2024
3ddd7ce
change names of convert and download caches and add migration script
Apr 14, 2024
34438ce
add simplified model manager install API to InvocationContext
Apr 4, 2024
c140d3b
add invocation_context.load_ckpt_from_url() method
Apr 12, 2024
3ead827
port dw_openpose, depth_anything, and lama processors to new model do…
Apr 13, 2024
fa6efac
change names of convert and download caches and add migration script
Apr 14, 2024
f055e1e
Merge branch 'lstein/feat/simple-mm2-api' of github.com:invoke-ai/Inv…
Apr 15, 2024
f1e79d5
Merge branch 'main' into lstein/feat/simple-mm2-api
Apr 15, 2024
470a399
fix merge conflicts with main
Apr 15, 2024
34cdfc6
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein Apr 17, 2024
d72f272
Address change requests in first round of PR reviews.
Apr 25, 2024
70903ef
refactor load_ckpt_from_url()
Apr 28, 2024
bb04f49
Merge branch 'main' into lstein/feat/simple-mm2-api
Apr 28, 2024
a26667d
make download and convert cache keys safe for filename length
Apr 28, 2024
7c39929
support VRAM caching of dict models that lack `to()`
Apr 28, 2024
f65c7e2
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein Apr 28, 2024
57c8314
fix safe_filename() on windows
Apr 28, 2024
fcb071f
feat(backend): lift managed model loading out of lama class
psychedelicious Apr 28, 2024
1fe90c3
feat(backend): lift managed model loading out of depthanything class
psychedelicious Apr 28, 2024
49c84cd
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein Apr 30, 2024
3b64e7a
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein May 3, 2024
38df6f3
fix ruff error
May 3, 2024
e9a2005
refactor DWOpenPose and add type hints
May 3, 2024
8e5e9b5
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein May 4, 2024
f211c95
move access token regex matching into download queue
lstein May 6, 2024
b48d4a0
bad implementation of diffusers folder download
lstein May 9, 2024
0bf14c2
add multifile_download() method to download service
lstein May 13, 2024
287c679
clean up type checking for single file and multifile download job cal…
May 13, 2024
f29c406
refactor model_install to work with refactored download queue
May 14, 2024
911a244
add tests for model install file size reporting
May 16, 2024
2dae5eb
more refactoring; HF subfolders not working
May 17, 2024
d968c6f
refactor multifile download code
May 18, 2024
8aebc29
fix test to run on 32bit cpu
May 18, 2024
e77c7e4
fix ruff error
May 18, 2024
987ee70
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein May 18, 2024
34e1eb1
merge with main and resolve conflicts
May 28, 2024
cd12ca6
add migration_11; fix typo
May 28, 2024
ead1748
issue a download progress event when install download starts
May 28, 2024
2276f32
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein Jun 2, 2024
132bbf3
tidy(app): remove unnecessary changes in invocation_context
psychedelicious Jun 2, 2024
e3a70e5
docs(app): simplify docstring in invocation_context
psychedelicious Jun 2, 2024
b124440
tidy(mm): move `load_model_from_url` from mm to invocation context
psychedelicious Jun 2, 2024
ccdecf2
tidy(nodes): cnet processors
psychedelicious Jun 2, 2024
521f907
tidy(nodes): infill
psychedelicious Jun 2, 2024
6cc6a45
feat(download): add type for callback_name
psychedelicious Jun 3, 2024
c58ac1e
tidy(mm): minor formatting
psychedelicious Jun 3, 2024
aa9695e
tidy(download): `_download_job` -> `_multifile_job`
psychedelicious Jun 3, 2024
9941325
tidy(mm): pass enum member instead of string
psychedelicious Jun 3, 2024
c7f22b6
tidy(mm): remove extraneous docstring
psychedelicious Jun 3, 2024
e7513f6
docs(mm): add comment in `move_model_to_device`
psychedelicious Jun 3, 2024
a9962fd
chore: ruff
psychedelicious Jun 3, 2024
f81b8bc
add support for generic loading of diffusers directories
Jun 4, 2024
9f93796
ruff fixes
Jun 4, 2024
dc13493
replace load_and_cache_model() with load_remote_model() and load_loca…
Jun 6, 2024
fde58ce
Merge remote-tracking branch 'origin/main' into lstein/feat/simple-mm…
psychedelicious Jun 7, 2024
7d19af2
Merge branch 'main' into lstein/feat/simple-mm2-api
lstein Jun 8, 2024
63 changes: 60 additions & 3 deletions docs/contributing/DOWNLOAD_QUEUE.md
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.

The only job type currently implemented is `DownloadJob`, a pydantic object with the
Two job types are defined: `DownloadJob` and
`MultiFileDownloadJob`. The former is a pydantic object with the
following fields:

| **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.

The `MultiFileDownloadJob` is used for diffusers model downloads,
which contain multiple files and directories under a common root:

| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `download_parts` | Set[DownloadJob]| | Component download jobs |
| `dest` | Path | | Where to download to |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| _Fields updated over the course of the download task_ |
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the root of the downloaded files |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|

Note that the `MultiFileDownloadJob` does not support the `priority`,
`job_started`, `job_ended`, or `content_type` attributes. You can get
these from the individual download jobs in `download_parts`.
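The relationship between the multifile job and its parts can be sketched with a toy model. The dataclasses below are illustrative stand-ins for the real job classes, not InvokeAI code; they only show how aggregate progress can be derived from the component jobs in `download_parts`:

```python
# Toy model of a multifile job whose progress rolls up from its parts.
# PartJob and MultiPartJob are hypothetical stand-ins for DownloadJob
# and MultiFileDownloadJob.
from dataclasses import dataclass, field
from typing import Set


@dataclass(frozen=True)
class PartJob:
    """Stand-in for a component DownloadJob."""
    id: int
    bytes: int = 0        # bytes downloaded so far
    total_bytes: int = 0  # total size at the remote site


@dataclass
class MultiPartJob:
    """Stand-in for MultiFileDownloadJob."""
    download_parts: Set[PartJob] = field(default_factory=set)

    @property
    def bytes(self) -> int:
        # aggregate progress is the sum over the component jobs
        return sum(p.bytes for p in self.download_parts)

    @property
    def total_bytes(self) -> int:
        return sum(p.total_bytes for p in self.download_parts)


job = MultiPartJob({PartJob(0, 512, 1024), PartJob(1, 1024, 1024)})
print(job.bytes, "/", job.total_bytes)  # 1536 / 2048
```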


### Callbacks

Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its ID, cancel all
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.

#### job = queue.download(source, dest, priority, access_token)
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)

Create a new download job and put it on the queue, returning the
DownloadJob object.

#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)

This is similar to `download()`, but instead of taking a single source,
it accepts a `parts` argument consisting of a list of
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
where the URL is the location of the remote file, and the Path is the
destination.

`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata` and
consists of a URL/path pair. Note that the path *must* be relative.

The method returns a `MultiFileDownloadJob`.


```
from invokeai.backend.model_manager.metadata import RemoteModelFile

remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
                                path='my_model/textencoder/pytorch_model.safetensors',
                                )
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
                                path='my_model/vae/diffusers_model.safetensors',
                                )
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
                               dest='/tmp/downloads',
                               on_progress=TqdmProgress().update)
queue.wait_for_job(job)
print(f"The files were downloaded to {job.download_path}")
```

#### jobs = queue.list_jobs()

Return a list of all active and inactive `DownloadJob`s.
71 changes: 63 additions & 8 deletions docs/contributing/MODEL_MANAGER.md
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
following initialization pattern:

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config import get_config
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger

config = InvokeAIAppConfig.get_config()
config.parse_args()
config = get_config()

logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()

installer = ModelInstallService(app_config=config,
installer = ModelInstallService(app_config=config,
record_store=record_store,
download_queue=queue
)
download_queue=queue
)
installer.start()
```

@@ -1602,3 +1601,59 @@ This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.

## Invocation Context Model Manager API

Within invocations, the following methods are available from the
`InvocationContext` object:

### context.download_and_cache_model(source) -> Path

This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
be a direct download URL or a HuggingFace repo_id.

In the case of HuggingFace repo_id, the following variants are
recognized:

* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder

You can also point at an arbitrary individual file within a repo_id
directory using this syntax:

* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
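To make the syntax above concrete, here is a hypothetical parser (not part of the InvokeAI API) for these source strings, under the assumption that `::` always introduces a file path and a single `:` separates variant and subfolder:

```python
# Hypothetical decomposition of the repo_id source syntax:
#   repo_id[:variant[:subfolder]]   or   repo_id::/path/to/file
# This helper is for illustration only; InvokeAI's actual parsing
# may differ in details.
from typing import Optional, Tuple


def parse_hf_source(source: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Return (repo_id, variant, subfolder_or_file)."""
    if "::" in source:
        # repo_id::/checkpoints/sd4.safetensors -> individual file
        repo_id, _, file_path = source.partition("::")
        return repo_id, None, file_path.lstrip("/")
    parts = source.split(":")
    repo_id = parts[0]
    variant = parts[1] if len(parts) > 1 else None
    subfolder = parts[2] if len(parts) > 2 else None
    return repo_id, variant, subfolder
```

With this reading, `stabilityai/stable-diffusion-v4:fp16:vae` names the `fp16` variant of the `vae` subfolder, while the `::` form points at one file inside the repo.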

### context.load_local_model(model_path, [loader]) -> LoadedModel

This method loads a local model from the indicated path, returning a
`LoadedModel`. The optional loader is a Callable that accepts a Path
to the object and returns an `AnyModel` object. If no loader is
provided, then the method will use `torch.load()` for a .ckpt or .bin
checkpoint file, `safetensors.torch.load_file()` for a safetensors
checkpoint file, or `cls.from_pretrained()` for a directory that looks
like a diffusers directory.
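The fallback selection can be sketched as follows. The helper name and the directory heuristic are assumptions for illustration, not the actual implementation; it simply returns the name of the loader the paragraph above says would be used:

```python
# Sketch of the default-loader dispatch described above.  Assumed
# heuristic: anything without a file suffix (or an existing directory)
# is treated as a diffusers directory.
from pathlib import Path


def pick_default_loader(model_path: Path) -> str:
    """Name the loader load_local_model() would fall back to."""
    if model_path.is_dir() or model_path.suffix == "":
        return "from_pretrained"  # diffusers-style directory
    if model_path.suffix in (".ckpt", ".bin"):
        return "torch.load"
    if model_path.suffix == ".safetensors":
        return "safetensors.torch.load_file"
    raise ValueError(f"no default loader for {model_path}")
```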

### context.load_remote_model(source, [loader]) -> LoadedModel

This method accepts a `source` of a remote model, downloads and caches
it locally, loads it, and returns a `LoadedModel`. The source can be a
direct download URL or a HuggingFace repo_id.

In the case of HuggingFace repo_id, the following variants are
recognized:

* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder

You can also point at an arbitrary individual file within a repo_id
directory using this syntax:

* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors



2 changes: 1 addition & 1 deletion invokeai/app/api/dependencies.py
@@ -93,7 +93,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
55 changes: 34 additions & 21 deletions invokeai/app/invocations/controlnet_image_processors.py
@@ -2,6 +2,7 @@
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union

import cv2
@@ -36,12 +37,13 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.util.devices import TorchDevice

from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output

@@ -139,6 +141,7 @@ def load_image(self, context: InvocationContext) -> Image.Image:
return context.images.get_pil(self.image.image_name, "RGB")

def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
@@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
@@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
@@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
@@ -405,7 +409,7 @@ def run_processor(self, image):
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image,
@@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@@ -496,8 +500,8 @@ def tile_resample(
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img

def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
def run_processor(self, image: Image.Image) -> Image.Image:
np_img = np.array(image, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size,
@@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
@@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):

color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)

def run_processor(self, image: Image.Image):
def run_processor(self, image: Image.Image) -> Image.Image:
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]

@@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
def run_processor(self, image: Image.Image) -> Image.Image:
def loader(model_path: Path):
return DepthAnythingDetector.load_model(
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)

processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image


@invocation(
Expand All @@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

def run_processor(self, image: Image.Image):
dw_openpose = DWOpenposeDetector()
def run_processor(self, image: Image.Image) -> Image.Image:
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])

dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose(
image,
draw_face=self.draw_face,