From 89ce7f60e355d26bfb74468ba8e23f16719c8a6d Mon Sep 17 00:00:00 2001 From: "shiwktju@gmail.com" Date: Wed, 4 Sep 2024 15:28:46 +0800 Subject: [PATCH 1/5] load with connector. --- modules/sd_models.py | 8 ++--- modules/sd_remote_models.py | 64 +++++++++++++++++++++---------------- 2 files changed, 41 insertions(+), 31 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4211131dfa1..8e6b1e6e0b0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -213,7 +213,7 @@ def remote_model_hash(model_name): import hashlib m = hashlib.sha256() - m.update(read_remote_model(model_name, start=0x100000, size=0x10000).getvalue()) + m.update(read_remote_model(model_name, start=0x100000, size=0x10000)) return m.hexdigest()[0:8] @@ -302,12 +302,12 @@ def read_metadata_from_safetensors(filename): def read_metadata_from_remote_safetensors(filename): import json - metadata_len = read_remote_model(filename, start=0, size=8).getvalue() + metadata_len = read_remote_model(filename, start=0, size=8) metadata_len = int.from_bytes(metadata_len, "little") - json_start = read_remote_model(filename, start=8, size=2).getvalue() + json_start = read_remote_model(filename, start=8, size=2) assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file" - json_data = json_start + read_remote_model(filename, start=10, size=metadata_len-2).getvalue() + json_data = json_start + read_remote_model(filename, start=10, size=metadata_len-2) json_obj = json.loads(json_data) res = {} diff --git a/modules/sd_remote_models.py b/modules/sd_remote_models.py index 18147302e1c..4fc18c7e769 100644 --- a/modules/sd_remote_models.py +++ b/modules/sd_remote_models.py @@ -4,7 +4,7 @@ import threading from io import BytesIO from modules import shared - +from osstorchconnector import OssCheckpoint def __bucket__(): @@ -35,39 +35,49 @@ def list_remote_models(ext_filter): return output -def read_remote_model(checkpoint_file, start=0, size=-1): - time_start = time.time() - buffer = BytesIO() - obj_size = __get_object_size(checkpoint_file) +# def read_remote_model(checkpoint_file, start=0, size=-1): +# time_start = time.time() +# buffer = BytesIO() +# obj_size = __get_object_size(checkpoint_file) - s = start - end = (obj_size if size == -1 else start + size) - 1 +# s = start +# end = (obj_size if size == -1 else start + size) - 1 - tasks = [] +# tasks = [] - read_chunk_size = 2 * 1024 * 1024 - part_size = 256 * 1024 * 1024 +# read_chunk_size = 2 * 1024 * 1024 +# part_size = 256 * 1024 * 1024 - while True: - if s > end: - break - - e = min(s + part_size - 1, end) - t = threading.Thread(target=__range_get, - args=(checkpoint_file, buffer, start, s, e, read_chunk_size)) - tasks.append(t) - t.start() - s += part_size - - for t in tasks: - t.join() +# while True: +# if s > end: +# break + +# e = min(s + part_size - 1, end) +# t = threading.Thread(target=__range_get, +# args=(checkpoint_file, buffer, start, s, e, read_chunk_size)) +# tasks.append(t) +# t.start() +# s += part_size + +# for t in tasks: +# t.join() - time_end = time.time() +# time_end = time.time() + +# print ("remote %s read time cost: %f"%(checkpoint_file, time_end - time_start)) +# buffer.seek(0) +# return buffer + - print ("remote %s read time cost: %f"%(checkpoint_file, time_end - time_start)) - buffer.seek(0) - return buffer + +def read_remote_model(checkpoint_file, start=0, size=-1): + + checkpoint = OssCheckpoint(endpoint=endpoint, cred_path=cred_path, config_path=config_path) + CHECKPOINT_URI = "oss://%s/%s" % 
(shared.opts.bucket_name, checkpoint_file) + with checkpoint.reader(CHECKPOINT_URI) as reader: + reader.seek(start) + return reader.read(size) def __range_get(object_name, buffer, offset, start, end, read_chunk_size): chunk_size = int(read_chunk_size) From c7ee35d1bdf4940e3df1be5b0d76ff1fec3a4e98 Mon Sep 17 00:00:00 2001 From: "shiwktju@gmail.com" Date: Wed, 4 Sep 2024 15:28:46 +0800 Subject: [PATCH 2/5] load with connector. --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 99096afddcd..06e2a193a9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,4 +33,5 @@ torchdiffeq torchsde transformers==4.30.2 -oss2 \ No newline at end of file +oss2 +osstorchconnector \ No newline at end of file From 39216547f9db92c853ed61f6fa4554a21d0dc0c1 Mon Sep 17 00:00:00 2001 From: shiwk Date: Fri, 6 Sep 2024 15:07:52 +0800 Subject: [PATCH 3/5] conn debug --- .gitignore | 1 + modules/initialize.py | 2 +- modules/launch_utils.py | 6 +++++- modules/sd_remote_models.py | 18 ++++++++++++++---- modules/shared_options.py | 2 +- requirements_versions.txt | 5 +++-- 6 files changed, 25 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 09734267ff5..942d4d248c1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ /SwinIR/* /repositories /venv +/venv310 /tmp /model.ckpt /models/**/* diff --git a/modules/initialize.py b/modules/initialize.py index f24f76375db..dda41716075 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -152,7 +152,7 @@ def load_model(): from modules import devices devices.first_time_calculation() - Thread(target=load_model).start() + # Thread(target=load_model).start() from modules import shared_items shared_items.reload_hypernetworks() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 6e54d06367c..ab21a419ff3 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -351,7 +351,11 @@ def prepare_environment(): if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) startup_timer.record("install torch") - + + if not is_installed("cython"): + run_pip("install cython", "cython") + startup_timer.record("install cython") + if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"): raise RuntimeError( 'Torch is not able to use GPU; ' diff --git a/modules/sd_remote_models.py b/modules/sd_remote_models.py index 4fc18c7e769..73889833823 100644 --- a/modules/sd_remote_models.py +++ b/modules/sd_remote_models.py @@ -7,6 +7,12 @@ from osstorchconnector import OssCheckpoint +def __check_bucket_opts(): + if shared.opts.bucket_name and shared.opts.bucket_endpoint: + return True + print("Bucket opts not specified.") + return False + def __bucket__(): auth = oss2.Auth(os.environ.get('ACCESS_KEY_ID'), os.environ.get('ACCESS_KEY_SECRET')) return oss2.Bucket(auth, shared.opts.bucket_endpoint, shared.opts.bucket_name, enable_crc=False) @@ -19,8 +25,10 @@ def get_remote_model_mmtime(model_name): return __bucket__().head_object(model_name).last_modified def list_remote_models(ext_filter): - dir = shared.opts.bucket_model_ckpt_dir if shared.opts.bucket_model_ckpt_dir.endswith('/') else shared.opts.bucket_model_ckpt_dir + '/' + if not __check_bucket_opts(): + return [] output = [] + dir = shared.opts.bucket_model_ckpt_dir if 
shared.opts.bucket_model_ckpt_dir.endswith('/') else shared.opts.bucket_model_ckpt_dir + '/' for obj in oss2.ObjectIteratorV2(__bucket__(), prefix = dir, delimiter = '/', start_after=dir, fetch_owner=False): if obj.is_prefix(): print('directory: ', obj.key) @@ -71,9 +79,11 @@ def list_remote_models(ext_filter): -def read_remote_model(checkpoint_file, start=0, size=-1): - - checkpoint = OssCheckpoint(endpoint=endpoint, cred_path=cred_path, config_path=config_path) +def read_remote_model(checkpoint_file, start=0, size=-1) -> bytes: + if not __check_bucket_opts(): + return bytes() + + checkpoint = OssCheckpoint(endpoint=shared.opts.bucket_endpoint) CHECKPOINT_URI = "oss://%s/%s" % (shared.opts.bucket_name, checkpoint_file) with checkpoint.reader(CHECKPOINT_URI) as reader: reader.seek(start) diff --git a/modules/shared_options.py b/modules/shared_options.py index 83eee599aed..6a1549e8436 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -264,7 +264,7 @@ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), - "load_remote_ckpt": OptionInfo(False, "Load ckpt models from remote object storage").needs_reload_ui(), + "load_remote_ckpt": OptionInfo(True, "Load ckpt models from remote object storage").needs_reload_ui(), 'bucket_name': OptionInfo("", "Bucket name to download ckpt model"), 'bucket_endpoint': OptionInfo("", "Bucket endpoint to download ckpt model"), 'bucket_model_ckpt_dir': OptionInfo("", "Ckpt model directory in bucket"), diff --git a/requirements_versions.txt b/requirements_versions.txt index ba1af69dfa8..c5e3bb43d9a 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -8,7 +8,7 @@ einops==0.4.1 fastapi==0.94.0 gfpgan==1.3.8 gradio==3.41.2 -httpcore==0.15 +httpcore==1.0.5 inflection==0.5.1 jsonmerge==1.8.0 kornia==0.6.7 @@ -30,4 +30,5 @@ torchdiffeq==0.2.3 torchsde==0.2.5 transformers==4.30.2 oss2==2.18.1 -urllib3==1.26.16 \ No newline at end of file +urllib3==1.26.16 +osstorchconnector==1.0.0rc1 From 1d4e0fc660d9700a8f48fd991fc7f78cead8d9e9 Mon Sep 17 00:00:00 2001 From: shiwk Date: Fri, 6 Sep 2024 15:18:33 +0800 Subject: [PATCH 4/5] gitignore edit --- .gitignore | 2 ++ webui-user.sh | 53 --------------------------------------------------- 2 files changed, 2 insertions(+), 53 deletions(-) delete mode 100644 webui-user.sh diff --git a/.gitignore b/.gitignore index 942d4d248c1..ecc1f38edc9 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,5 @@ notification.mp3 /node_modules /package-lock.json /.coverage* +/output.log +/webui-user.sh \ No newline at end of file diff --git a/webui-user.sh b/webui-user.sh deleted file mode 100644 index e57d4740a75..00000000000 --- a/webui-user.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -######################################################### -# Uncomment and change the variables below to your need:# -######################################################### - -# Install directory without trailing slash -#install_dir="/home/$(whoami)" - -# Name of the subdirectory -#clone_dir="stable-diffusion-webui" - -# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" -#export COMMANDLINE_ARGS="" - -# python3 executable -#python_cmd="python3" - -# git executable -#export GIT="git" - -# 
python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) -#venv_dir="venv" - -# script to launch to start the app -#export LAUNCH_SCRIPT="launch.py" - -# install command for torch -#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" - -# Requirements file to use for stable-diffusion-webui -#export REQS_FILE="requirements_versions.txt" - -# Fixed git repos -#export K_DIFFUSION_PACKAGE="" -#export GFPGAN_PACKAGE="" - -# Fixed git commits -#export STABLE_DIFFUSION_COMMIT_HASH="" -#export CODEFORMER_COMMIT_HASH="" -#export BLIP_COMMIT_HASH="" - -# Uncomment to enable accelerated launch -#export ACCELERATE="True" - -# Uncomment to disable TCMalloc -#export NO_TCMALLOC="True" - -# Bucket access -#export ACCESS_KEY_ID="" -#export ACCESS_KEY_SECRET="" - - -########################################### From 1814e59592226cea111fed2b6c4189dcf716eefa Mon Sep 17 00:00:00 2001 From: shiwk Date: Tue, 10 Sep 2024 00:14:45 +0800 Subject: [PATCH 5/5] conn debug --- modules/api/api.py | 2 + modules/hashes.py | 8 + modules/sd_models.py | 19 +- modules/sd_remote_models.py | 65 +- modules/ui.py | 1414 +++++++++++++++++------------------ 5 files changed, 764 insertions(+), 744 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index e6edffe7144..e51b6a77563 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -546,7 +546,9 @@ def unloadapi(self): return {} def reloadapi(self): + print("start reload api") reload_model_weights() + print("end reload api") return {} diff --git a/modules/hashes.py b/modules/hashes.py index 01b0865ea74..4527dd1114c 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -30,6 +30,14 @@ def calculate_remote_sha256(filename): return hash_sha256.hexdigest() +# def calculate_remote_sha256(filename): +# blksize = 1024 * 1024 + +# buf = read_remote_model(filename, start = 0, size=blksize) +# hash_object = hashlib.sha256(buf) + +# return hash_object.hexdigest() + def sha256_from_cache(filename, title, use_addnet_hash=False, remote_model = False): hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes") ondisk_mtime = os.path.getmtime(filename) if not remote_model else get_remote_model_mmtime(filename) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8e6b1e6e0b0..3fc4cccea60 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -76,7 +76,7 @@ def read_metadata(): self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name + ('[remote]' if self.remote_model else '' ), filename, read_metadata, remote_model) except Exception as e: errors.display(e, f"reading metadata for {filename}") - + print("CheckpointInfo start") self.name = name self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] @@ -84,13 +84,18 @@ def read_metadata(): self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name + ('[remote]' if self.remote_model else '' )}", remote_model=remote_model) self.shorthash = self.sha256[0:10] if self.sha256 else None - + print("sha256: %s" % self.sha256) + print("shorthash: %s" % self.shorthash) + self.title = name + ('[remote]' if self.remote_model else '' )+ ('' if self.shorthash is None else f'[{self.shorthash}]') self.short_title = self.name_for_extra + ('[remote]' if self.remote_model else '') + ('' if self.shorthash is None else f'[{self.shorthash}]') + print("title: %s" % self.title) + print("short_title: 
%s" % self.short_title) self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] if self.shorthash: self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]'] + print("CheckpointInfo end") def register(self): checkpoints_list[self.title] = self @@ -170,6 +175,8 @@ def list_models(): for filename in remote_models: checkpoint_info = CheckpointInfo(filename, remote_model=True) checkpoint_info.register() + print ("list_model: %s " % filename) + re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") @@ -213,7 +220,7 @@ def remote_model_hash(model_name): import hashlib m = hashlib.sha256() - m.update(read_remote_model(model_name, start=0x100000, size=0x10000)) + m.update(read_remote_model(model_name, start=0x100000, size=0x10000).getvalue()) return m.hexdigest()[0:8] @@ -302,12 +309,12 @@ def read_metadata_from_safetensors(filename): def read_metadata_from_remote_safetensors(filename): import json - metadata_len = read_remote_model(filename, start=0, size=8) + metadata_len = read_remote_model(filename, start=0, size=8).getvalue() metadata_len = int.from_bytes(metadata_len, "little") - json_start = read_remote_model(filename, start=8, size=2) + json_start = read_remote_model(filename, start=8, size=2).getvalue() assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file" - json_data = json_start + read_remote_model(filename, start=10, size=metadata_len-2) + json_data = json_start + read_remote_model(filename, start=10, size=metadata_len-2).getvalue() json_obj = json.loads(json_data) res = {} diff --git a/modules/sd_remote_models.py b/modules/sd_remote_models.py index 73889833823..e42ce008d88 100644 --- a/modules/sd_remote_models.py +++ b/modules/sd_remote_models.py @@ -5,7 +5,7 @@ from io import BytesIO from modules import shared from osstorchconnector import OssCheckpoint - +import torch def __check_bucket_opts(): if shared.opts.bucket_name and shared.opts.bucket_endpoint: @@ -43,51 +43,54 @@ def list_remote_models(ext_filter): return output -# def read_remote_model(checkpoint_file, start=0, size=-1): -# time_start = time.time() -# buffer = BytesIO() -# obj_size = __get_object_size(checkpoint_file) +def read_remote_model(checkpoint_file, start=0, size=-1): + time_start = time.time() + buffer = BytesIO() + obj_size = __get_object_size(checkpoint_file) -# s = start -# end = (obj_size if size == -1 else start + size) - 1 + s = start + end = (obj_size if size == -1 else start + size) - 1 -# tasks = [] + tasks = [] -# read_chunk_size = 2 * 1024 * 1024 -# part_size = 256 * 1024 * 1024 + read_chunk_size = 2 * 1024 * 1024 + part_size = 256 * 1024 * 1024 -# while True: -# if s > end: -# break - -# e = min(s + part_size - 1, end) -# t = threading.Thread(target=__range_get, -# args=(checkpoint_file, buffer, start, s, e, read_chunk_size)) -# tasks.append(t) -# t.start() -# s += part_size - -# for t in tasks: -# t.join() + while True: + if s > end: + break + + e = min(s + part_size - 1, end) + t = threading.Thread(target=__range_get, + args=(checkpoint_file, buffer, start, s, e, read_chunk_size)) + tasks.append(t) + t.start() + s += part_size + + for t in tasks: + t.join() -# time_end = time.time() + time_end = time.time() -# print ("remote %s read time cost: %f"%(checkpoint_file, time_end - time_start)) -# buffer.seek(0) -# return buffer + print ("remote %s read time cost: %f"%(checkpoint_file, time_end - time_start)) + buffer.seek(0) + return buffer -def 
read_remote_model(checkpoint_file, start=0, size=-1) -> bytes: +def load_remote_model_ckpt(checkpoint_file, map_location) -> bytes: if not __check_bucket_opts(): return bytes() - + checkpoint = OssCheckpoint(endpoint=shared.opts.bucket_endpoint) CHECKPOINT_URI = "oss://%s/%s" % (shared.opts.bucket_name, checkpoint_file) + print("load %s state.." % CHECKPOINT_URI) + state_dict = None with checkpoint.reader(CHECKPOINT_URI) as reader: - reader.seek(start) - return reader.read(size) + state_dict = torch.load(reader, map_location = map_location, weights_only = True) + print("type:", type(state_dict)) + return state_dict def __range_get(object_name, buffer, offset, start, end, read_chunk_size): chunk_size = int(read_chunk_size) diff --git a/modules/ui.py b/modules/ui.py index 579bab9800c..e60a1caec27 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -540,705 +540,705 @@ def create_ui(): extra_tabs.__exit__() - scripts.scripts_current = scripts.scripts_img2img - scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - toprow = Toprow(is_img2img=True) - - extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") - extra_tabs.__enter__() - - with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Column(variant='compact', elem_id="img2img_settings"): - copy_image_buttons = [] - copy_image_destinations = {} - - def add_copy_image_controls(tab_name, elem): - with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): - gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") - - for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): - if name == tab_name: - gr.Button(title, interactive=False) - copy_image_destinations[name] = elem - continue - - button = gr.Button(title) - copy_image_buttons.append((button, name, elem)) - - with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) - - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) - add_copy_image_controls('img2img', init_img) - - with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) - add_copy_image_controls('sketch', sketch) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) - add_copy_image_controls('inpaint', init_img_with_mask) - - with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", 
height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) - inpaint_color_sketch_orig = gr.State(None) - add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - - with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") - - with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML( - "<p>Process images in a directory on the same machine where the server is running." + - "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + - f"<br>Add inpaint batch mask directory to enable inpaint batch processing." - f"{hidden}</p>
" - ) - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info", open=False): - img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") - img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") - img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") - - img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] - - for i, tab in enumerate(img2img_tabs): - tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) - - def copy_image(img): - if isinstance(img, dict) and 'image' in img: - return img['image'] - - return img - - for button, name, elem in copy_image_buttons: - button.click( - fn=copy_image, - inputs=[elem], - outputs=[copy_image_destinations[name]], - ) - button.click( - fn=lambda: None, - _js=f"switch_to_{name.replace(' ', '_')}", - inputs=[], - outputs=[], - ) - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - scripts.scripts_img2img.prepare_ui() - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) - - with gr.Tabs(): - with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") - detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") - - with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: - scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") - - with FormRow(): - scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") - gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") - button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") - - on_change_args = dict( - fn=resize_from_to_html, - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, scale_by], - 
outputs=scale_by_html, - show_progress=False, - ) - - scale_by.release(**on_change_args) - button_update_resize_to.click(**on_change_args) - - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. - for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - - tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) - tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "denoising": - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "cfg": - with gr.Row(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) - - elif category == "checkboxes": - with FormRow(elem_classes="checkboxes-row", variant="compact"): - pass - - elif category == "accordions": - with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"): - scripts.scripts_img2img.setup_ui_for_section(category) - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "override_settings": - with FormRow(elem_id="img2img_override_settings_row") as row: - override_settings = create_override_settings_dropdown('img2img', row) - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = scripts.scripts_img2img.setup_ui() - - elif category == "inpaint": - with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - - with FormRow(): - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = 
gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - def select_img2img_tab(tab): - return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), - - for i, elem in enumerate(img2img_tabs): - elem.select( - fn=lambda tab=i: select_img2img_tab(tab), - inputs=[], - outputs=[inpaint_controls, mask_alpha], - ) - - if category not in {"accordions"}: - scripts.scripts_img2img.setup_ui_for_section(category) - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - dummy_component, - dummy_component, - toprow.prompt, - toprow.negative_prompt, - toprow.ui_styles.dropdown, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - inpaint_color_sketch_orig, - init_img_inpaint, - init_mask_inpaint, - steps, - sampler_name, - mask_blur, - mask_alpha, - inpainting_fill, - batch_count, - batch_size, - cfg_scale, - image_cfg_scale, - denoising_strength, - selected_scale_tab, - height, - width, - scale_by, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - img2img_batch_inpaint_mask_dir, - override_settings, - img2img_batch_use_png_info, - img2img_batch_png_info_props, - img2img_batch_png_info_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - interrogate_args = dict( - _js="get_img2img_tab_index", - inputs=[ - dummy_component, - img2img_batch_input_dir, - img2img_batch_output_dir, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - init_img_inpaint, - ], - outputs=[toprow.prompt, dummy_component], - ) - - toprow.prompt.submit(**img2img_args) - toprow.submit.click(**img2img_args) - - res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False) - - detect_image_size_btn.click( - fn=lambda w, h, _: (w or gr.update(), h or gr.update()), - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, dummy_component], - outputs=[width, height], - show_progress=False, - ) - - toprow.restore_progress_button.click( - fn=progress.restore_progress, - _js="restoreProgressImg2img", - inputs=[dummy_component], - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - toprow.button_interrogate.click( - fn=lambda *args: process_interrogate(interrogate, *args), - **interrogate_args, - ) - - toprow.button_deepbooru.click( - fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), - **interrogate_args, - ) - - toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) - toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) - - img2img_paste_fields = [ - (toprow.prompt, "Prompt"), - (toprow.negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_name, "Sampler"), - (cfg_scale, "CFG scale"), - (image_cfg_scale, "Image CFG scale"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else 
gr.update()), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) - parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, - )) - - extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') - ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) - - extra_tabs.__exit__() - - scripts.scripts_current = None + # scripts.scripts_current = scripts.scripts_img2img + # scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + # with gr.Blocks(analytics_enabled=False) as img2img_interface: + # toprow = Toprow(is_img2img=True) + + # extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") + # extra_tabs.__enter__() + + # with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): + # with gr.Column(variant='compact', elem_id="img2img_settings"): + # copy_image_buttons = [] + # copy_image_destinations = {} + + # def add_copy_image_controls(tab_name, elem): + # with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): + # gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") + + # for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): + # if name == tab_name: + # gr.Button(title, interactive=False) + # copy_image_destinations[name] = elem + # continue + + # button = gr.Button(title) + # copy_image_buttons.append((button, name, elem)) + + # with gr.Tabs(elem_id="mode_img2img"): + # img2img_selected_tab = gr.State(0) + + # with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + # init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) + # add_copy_image_controls('img2img', init_img) + + # with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + # sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) + # add_copy_image_controls('sketch', sketch) + + # with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + # init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) + # add_copy_image_controls('inpaint', init_img_with_mask) + + # with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + # inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, 
brush_color=opts.img2img_inpaint_sketch_default_brush_color) + # inpaint_color_sketch_orig = gr.State(None) + # add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + # def update_orig(image, state): + # if image is not None: + # same_size = state is not None and state.size == image.size + # has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + # edited = same_size and has_exact_match + # return image if not edited or state is None else state + + # inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + # with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + # init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + # init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") + + # with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + # hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + # gr.HTML( + # "<p>Process images in a directory on the same machine where the server is running." + + # "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + + # f"<br>Add inpaint batch mask directory to enable inpaint batch processing." + # f"{hidden}</p>
" + # ) + # img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + # img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + # img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + # with gr.Accordion("PNG info", open=False): + # img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + # img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + # img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") + + # img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] + + # for i, tab in enumerate(img2img_tabs): + # tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) + + # def copy_image(img): + # if isinstance(img, dict) and 'image' in img: + # return img['image'] + + # return img + + # for button, name, elem in copy_image_buttons: + # button.click( + # fn=copy_image, + # inputs=[elem], + # outputs=[copy_image_destinations[name]], + # ) + # button.click( + # fn=lambda: None, + # _js=f"switch_to_{name.replace(' ', '_')}", + # inputs=[], + # outputs=[], + # ) + + # with FormRow(): + # resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + # scripts.scripts_img2img.prepare_ui() + + # for category in ordered_ui_categories(): + # if category == "sampler": + # steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") + + # elif category == "dimensions": + # with FormRow(): + # with gr.Column(elem_id="img2img_column_size", scale=4): + # selected_scale_tab = gr.State(value=0) + + # with gr.Tabs(): + # with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: + # with FormRow(): + # with gr.Column(elem_id="img2img_column_size", scale=4): + # width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + # height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + # with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): + # res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") + # detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") + + # with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: + # scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") + + # with FormRow(): + # scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") + # gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") + # button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") + + # on_change_args = dict( + # fn=resize_from_to_html, + 
# _js="currentImg2imgSourceResolution", + # inputs=[dummy_component, dummy_component, scale_by], + # outputs=scale_by_html, + # show_progress=False, + # ) + + # scale_by.release(**on_change_args) + # button_update_resize_to.click(**on_change_args) + + # # the code below is meant to update the resolution label after the image in the image selection UI has changed. + # # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. + # # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. + # for component in [init_img, sketch]: + # component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + + # tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) + # tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) + + # if opts.dimensions_and_batch_together: + # with gr.Column(elem_id="img2img_column_batch"): + # batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + # batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + # elif category == "denoising": + # denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + # elif category == "cfg": + # with gr.Row(): + # cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + # image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) + + # elif category == "checkboxes": + # with FormRow(elem_classes="checkboxes-row", variant="compact"): + # pass + + # elif category == "accordions": + # with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"): + # scripts.scripts_img2img.setup_ui_for_section(category) + + # elif category == "batch": + # if not opts.dimensions_and_batch_together: + # with FormRow(elem_id="img2img_column_batch"): + # batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + # batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + # elif category == "override_settings": + # with FormRow(elem_id="img2img_override_settings_row") as row: + # override_settings = create_override_settings_dropdown('img2img', row) + + # elif category == "scripts": + # with FormGroup(elem_id="img2img_script_container"): + # custom_inputs = scripts.scripts_img2img.setup_ui() + + # elif category == "inpaint": + # with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + # with FormRow(): + # mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + # mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + # with FormRow(): + # inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + # with FormRow(): + # inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + # with FormRow(): + # with gr.Column(): + # inpaint_full_res = gr.Radio(label="Inpaint 
area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + # with gr.Column(scale=4): + # inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + # def select_img2img_tab(tab): + # return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + # for i, elem in enumerate(img2img_tabs): + # elem.select( + # fn=lambda tab=i: select_img2img_tab(tab), + # inputs=[], + # outputs=[inpaint_controls, mask_alpha], + # ) + + # if category not in {"accordions"}: + # scripts.scripts_img2img.setup_ui_for_section(category) + + # img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + + # img2img_args = dict( + # fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + # _js="submit_img2img", + # inputs=[ + # dummy_component, + # dummy_component, + # toprow.prompt, + # toprow.negative_prompt, + # toprow.ui_styles.dropdown, + # init_img, + # sketch, + # init_img_with_mask, + # inpaint_color_sketch, + # inpaint_color_sketch_orig, + # init_img_inpaint, + # init_mask_inpaint, + # steps, + # sampler_name, + # mask_blur, + # mask_alpha, + # inpainting_fill, + # batch_count, + # batch_size, + # cfg_scale, + # image_cfg_scale, + # denoising_strength, + # selected_scale_tab, + # height, + # width, + # scale_by, + # resize_mode, + # inpaint_full_res, + # inpaint_full_res_padding, + # inpainting_mask_invert, + # img2img_batch_input_dir, + # img2img_batch_output_dir, + # img2img_batch_inpaint_mask_dir, + # override_settings, + # img2img_batch_use_png_info, + # img2img_batch_png_info_props, + # img2img_batch_png_info_dir, + # ] + custom_inputs, + # outputs=[ + # img2img_gallery, + # generation_info, + # html_info, + # html_log, + # ], + # show_progress=False, + # ) + + # interrogate_args = dict( + # _js="get_img2img_tab_index", + # inputs=[ + # dummy_component, + # img2img_batch_input_dir, + # img2img_batch_output_dir, + # init_img, + # sketch, + # init_img_with_mask, + # inpaint_color_sketch, + # init_img_inpaint, + # ], + # outputs=[toprow.prompt, dummy_component], + # ) + + # toprow.prompt.submit(**img2img_args) + # toprow.submit.click(**img2img_args) + + # res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False) + + # detect_image_size_btn.click( + # fn=lambda w, h, _: (w or gr.update(), h or gr.update()), + # _js="currentImg2imgSourceResolution", + # inputs=[dummy_component, dummy_component, dummy_component], + # outputs=[width, height], + # show_progress=False, + # ) + + # toprow.restore_progress_button.click( + # fn=progress.restore_progress, + # _js="restoreProgressImg2img", + # inputs=[dummy_component], + # outputs=[ + # img2img_gallery, + # generation_info, + # html_info, + # html_log, + # ], + # show_progress=False, + # ) + + # toprow.button_interrogate.click( + # fn=lambda *args: process_interrogate(interrogate, *args), + # **interrogate_args, + # ) + + # toprow.button_deepbooru.click( + # fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), + # **interrogate_args, + # ) + + # toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) + # toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) + + 
# img2img_paste_fields = [ + # (toprow.prompt, "Prompt"), + # (toprow.negative_prompt, "Negative prompt"), + # (steps, "Steps"), + # (sampler_name, "Sampler"), + # (cfg_scale, "CFG scale"), + # (image_cfg_scale, "Image CFG scale"), + # (width, "Size-1"), + # (height, "Size-2"), + # (batch_size, "Batch size"), + # (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + # (denoising_strength, "Denoising strength"), + # (mask_blur, "Mask blur"), + # *scripts.scripts_img2img.infotext_fields + # ] + # parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) + # parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) + # parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + # paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, + # )) + + # extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') + # ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + + # extra_tabs.__exit__() + + # scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: ui_postprocessing.create_ui() - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - - for tabname, button in buttons.items(): - parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image, - )) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger() - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row(equal_height=False): - gr.HTML(value="
<p>See wiki for detailed explanation.</p>
") - - with gr.Row(variant="compact", equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding", id="create_embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images", id="preprocess_images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size") - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Column(visible=False) as process_multicrop_col: - gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') - with gr.Row(): - process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") - process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim") - with gr.Row(): - process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea") - process_multicrop_maxarea = 
gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") - with gr.Row(): - process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") - process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - process_multicrop.change( - fn=lambda show: gr_show(show), - inputs=[process_multicrop], - outputs=[process_multicrop_col], - ) - - def get_textual_inversion_template_names(): - return sorted(textual_inversion.textual_inversion_templates) - - with gr.Tab(label="Train", id="train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks)) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(elem_id='ti_gallery_container'): - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4) - gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - - create_embedding.click( - fn=textual_inversion_ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=hypernetworks_ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - dummy_component, - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_keep_original_size, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - process_multicrop, - process_multicrop_mindim, - process_multicrop_maxdim, - process_multicrop_minarea, - process_multicrop_maxarea, - process_multicrop_objective, - process_multicrop_threshold, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - dummy_component, - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - use_weight, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - 
-            fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
-            _js="start_training_textual_inversion",
-            inputs=[
-                dummy_component,
-                train_hypernetwork_name,
-                hypernetwork_learn_rate,
-                batch_size,
-                gradient_step,
-                dataset_directory,
-                log_directory,
-                training_width,
-                training_height,
-                varsize,
-                steps,
-                clip_grad_mode,
-                clip_grad_value,
-                shuffle_tags,
-                tag_drop_out,
-                latent_sampling_method,
-                use_weight,
-                create_image_every,
-                save_embedding_every,
-                template_file,
-                preview_from_txt2img,
-                *txt2img_preview_params,
-            ],
-            outputs=[
-                ti_output,
-                ti_outcome,
-            ]
-        )
-
-        interrupt_training.click(
-            fn=lambda: shared.state.interrupt(),
-            inputs=[],
-            outputs=[],
-        )
-
-        interrupt_preprocessing.click(
-            fn=lambda: shared.state.interrupt(),
-            inputs=[],
-            outputs=[],
-        )
+    # with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
+    #     with gr.Row(equal_height=False):
+    #         with gr.Column(variant='panel'):
+    #             image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
+
+    #         with gr.Column(variant='panel'):
+    #             html = gr.HTML()
+    #             generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
+    #             html2 = gr.HTML()
+    #             with gr.Row():
+    #                 buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
+
+    #     for tabname, button in buttons.items():
+    #         parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+    #             paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
+    #         ))
+
+    #     image.change(
+    #         fn=wrap_gradio_call(modules.extras.run_pnginfo),
+    #         inputs=[image],
+    #         outputs=[html, generation_info, html2],
+    #     )
+
+    # modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
+
+    # with gr.Blocks(analytics_enabled=False) as train_interface:
+    #     with gr.Row(equal_height=False):
+    #         gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
") + + # with gr.Row(variant="compact", equal_height=False): + # with gr.Tabs(elem_id="train_tabs"): + + # with gr.Tab(label="Create embedding", id="create_embedding"): + # new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + # initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + # nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + # overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + # with gr.Row(): + # with gr.Column(scale=3): + # gr.HTML(value="") + + # with gr.Column(): + # create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + # with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"): + # new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + # new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + # new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + # new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func") + # new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + # new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + # new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + # new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") + # overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + # with gr.Row(): + # with gr.Column(scale=3): + # gr.HTML(value="") + + # with gr.Column(): + # create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + # with gr.Tab(label="Preprocess images", id="preprocess_images"): + # process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + # process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + # process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + # process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + # preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + # with gr.Row(): + # process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size") + # process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + # process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + # process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + # process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") + # process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + # process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + # with gr.Row(visible=False) as process_split_extra_row: + # process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + # process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + # with gr.Row(visible=False) as process_focal_crop_row: + # process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + # process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + # process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + # process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + # with gr.Column(visible=False) as process_multicrop_col: + # gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + # with gr.Row(): + # process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") + # process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim") + # with gr.Row(): + # process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, 
elem_id="train_process_multicrop_minarea") + # process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") + # with gr.Row(): + # process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") + # process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") + + # with gr.Row(): + # with gr.Column(scale=3): + # gr.HTML(value="") + + # with gr.Column(): + # with gr.Row(): + # interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + # run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + # process_split.change( + # fn=lambda show: gr_show(show), + # inputs=[process_split], + # outputs=[process_split_extra_row], + # ) + + # process_focal_crop.change( + # fn=lambda show: gr_show(show), + # inputs=[process_focal_crop], + # outputs=[process_focal_crop_row], + # ) + + # process_multicrop.change( + # fn=lambda show: gr_show(show), + # inputs=[process_multicrop], + # outputs=[process_multicrop_col], + # ) + + # def get_textual_inversion_template_names(): + # return sorted(textual_inversion.textual_inversion_templates) + + # with gr.Tab(label="Train", id="train"): + # gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") + # with FormRow(): + # train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + # create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + # train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks)) + # create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name") + + # with FormRow(): + # embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + # hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + # with FormRow(): + # clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + # clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + # with FormRow(): + # batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + # gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + # dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + # log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + # with FormRow(): + # template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + # create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + # training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + # training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + # varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + # steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + # with FormRow(): + # create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + # save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + # use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight") + + # save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + # preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + # shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + # tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + # latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + # with gr.Row(): + # train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + # interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + # train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + # params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + # script_callbacks.ui_train_tabs_callback(params) + + # with gr.Column(elem_id='ti_gallery_container'): + # ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + # gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4) + # gr.HTML(elem_id="ti_progress", value="") + # ti_outcome = gr.HTML(elem_id="ti_error", value="") + + # create_embedding.click( + # fn=textual_inversion_ui.create_embedding, + # inputs=[ + # new_embedding_name, + # initialization_text, + # nvpt, + # overwrite_old_embedding, + # ], + # outputs=[ + # train_embedding_name, + # ti_output, + # ti_outcome, + # ] + # ) + + # create_hypernetwork.click( + # fn=hypernetworks_ui.create_hypernetwork, + # inputs=[ + # new_hypernetwork_name, + # new_hypernetwork_sizes, + # overwrite_old_hypernetwork, + # new_hypernetwork_layer_structure, + # new_hypernetwork_activation_func, + # new_hypernetwork_initialization_option, + # new_hypernetwork_add_layer_norm, + # new_hypernetwork_use_dropout, + # new_hypernetwork_dropout_structure + # ], + # outputs=[ + # train_hypernetwork_name, + # ti_output, + # ti_outcome, + # ] + # ) + + # run_preprocess.click( + # fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]), + # _js="start_training_textual_inversion", + # inputs=[ + # dummy_component, + # process_src, + # process_dst, + # process_width, + # process_height, + # preprocess_txt_action, + # process_keep_original_size, + # process_flip, + # process_split, + # process_caption, + # process_caption_deepbooru, + # process_split_threshold, + # process_overlap_ratio, + # process_focal_crop, + # process_focal_crop_face_weight, + # process_focal_crop_entropy_weight, + # process_focal_crop_edges_weight, + # process_focal_crop_debug, + # process_multicrop, + # process_multicrop_mindim, + # process_multicrop_maxdim, + # process_multicrop_minarea, + # process_multicrop_maxarea, + # process_multicrop_objective, + # process_multicrop_threshold, + # ], + # outputs=[ + # ti_output, + # ti_outcome, + # ], + # ) + + # train_embedding.click( + # fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), + # _js="start_training_textual_inversion", + # inputs=[ + # dummy_component, + # train_embedding_name, + # embedding_learn_rate, + # batch_size, + # gradient_step, + # dataset_directory, + # log_directory, + # training_width, + # training_height, + # varsize, + # steps, + # clip_grad_mode, + # clip_grad_value, + # shuffle_tags, + # tag_drop_out, + # latent_sampling_method, + # use_weight, + # create_image_every, + # 
+    #             save_embedding_every,
+    #             template_file,
+    #             save_image_with_stored_embedding,
+    #             preview_from_txt2img,
+    #             *txt2img_preview_params,
+    #         ],
+    #         outputs=[
+    #             ti_output,
+    #             ti_outcome,
+    #         ]
+    #     )
+
+    #     train_hypernetwork.click(
+    #         fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
+    #         _js="start_training_textual_inversion",
+    #         inputs=[
+    #             dummy_component,
+    #             train_hypernetwork_name,
+    #             hypernetwork_learn_rate,
+    #             batch_size,
+    #             gradient_step,
+    #             dataset_directory,
+    #             log_directory,
+    #             training_width,
+    #             training_height,
+    #             varsize,
+    #             steps,
+    #             clip_grad_mode,
+    #             clip_grad_value,
+    #             shuffle_tags,
+    #             tag_drop_out,
+    #             latent_sampling_method,
+    #             use_weight,
+    #             create_image_every,
+    #             save_embedding_every,
+    #             template_file,
+    #             preview_from_txt2img,
+    #             *txt2img_preview_params,
+    #         ],
+    #         outputs=[
+    #             ti_output,
+    #             ti_outcome,
+    #         ]
+    #     )
+
+    #     interrupt_training.click(
+    #         fn=lambda: shared.state.interrupt(),
+    #         inputs=[],
+    #         outputs=[],
+    #     )
+
+    #     interrupt_preprocessing.click(
+    #         fn=lambda: shared.state.interrupt(),
+    #         inputs=[],
+    #         outputs=[],
+    #     )
 
     loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
 
@@ -1247,18 +1247,18 @@ def get_textual_inversion_template_names():
 
     interfaces = [
         (txt2img_interface, "txt2img", "txt2img"),
-        (img2img_interface, "img2img", "img2img"),
-        (extras_interface, "Extras", "extras"),
-        (pnginfo_interface, "PNG Info", "pnginfo"),
-        (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
-        (train_interface, "Train", "train"),
+        # (img2img_interface, "img2img", "img2img"),
+        # (extras_interface, "Extras", "extras"),
+        # (pnginfo_interface, "PNG Info", "pnginfo"),
+        # (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
+        # (train_interface, "Train", "train"),
     ]
 
     interfaces += script_callbacks.ui_tabs_callback()
    interfaces += [(settings.interface, "Settings", "settings")]
 
-    extensions_interface = ui_extensions.create_ui()
-    interfaces += [(extensions_interface, "Extensions", "extensions")]
+    # extensions_interface = ui_extensions.create_ui()
+    # interfaces += [(extensions_interface, "Extensions", "extensions")]
 
     shared.tab_names = []
     for _interface, label, _ifid in interfaces:
@@ -1267,7 +1267,7 @@ def get_textual_inversion_template_names():
 
     with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
         settings.add_quicksettings()
-        parameters_copypaste.connect_paste_params_buttons()
+        # parameters_copypaste.connect_paste_params_buttons()
 
         with gr.Tabs(elem_id="tabs") as tabs:
             tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
@@ -1295,11 +1295,11 @@ def get_textual_inversion_template_names():
 
         settings.add_functionality(demo)
 
-        update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
-        settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
-        demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+        # update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+        # settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+        # demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
 
-        modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
+        # modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
 
         loadsave.dump_defaults()
        demo.ui_loadsave = loadsave
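
The ui.py hunks above strip the debug build down to the txt2img and Settings tabs (plus any extension tabs still registered through script_callbacks.ui_tabs_callback()) by commenting entries out of `interfaces`. A minimal sketch of an alternative, assuming a hypothetical `ui_minimal` command-line flag and a helper function that are not part of this patch: filter the list at runtime, so the full UI stays one flag away while the connector is being debugged.

    # Hypothetical sketch only: `ui_minimal` is an assumed new flag and
    # `filter_interfaces` an assumed helper; the (blocks, label, ifid)
    # tuple layout matches the existing `interfaces` list in modules/ui.py.
    def filter_interfaces(interfaces, ui_minimal, keep=("txt2img", "settings")):
        # With the flag unset, the tab list is unchanged.
        if not ui_minimal:
            return interfaces
        # Otherwise keep only tabs whose ifid is whitelisted.
        return [entry for entry in interfaces if entry[2] in keep]

    # Assumed usage, right after the list is built:
    # interfaces = filter_interfaces(interfaces, cmd_opts.ui_minimal)

This keeps the event wiring intact, avoids the large commented-out duplicate of the train and pnginfo blocks, and makes the debug configuration reversible without touching the source again.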