Skip to content

Commit

Permalink
Merge branch 'main' into stalker7779/modular_rescale_cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick committed Jul 23, 2024
2 parents 1b359b5 + 154e8f6 commit d014dc9
Show file tree
Hide file tree
Showing 45 changed files with 1,374 additions and 240 deletions.
196 changes: 137 additions & 59 deletions invokeai/app/invocations/spandrel_image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

import numpy as np
import torch
from PIL import Image
Expand All @@ -21,7 +23,7 @@
from invokeai.backend.tiles.utils import TBLR, Tile


@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""

Expand All @@ -34,8 +36,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
)
scale: float = InputField(
default=4.0,
gt=0.0,
le=16.0,
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
)
fit_to_multiple_of_8: bool = InputField(
default=False,
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)

def _scale_tile(self, tile: Tile, scale: int) -> Tile:
@classmethod
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
return Tile(
coords=TBLR(
top=tile.coords.top * scale,
Expand All @@ -51,20 +64,22 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
),
)

@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")

@classmethod
def upscale_image(
cls,
image: Image.Image,
tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
) -> Image.Image:
# Compute the image tiles.
if self.tile_size > 0:
if tile_size > 0:
min_overlap = 20
tiles = calc_tiles_min_overlap(
image_height=image.height,
image_width=image.width,
tile_height=self.tile_size,
tile_width=self.tile_size,
tile_height=tile_size,
tile_width=tile_size,
min_overlap=min_overlap,
)
else:
Expand All @@ -85,60 +100,123 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
# Prepare input image for inference.
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)

# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)

# Run the model on each tile.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Scale the tiles for re-assembling the final image.
scale = spandrel_model.scale
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]

# Scale the tiles for re-assembling the final image.
scale = spandrel_model.scale
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
# Prepare the output tensor.
_, channels, height, width = image_tensor.shape
output_tensor = torch.zeros(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)

# Prepare the output tensor.
_, channels, height, width = image_tensor.shape
output_tensor = torch.zeros(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
# Exit early if the invocation has been canceled.
if context.util.is_canceled():
raise CanceledException

# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)

# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)

# Convert the output tile into the output tensor's format.
# (N, C, H, W) -> (C, H, W)
output_tile = output_tile.squeeze(0)
# (C, H, W) -> (H, W, C)
output_tile = output_tile.permute(1, 2, 0)
output_tile = output_tile.clamp(0, 1)
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))

# Merge the output tile into the output tensor.
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
# it seems unnecessary, but we may find a need in the future.
top_overlap = scaled_tile.overlap.top // 2
left_overlap = scaled_tile.overlap.left // 2
output_tensor[
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
:,
] = output_tile[top_overlap:, left_overlap:, :]
# Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException

# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)

# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)

# Convert the output tile into the output tensor's format.
# (N, C, H, W) -> (C, H, W)
output_tile = output_tile.squeeze(0)
# (C, H, W) -> (H, W, C)
output_tile = output_tile.permute(1, 2, 0)
output_tile = output_tile.clamp(0, 1)
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))

# Merge the output tile into the output tensor.
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
# it seems unnecessary, but we may find a need in the future.
top_overlap = scaled_tile.overlap.top // 2
left_overlap = scaled_tile.overlap.left // 2
output_tensor[
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
:,
] = output_tile[top_overlap:, left_overlap:, :]

# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)

return pil_image

@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")

# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)

# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale)

# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)

# First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)

# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
# to be considered an upscale model.
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height

if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
iterations += 1

# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit.
if iterations >= 5:
context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
)
break
else:
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
# to be the same as the processed image size.

# The output size is now the size of the processed image.
target_width = pil_image.width
target_height = pil_image.height

# Warn the user if they requested a scale greater than 1.
if self.scale > 1:
context.logger.warning(
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
)

# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
# in the final resize
if self.fit_to_multiple_of_8:
target_width = int(target_width // 8 * 8)
target_height = int(target_height // 8 * 8)

# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)

image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
18 changes: 17 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@
"imageActions": "Image Actions",
"sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas",
"sendToUpscale": "Send To Upscale",
"showOptionsPanel": "Show Side Panel (O or T)",
"shuffle": "Shuffle Seed",
"steps": "Steps",
Expand Down Expand Up @@ -1640,6 +1641,19 @@
"layers_one": "Layer",
"layers_other": "Layers"
},
"upscaling": {
"creativity": "Creativity",
"structure": "Structure",
"upscaleModel": "Upscale Model",
"scale": "Scale",
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
"tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture",
"upscaleModelDesc": "Upscale (image to image) model",
"missingUpscaleInitialImage": "Missing initial image for upscaling",
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
},
"ui": {
"tabs": {
"generation": "Generation",
Expand All @@ -1651,7 +1665,9 @@
"models": "Models",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Queue",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerM
import type { AppDispatch, RootState } from 'app/store/store';

import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';

export const listenerMiddleware = createListenerMiddleware();

Expand Down Expand Up @@ -85,6 +86,7 @@ addGalleryOffsetChangedListener(startAppListening);
addEnqueueRequestedCanvasListener(startAppListening);
addEnqueueRequestedNodes(startAppListening);
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { queueApi } from 'services/api/endpoints/queue';

export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { shouldShowProgressInViewer } = state.ui;
const { prepend } = action.payload;

const graph = await buildMultidiffusionUpscaleGraph(state);

const batchConfig = prepareLinearUIBatch(state, graph, prepend);

const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
try {
await req.unwrap();
if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true));
}
} finally {
req.reset();
}
},
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
} from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';

export const dndDropped = createAction<{
Expand Down Expand Up @@ -243,6 +244,20 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
return;
}

/**
* Image dropped on upscale initial image
*/
if (
overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;

dispatch(upscaleInitialImageChanged(imageDTO));
return;
}

/**
* Multiple images dropped on user board
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
Expand Down Expand Up @@ -89,6 +90,15 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return;
}

if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
dispatch(upscaleInitialImageChanged(imageDTO));
toast({
...DEFAULT_UPLOADED_TOAST,
description: 'set as upscale initial image',
});
return;
}

if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
const { id } = postUploadAction;
dispatch(
Expand Down
Loading

0 comments on commit d014dc9

Please sign in to comment.