Skip to content

Commit

Permalink
Merge branch 'main' into ebr/docker-readme-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ebr committed Aug 22, 2024
2 parents 12d1c0b + c451f52 commit 6afc2df
Show file tree
Hide file tree
Showing 290 changed files with 29,681 additions and 23,493 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:

- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff
run: pip install ruff==0.6.0
shell: bash

- name: ruff check
Expand Down
1 change: 1 addition & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
FROM node:20-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use [email protected]
RUN corepack enable

WORKDIR /build
Expand Down
2 changes: 1 addition & 1 deletion installer/templates/invoke.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
set -eu

# Ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
scriptdir=$(dirname $(readlink -f "$0"))
cd "$scriptdir"

. .venv/bin/activate
Expand Down
17 changes: 15 additions & 2 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import asyncio
from logging import Logger

import torch
Expand Down Expand Up @@ -31,6 +32,8 @@
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
Expand Down Expand Up @@ -63,7 +66,12 @@ class ApiDependencies:
invoker: Invoker

@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
def initialize(
config: InvokeAIAppConfig,
event_handler_id: int,
loop: asyncio.AbstractEventLoop,
logger: Logger = logger,
) -> None:
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")

Expand All @@ -74,6 +82,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
image_files = DiskImageFileStorage(f"{output_folder}/images")

model_images_folder = config.models_path
style_presets_folder = config.style_presets_path

db = init_db(config=config, logger=logger, image_files=image_files)

Expand All @@ -84,7 +93,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
events = FastAPIEventService(event_handler_id, loop=loop)
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
Expand All @@ -109,6 +118,8 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")

services = InvocationServices(
board_image_records=board_image_records,
Expand All @@ -134,6 +145,8 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
workflow_records=workflow_records,
tensors=tensors,
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
)

ApiDependencies.invoker = Invoker(services)
Expand Down
16 changes: 14 additions & 2 deletions invokeai/app/api/routers/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,8 @@ async def get_image_workflow(
raise HTTPException(status_code=404)


@images_router.api_route(
@images_router.get(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={
Expand All @@ -231,6 +230,18 @@ async def get_image_workflow(
404: {"description": "Image not found"},
},
)
@images_router.head(
"/i/{image_name}/full",
operation_id="get_image_full_head",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
Expand All @@ -242,6 +253,7 @@ async def get_image_full(
content = f.read()
response = Response(content, media_type="image/png")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
return response
except Exception:
raise HTTPException(status_code=404)
Expand Down
29 changes: 14 additions & 15 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type
from typing import List, Optional, Type

from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
Expand Down Expand Up @@ -430,13 +430,11 @@ async def delete_model_image(
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
example={"name": "string", "description": "string"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
Expand All @@ -451,8 +449,9 @@ async def install_model(
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`config` is a ModelRecordChanges object. Fields in this object will override
the ones that are probed automatically. Pass an empty object to accept
all the defaults.
`access_token` is an optional access token for use with Urls that require
authentication.
Expand Down Expand Up @@ -737,7 +736,7 @@ async def convert_model(
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path)
raw_model.save_pretrained(convert_path) # type: ignore
assert convert_path.exists()

# temporarily rename the original safetensors file so that there is no naming conflict
Expand All @@ -750,12 +749,12 @@ async def convert_model(
try:
new_key = installer.install_path(
convert_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
config=ModelRecordChanges(
name=original_name,
description=model_config.description,
hash=model_config.hash,
source=model_config.source,
),
)
except Exception as e:
logger.error(str(e))
Expand Down
Loading

0 comments on commit 6afc2df

Please sign in to comment.