Skip to content

Commit

Permalink
Feat/controlnet extras (#3596)
Browse files Browse the repository at this point in the history
Trying to get a few ControlNet extras in before 3.0 release:

- SegmentAnything ControlNet preprocessor node
- LeResDepth ControlNet preprocessor node (but commented out till
controlnet_aux v0.0.6 is released & required by InvokeAI)
- TileResampler ControlNet preprocessor node (should be equivalent to
Mikubill/sd-webui-controlnet extension tile_resampler)
- fix for Midas ControlNet preprocessor error with images that have
alpha channel

Example usage of SegmentAnything preprocessor node:
![Screenshot from 2023-06-26
16-53-44](https://github.com/invoke-ai/InvokeAI/assets/303100/c6278f9a-5f6b-44bd-98b1-fcaf77251a76)
  • Loading branch information
blessedcoolant authored Jun 28, 2023
2 parents 00c78b1 + 32883ad commit 201b843
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
116 changes: 112 additions & 4 deletions invokeai/app/invocations/controlnet_image_processors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# InvokeAI nodes for ControlNet image preprocessors
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float, bool

import cv2
import numpy as np
from typing import Literal, Optional, Union, List
from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator

Expand All @@ -29,8 +30,13 @@
ContentShuffleDetector,
ZoeDetector,
MediapipeFaceDetector,
SamDetector,
LeresDetector,
)

from controlnet_aux.util import HWC3, ade_palette


from .image import ImageOutput, PILInvocationConfig

CONTROLNET_DEFAULT_MODELS = [
Expand Down Expand Up @@ -95,6 +101,9 @@

CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]


class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
Expand All @@ -105,7 +114,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

@validator("control_weight")
def abs_le_one(cls, v):
Expand Down Expand Up @@ -180,7 +190,7 @@ def invoke(self, context: InvocationContext) -> ControlOutput:
),
)

# TODO: move image processors to separate file (image_analysis.py

class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet"""

Expand Down Expand Up @@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# fmt: on

def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
# so convert to RGB if needed
if image.mode == 'RGBA':
image = image.convert('RGB')
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image

class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies leres processing to image"""
# fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
boost: bool = Field(default=False, description="Whether to use boost mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on

def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(image,
thr_a=self.thr_a,
thr_b=self.thr_b,
boost=self.boost,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image


class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):

# fmt: off
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# fmt: on

# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(self,
np_img: np.ndarray,
res=512, # never used?
down_sampling_rate=1.0,
):
np_img = HWC3(np_img)
if down_sampling_rate < 1.1:
return np_img
H, W, C = np_img.shape
H = int(float(H) / float(down_sampling_rate))
W = int(float(W) / float(down_sampling_rate))
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img

def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(np_img,
#res=self.tile_size,
down_sampling_rate=self.down_sampling_rate
)
processed_image = Image.fromarray(processed_np_image)
return processed_image




class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies segment anything processing to image"""
# fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on

def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img)
return processed_image

class SamDetectorReproducibleColors(SamDetector):

# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation
# so using ADE20k color palette instead
def show_anns(self, anns: List[Dict]):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
h, w = anns[0]['segmentation'].shape
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
palette = ade_palette()
for i, ann in enumerate(sorted_anns):
m = ann['segmentation']
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)]
img[:, :] = ann_color
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies = [
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=1.2.1",
"controlnet-aux>=0.0.4",
"controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.17.1",
Expand Down

0 comments on commit 201b843

Please sign in to comment.