From 654f44f66f4496f041caaf652d9f62549f0f6662 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 1 Feb 2025 11:47:20 -0500 Subject: [PATCH] refactor interrogate/analyze/vqa code Signed-off-by: Vladimir Mandic --- CHANGELOG.md | 1 + modules/api/endpoints.py | 11 +++---- modules/{ => interrogate}/deepbooru.py | 5 ++-- modules/{ => interrogate}/deepbooru_model.py | 0 modules/interrogate/interrogate.py | 4 +++ .../{interrogate.py => interrogate/legacy.py} | 13 ++++++--- modules/{ => interrogate}/vqa.py | 0 modules/sd_offload.py | 6 ++++ modules/shared.py | 5 ++-- modules/ui_common.py | 11 +++---- modules/ui_control_helpers.py | 9 +++--- modules/ui_postprocessing.py | 29 ++++++++++--------- modules/ui_sections.py | 6 ++-- scripts/loopback.py | 6 ++-- 14 files changed, 64 insertions(+), 42 deletions(-) rename modules/{ => interrogate}/deepbooru.py (94%) rename modules/{ => interrogate}/deepbooru_model.py (100%) create mode 100644 modules/interrogate/interrogate.py rename modules/{interrogate.py => interrogate/legacy.py} (97%) rename modules/{ => interrogate}/vqa.py (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccfe41ace..4957fb1c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ > python modules/api/nvml.py - **Refactor**: - unified trace handler with configurable tracebacks + - refactor interrogate/analyze/vqa code - **Fixes**: - photomaker with offloading - photomaker with refine diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py index 1c56b7171..8112fe4a8 100644 --- a/modules/api/endpoints.py +++ b/modules/api/endpoints.py @@ -74,7 +74,7 @@ def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, f return res def get_interrogate(): - from modules.interrogate import get_clip_models + from modules.interrogate.legacy import get_clip_models return ['clip', 'deepdanbooru'] + get_clip_models() def post_interrogate(req: models.ReqInterrogate): @@ -84,16 +84,17 @@ def post_interrogate(req: models.ReqInterrogate): image = image.convert('RGB') if req.model == "clip": try: - caption = shared.interrogator.interrogate(image) + from modules.interrogate import legacy + caption = legacy.interrogator.interrogate(image) except Exception as e: caption = str(e) return models.ResInterrogate(caption=caption) elif req.model == "deepdanbooru" or req.model == 'deepbooru': - from modules import deepbooru + from modules.interrogate import deepbooru caption = deepbooru.model.tag(image) return models.ResInterrogate(caption=caption) else: - from modules.interrogate import interrogate_image, analyze_image, get_clip_models + from modules.interrogate.legacy import interrogate_image, analyze_image, get_clip_models if req.model not in get_clip_models(): raise HTTPException(status_code=404, detail="Model not found") try: @@ -111,7 +112,7 @@ def post_vqa(req: models.ReqVQA): raise HTTPException(status_code=404, detail="Image not found") image = helpers.decode_base64_to_image(req.image) image = image.convert('RGB') - from modules import vqa + from modules.interrogate import vqa answer = vqa.interrogate(req.question, image, req.model) return models.ResVQA(answer=answer) diff --git a/modules/deepbooru.py b/modules/interrogate/deepbooru.py similarity index 94% rename from modules/deepbooru.py rename to modules/interrogate/deepbooru.py index f961a4351..5e54dcc1a 100644 --- a/modules/deepbooru.py +++ b/modules/interrogate/deepbooru.py @@ -4,7 +4,7 @@ import torch import numpy as np from PIL import Image -from modules import modelloader, paths, deepbooru_model, devices, images, shared +from modules import modelloader, paths, devices, images, shared re_special = re.compile(r'([\\()])') load_lock = threading.Lock() @@ -27,7 +27,8 @@ def load(self): download_name='model-resnet_custom_v3.pt', ) - self.model = deepbooru_model.DeepDanbooruModel() + from modules.interrogate.deepbooru_model import DeepDanbooruModel + self.model = DeepDanbooruModel() self.model.load_state_dict(torch.load(files[0], map_location="cpu")) self.model.eval() diff --git a/modules/deepbooru_model.py b/modules/interrogate/deepbooru_model.py similarity index 100% rename from modules/deepbooru_model.py rename to modules/interrogate/deepbooru_model.py diff --git a/modules/interrogate/interrogate.py b/modules/interrogate/interrogate.py new file mode 100644 index 000000000..1696e88e7 --- /dev/null +++ b/modules/interrogate/interrogate.py @@ -0,0 +1,4 @@ +def interrogate(image): + from modules.interrogate import legacy + prompt = legacy.interrogator.interrogate(image) + return prompt diff --git a/modules/interrogate.py b/modules/interrogate/legacy.py similarity index 97% rename from modules/interrogate.py rename to modules/interrogate/legacy.py index 1cf9aee95..486a6e2bb 100644 --- a/modules/interrogate.py +++ b/modules/interrogate/legacy.py @@ -10,7 +10,7 @@ from PIL import Image from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, errors +from modules import devices, paths, shared, lowvram, errors, sd_models config = { @@ -39,7 +39,7 @@ def category_types(): - return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')] + return [f.stem for f in Path(interrogator.content_dir).glob('*.txt')] def download_default_clip_interrogate_categories(content_dir): @@ -65,10 +65,10 @@ class InterrogateModels: dtype = None running_on_cpu = None - def __init__(self, content_dir): + def __init__(self, content_dir: str = None): self.loaded_categories = None self.skip_categories = [] - self.content_dir = content_dir + self.content_dir = content_dir or os.path.join(paths.models_path, "interrogate") self.running_on_cpu = False def categories(self): @@ -327,6 +327,8 @@ def interrogate_image(image, clip_model, blip_model, mode): if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): lowvram.send_everything_to_cpu() devices.torch_gc() + if shared.native: + sd_models.apply_balanced_offload(shared.sd_model) load_interrogator(clip_model, blip_model) image = image.convert('RGB') prompt = interrogate(image, mode) @@ -410,3 +412,6 @@ def analyze_image(image, clip_model, blip_model): trending_ranks = dict(zip(top_trendings, ci.similarities(image_features, top_trendings))) flavor_ranks = dict(zip(top_flavors, ci.similarities(image_features, top_flavors))) return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks + + +interrogator = InterrogateModels() diff --git a/modules/vqa.py b/modules/interrogate/vqa.py similarity index 100% rename from modules/vqa.py rename to modules/interrogate/vqa.py diff --git a/modules/sd_offload.py b/modules/sd_offload.py index 1c3265b66..b200bed1a 100644 --- a/modules/sd_offload.py +++ b/modules/sd_offload.py @@ -179,6 +179,12 @@ def apply_balanced_offload(sd_model, exclude=[]): global offload_hook_instance # pylint: disable=global-statement if shared.opts.diffusers_offload_mode != "balanced": return sd_model + if sd_model is None: + if not shared.sd_loaded: + return sd_model + sd_model = shared.sd_model + if sd_model is None: + return sd_model t0 = time.time() excluded = ['OmniGenPipeline'] if sd_model.__class__.__name__ in excluded: diff --git a/modules/shared.py b/modules/shared.py index d9a3b1572..4aa5a87ce 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -18,8 +18,8 @@ from modules.dml import memory_providers, default_memory_provider, directml_do_hijack from modules.onnx_impl import initialize_onnx, execution_providers from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import +from modules.interrogate.legacy import category_types from modules.ui_components import DropdownEditable -import modules.interrogate import modules.memmon import modules.styles import modules.paths as paths @@ -36,7 +36,6 @@ hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} xformers_available = False locking_available = True # used by file read/write locking -interrogator = modules.interrogate.InterrogateModels(os.path.join("models", "interrogate")) sd_upscalers = [] detailers = [] face_restorers = [] @@ -896,7 +895,7 @@ def get_default_modes(): "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(32, "Interrogate: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(192, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), - "interrogate_clip_skip_categories": OptionInfo(["artists", "movements", "flavors"], "Interrogate: skip categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), + "interrogate_clip_skip_categories": OptionInfo(["artists", "movements", "flavors"], "Interrogate: skip categories", gr.CheckboxGroup, lambda: {"choices": category_types()}, refresh=category_types), "interrogate_deepbooru_score_threshold": OptionInfo(0.65, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(False, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "Use spaces for tags in deepbooru"), diff --git a/modules/ui_common.py b/modules/ui_common.py index b42373957..c3d5f80fd 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -221,19 +221,20 @@ def open_folder(result_gallery, gallery_index = 0): subprocess.Popen(["xdg-open", path]) # pylint: disable=consider-using-with -def interrogate_clip(image): +def interrogate_clip(image): # legacy function if image is None: shared.log.error("Interrogate: no image selected") return gr.update() - prompt = shared.interrogator.interrogate(image) + from modules.interrogate import legacy + prompt = legacy.interrogator.interrogate(image) return gr.update() if prompt is None else prompt -def interrogate_booru(image): +def interrogate_booru(image): # legacy function if image is None: shared.log.error("Interrogate: no image selected") return gr.update() - from modules import deepbooru + from modules.interrogate import deepbooru prompt = deepbooru.model.tag(image) return gr.update() if prompt is None else prompt @@ -258,7 +259,7 @@ def create_output_panel(tabname, preview=True, prompt=None, height=None): elem_classes=["gallery_main"], ) if prompt is not None: - ui_sections.create_interrogate_button(tab=tabname, inputs=result_gallery, output=prompt) + ui_sections.create_interrogate_button(tab=tabname, inputs=result_gallery, outputs=prompt) # interrogate_clip_btn, interrogate_booru_btn = ui_sections.create_interrogate_buttons(tabname) # interrogate_clip_btn.click(fn=interrogate_clip, inputs=[result_gallery], outputs=[prompt]) # interrogate_booru_btn.click(fn=interrogate_booru, inputs=[result_gallery], outputs=[prompt]) diff --git a/modules/ui_control_helpers.py b/modules/ui_control_helpers.py index 7573d6f69..dae07b021 100644 --- a/modules/ui_control_helpers.py +++ b/modules/ui_control_helpers.py @@ -47,19 +47,20 @@ def initialize(): scripts.scripts_control.initialize_scripts(is_img2img=False, is_control=True) -def interrogate_clip(): +def interrogate_clip(): # legacy function prompt = None try: - prompt = shared.interrogator.interrogate(input_source[0]) + from modules.interrogate import legacy + prompt = legacy.interrogator.interrogate(input_source[0]) except Exception: pass return gr.update() if prompt is None else prompt -def interrogate_booru(): +def interrogate_booru(): # legacy function prompt = None try: - from modules import deepbooru + from modules.interrogate import deepbooru prompt = deepbooru.model.tag(input_source[0]) except Exception: pass diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 6e12339e7..a3faf9766 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,7 +1,8 @@ import json import gradio as gr -from modules import scripts, shared, ui_common, postprocessing, call_queue, interrogate, generation_parameters_copypaste +from modules import scripts, shared, ui_common, postprocessing, call_queue, generation_parameters_copypaste from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call # pylint: disable=unused-import +from modules.interrogate import legacy def submit_info(image): @@ -46,8 +47,8 @@ def create_ui(): flavor = gr.Label(elem_id="interrogate_label_flavor", label="Flavor", num_top_classes=5) with gr.Row(): clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLiP model') - ui_common.create_refresh_button(clip_model, interrogate.get_clip_models, lambda: {"choices": interrogate.get_clip_models()}, 'refresh_interrogate_models') - blip_model = gr.Dropdown(list(interrogate.caption_models), value='blip-base', label='Caption model') + ui_common.create_refresh_button(clip_model, legacy.get_clip_models, lambda: {"choices": legacy.get_clip_models()}, 'refresh_interrogate_models') + blip_model = gr.Dropdown(list(legacy.caption_models), value='blip-base', label='Caption model') mode = gr.Dropdown(['best', 'fast', 'classic', 'caption', 'negative'], label='Mode', value='fast') with gr.Accordion(label='Advanced', open=False, visible=True): with gr.Row(): @@ -56,20 +57,20 @@ def create_ui(): min_flavors = gr.Number(label='Min flavors', value=2, minimum=1, maximum=16, min_width=300) max_flavors = gr.Number(label='Max flavors', value=8, minimum=1, maximum=64, min_width=300) flavor_intermediate_count = gr.Number(label='Intermediates', value=1024, minimum=256, maximum=4096) - caption_max_length.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) - chunk_size.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) - min_flavors.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) - max_flavors.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) - flavor_intermediate_count.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) + caption_max_length.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) + chunk_size.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) + min_flavors.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) + max_flavors.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) + flavor_intermediate_count.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[]) with gr.Row(elem_id='interrogate_buttons_image'): btn_interrogate_img = gr.Button("Interrogate", elem_id="interrogate_btn_interrogate", variant='primary') btn_analyze_img = gr.Button("Analyze", elem_id="interrogate_btn_analyze", variant='primary') btn_unload = gr.Button("Unload", elem_id="interrogate_btn_unload") with gr.Row(elem_id='copy_buttons_interrogate'): copy_interrogate_buttons = generation_parameters_copypaste.create_buttons(["txt2img", "img2img", "extras", "control"]) - btn_interrogate_img.click(interrogate.interrogate_image, inputs=[image, clip_model, blip_model, mode], outputs=prompt) - btn_analyze_img.click(interrogate.analyze_image, inputs=[image, clip_model, blip_model], outputs=[medium, artist, movement, trending, flavor]) - btn_unload.click(interrogate.unload_clip_model) + btn_interrogate_img.click(legacy.interrogate_image, inputs=[image, clip_model, blip_model, mode], outputs=prompt) + btn_analyze_img.click(legacy.analyze_image, inputs=[image, clip_model, blip_model], outputs=[medium, artist, movement, trending, flavor]) + btn_unload.click(legacy.unload_clip_model) with gr.Tab("Interrogate Batch"): with gr.Row(): batch_files = gr.File(label="Files", show_label=True, file_count='multiple', file_types=['image'], type='file', interactive=True, height=100) @@ -81,11 +82,11 @@ def create_ui(): batch = gr.Text(label="Prompts", lines=10) with gr.Row(): clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLiP Batch Model') - ui_common.create_refresh_button(clip_model, interrogate.get_clip_models, lambda: {"choices": interrogate.get_clip_models()}, 'refresh_interrogate_models') + ui_common.create_refresh_button(clip_model, legacy.get_clip_models, lambda: {"choices": legacy.get_clip_models()}, 'refresh_interrogate_models') with gr.Row(elem_id='interrogate_buttons_batch'): btn_interrogate_batch = gr.Button("Interrogate", elem_id="interrogate_btn_interrogate", variant='primary') with gr.Tab("Visual Query"): - from modules import vqa + from modules.interrogate import vqa with gr.Row(): vqa_image = gr.Image(type='pil', label="Image") with gr.Row(): @@ -148,7 +149,7 @@ def create_ui(): ] ) btn_interrogate_batch.click( - fn=interrogate.interrogate_batch, + fn=legacy.interrogate_batch, inputs=[batch_files, batch_folder, batch_str, clip_model, blip_model, mode, save_output], outputs=[batch], ) diff --git a/modules/ui_sections.py b/modules/ui_sections.py index f3395ff4a..fca9466b8 100644 --- a/modules/ui_sections.py +++ b/modules/ui_sections.py @@ -1,6 +1,7 @@ import gradio as gr from modules import shared, modelloader, ui_symbols, ui_common, sd_samplers from modules.ui_components import ToolButton +from modules.interrogate import interrogate def create_toprow(is_img2img: bool = False, id_part: str = None): @@ -92,11 +93,10 @@ def create_resolution_inputs(tab): def create_interrogate_button(tab: str, inputs: list, outputs: str): button_interrogate = gr.Button(ui_symbols.interrogate, elem_id=f"{tab}_interrogate", elem_classes=['interrogate']) - button_interrogate.click(fn=lambda: None, _js='() => quickInterrogate()', inputs=[], outputs=[]) - return button_interrogate + button_interrogate.click(fn=interrogate.interrogate, inputs=inputs, outputs=[outputs]) -def create_interrogate_buttons(tab): +def create_interrogate_buttons(tab): # legacy function button_interrogate = gr.Button(ui_symbols.int_clip, elem_id=f"{tab}_interrogate", elem_classes=['interrogate-clip']) button_deepbooru = gr.Button(ui_symbols.int_blip, elem_id=f"{tab}_deepbooru", elem_classes=['interrogate-blip']) return button_interrogate, button_deepbooru diff --git a/scripts/loopback.py b/scripts/loopback.py index af4844b8e..a1541a161 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -2,7 +2,7 @@ import gradio as gr import modules.scripts as scripts -from modules import deepbooru, images, processing, shared +from modules import images, processing from modules.processing import Processed from modules.shared import opts, state @@ -90,8 +90,10 @@ def calculate_denoising_strength(loop): if append_interrogation != "None": p.prompt = f"{original_prompt}, " if original_prompt else "" if append_interrogation == "CLIP": - p.prompt += shared.interrogator.interrogate(p.init_images[0]) + from modules.interrogate import legacy + p.prompt += legacy.interrogator.interrogate(p.init_images[0]) elif append_interrogation == "DeepBooru": + from modules.interrogate import deepbooru p.prompt += deepbooru.model.tag(p.init_images[0]) state.job = f"loopback iteration {i+1}/{loops} batch {n+1}/{batch_count}"