Feature mcboaty llm + new node McPrompty (#79)
* Update McBoaty Upscale Mechanism + add Per Tile Prompting generation + add new node to generate prompt from image

* hotfix llm vision name calling

* hotfix output size upscaling

* hotfix tiles order and execution duration display as int

* hotfix tiles order display and default llm model

* hotfix display index + default LLM Models list

* hotfix McPrompty name

* hotfix McPrompty name

* hotfix prestartup_script

* hotfix readme on tile prompting
MaraScott committed Jun 15, 2024
1 parent f1948e4 commit 4121e4a
Showing 11 changed files with 355 additions and 34 deletions.
5 changes: 4 additions & 1 deletion MaraScott_Nodes.py
@@ -13,6 +13,7 @@
from .py.nodes.UpscalerRefiner.McBoaty_v2 import UpscalerRefiner_McBoaty_v2
from .py.nodes.UpscalerRefiner.McBoaty_v3 import UpscalerRefiner_McBoaty_v3
from .py.nodes.KSampler.InpaintingTileByMask_v1 import KSampler_setInpaintingTileByMask_v1, KSampler_pasteInpaintingTileByMask_v1
from .py.nodes.Prompt.PromptFromImage_v1 import PromptFromImage_v1
from .py.vendor.ComfyUI_JNodes.blob.main.py.prompting_nodes import TokenCounter as TokenCounter_v1

WEB_DIRECTORY = "./web/assets/js"
@@ -24,6 +25,7 @@
"MaraScottUpscalerRefinerNode_v3": UpscalerRefiner_McBoaty_v3,
"MaraScottSetInpaintingByMask_v1": KSampler_setInpaintingTileByMask_v1,
"MaraScottPasteInpaintingByMask_v1": KSampler_pasteInpaintingTileByMask_v1,
"MaraScottPromptFromImage_v1": PromptFromImage_v1,

"MaraScott_Kijai_TokenCounter_v1": TokenCounter_v1,

@@ -41,6 +43,7 @@
"MaraScottUpscalerRefinerNode_v3": "\ud83d\udc30 Large Refiner - McBoaty v3 /u",
"MaraScottSetInpaintingByMask_v1": "\ud83d\udc30 Set Inpainting Tile by mask - McInpainty [1/2] v1 /m",
"MaraScottPasteInpaintingByMask_v1": "\ud83d\udc30 Paste Inpainting Tile by mask - McInpainty [2/2] v1 /m",
"MaraScottPromptFromImage_v1": "\ud83d\udc30 Prompt From Image - McPrompty v1 /p",

"MaraScott_Kijai_TokenCounter_v1": "\ud83d\udc30 TokenCounter (from kijai/ComfyUI-KJNodes) /v",

@@ -49,4 +52,4 @@
"MaraScottUpscalerRefinerNode_v2": "\u274C Large Refiner - McBoaty v2 /u",
}

print('\033[34m[Maras IT] \033[92mLoaded\033[0m')
print('\033[34m[MaraScott] \033[92mLoaded\033[0m')
7 changes: 7 additions & 0 deletions README.md
@@ -60,9 +60,16 @@ This AnyBus is dyslexia friendly :D
The Upscaler Refiner Node (AKA McBoaty Node) is an upscaler coupled with a refiner to achieve higher-quality rendering results.
The output image is a slightly modified version of the input.

To use `Tile Prompting`, we recommend [setting up your Groq API key on your computer](https://console.groq.com/docs/quickstart) to improve tile prompting accuracy.

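A quick way to check that the key is actually visible to ComfyUI's process is a minimal sketch like the one below (based on how this commit's `MS_Llm` class reads the variable; the messages are illustrative only):

```python
import os

# McBoaty/McPrompty read the key straight from the environment; when it is
# empty they fall back to the raw vision-model caption instead of refining
# the tile prompt through the Groq API.
if os.getenv("GROQ_API_KEY", ""):
    print("GROQ_API_KEY found: tile prompts will be refined through the Groq API.")
else:
    print("GROQ_API_KEY not set: tile prompts will use the vision-model caption as-is.")
```
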
**Not Supported**:
- ControlNet: your conditioning needs to be ControlNet-free

# Prompt From Image Node AKA McPrompty Node

The Prompt From Image Node (AKA McPrompty Node) is a prompt generator node that takes an image as input and, coupled with an LLM engine (Groq), generates the text.
The output text can be used as a prompt afterwards.

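As a rough sketch of what McPrompty does under the hood (the import path and the random tensor are placeholders for illustration; the node calls the same `generate_prompt` method added in `py/inc/lib/llm.py` by this commit):

```python
import torch
from py.inc.lib.llm import MS_Llm  # import path assumed for a standalone script

llm = MS_Llm(vision_llm_name="microsoft/kosmos-2-patch14-224",
             llm_name="llama3-70b-8192")

image = torch.rand(1, 512, 512, 3)              # ComfyUI IMAGE layout: BHWC floats in [0, 1]
prompt = llm.vision_llm.generate_prompt(image)  # image caption returned as the prompt text
print(prompt)
```
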
# Thanks

## Special thanks
6 changes: 5 additions & 1 deletion prestartup_script.py
@@ -7,4 +7,8 @@
#
###

print('\033[34m[Maras IT] \033[92mInitialization\033[0m')
from custom_nodes.ComfyUI_MaraScott_Nodes.py.inc.lib.llm import MS_Llm

print('\033[34m[MaraScott] \033[92mInitialization\033[0m')

MS_Llm.prestartup_script()
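For reference, the prestartup hook called above amounts to the following (a sketch; `add_model_folder_path` and `get_folder_paths` are ComfyUI's `folder_paths` helpers, and the printed path depends on your install):

```python
import os
import folder_paths  # ComfyUI's model-folder registry

# Equivalent to MS_Llm.prestartup_script() in this commit: register a
# "nlpconnect" model folder under the ComfyUI models directory.
folder_paths.add_model_folder_path("nlpconnect",
                                   os.path.join(folder_paths.models_dir, "nlpconnect"))

print(folder_paths.get_folder_paths("nlpconnect"))  # e.g. ['.../ComfyUI/models/nlpconnect']
```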
11 changes: 6 additions & 5 deletions py/inc/lib/image.py
@@ -86,7 +86,7 @@ def get_dynamic_grid_specs(self, width, height, rows_qty = 3, cols_qty = 3, size
tiles = []
for row_index, row in enumerate(tile_order_rows):
for col_index, col in enumerate(tile_order_cols):
index = (row * len(tile_order_rows)) + col
order = (row_index * len(tile_order_rows)) + col_index

_tile_width = (tile_width_units_qty + 2) * size_unit
_tile_height = (tile_height_units_qty + 2) * size_unit
@@ -112,7 +112,7 @@
tiles.append([
row_index,
col_index,
index,
order,
x, # x
y, # y
_tile_width, # width
@@ -136,7 +136,7 @@ def get_grid_images(self, image, grid_specs):
return grids

@classmethod
def rebuild_image_from_parts(self, iteration, output_images, image, grid_specs, feather_mask, upscale_scale):
def rebuild_image_from_parts(self, iteration, output_images, image, grid_specs, feather_mask, upscale_scale, grid_prompts):

width_feather_seam = feather_mask
height_feather_seam = feather_mask
@@ -186,7 +186,8 @@ def rebuild_image_from_parts(self, iteration, output_images, image, grid_specs,
for index, grid_spec in enumerate(grid_specs):
log(f"Rebuilding tile {index + 1}/{total}", None, None, f"Refining {iteration}")
row, col, order, x_start, y_start, width_inc, height_inc = grid_spec
tiles_order.append((order, output_images[index]))
prompt = grid_prompts[index] if 0 <= index < len(grid_prompts) else ""
tiles_order.append((order, output_images[index], prompt))
if col == 0:
outputRow = nodes.ImagePadForOutpaint().expand_image(output_images[index], 0, 0, (image.shape[2]*upscale_scale) - tile_width, 0, 0)[0]
elif col == last_tile_col_index:
@@ -238,7 +239,7 @@ def get_dynamic_grid_specs(self, width, height, tile_rows = 3, tile_cols =3):
tiles.append([
(col * len(tile_order)) + row,
(row * tile_width) - (row * width_unit), # x
(col * tile_height) - (col * height_unit), # x
(col * tile_height) - (col * height_unit), # y
tile_width, # width
tile_height, # height
])
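The `index` to `order` change above is easier to see with a toy example (a sketch with hypothetical traversal lists; the point is that the displayed order must come from the enumeration indices, not from the possibly shuffled row/column values):

```python
# Hypothetical traversal order, e.g. refining the centre row/column first.
tile_order_rows = [1, 0, 2]
tile_order_cols = [1, 0, 2]

for row_index, row in enumerate(tile_order_rows):
    for col_index, col in enumerate(tile_order_cols):
        order = (row_index * len(tile_order_rows)) + col_index  # new: 0, 1, 2, ... in sequence
        index = (row * len(tile_order_rows)) + col               # old: grid position, not order
        print(f"tile at grid ({row}, {col}) -> order={order}, old index={index}")
```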
174 changes: 174 additions & 0 deletions py/inc/lib/llm.py
@@ -0,0 +1,174 @@
import os
import requests
import torch
import folder_paths
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers import BlipProcessor, BlipForConditionalGeneration
from groq import Groq
from .image import MS_Image_v2 as MS_Image

from ...utils.log import log

class MS_Llm_Microsoft():

@classmethod
def __init__(self, model_name = "microsoft/kosmos-2-patch14-224"):
self.name = model_name
self.model = AutoModelForVision2Seq.from_pretrained(self.name)
self.processor = AutoProcessor.from_pretrained(self.name)

@classmethod
def generate_prompt(self, image):

# prompt_prefix = "<grounding>An image of"
prompt_prefix = ""

_image = MS_Image.tensor2pil(image)

inputs = self.processor(text=prompt_prefix, images=_image, return_tensors="pt")

# Generate the caption
generated_ids = self.model.generate(
pixel_values=inputs["pixel_values"],
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
image_embeds=None,
image_embeds_position_mask=inputs["image_embeds_position_mask"],
use_cache=True,
max_new_tokens=128,
)
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
caption, _ = self.processor.post_process_generation(generated_text)

return caption


class MS_Llm_Salesforce():

@classmethod
def __init__(self, model_name = "Salesforce/blip-image-captioning-large"):
self.name = model_name
self.model = BlipForConditionalGeneration.from_pretrained(self.name)
self.processor = BlipProcessor.from_pretrained(self.name)

@classmethod
def generate_prompt(self, image):

# prompt_prefix = "<grounding>An image of"
prompt_prefix = ""

_image = MS_Image.tensor2pil(image)

inputs = self.processor(text=prompt_prefix, images=_image, return_tensors="pt")

# Generate the caption
generated_ids = self.model.generate(**inputs)
caption = self.processor.decode(generated_ids[0], skip_special_tokens=True)

return caption

class MS_Llm_Nlpconnect():

@classmethod
def __init__(self, model_name = "nlpconnect/vit-gpt2-image-captioning"):
self.name = model_name
self.model = VisionEncoderDecoderModel.from_pretrained(self.name)
self.processor = ViTImageProcessor.from_pretrained(self.name)
self.tokenizer = AutoTokenizer.from_pretrained(self.name)

@classmethod
def generate_prompt(self, image):

_image = MS_Image.tensor2pil(image)
inputs = self.processor(images=_image, return_tensors="pt")
# Generate the caption
generated_ids = self.model.generate(
inputs.pixel_values,
max_length=16,
num_beams=4,
num_return_sequences=1
)
caption = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

return caption

class MS_Llm():

LLM_MODELS = [
# 'gemma-7b-it',
'llama3-70b-8192',
# 'llama3-8b-8192',
# 'mixtral-8x7b-32768',
]

# list of model https://huggingface.co/models?pipeline_tag=image-to-text&sort=downloads
VISION_LLM_MODELS = [
# 'nlpconnect/vit-gpt2-image-captioning',
'microsoft/kosmos-2-patch14-224',
# 'Salesforce/blip-image-captioning-large',
]

@staticmethod
def prestartup_script():
folder_paths.add_model_folder_path("nlpconnect", os.path.join(folder_paths.models_dir, "nlpconnect"))

@classmethod
def __init__(self, vision_llm_name = "nlpconnect/vit-gpt2-image-captioning", llm_name = "llama3-8b-8192"):

if vision_llm_name == 'microsoft/kosmos-2-patch14-224':
self.vision_llm = MS_Llm_Microsoft()
elif vision_llm_name == 'Salesforce/blip-image-captioning-large':
self.vision_llm = MS_Llm_Salesforce()
else:
self.vision_llm = MS_Llm_Nlpconnect()

self._groq_key = os.getenv("GROQ_API_KEY", "")
self.llm = llm_name

@classmethod
def generate_tile_prompt(self, image, prompt_context, seed=None):
prompt_tile = self.vision_llm.generate_prompt(image)
if self.vision_llm.name == 'microsoft/kosmos-2-patch14-224':
_prompt = self.get_grok_prompt(prompt_context, prompt_tile)
else:
_prompt = self.get_grok_prompt(prompt_context, prompt_tile)
if self._groq_key != "":
prompt = self.call_grok_api(_prompt, seed)
else:
prompt = _prompt
log(prompt, None, None, self.vision_llm.name)
return prompt


@classmethod
def get_grok_prompt(self, prompt_context, prompt_tile):
prompt = [
f"tile_prompt: \"{prompt_tile}\".",
f"full_image_prompt: \"{prompt_context}\".",
"tile_prompt is part of full_image_prompt.",
"If tile_prompt is describing something different than the full image, correct tile_prompt to match full_image_prompt.",
"if you don't need to change the tile_prompt return the tile_prompt.",
"your answer will strictly and only return the tile_prompt string without any decoration like markdown syntax."
]
return " ".join(prompt)

@classmethod
def call_grok_api(self, prompt, seed=None):

client = Groq(api_key=self._groq_key) # Assuming the Groq client accepts an api_key parameter
completion = client.chat.completions.create(
model=self.llm,
messages=[{
"role": "user",
"content": prompt
}],
temperature=1,
max_tokens=1024,
top_p=1,
stream=False,
stop=None,
seed=seed,
)

return completion.choices[0].message.content
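A minimal usage sketch of the class above (the import path and tile tensor are placeholders; the vision model weights download on first use, and the Groq call only happens when `GROQ_API_KEY` is set):

```python
import torch
from py.inc.lib.llm import MS_Llm  # path assumed when running outside ComfyUI

llm = MS_Llm(vision_llm_name="microsoft/kosmos-2-patch14-224",
             llm_name="llama3-70b-8192")

tile = torch.rand(1, 256, 256, 3)  # stand-in for one upscaled tile (BHWC floats in [0, 1])
prompt = llm.generate_tile_prompt(tile, "a castle on a hill at sunset", seed=42)
print(prompt)  # tile caption, reconciled with the full-image prompt via Groq when available
```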
64 changes: 64 additions & 0 deletions py/nodes/Prompt/PromptFromImage_v1.py
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#
###
#
# Generate a prompt from an input image
#
# Largely inspired by PYSSSSS - ShowText
#
###

from types import SimpleNamespace

from ...inc.lib.llm import MS_Llm

from ...utils.log import *

class PromptFromImage_v1:

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE", {"label": "image"}),
"vision_llm_model": (MS_Llm.VISION_LLM_MODELS, { "label": "Vision LLM Model", "default": "microsoft/kosmos-2-patch14-224" }),
"llm_model": (MS_Llm.LLM_MODELS, { "label": "LLM Model", "default": "llama3-70b-8192" }),
},
"hidden": {
"unique_id": "UNIQUE_ID",
"extra_pnginfo": "EXTRA_PNGINFO",
},
}

INPUT_IS_LIST = False
FUNCTION = "fn"
OUTPUT_NODE = True
OUTPUT_IS_LIST = (False,)
CATEGORY = "MaraScott/Prompt"

RETURN_TYPES = (
"STRING",
)
RETURN_NAMES = (
"Prompt",
)

@classmethod
def fn(self, **kwargs):

self.INPUTS = SimpleNamespace(
image = kwargs.get('image', None)
)
self.LLM = SimpleNamespace(
vision_model_name = kwargs.get('vision_llm_model', None),
model_name = kwargs.get('llm_model', None),
model = None,
)
self.LLM.model = MS_Llm(self.LLM.vision_model_name, self.LLM.model_name)

self.OUPUTS = SimpleNamespace(
prompt = self.LLM.model.vision_llm.generate_prompt(self.INPUTS.image)
)

return {"ui": {"text": self.OUPUTS.prompt}, "result": (self.OUPUTS.prompt,)}