From e4c2fc8bdbdc2115e8f3cf8d19b7a38cec610b0a Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Sun, 9 Jun 2024 07:26:09 -0700 Subject: [PATCH] Revert "aider: Added a third page focused on model merging." This reverts commit 2e4566994ae7a9489c2ac0289106cf66ff9efa87. --- src/invoke_training/ui/app.py | 3 - .../ui/pages/model_merge_page.py | 56 ------------------- 2 files changed, 59 deletions(-) delete mode 100644 src/invoke_training/ui/pages/model_merge_page.py diff --git a/src/invoke_training/ui/app.py b/src/invoke_training/ui/app.py index b36f0f4a..6ba895b4 100644 --- a/src/invoke_training/ui/app.py +++ b/src/invoke_training/ui/app.py @@ -7,12 +7,10 @@ from invoke_training.ui.pages.data_page import DataPage from invoke_training.ui.pages.training_page import TrainingPage -from invoke_training.ui.pages.model_merge_page import ModelMergePage def build_app(): training_page = TrainingPage() - model_merge_page = ModelMergePage() data_page = DataPage() app = FastAPI() @@ -26,5 +24,4 @@ async def root(): app = gr.mount_gradio_app(app, training_page.app(), "/train", app_kwargs={"favicon_path": "/assets/favicon.png"}) app = gr.mount_gradio_app(app, data_page.app(), "/data", app_kwargs={"favicon_path": "/assets/favicon.png"}) - app = gr.mount_gradio_app(app, model_merge_page.app(), "/merge", app_kwargs={"favicon_path": "/assets/favicon.png"}) return app diff --git a/src/invoke_training/ui/pages/model_merge_page.py b/src/invoke_training/ui/pages/model_merge_page.py deleted file mode 100644 index c7596574..00000000 --- a/src/invoke_training/ui/pages/model_merge_page.py +++ /dev/null @@ -1,56 +0,0 @@ -import gradio as gr - -class ModelMergePage: - def __init__(self): - with gr.Blocks( - title="Model Merging", - analytics_enabled=False, - head='', - ) as app: - with gr.Tab(label="Merge LoRA into SD Model"): - self._create_merge_tab() - with gr.Tab(label="Extract LoRA from Checkpoint"): - self._create_extract_tab() - - self._app = app - - def _create_merge_tab(self): - with gr.Row(): - gr.Markdown("## Merge LoRA into SD Model") - with gr.Row(): - base_model = gr.Textbox(label="Base Model Path") - lora_model = gr.Textbox(label="LoRA Model Path") - output_path = gr.Textbox(label="Output Path") - merge_button = gr.Button("Merge") - - merge_button.click( - fn=self._merge_lora_into_sd_model, - inputs=[base_model, lora_model, output_path], - outputs=[] - ) - - def _create_extract_tab(self): - with gr.Row(): - gr.Markdown("## Extract LoRA from Checkpoint") - with gr.Row(): - model_orig = gr.Textbox(label="Original Model Path") - model_tuned = gr.Textbox(label="Tuned Model Path") - save_to = gr.Textbox(label="Save To Path") - extract_button = gr.Button("Extract") - - extract_button.click( - fn=self._extract_lora_from_checkpoint, - inputs=[model_orig, model_tuned, save_to], - outputs=[] - ) - - def _merge_lora_into_sd_model(self, base_model, lora_model, output_path): - # Placeholder function for merging LoRA into SD model - print(f"Merging LoRA model {lora_model} into base model {base_model} and saving to {output_path}") - - def _extract_lora_from_checkpoint(self, model_orig, model_tuned, save_to): - # Placeholder function for extracting LoRA from checkpoint - print(f"Extracting LoRA from {model_tuned} using original model {model_orig} and saving to {save_to}") - - def app(self): - return self._app