# scribble_color_edit.py — forked from magic-quill/ComfyUI_MagicQuill
import os
import torch.nn.functional as F
import torch
import sys

# Make sibling custom-node packages (ComfyUI_BrushNet, comfyui_controlnet_aux)
# and ComfyUI's comfy_extras directory importable as top-level modules.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))
sys.path.append(os.path.abspath(os.path.join(current_dir, '..', '..', 'comfy_extras')))
# NOTE(review): debug output only -- consider removing or switching to logging.
print(sys.path)

# Node wrappers from sibling ComfyUI custom-node packages.
from ComfyUI_BrushNet.brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, get_files_with_extension
from comfyui_controlnet_aux.node_wrappers.lineart import LineArt_Preprocessor
from comfyui_controlnet_aux.node_wrappers.pidinet import PIDINET_Preprocessor
from comfyui_controlnet_aux.node_wrappers.color import Color_Preprocessor
# Core ComfyUI nodes (resolvable because ComfyUI's root is on sys.path).
from nodes import ControlNetLoader, ControlNetApplyAdvanced, CLIPTextEncode, KSampler, VAEDecode
# nodes_mask lives in comfy_extras, appended to sys.path above.
from nodes_mask import GrowMask
class ScribbleColorEditModel():
    """Orchestrates a scribble/color image edit using ComfyUI nodes.

    Combines BrushNet inpainting with two ControlNets -- an edge (scribble or
    lineart) ControlNet and a color ControlNet -- to regenerate a masked image
    region guided by user strokes and optional recoloring.

    The lightweight node wrappers are created once in ``__init__``; the actual
    model weights (ControlNets, BrushNet checkpoint) are loaded lazily, either
    via an explicit ``load_models`` call or on the first ``process`` call.
    """

    def __init__(self):
        # Instantiate the ComfyUI node wrappers once and reuse them across
        # process() calls. No model weights are loaded here.
        self.clip_text_encoder = CLIPTextEncode()
        self.mask_processor = GrowMask()
        self.controlnet_loader = ControlNetLoader()
        self.scribble_processor = PIDINET_Preprocessor()   # coarse "scribble"-style edges
        self.lineart_processor = LineArt_Preprocessor()    # fine lineart edges
        self.color_processor = Color_Preprocessor()        # downsampled color-block map
        self.brushnet_loader = BrushNetLoader()
        self.brushnet_node = BrushNet()
        self.controlnet_apply = ControlNetApplyAdvanced()
        self.ksampler = KSampler()
        self.vae_decoder = VAEDecode()
        self.blender = BlendInpaint()
        # self.load_models('SD1.5', 'float16')

    def load_models(self, base_model_version, dtype):
        """Load the edge/color ControlNets and the BrushNet checkpoint.

        Args:
            base_model_version: Base diffusion model family. Only ``"SD1.5"``
                is supported; any other value raises ``ValueError``.
            dtype: dtype selector forwarded to ``BrushNetLoader.brushnet_loading``
                (e.g. ``"float16"``).

        Raises:
            ValueError: If ``base_model_version`` is not ``"SD1.5"``.
        """
        if base_model_version == "SD1.5":
            edge_controlnet_name = "control_v11p_sd15_scribble.safetensors"
            color_controlnet_name = "color_finetune.safetensors"
            brushnet_name = os.path.join("brushnet", "random_mask_brushnet_ckpt", "diffusion_pytorch_model.safetensors")
        else:
            raise ValueError("Invalid base_model_version, not supported yet!!!: {}".format(base_model_version))
        # The loader nodes return 1-tuples; [0] unwraps the model object.
        self.edge_controlnet = self.controlnet_loader.load_controlnet(edge_controlnet_name)[0]
        self.color_controlnet = self.controlnet_loader.load_controlnet(color_controlnet_name)[0]
        self.brushnet_loader.inpaint_files = get_files_with_extension('inpaint')
        # NOTE(review): debug print rescans the directory instead of reusing
        # self.brushnet_loader.inpaint_files just assigned above.
        print("self.brushnet_loader.inpaint_files: ", get_files_with_extension('inpaint'))
        self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]

    def process(self, model, vae, clip, image, colored_image, base_model_version, positive_prompt, negative_prompt, dtype, mask, add_mask, remove_mask, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
        """Run the full scribble/color edit pipeline.

        Pipeline: encode prompts -> grow the inpaint mask -> build edge (and,
        if the image was recolored, color) ControlNet hints -> patch the model
        with BrushNet -> KSampler -> VAE decode -> blend the result back into
        the original image.

        Args:
            model, vae, clip: ComfyUI model/VAE/CLIP objects.
            image: Original image tensor. Assumed ComfyUI image layout
                (batch, H, W, C) in [0, 1] -- TODO confirm with callers.
            colored_image: Image with the user's color strokes applied; if it
                differs from ``image`` the color ControlNet is engaged.
            mask: Inpaint mask; expanded by ``grow_size`` before use.
            add_mask, remove_mask: Stroke masks for edges the user added or
                erased.
            stroke_as_edge: "enable"/"disable" -- whether added strokes are
                forced into the edge map.
            fine_edge: "enable" selects lineart edges, otherwise PiDiNet
                scribble edges.
            edge_strength, color_strength, inpaint_strength: ControlNet /
                BrushNet conditioning strengths.
            seed, steps, cfg, sampler_name, scheduler: KSampler settings.

        Returns:
            Tuple of (latent_samples, final_image, lineart_output, color_output).
        """
        print("mask.shape", mask.shape)
        print("image.shape", image.shape)
        # Lazily load weights on first use (or after a fresh construction).
        if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
            self.load_models(base_model_version, dtype)
        positive = self.clip_text_encoder.encode(clip, positive_prompt)[0]
        negative = self.clip_text_encoder.encode(clip, negative_prompt)[0]
        # Dilate the inpaint mask so the repaint region covers stroke edges.
        mask = self.mask_processor.expand_mask(mask, expand=grow_size, tapered_corners=True)[0]
        image_copy = image.clone()
        # NOTE(review): image_copy is mutated below but the preprocessors
        # further down read `image`, not `image_copy` -- these stroke
        # suppressions appear to have no downstream effect. Verify against
        # upstream intent before relying on this branch.
        if stroke_as_edge == "disable":
            # Hide added strokes by painting them toward the opposite of the
            # image's mean brightness in the stroked region.
            bool_add_mask = add_mask > 0.5
            mean_brightness = image_copy[bool_add_mask].mean()
            if mean_brightness > 0.8:
                image_copy[bool_add_mask] = 0.0
            else:
                image_copy[bool_add_mask] = 1.0
        if not torch.equal(image, colored_image):
            # The user recolored the image: condition on both the color map and
            # a fixed-strength (0.8) lineart edge map.
            print("Apply color controlnet")
            color_output = self.color_processor.execute(colored_image, resolution=2048)[0]
            lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            print("edge_map.shape", lineart_output.shape)
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.color_controlnet, color_output, color_strength, 0.0, 1.0)
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, 0.8, 0.0, 1.0)
        else:
            # Edge-only edit: build an edge map (lineart or PiDiNet scribble),
            # then splice the user's add/remove strokes into it.
            print("Apply edge controlnet")
            color_output = self.color_processor.execute(colored_image, resolution=2048)[0]
            if fine_edge == "enable":
                lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            else:
                lineart_output = self.scribble_processor.execute(image, safe='enable', resolution=512)[0]
            # Resize stroke masks to the edge map's spatial size with nearest
            # interpolation so the masks stay binary after thresholding.
            add_mask_resized = F.interpolate(add_mask.unsqueeze(0).unsqueeze(0).float(), size=(1, lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0).squeeze(0)
            remove_mask_resized = F.interpolate(remove_mask.unsqueeze(0).unsqueeze(0).float(), size=(1, lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0).squeeze(0)
            bool_add_mask_resized = (add_mask_resized > 0.5)
            bool_remove_mask_resized = (remove_mask_resized > 0.5)
            if stroke_as_edge == "enable":
                # Erased strokes clear edges; added strokes force edges on.
                lineart_output[bool_remove_mask_resized] = 0.0
                lineart_output[bool_add_mask_resized] = 1.0
            else:
                # Only clear edges where the user erased without re-adding.
                lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, edge_strength, 0.0, 1.0)
        # Patch the UNet with BrushNet so sampling inpaints the masked region.
        model, positive, negative, latent = self.brushnet_node.model_update(
            model=model,
            vae=vae,
            image=image,
            mask=mask,
            brushnet=self.brushnet,
            positive=positive,
            negative=negative,
            scale=inpaint_strength,
            start_at=0,
            end_at=10000
        )
        latent_samples = self.ksampler.sample(
            model=model,
            seed=seed,
            steps=steps,
            cfg=cfg,
            sampler_name=sampler_name,
            scheduler=scheduler,
            positive=positive,
            negative=negative,
            latent_image=latent,
        )[0]
        final_image = self.vae_decoder.decode(vae, latent_samples)[0]
        # Feather the generated region back into the untouched original.
        final_image = self.blender.blend_inpaint(final_image, image, mask, kernel=10, sigma=10.0)[0]
        return (latent_samples, final_image, lineart_output, color_output)