Skip to content

Commit

Permalink
feat(nodes): add PiDiNetEdgeDetectionInvocation
Browse files Browse the repository at this point in the history
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.

All code related to the invocation now lives in the Invoke repo.
  • Loading branch information
psychedelicious committed Sep 10, 2024
1 parent 11704da commit 8dab787
Show file tree
Hide file tree
Showing 5 changed files with 824 additions and 2 deletions.
33 changes: 33 additions & 0 deletions invokeai/app/invocations/pidi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.pidi import PIDINetDetector
from invokeai.backend.image_util.pidi.model import PiDiNet


@invocation(
    "pidi_edge_detection",
    title="PiDiNet Edge Detection",
    tags=["controlnet", "edge"],
    category="controlnet",
    version="1.0.0",
)
class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an edge map using PiDiNet."""

    # Node inputs. `safe` and `scribble` are passed straight through to PIDINetDetector.run().
    image: ImageField = InputField(description="The image to process")
    safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # The input image is fetched (as RGB) before the model is requested.
        source_image = context.images.get_pil(self.image.image_name, "RGB")

        # The model is obtained through the context's model service, using the
        # detector's own URL and loader callback.
        model_handle = context.models.load_remote_model(
            PIDINetDetector.get_model_url(),
            PIDINetDetector.load_model,
        )

        with model_handle as pidinet_model:
            # Narrow the type of whatever the handle yields before wrapping it.
            assert isinstance(pidinet_model, PiDiNet)
            edges = PIDINetDetector(pidinet_model).run(
                image=source_image,
                safe=self.safe,
                scribble=self.scribble,
            )

        # Saving happens after the model handle has been released.
        return ImageOutput.build(context.images.save(image=edges))
79 changes: 79 additions & 0 deletions invokeai/backend/image_util/pidi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Adapted from https://github.com/huggingface/controlnet_aux

import pathlib

import cv2
import huggingface_hub
import numpy as np
import torch
from einops import rearrange
from PIL import Image

from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step


class PIDINetDetector:
    """Simple wrapper around a PiDiNet model for edge detection."""

    # Hugging Face Hub coordinates of the checkpoint used by get_model_url().
    hf_repo_id = "lllyasviel/Annotators"
    hf_filename = "table5_pidinet.pth"

    @classmethod
    def get_model_url(cls) -> str:
        """Get the URL to download the model from the Hugging Face Hub."""
        return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)

    @classmethod
    def load_model(cls, model_path: pathlib.Path) -> PiDiNet:
        """Load the model from a file.

        Args:
            model_path: Path to a checkpoint file whose top-level object has a
                "state_dict" entry.

        Returns:
            A PiDiNet instance with the checkpoint weights loaded, switched to eval mode.
        """

        model = pidinet()

        # NOTE(security): torch.load unpickles the file, so only trusted checkpoints
        # should ever be passed here.
        # map_location="cpu" keeps loading independent of the device the checkpoint
        # was saved from; without it, a checkpoint serialized with accelerator tensors
        # cannot be deserialized on a host that lacks that device.
        checkpoint = torch.load(model_path, map_location="cpu")

        # Strip every "module." occurrence from the parameter names so they match
        # the names expected by the bare model.
        state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def __init__(self, model: PiDiNet) -> None:
        """Wrap an already-loaded PiDiNet model."""
        self.model = model

    def to(self, device: torch.device) -> "PIDINetDetector":
        """Move the wrapped model to `device` and return this detector for chaining."""
        self.model.to(device)
        return self

    def run(
        self, image: Image.Image, safe: bool = False, scribble: bool = False, apply_filter: bool = False
    ) -> Image.Image:
        """Processes an image and returns the detected edges.

        Args:
            image: The input image.
            safe: When true, the raw edge response is passed through `safe_step`.
            scribble: When true, the edge map is post-processed with `nms`, a Gaussian
                blur and a threshold, yielding an image containing only 0 and 255.
            apply_filter: When true, the raw edge response is binarized at 0.5.

        Returns:
            The detected edge map as a PIL image.
        """

        # Run inference on whichever device the model's parameters currently live on.
        device = next(iter(self.model.parameters())).device

        np_img = pil_to_np(image)
        np_img = normalize_image_channel_count(np_img)

        # The reversal and rearrange below require an (h, w, c) array.
        assert np_img.ndim == 3

        # Reverse the channel axis (RGB -> BGR, given RGB input); copy() makes the
        # array contiguous, which torch.from_numpy needs (negative strides are unsupported).
        bgr_img = np_img[:, :, ::-1].copy()

        with torch.no_grad():
            image_pidi = torch.from_numpy(bgr_img).float().to(device)
            image_pidi = image_pidi / 255.0
            image_pidi = rearrange(image_pidi, "h w c -> 1 c h w")
            # Only the last element of the model's output is used as the edge response.
            edge = self.model(image_pidi)[-1]
            edge = edge.cpu().numpy()

        if apply_filter:
            edge = edge > 0.5
        if safe:
            edge = safe_step(edge)
        edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        # First batch item, first channel.
        # NOTE(review): presumably the output is (1, 1, h, w) - confirm against the model.
        detected_map = edge[0, 0]

        if scribble:
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            # Binarize: anything above 4 after the blur becomes 255, everything else 0.
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0

        output_img = np_to_pil(detected_map)

        return output_img
Loading

0 comments on commit 8dab787

Please sign in to comment.