
Commit

add vlms: qwen-vl2, smol-vl2, toriigate
Signed-off-by: Vladimir Mandic <[email protected]>
vladmandic committed Feb 2, 2025
1 parent 0b0fd79 commit e40b33d
Showing 4 changed files with 166 additions and 31 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -1,6 +1,6 @@
# Change Log for SD.Next

## Update for 2025-02-01
## Update for 2025-02-02

- **GitHub**
- rename core repo from <https://github.com/vladmandic/automatic> to <https://github.com/vladmandic/sdnext>
@@ -27,7 +27,9 @@
- single interrogate button for every input or output image
- behavior of interrogate configurable in *settings -> interrogate*
  with detailed defaults for each model type also configurable
- select between 100+ *OpenCLiP* supported models, 10+ built-in *VLMs*, *DeepDanbooru*
- select between 150+ *OpenCLiP* supported models, 20+ built-in *VLMs*, *DeepDanbooru*
- **VLM**: now that we can use VLMs freely, we've also added support for a few more out-of-the-box:
  *Alibaba Qwen VL2*, *Huggingface Smol VL2*, *ToriiGate 0.4*
- **Other**:
- **networks**: improve search/filter and add visual indicators for types
- **balanced offload** new defaults: *lowvram/4gb min threshold: 0, medvram/8gb min threshold: 0, default min threshold 0.25*
2 changes: 1 addition & 1 deletion modules/api/models.py
@@ -338,7 +338,7 @@ class ResInterrogate(BaseModel):

class ReqVQA(BaseModel):
    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
    model: str = Field(default="MS Florence 2 Base", title="Model", description="The interrogate model used.")
    model: str = Field(default="Microsoft Florence 2 Base", title="Model", description="The interrogate model used.")
    question: str = Field(default="describe the image", title="Question", description="Question to ask the model.")

class ReqHistory(BaseModel):
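For reference, here is a client-side sketch of the request the `ReqVQA` schema above describes, pointed at one of the newly added VLMs. The endpoint path and the response shape are assumptions, not part of this diff; check the API route registration in `modules/api` for the exact path on your install.

```python
# Hedged sketch: exercises the ReqVQA fields above with one of the newly added VLMs.
# Assumptions: a local SD.Next instance serving ReqVQA at POST /sdapi/v1/vqa.
import base64
import requests

with open("test.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "image": image_b64,                  # Base64 string containing the image data
    "model": "Alibaba Qwen VL2 2B",      # one of the VLM entries added in vqa.py below
    "question": "describe the image",    # default question from ReqVQA
}
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/vqa", json=payload, timeout=600)
print(resp.json())                       # exact response model depends on the server
```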
187 changes: 160 additions & 27 deletions modules/interrogate/vqa.py
@@ -1,31 +1,42 @@
import io
import time
import json
import base64
import torch
import transformers
import transformers.dynamic_module_utils
from PIL import Image
from modules import shared, devices, errors

# TODO add additional vlms
# https://huggingface.co/nvidia/Eagle2-1B
# https://huggingface.co/deepseek-ai/deepseek-vl2-tiny


processor = None
model = None
loaded: str = None
vlm_models = {
"MS Florence 2 Base": "microsoft/Florence-2-base", # 0.5GB
"MS Florence 2 Large": "microsoft/Florence-2-large", # 1.5GB
"Microsoft Florence 2 Base": "microsoft/Florence-2-base", # 0.5GB
"Microsoft Florence 2 Large": "microsoft/Florence-2-large", # 1.5GB
"MiaoshouAI PromptGen 1.5 Base": "MiaoshouAI/Florence-2-base-PromptGen-v1.5@c06a5f02cc6071a5d65ee5d294cf3732d3097540", # 1.1GB
"MiaoshouAI PromptGen 1.5 Large": "MiaoshouAI/Florence-2-large-PromptGen-v1.5@28a42440e39c9c32b83f7ae74ec2b3d1540404f0", # 3.3GB
"MiaoshouAI PromptGen 2.0 Base": "MiaoshouAI/Florence-2-base-PromptGen-v2.0", # 1.1GB
"MiaoshouAI PromptGen 2.0 Large": "MiaoshouAI/Florence-2-large-PromptGen-v2.0", # 3.3GB
"CogFlorence 2.0 Large": "thwri/CogFlorence-2-Large-Freeze", # 1.6GB
"CogFlorence 2.2 Large": "thwri/CogFlorence-2.2-Large", # 1.6GB
"Moondream 2": "vikhyatk/moondream2", # 3.7GB
"GIT TextCaps Base": "microsoft/git-base-textcaps", # 0.7GB
"GIT VQA Base": "microsoft/git-base-vqav2", # 0.7GB
"GIT VQA Large": "microsoft/git-large-vqav2", # 1.6GB
"BLIP Base": "Salesforce/blip-vqa-base", # 1.5GB
"BLIP Large": "Salesforce/blip-vqa-capfilt-large", # 1.5GB
"Alibaba Qwen VL2 2B": "Qwen/Qwen2-VL-2B-Instruct",
"Huggingface Smol VL2 0.5B": "HuggingFaceTB/SmolVLM-500M-Instruct",
"Huggingface Smol VL2 2B": "HuggingFaceTB/SmolVLM-Instruct",
"Salesforce BLIP Base": "Salesforce/blip-vqa-base", # 1.5GB
"Salesforce BLIP Large": "Salesforce/blip-vqa-capfilt-large", # 1.5GB
"Google Pix Textcaps": "google/pix2struct-textcaps-base", # 1.1GB
"Microsoft GIT TextCaps Base": "microsoft/git-base-textcaps", # 0.7GB
"Microsoft GIT VQA Base": "microsoft/git-base-vqav2", # 0.7GB
"Microsoft GIT VQA Large": "microsoft/git-large-vqav2", # 1.6GB
"ToriiGate 0.4 2B": "Minthy/ToriiGate-v0.4-2B", # TODO
"ViLT Base": "dandelin/vilt-b32-finetuned-vqa", # 0.5GB
"Pix Textcaps": "google/pix2struct-textcaps-base", # 1.1GB
}
vlm_prompts = [
'<CAPTION>',
@@ -44,6 +44,120 @@
]


def b64(image):
    with io.BytesIO() as stream:
        image.save(stream, 'JPEG')
        values = stream.getvalue()
        encoded = base64.b64encode(values).decode()
        return encoded


def clean(response, question):
    if isinstance(response, dict):
        if 'task' in response:
            response = response['task']
        if 'answer' in response:
            response = response['answer']
        response = json.dumps(response)
    if isinstance(response, list):
        response = response[0]
    question = question.replace('<', '').replace('>', '')
    if question in response:
        response = response.split(question, 1)[1]
    response = response.replace('\n', '').replace('\r', '').replace('\t', '').strip()
    response = response.replace('Assistant:', '').strip()
    return response


def qwen(question: str, image: Image.Image, repo: str = None):
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
    model.to(devices.device, devices.dtype)
    if len(question) < 2:
        question = "Describe the image."
    question = question.replace('<', '').replace('>', '')
    conversation = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored."}
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": b64(image)},
                {"type": "text", "text": question},
            ],
        }
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>\n<|im_start|>assistant\n'
    inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=shared.opts.interrogate_vlm_max_length,
    )
    generated_ids = [
        output_ids[len(input_ids) :]
        for input_ids, output_ids in zip(inputs.input_ids, output_ids)
    ]
    response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return response


def deepseek(question: str, image: Image.Image, repo: str = None):
    return ''


def smol(question: str, image: Image.Image, repo: str = None):
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = transformers.AutoModelForVision2Seq.from_pretrained(
            repo,
            cache_dir=shared.opts.hfcache_dir,
            torch_dtype=devices.dtype,
            _attn_implementation="eager",
        )
        processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
    model.to(devices.device, devices.dtype)
    if len(question) < 2:
        question = "Describe the image."
    question = question.replace('<', '').replace('>', '')
    conversation = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored."}
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": b64(image)},
                {"type": "text", "text": question},
            ],
        }
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(text=text_prompt, images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=shared.opts.interrogate_vlm_max_length,
    )
    response = processor.batch_decode(output_ids, skip_special_tokens=True)
    return response


def git(question: str, image: Image.Image, repo: str = None):
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
@@ -149,6 +274,10 @@ def get_imports(f):
if "flash_attn" in R:
R.remove("flash_attn") # flash_attn is optional
return R
if '@' in model:
model, revision = model.split('@')
else:
revision = None
if model is None or loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
transformers.dynamic_module_utils.get_imports = get_imports
@@ -176,37 +305,32 @@ def get_imports(f):
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(generated_text, task="task", image_size=(image.width, image.height))
    if 'task' in response:
        response = response['task']
    if 'answer' in response:
        response = response['answer']
    if isinstance(response, dict):
        response = json.dumps(response)
    response = response.replace('\n', '').replace('\r', '').replace('\t', '').strip()
    return response


def interrogate(question, image, model_name):
    t0 = time.time()
    if isinstance(image, list):
        image = image[0] if len(image) > 0 else None
    if isinstance(image, dict) and 'name' in image:
        image = Image.open(image['name'])
    if image is None:
        return ''
    if image.width > 768 or image.height > 768:
        image.thumbnail((768, 768), Image.Resampling.HAMMING)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    try:
        vqa_model = vlm_models.get(model_name, None)
        revision = None
        if '@' in vqa_model:
            vqa_model, revision = vqa_model.split('@')
        if image is None:
            answer = 'no image provided'
            return answer
        if model_name is None:
            answer = 'no model selected'
            return answer
            shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
            return ''
        vqa_model = vlm_models.get(model_name, None)
        if vqa_model is None:
            answer = f'unknown: model={model_name} available={vlm_models.keys()}'
            return answer
            shared.log.error(f'Interrogate: type=vlm model="{model_name}" unknown')
            return ''
        if image is None:
            shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image')
            return ''
        if 'git' in vqa_model.lower():
            answer = git(question, image, vqa_model)
        elif 'vilt' in vqa_model.lower():
@@ -218,7 +342,13 @@ def interrogate(question, image, model_name):
        elif 'moondream2' in vqa_model.lower():
            answer = moondream(question, image, vqa_model)
        elif 'florence' in vqa_model.lower():
            answer = florence(question, image, vqa_model, revision)
            answer = florence(question, image, vqa_model)
        elif 'qwen' in vqa_model.lower() or 'torii' in vqa_model.lower():
            answer = qwen(question, image, vqa_model)
        elif 'smol' in vqa_model.lower():
            answer = smol(question, image, vqa_model)
        elif 'deepseek' in vqa_model.lower():
            answer = deepseek(question, image, vqa_model)
        else:
            answer = 'unknown model'
    except Exception as e:
@@ -227,4 +357,7 @@
    if shared.opts.interrogate_offload and model is not None:
        model.to(devices.cpu)
    devices.torch_gc()
    answer = clean(answer, question)
    t1 = time.time()
    shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" time={t1-t0:.2f}')
    return answer
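For orientation, below is a minimal standalone sketch of the same load-and-generate flow the new `qwen()` helper implements, using only `transformers` and no SD.Next modules. The repo name comes from the `vlm_models` table above; the input file name, token budget, and dtype/device handling are illustrative assumptions.

```python
# Standalone sketch of the Qwen2-VL captioning path added above (assumptions noted inline).
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

repo = "Qwen/Qwen2-VL-2B-Instruct"  # same repo as the "Alibaba Qwen VL2 2B" entry
model = Qwen2VLForConditionalGeneration.from_pretrained(repo, torch_dtype="auto", device_map="auto")
processor = AutoProcessor.from_pretrained(repo)

image = Image.open("test.jpg").convert("RGB")  # illustrative input file
conversation = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe the image."},
    ]},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], padding=True, return_tensors="pt").to(model.device)

output_ids = model.generate(**inputs, max_new_tokens=128)  # vqa.py uses shared.opts.interrogate_vlm_max_length
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0])
```

The in-tree `qwen()` helper differs mainly in that it embeds the image as base64 in the conversation, routes device and dtype through `modules.devices`, and is also reused for *ToriiGate 0.4*, which the new dispatch in `interrogate()` sends through the same loader.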
2 changes: 1 addition & 1 deletion modules/ui_postprocessing.py
@@ -92,7 +92,7 @@ def create_ui():
    with gr.Row():
        vqa_question = gr.Dropdown(label="Question", allow_custom_value=True, choices=vqa.vlm_prompts, value=vqa.vlm_prompts[2])
    with gr.Row():
        vqa_answer = gr.Textbox(label="Answer", lines=3)
        vqa_answer = gr.Textbox(label="Answer", lines=5)
    with gr.Row(elem_id='interrogate_buttons_query'):
        vqa_model = gr.Dropdown(list(vqa.vlm_models), value=list(vqa.vlm_models)[0], label='VLM Model')
        vqa_submit = gr.Button("Interrogate", elem_id="interrogate_btn_interrogate", variant='primary')
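To show how these pieces connect, here is an illustrative wiring sketch of the same components outside `create_ui()`. It is not a drop-in script: it assumes it runs inside the SD.Next source tree (the `vqa` module is not importable standalone), and the `gr.Image` input and `launch()` call are additions for the sketch only; the real UI supplies the image elsewhere in `create_ui()`.

```python
# Hedged sketch: wires the dropdowns/button shown above to vqa.interrogate(question, image, model_name).
import gradio as gr
from modules.interrogate import vqa  # only importable inside the SD.Next tree

with gr.Blocks() as demo:
    vqa_image = gr.Image(type="pil", label="Image")  # illustrative; not part of this hunk
    vqa_question = gr.Dropdown(label="Question", allow_custom_value=True, choices=vqa.vlm_prompts, value=vqa.vlm_prompts[2])
    vqa_model = gr.Dropdown(list(vqa.vlm_models), value=list(vqa.vlm_models)[0], label="VLM Model")
    vqa_answer = gr.Textbox(label="Answer", lines=5)
    vqa_submit = gr.Button("Interrogate", variant="primary")
    vqa_submit.click(fn=vqa.interrogate, inputs=[vqa_question, vqa_image, vqa_model], outputs=[vqa_answer])

demo.launch()
```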