From 17ec5d52dfc071c0a3f50d15b033b3138594bd26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Thu, 21 Dec 2023 18:43:32 -0300 Subject: [PATCH] added download of fusecap for captioning (#27) --- configs/download_list.json | 38 +++++++++++++++++ src/iartisanxl/app/downloader_dialog.py | 9 +++- .../modules/dataset/dataset_module.py | 42 +++++++++++-------- .../threads/generate_captions_thread.py | 13 ++++-- 4 files changed, 80 insertions(+), 22 deletions(-) diff --git a/configs/download_list.json b/configs/download_list.json index 831b4f0..94d20d8 100644 --- a/configs/download_list.json +++ b/configs/download_list.json @@ -388,5 +388,43 @@ } ] } + ], + "captions_items": [ + { + "title": "FuseCap", + "description": "A framework designed to enhance image captioning by incorporating detailed visual information into traditional captions.", + "destination_directory": "app_models", + "destination_subdirectory": "captions/fusecap", + "files": [ + { + "file": "config.json", + "url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/config.json?download=true" + }, + { + "file": "preprocessor_config.json", + "url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/preprocessor_config.json?download=true" + }, + { + "file": "pytorch_model.bin", + "url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/resolve/main/pytorch_model.bin?download=true" + }, + { + "file": "special_tokens_map.json", + "url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/special_tokens_map.json?download=true" + }, + { + "file": "tokenizer.json", + "url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/tokenizer.json?download=true" + }, + { + "file": "tokenizer_config.json", + "url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/tokenizer_config.json?download=true" + }, + { + "file": "vocab.txt", + "url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/vocab.txt?download=true" + } + ] + } ] } \ No newline at end of file diff --git a/src/iartisanxl/app/downloader_dialog.py b/src/iartisanxl/app/downloader_dialog.py index 4f5a164..e138d1e 100644 --- a/src/iartisanxl/app/downloader_dialog.py +++ b/src/iartisanxl/app/downloader_dialog.py @@ -67,9 +67,15 @@ def init_ui(self): self.t2i_items_layout.setAlignment(Qt.AlignmentFlag.AlignTop) t2i_widget.setLayout(self.t2i_items_layout) + captions_widget = QWidget() + self.captions_items_layout = QGridLayout() + self.captions_items_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + captions_widget.setLayout(self.captions_items_layout) + tab_widget.addTab(essentials_widget, "Essentials") tab_widget.addTab(controlnets_widget, "ControlNet") tab_widget.addTab(t2i_widget, "T2I Adapters") + tab_widget.addTab(captions_widget, "Captions") self.main_layout.addWidget(tab_widget) sdxl_download_button = QPushButton("Download") @@ -103,6 +109,7 @@ def load_items(self): "essential_items": self.essentials_items_layout, "controlnet_items": self.controlnets_items_layout, "t2i_items": self.t2i_items_layout, + "captions_items": self.captions_items_layout, } for category, layout in layouts.items(): @@ -156,7 +163,7 @@ def make_final_directory(self, destination_directory, destination_subdirectory): return final_directory def on_start_download(self): - layouts = [self.essentials_items_layout, self.controlnets_items_layout, self.t2i_items_layout] + layouts = [self.essentials_items_layout, self.controlnets_items_layout, self.t2i_items_layout, self.captions_items_layout] for layout in layouts: for i in range(layout.count()): item = layout.itemAt(i).widget() diff --git a/src/iartisanxl/modules/dataset/dataset_module.py b/src/iartisanxl/modules/dataset/dataset_module.py index 6768b60..4103cde 100644 --- a/src/iartisanxl/modules/dataset/dataset_module.py +++ b/src/iartisanxl/modules/dataset/dataset_module.py @@ -190,6 +190,7 @@ def on_ai_caption(self): self.generate_captions_thread = GenerateCaptionsThread(self.device) self.generate_captions_thread.status_update.connect(self.update_status_bar) self.generate_captions_thread.caption_done.connect(self.on_ai_caption_done) + self.generate_captions_thread.error.connect(self.ai_caption_error) else: try: self.generate_captions_thread.caption_done.disconnect(self.generate_item_ai_caption_done) @@ -205,6 +206,11 @@ def on_ai_caption(self): self.generate_captions_thread.start() + def ai_caption_error(self, text): + self.enable_ui() + self.show_snackbar(text) + self.update_status_bar(text) + def on_ai_caption_done(self, text): self.image_caption_edit.setPlainText(text) self.enable_ui() @@ -243,25 +249,27 @@ def on_mass_caption(self): break def on_ai_mass_caption(self): - self.disable_ui() - - if self.generate_captions_thread is None: - self.generate_captions_thread = GenerateCaptionsThread(self.device) - self.generate_captions_thread.status_update.connect(self.update_status_bar) - self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done) - else: - try: - self.generate_captions_thread.caption_done.disconnect(self.on_ai_caption_done) - except TypeError: - pass - self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done) + if self.dataset_dir is not None and len(self.dataset_dir) > 0: + self.disable_ui() + + if self.generate_captions_thread is None: + self.generate_captions_thread = GenerateCaptionsThread(self.device) + self.generate_captions_thread.status_update.connect(self.update_status_bar) + self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done) + self.generate_captions_thread.error.connect(self.ai_caption_error) + else: + try: + self.generate_captions_thread.caption_done.disconnect(self.on_ai_caption_done) + except TypeError: + pass + self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done) - text = self.image_caption_edit.toPlainText() + text = self.image_caption_edit.toPlainText() - self.progress_bar.setMaximum(self.dataset_items_view.item_count) - self.dataset_items_view.get_first_item() - self.update_status_bar("Generating captions...") - self.generate_item_ai_caption(text) + self.progress_bar.setMaximum(self.dataset_items_view.item_count) + self.dataset_items_view.get_first_item() + self.update_status_bar("Generating captions...") + self.generate_item_ai_caption(text) def generate_item_ai_caption(self, text): item = self.dataset_items_view.current_item diff --git a/src/iartisanxl/threads/generate_captions_thread.py b/src/iartisanxl/threads/generate_captions_thread.py index dc00c78..e29e915 100644 --- a/src/iartisanxl/threads/generate_captions_thread.py +++ b/src/iartisanxl/threads/generate_captions_thread.py @@ -8,6 +8,7 @@ class GenerateCaptionsThread(QThread): status_update = pyqtSignal(str) caption_done = pyqtSignal(str) + error = pyqtSignal(str) def __init__(self, device): super().__init__() @@ -21,8 +22,14 @@ def __init__(self, device): def run(self): if self.model is None: self.status_update.emit("Loading FuseCap model...") - self.processor = BlipProcessor.from_pretrained("models/captions/fusecap") - self.model = BlipForConditionalGeneration.from_pretrained("models/captions/fusecap").to(self.device) + + try: + self.processor = BlipProcessor.from_pretrained("models/captions/fusecap") + self.model = BlipForConditionalGeneration.from_pretrained("models/captions/fusecap").to(self.device) + except OSError: + self.error.emit("Need to download the FuseCap model from the downloader menu.") + return + self.status_update.emit("FuseCap loaded.") self.status_update.emit("Generating AI caption...") @@ -32,8 +39,6 @@ def run(self): buffer.open(QBuffer.ReadWrite) qimage.save(buffer, "PNG") - print(f"{self.text=}") - raw_image = Image.open(io.BytesIO(buffer.data())) inputs = self.processor(raw_image, self.text, return_tensors="pt").to(self.device)