Skip to content

Commit

Permalink
refactor interrogate/analyze/vqa code
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Feb 1, 2025
1 parent ceaf023 commit 654f44f
Show file tree
Hide file tree
Showing 14 changed files with 64 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
> python modules/api/nvml.py
- **Refactor**:
- unified trace handler with configurable tracebacks
- refactor interrogate/analyze/vqa code
- **Fixes**:
- photomaker with offloading
- photomaker with refine
Expand Down
11 changes: 6 additions & 5 deletions modules/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, f
return res

def get_interrogate():
from modules.interrogate import get_clip_models
from modules.interrogate.legacy import get_clip_models
return ['clip', 'deepdanbooru'] + get_clip_models()

def post_interrogate(req: models.ReqInterrogate):
Expand All @@ -84,16 +84,17 @@ def post_interrogate(req: models.ReqInterrogate):
image = image.convert('RGB')
if req.model == "clip":
try:
caption = shared.interrogator.interrogate(image)
from modules.interrogate import legacy
caption = legacy.interrogator.interrogate(image)
except Exception as e:
caption = str(e)
return models.ResInterrogate(caption=caption)
elif req.model == "deepdanbooru" or req.model == 'deepbooru':
from modules import deepbooru
from modules.interrogate import deepbooru
caption = deepbooru.model.tag(image)
return models.ResInterrogate(caption=caption)
else:
from modules.interrogate import interrogate_image, analyze_image, get_clip_models
from modules.interrogate.legacy import interrogate_image, analyze_image, get_clip_models
if req.model not in get_clip_models():
raise HTTPException(status_code=404, detail="Model not found")
try:
Expand All @@ -111,7 +112,7 @@ def post_vqa(req: models.ReqVQA):
raise HTTPException(status_code=404, detail="Image not found")
image = helpers.decode_base64_to_image(req.image)
image = image.convert('RGB')
from modules import vqa
from modules.interrogate import vqa
answer = vqa.interrogate(req.question, image, req.model)
return models.ResVQA(answer=answer)

Expand Down
5 changes: 3 additions & 2 deletions modules/deepbooru.py → modules/interrogate/deepbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import numpy as np
from PIL import Image
from modules import modelloader, paths, deepbooru_model, devices, images, shared
from modules import modelloader, paths, devices, images, shared

re_special = re.compile(r'([\\()])')
load_lock = threading.Lock()
Expand All @@ -27,7 +27,8 @@ def load(self):
download_name='model-resnet_custom_v3.pt',
)

self.model = deepbooru_model.DeepDanbooruModel()
from modules.interrogate.deepbooru_model import DeepDanbooruModel
self.model = DeepDanbooruModel()
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))

self.model.eval()
Expand Down
File renamed without changes.
4 changes: 4 additions & 0 deletions modules/interrogate/interrogate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def interrogate(image):
from modules.interrogate import legacy
prompt = legacy.interrogator.interrogate(image)
return prompt
13 changes: 9 additions & 4 deletions modules/interrogate.py → modules/interrogate/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, errors
from modules import devices, paths, shared, lowvram, errors, sd_models


config = {
Expand Down Expand Up @@ -39,7 +39,7 @@


def category_types():
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
return [f.stem for f in Path(interrogator.content_dir).glob('*.txt')]


def download_default_clip_interrogate_categories(content_dir):
Expand All @@ -65,10 +65,10 @@ class InterrogateModels:
dtype = None
running_on_cpu = None

def __init__(self, content_dir):
def __init__(self, content_dir: str = None):
self.loaded_categories = None
self.skip_categories = []
self.content_dir = content_dir
self.content_dir = content_dir or os.path.join(paths.models_path, "interrogate")
self.running_on_cpu = False

def categories(self):
Expand Down Expand Up @@ -327,6 +327,8 @@ def interrogate_image(image, clip_model, blip_model, mode):
if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
lowvram.send_everything_to_cpu()
devices.torch_gc()
if shared.native:
sd_models.apply_balanced_offload(shared.sd_model)
load_interrogator(clip_model, blip_model)
image = image.convert('RGB')
prompt = interrogate(image, mode)
Expand Down Expand Up @@ -410,3 +412,6 @@ def analyze_image(image, clip_model, blip_model):
trending_ranks = dict(zip(top_trendings, ci.similarities(image_features, top_trendings)))
flavor_ranks = dict(zip(top_flavors, ci.similarities(image_features, top_flavors)))
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks


interrogator = InterrogateModels()
File renamed without changes.
6 changes: 6 additions & 0 deletions modules/sd_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ def apply_balanced_offload(sd_model, exclude=[]):
global offload_hook_instance # pylint: disable=global-statement
if shared.opts.diffusers_offload_mode != "balanced":
return sd_model
if sd_model is None:
if not shared.sd_loaded:
return sd_model
sd_model = shared.sd_model
if sd_model is None:
return sd_model
t0 = time.time()
excluded = ['OmniGenPipeline']
if sd_model.__class__.__name__ in excluded:
Expand Down
5 changes: 2 additions & 3 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
from modules.onnx_impl import initialize_onnx, execution_providers
from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import
from modules.interrogate.legacy import category_types
from modules.ui_components import DropdownEditable
import modules.interrogate
import modules.memmon
import modules.styles
import modules.paths as paths
Expand All @@ -36,7 +36,6 @@
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
xformers_available = False
locking_available = True # used by file read/write locking
interrogator = modules.interrogate.InterrogateModels(os.path.join("models", "interrogate"))
sd_upscalers = []
detailers = []
face_restorers = []
Expand Down Expand Up @@ -896,7 +895,7 @@ def get_default_modes():
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(32, "Interrogate: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(192, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_skip_categories": OptionInfo(["artists", "movements", "flavors"], "Interrogate: skip categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
"interrogate_clip_skip_categories": OptionInfo(["artists", "movements", "flavors"], "Interrogate: skip categories", gr.CheckboxGroup, lambda: {"choices": category_types()}, refresh=category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.65, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(False, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "Use spaces for tags in deepbooru"),
Expand Down
11 changes: 6 additions & 5 deletions modules/ui_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,19 +221,20 @@ def open_folder(result_gallery, gallery_index = 0):
subprocess.Popen(["xdg-open", path]) # pylint: disable=consider-using-with


def interrogate_clip(image):
def interrogate_clip(image): # legacy function
if image is None:
shared.log.error("Interrogate: no image selected")
return gr.update()
prompt = shared.interrogator.interrogate(image)
from modules.interrogate import legacy
prompt = legacy.interrogator.interrogate(image)
return gr.update() if prompt is None else prompt


def interrogate_booru(image):
def interrogate_booru(image): # legacy function
if image is None:
shared.log.error("Interrogate: no image selected")
return gr.update()
from modules import deepbooru
from modules.interrogate import deepbooru
prompt = deepbooru.model.tag(image)
return gr.update() if prompt is None else prompt

Expand All @@ -258,7 +259,7 @@ def create_output_panel(tabname, preview=True, prompt=None, height=None):
elem_classes=["gallery_main"],
)
if prompt is not None:
ui_sections.create_interrogate_button(tab=tabname, inputs=result_gallery, output=prompt)
ui_sections.create_interrogate_button(tab=tabname, inputs=result_gallery, outputs=prompt)
# interrogate_clip_btn, interrogate_booru_btn = ui_sections.create_interrogate_buttons(tabname)
# interrogate_clip_btn.click(fn=interrogate_clip, inputs=[result_gallery], outputs=[prompt])
# interrogate_booru_btn.click(fn=interrogate_booru, inputs=[result_gallery], outputs=[prompt])
Expand Down
9 changes: 5 additions & 4 deletions modules/ui_control_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,20 @@ def initialize():
scripts.scripts_control.initialize_scripts(is_img2img=False, is_control=True)


def interrogate_clip():
def interrogate_clip(): # legacy function
prompt = None
try:
prompt = shared.interrogator.interrogate(input_source[0])
from modules.interrogate import legacy
prompt = legacy.interrogator.interrogate(input_source[0])
except Exception:
pass
return gr.update() if prompt is None else prompt


def interrogate_booru():
def interrogate_booru(): # legacy function
prompt = None
try:
from modules import deepbooru
from modules.interrogate import deepbooru
prompt = deepbooru.model.tag(input_source[0])
except Exception:
pass
Expand Down
29 changes: 15 additions & 14 deletions modules/ui_postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import gradio as gr
from modules import scripts, shared, ui_common, postprocessing, call_queue, interrogate, generation_parameters_copypaste
from modules import scripts, shared, ui_common, postprocessing, call_queue, generation_parameters_copypaste
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call # pylint: disable=unused-import
from modules.interrogate import legacy


def submit_info(image):
Expand Down Expand Up @@ -46,8 +47,8 @@ def create_ui():
flavor = gr.Label(elem_id="interrogate_label_flavor", label="Flavor", num_top_classes=5)
with gr.Row():
clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLiP model')
ui_common.create_refresh_button(clip_model, interrogate.get_clip_models, lambda: {"choices": interrogate.get_clip_models()}, 'refresh_interrogate_models')
blip_model = gr.Dropdown(list(interrogate.caption_models), value='blip-base', label='Caption model')
ui_common.create_refresh_button(clip_model, legacy.get_clip_models, lambda: {"choices": legacy.get_clip_models()}, 'refresh_interrogate_models')
blip_model = gr.Dropdown(list(legacy.caption_models), value='blip-base', label='Caption model')
mode = gr.Dropdown(['best', 'fast', 'classic', 'caption', 'negative'], label='Mode', value='fast')
with gr.Accordion(label='Advanced', open=False, visible=True):
with gr.Row():
Expand All @@ -56,20 +57,20 @@ def create_ui():
min_flavors = gr.Number(label='Min flavors', value=2, minimum=1, maximum=16, min_width=300)
max_flavors = gr.Number(label='Max flavors', value=8, minimum=1, maximum=64, min_width=300)
flavor_intermediate_count = gr.Number(label='Intermediates', value=1024, minimum=256, maximum=4096)
caption_max_length.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
chunk_size.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
min_flavors.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
max_flavors.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
flavor_intermediate_count.change(fn=interrogate.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
caption_max_length.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
chunk_size.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
min_flavors.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
max_flavors.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
flavor_intermediate_count.change(fn=legacy.update_interrogate_params, inputs=[caption_max_length, chunk_size, min_flavors, max_flavors, flavor_intermediate_count], outputs=[])
with gr.Row(elem_id='interrogate_buttons_image'):
btn_interrogate_img = gr.Button("Interrogate", elem_id="interrogate_btn_interrogate", variant='primary')
btn_analyze_img = gr.Button("Analyze", elem_id="interrogate_btn_analyze", variant='primary')
btn_unload = gr.Button("Unload", elem_id="interrogate_btn_unload")
with gr.Row(elem_id='copy_buttons_interrogate'):
copy_interrogate_buttons = generation_parameters_copypaste.create_buttons(["txt2img", "img2img", "extras", "control"])
btn_interrogate_img.click(interrogate.interrogate_image, inputs=[image, clip_model, blip_model, mode], outputs=prompt)
btn_analyze_img.click(interrogate.analyze_image, inputs=[image, clip_model, blip_model], outputs=[medium, artist, movement, trending, flavor])
btn_unload.click(interrogate.unload_clip_model)
btn_interrogate_img.click(legacy.interrogate_image, inputs=[image, clip_model, blip_model, mode], outputs=prompt)
btn_analyze_img.click(legacy.analyze_image, inputs=[image, clip_model, blip_model], outputs=[medium, artist, movement, trending, flavor])
btn_unload.click(legacy.unload_clip_model)
with gr.Tab("Interrogate Batch"):
with gr.Row():
batch_files = gr.File(label="Files", show_label=True, file_count='multiple', file_types=['image'], type='file', interactive=True, height=100)
Expand All @@ -81,11 +82,11 @@ def create_ui():
batch = gr.Text(label="Prompts", lines=10)
with gr.Row():
clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLiP Batch Model')
ui_common.create_refresh_button(clip_model, interrogate.get_clip_models, lambda: {"choices": interrogate.get_clip_models()}, 'refresh_interrogate_models')
ui_common.create_refresh_button(clip_model, legacy.get_clip_models, lambda: {"choices": legacy.get_clip_models()}, 'refresh_interrogate_models')
with gr.Row(elem_id='interrogate_buttons_batch'):
btn_interrogate_batch = gr.Button("Interrogate", elem_id="interrogate_btn_interrogate", variant='primary')
with gr.Tab("Visual Query"):
from modules import vqa
from modules.interrogate import vqa
with gr.Row():
vqa_image = gr.Image(type='pil', label="Image")
with gr.Row():
Expand Down Expand Up @@ -148,7 +149,7 @@ def create_ui():
]
)
btn_interrogate_batch.click(
fn=interrogate.interrogate_batch,
fn=legacy.interrogate_batch,
inputs=[batch_files, batch_folder, batch_str, clip_model, blip_model, mode, save_output],
outputs=[batch],
)
Expand Down
6 changes: 3 additions & 3 deletions modules/ui_sections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gradio as gr
from modules import shared, modelloader, ui_symbols, ui_common, sd_samplers
from modules.ui_components import ToolButton
from modules.interrogate import interrogate


def create_toprow(is_img2img: bool = False, id_part: str = None):
Expand Down Expand Up @@ -92,11 +93,10 @@ def create_resolution_inputs(tab):

def create_interrogate_button(tab: str, inputs: list, outputs: str):
button_interrogate = gr.Button(ui_symbols.interrogate, elem_id=f"{tab}_interrogate", elem_classes=['interrogate'])
button_interrogate.click(fn=lambda: None, _js='() => quickInterrogate()', inputs=[], outputs=[])
return button_interrogate
button_interrogate.click(fn=interrogate.interrogate, inputs=inputs, outputs=[outputs])


def create_interrogate_buttons(tab):
def create_interrogate_buttons(tab): # legacy function
button_interrogate = gr.Button(ui_symbols.int_clip, elem_id=f"{tab}_interrogate", elem_classes=['interrogate-clip'])
button_deepbooru = gr.Button(ui_symbols.int_blip, elem_id=f"{tab}_deepbooru", elem_classes=['interrogate-blip'])
return button_interrogate, button_deepbooru
Expand Down
6 changes: 4 additions & 2 deletions scripts/loopback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gradio as gr
import modules.scripts as scripts
from modules import deepbooru, images, processing, shared
from modules import images, processing
from modules.processing import Processed
from modules.shared import opts, state

Expand Down Expand Up @@ -90,8 +90,10 @@ def calculate_denoising_strength(loop):
if append_interrogation != "None":
p.prompt = f"{original_prompt}, " if original_prompt else ""
if append_interrogation == "CLIP":
p.prompt += shared.interrogator.interrogate(p.init_images[0])
from modules.interrogate import legacy
p.prompt += legacy.interrogator.interrogate(p.init_images[0])
elif append_interrogation == "DeepBooru":
from modules.interrogate import deepbooru
p.prompt += deepbooru.model.tag(p.init_images[0])

state.job = f"loopback iteration {i+1}/{loops} batch {n+1}/{batch_count}"
Expand Down

0 comments on commit 654f44f

Please sign in to comment.