Skip to content

Commit

Permalink
added download of fusecap for captioning (#27)
Browse files: browse the repository at this point in the history
  • Loading branch information
asomoza authored Dec 21, 2023
1 parent 2b0837a commit 17ec5d5
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 22 deletions.
38 changes: 38 additions & 0 deletions configs/download_list.json
Original file line number Diff line number Diff line change
Expand Up @@ -388,5 +388,43 @@
}
]
}
],
"captions_items": [
{
"title": "FuseCap",
"description": "A framework designed to enhance image captioning by incorporating detailed visual information into traditional captions.",
"destination_directory": "app_models",
"destination_subdirectory": "captions/fusecap",
"files": [
{
"file": "config.json",
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/config.json?download=true"
},
{
"file": "preprocessor_config.json",
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/preprocessor_config.json?download=true"
},
{
"file": "pytorch_model.bin",
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/resolve/main/pytorch_model.bin?download=true"
},
{
"file": "special_tokens_map.json",
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/special_tokens_map.json?download=true"
},
{
"file": "tokenizer.json",
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/tokenizer.json?download=true"
},
{
"file": "tokenizer_config.json",
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/tokenizer_config.json?download=true"
},
{
"file": "vocab.txt",
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/vocab.txt?download=true"
}
]
}
]
}
9 changes: 8 additions & 1 deletion src/iartisanxl/app/downloader_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ def init_ui(self):
self.t2i_items_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
t2i_widget.setLayout(self.t2i_items_layout)

captions_widget = QWidget()
self.captions_items_layout = QGridLayout()
self.captions_items_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
captions_widget.setLayout(self.captions_items_layout)

tab_widget.addTab(essentials_widget, "Essentials")
tab_widget.addTab(controlnets_widget, "ControlNet")
tab_widget.addTab(t2i_widget, "T2I Adapters")
tab_widget.addTab(captions_widget, "Captions")
self.main_layout.addWidget(tab_widget)

sdxl_download_button = QPushButton("Download")
Expand Down Expand Up @@ -103,6 +109,7 @@ def load_items(self):
"essential_items": self.essentials_items_layout,
"controlnet_items": self.controlnets_items_layout,
"t2i_items": self.t2i_items_layout,
"captions_items": self.captions_items_layout,
}

for category, layout in layouts.items():
Expand Down Expand Up @@ -156,7 +163,7 @@ def make_final_directory(self, destination_directory, destination_subdirectory):
return final_directory

def on_start_download(self):
layouts = [self.essentials_items_layout, self.controlnets_items_layout, self.t2i_items_layout]
layouts = [self.essentials_items_layout, self.controlnets_items_layout, self.t2i_items_layout, self.captions_items_layout]
for layout in layouts:
for i in range(layout.count()):
item = layout.itemAt(i).widget()
Expand Down
42 changes: 25 additions & 17 deletions src/iartisanxl/modules/dataset/dataset_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def on_ai_caption(self):
self.generate_captions_thread = GenerateCaptionsThread(self.device)
self.generate_captions_thread.status_update.connect(self.update_status_bar)
self.generate_captions_thread.caption_done.connect(self.on_ai_caption_done)
self.generate_captions_thread.error.connect(self.ai_caption_error)
else:
try:
self.generate_captions_thread.caption_done.disconnect(self.generate_item_ai_caption_done)
Expand All @@ -205,6 +206,11 @@ def on_ai_caption(self):

self.generate_captions_thread.start()

# Handler connected to GenerateCaptionsThread's `error` signal.
# `text` is the string the thread emitted with that signal, e.g. the prompt to
# download the FuseCap model when loading it raises OSError.
# The UI was disabled before the thread was started (see on_ai_mass_caption),
# so enable_ui() is called first; the message is then shown in both the
# snackbar and the status bar.
def ai_caption_error(self, text):
self.enable_ui()
self.show_snackbar(text)
self.update_status_bar(text)

def on_ai_caption_done(self, text):
self.image_caption_edit.setPlainText(text)
self.enable_ui()
Expand Down Expand Up @@ -243,25 +249,27 @@ def on_mass_caption(self):
break

# Starts AI caption generation using GenerateCaptionsThread, routing each
# finished caption to generate_item_ai_caption_done.
#
# NOTE(review): this span is a rendered diff hunk with the +/- markers and the
# indentation stripped. The first copy of the body (unguarded, no `error`
# connection) is the removed version; the second copy, which sits under the
# `self.dataset_dir` check, is the added version. The duplicated `text = ...`
# line and the duplicated final four calls are likewise old/new pairs.
#
# Flow of the added version:
#   1. Do nothing unless self.dataset_dir is set and non-empty.
#   2. Disable the UI.
#   3. On first use, create the thread and connect status_update,
#      caption_done and error (error -> ai_caption_error).
#   4. Otherwise re-point caption_done from on_ai_caption_done to
#      generate_item_ai_caption_done; a TypeError from disconnect() is
#      ignored - presumably raised when that slot is not currently
#      connected (TODO confirm against PyQt behaviour).
#   5. Read the caption editor text, set the progress bar maximum to the item
#      count, call get_first_item() on the items view, update the status bar
#      and call generate_item_ai_caption(text).
def on_ai_mass_caption(self):
self.disable_ui()

if self.generate_captions_thread is None:
self.generate_captions_thread = GenerateCaptionsThread(self.device)
self.generate_captions_thread.status_update.connect(self.update_status_bar)
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)
else:
try:
self.generate_captions_thread.caption_done.disconnect(self.on_ai_caption_done)
except TypeError:
pass
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)
if self.dataset_dir is not None and len(self.dataset_dir) > 0:
self.disable_ui()

if self.generate_captions_thread is None:
self.generate_captions_thread = GenerateCaptionsThread(self.device)
self.generate_captions_thread.status_update.connect(self.update_status_bar)
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)
self.generate_captions_thread.error.connect(self.ai_caption_error)
else:
try:
self.generate_captions_thread.caption_done.disconnect(self.on_ai_caption_done)
except TypeError:
pass
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)

text = self.image_caption_edit.toPlainText()
text = self.image_caption_edit.toPlainText()

self.progress_bar.setMaximum(self.dataset_items_view.item_count)
self.dataset_items_view.get_first_item()
self.update_status_bar("Generating captions...")
self.generate_item_ai_caption(text)
self.progress_bar.setMaximum(self.dataset_items_view.item_count)
self.dataset_items_view.get_first_item()
self.update_status_bar("Generating captions...")
self.generate_item_ai_caption(text)

def generate_item_ai_caption(self, text):
item = self.dataset_items_view.current_item
Expand Down
13 changes: 9 additions & 4 deletions src/iartisanxl/threads/generate_captions_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class GenerateCaptionsThread(QThread):
status_update = pyqtSignal(str)
caption_done = pyqtSignal(str)
error = pyqtSignal(str)

def __init__(self, device):
super().__init__()
Expand All @@ -21,8 +22,14 @@ def __init__(self, device):
def run(self):
if self.model is None:
self.status_update.emit("Loading FuseCap model...")
self.processor = BlipProcessor.from_pretrained("models/captions/fusecap")
self.model = BlipForConditionalGeneration.from_pretrained("models/captions/fusecap").to(self.device)

try:
self.processor = BlipProcessor.from_pretrained("models/captions/fusecap")
self.model = BlipForConditionalGeneration.from_pretrained("models/captions/fusecap").to(self.device)
except OSError:
self.error.emit("Need to download the FuseCap model from the downloader menu.")
return

self.status_update.emit("FuseCap loaded.")

self.status_update.emit("Generating AI caption...")
Expand All @@ -32,8 +39,6 @@ def run(self):
buffer.open(QBuffer.ReadWrite)
qimage.save(buffer, "PNG")

print(f"{self.text=}")

raw_image = Image.open(io.BytesIO(buffer.data()))
inputs = self.processor(raw_image, self.text, return_tensors="pt").to(self.device)

Expand Down

0 comments on commit 17ec5d5

Please sign in to comment.