diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 00000000..0d8c0647
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,24 @@
+name: Lint
+
+on:
+ pull_request:
+ push:
+ branches: [main]
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Set up Python 3.9
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Install dev dependencies
+ run: python -m pip install -r requirements.dev.txt
+
+ - name: Run lint and static type checks
+ run: bash ${GITHUB_WORKSPACE}/style.sh
diff --git a/.gitignore b/.gitignore
index 4ba344f7..6bae41ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -155,6 +155,7 @@ cython_debug/
# repo specific
/models/
+models
/outputs/
/temp/
/gfpgan/
@@ -170,4 +171,5 @@ custom-conda-path.txt
bridgeData.py
src/taming-transformers
-/xformers-*
\ No newline at end of file
+/xformers-*
+/bin/
diff --git a/bridge.py b/bridge.py
index 8e0be8a8..be4ed512 100644
--- a/bridge.py
+++ b/bridge.py
@@ -1,82 +1,223 @@
-import requests, json, os, time, argparse, urllib3, time, base64, re, getpass
+import argparse
+import base64
+import getpass
+import json
+import os
+import random
+import sys
+import time
+from base64 import binascii
+from io import BytesIO
+
+import requests
+from PIL import Image, UnidentifiedImageError
+
+from nataili import disable_voodoo, disable_xformers
+from nataili.util import logger, quiesce_logger, set_logger_verbosity
+from nataili.util.cache import torch_gc
arg_parser = argparse.ArgumentParser()
-arg_parser.add_argument('-i', '--interval', action="store", required=False, type=int, default=1, help="The amount of seconds with which to check if there's new prompts to generate")
-arg_parser.add_argument('-a','--api_key', action="store", required=False, type=str, help="The API key corresponding to the owner of this Horde instance")
-arg_parser.add_argument('-n','--worker_name', action="store", required=False, type=str, help="The server name for the Horde. It will be shown to the world and there can be only one.")
-arg_parser.add_argument('-u','--horde_url', action="store", required=False, type=str, help="The SH Horde URL. Where the bridge will pickup prompts and send the finished generations.")
-arg_parser.add_argument('--priority_usernames',type=str, action='append', required=False, help="Usernames which get priority use in this horde instance. The owner's username is always in this list.")
-arg_parser.add_argument('-p','--max_power',type=int, required=False, help="How much power this instance has to generate pictures. Min: 2")
-arg_parser.add_argument('--sfw', action='store_true', required=False, help="Set to true if you do not want this worker generating NSFW images.")
-arg_parser.add_argument('--blacklist', nargs='+', required=False, help="List the words that you want to blacklist.")
-arg_parser.add_argument('--censorlist', nargs='+', required=False, help="List the words that you want to censor.")
-arg_parser.add_argument('--censor_nsfw', action='store_true', required=False, help="Set to true if you want this bridge worker to censor NSFW images.")
-arg_parser.add_argument('--allow_img2img', action='store_true', required=False, help="Set to true if you want this bridge worker to allow img2img request.")
-arg_parser.add_argument('--allow_painting', action='store_true', required=False, help="Set to true if you want this bridge worker to allow inpainting/outpainting requests.")
-arg_parser.add_argument('--allow_unsafe_ip', action='store_true', required=False, help="Set to true if you want this bridge worker to allow img2img requests from unsafe IPs.")
-arg_parser.add_argument('-m', '--model', action='store', required=False, help="Which model to run on this horde.")
-arg_parser.add_argument('--debug', action="store_true", default=False, help="Show debugging messages.")
-arg_parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
-arg_parser.add_argument('-q', '--quiet', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
-arg_parser.add_argument('--log_file', action='store_true', default=False, help="If specified will dump the log to the specified file")
-arg_parser.add_argument('--skip_md5', action='store_true', default=False, help="If specified will not check the downloaded model md5sum.")
-arg_parser.add_argument('--disable_voodoo', action='store_true', default=False, help="If specified this bridge will not use voodooray to offload models into RAM and save VRAM (useful for cloud providers).")
-arg_parser.add_argument('--disable_xformers', action='store_true', default=False, help="If specified this bridge will not try use xformers to speed up generations. This should normally be automatic, but in case you need to disable it manually, you can do so here.")
+arg_parser.add_argument(
+ "-i",
+ "--interval",
+ action="store",
+ required=False,
+ type=int,
+ default=1,
+ help="The amount of seconds with which to check if there's new prompts to generate",
+)
+arg_parser.add_argument(
+ "-a",
+ "--api_key",
+ action="store",
+ required=False,
+ type=str,
+ help="The API key corresponding to the owner of this Horde instance",
+)
+arg_parser.add_argument(
+ "-n",
+ "--worker_name",
+ action="store",
+ required=False,
+ type=str,
+ help="The server name for the Horde. It will be shown to the world and there can be only one.",
+)
+arg_parser.add_argument(
+ "-u",
+ "--horde_url",
+ action="store",
+ required=False,
+ type=str,
+ help="The SH Horde URL. Where the bridge will pickup prompts and send the finished generations.",
+)
+arg_parser.add_argument(
+ "--priority_usernames",
+ type=str,
+ action="append",
+ required=False,
+ help="Usernames which get priority use in this horde instance. The owner's username is always in this list.",
+)
+arg_parser.add_argument(
+ "-p",
+ "--max_power",
+ type=int,
+ required=False,
+ help="How much power this instance has to generate pictures. Min: 2",
+)
+arg_parser.add_argument(
+ "--sfw",
+ action="store_true",
+ required=False,
+ help="Set to true if you do not want this worker generating NSFW images.",
+)
+arg_parser.add_argument(
+ "--blacklist",
+ nargs="+",
+ required=False,
+ help="List the words that you want to blacklist.",
+)
+arg_parser.add_argument(
+ "--censorlist",
+ nargs="+",
+ required=False,
+ help="List the words that you want to censor.",
+)
+arg_parser.add_argument(
+ "--censor_nsfw",
+ action="store_true",
+ required=False,
+ help="Set to true if you want this bridge worker to censor NSFW images.",
+)
+arg_parser.add_argument(
+ "--allow_img2img",
+ action="store_true",
+ required=False,
+ help="Set to true if you want this bridge worker to allow img2img request.",
+)
+arg_parser.add_argument(
+ "--allow_painting",
+ action="store_true",
+ required=False,
+ help="Set to true if you want this bridge worker to allow inpainting/outpainting requests.",
+)
+arg_parser.add_argument(
+ "--allow_unsafe_ip",
+ action="store_true",
+ required=False,
+ help="Set to true if you want this bridge worker to allow img2img requests from unsafe IPs.",
+)
+arg_parser.add_argument(
+ "-m",
+ "--model",
+ action="store",
+ required=False,
+ help="Which model to run on this horde.",
+)
+arg_parser.add_argument("--debug", action="store_true", default=False, help="Show debugging messages.")
+arg_parser.add_argument(
+ "-v",
+ "--verbosity",
+ action="count",
+ default=0,
+ help=(
+ "The default logging level is ERROR or higher. "
+ "This value increases the amount of logging seen in your screen"
+ ),
+)
+arg_parser.add_argument(
+ "-q",
+ "--quiet",
+ action="count",
+ default=0,
+ help=(
+ "The default logging level is ERROR or higher. "
+ "This value decreases the amount of logging seen in your screen"
+ ),
+)
+arg_parser.add_argument(
+ "--log_file",
+ action="store_true",
+ default=False,
+ help="If specified will dump the log to the specified file",
+)
+arg_parser.add_argument(
+ "--skip_md5",
+ action="store_true",
+ default=False,
+ help="If specified will not check the downloaded model md5sum.",
+)
+arg_parser.add_argument(
+ "--disable_voodoo",
+ action="store_true",
+ default=False,
+ help=(
+ "If specified this bridge will not use voodooray to offload models into RAM and save VRAM"
+ " (useful for cloud providers)."
+ ),
+)
+arg_parser.add_argument(
+ "--disable_xformers",
+ action="store_true",
+ default=False,
+ help=(
+ "If specified this bridge will not try use xformers to speed up generations."
+ " This should normally be automatic, but in case you need to disable it manually, you can do so here."
+ ),
+)
args = arg_parser.parse_args()
-from nataili import disable_xformers, disable_voodoo
-disable_xformers.toggle(args.disable_xformers)
-disable_voodoo.toggle(args.disable_voodoo)
-from nataili.inference.diffusers.inpainting import inpainting
-from nataili.inference.compvis.img2img import img2img
-from nataili.model_manager import ModelManager
-from nataili.inference.compvis.txt2img import txt2img
-from nataili.util.cache import torch_gc
-from nataili.util import logger, set_logger_verbosity, quiesce_logger
-from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps, ImageChops, UnidentifiedImageError
-from io import BytesIO
-from base64 import binascii
+disable_xformers.toggle(args.disable_xformers)
+disable_voodoo.toggle(args.disable_voodoo)
-import random
-model = ''
+# Note: for now we cannot put them at the top of the file because the imports
+# will use the disable_voodoo and disable_xformers global variables
+from nataili.inference.compvis.img2img import img2img # noqa: E402
+from nataili.inference.compvis.txt2img import txt2img # noqa: E402
+from nataili.inference.diffusers.inpainting import inpainting # noqa: E402
+from nataili.model_manager import ModelManager # noqa: E402
+
+model = ""
max_content_length = 1024
max_length = 80
current_softprompt = None
softprompts = {}
-import os
+
class BridgeData(object):
def __init__(self):
random.seed()
- self.horde_url =os.environ.get("HORDE_URL", "https://stablehorde.net")
+ self.horde_url = os.environ.get("HORDE_URL", "https://stablehorde.net")
# Give a cool name to your instance
- self.worker_name = os.environ.get("HORDE_WORKER_NAME", f"Automated Instance #{random.randint(-100000000, 100000000)}")
+ self.worker_name = os.environ.get(
+ "HORDE_WORKER_NAME",
+ f"Automated Instance #{random.randint(-100000000, 100000000)}",
+ )
# The api_key identifies a unique user in the horde
self.api_key = os.environ.get("HORDE_API_KEY", "0000000000")
# Put other users whose prompts you want to prioritize.
- # The owner's username is always included so you don't need to add it here, unless you want it to have lower priority than another user
- self.priority_usernames = list(filter(lambda a : a,os.environ.get("HORDE_PRIORITY_USERNAMES", "").split(",")))
+ # The owner's username is always included so you don't need to add it here,
+ # unless you want it to have lower priority than another user
+ self.priority_usernames = list(filter(lambda a: a, os.environ.get("HORDE_PRIORITY_USERNAMES", "").split(",")))
self.max_power = int(os.environ.get("HORDE_MAX_POWER", 8))
self.nsfw = os.environ.get("HORDE_NSFW", "true") == "true"
- self.censor_nsfw = os.environ.get("HORDE_CENSOR", "false") == "true"
- self.blacklist = list(filter(lambda a : a,os.environ.get("HORDE_BLACKLIST", "").split(",")))
- self.censorlist = list(filter(lambda a : a,os.environ.get("HORDE_CENSORLIST", "").split(",")))
+ self.censor_nsfw = os.environ.get("HORDE_CENSOR", "false") == "true"
+ self.blacklist = list(filter(lambda a: a, os.environ.get("HORDE_BLACKLIST", "").split(",")))
+ self.censorlist = list(filter(lambda a: a, os.environ.get("HORDE_CENSORLIST", "").split(",")))
self.allow_img2img = os.environ.get("HORDE_IMG2IMG", "true") == "true"
self.allow_painting = os.environ.get("HORDE_PAINTING", "true") == "true"
self.allow_unsafe_ip = os.environ.get("HORDE_ALLOW_UNSAFE_IP", "true") == "true"
self.model_names = os.environ.get("HORDE_MODELNAMES", "stable_diffusion").split(",")
- self.max_pixels = 64*64*8*self.max_power
-
+ self.max_pixels = 64 * 64 * 8 * self.max_power
@logger.catch(reraise=True)
def bridge(interval, model_manager, bd):
- horde_url = bd.horde_url # Will replace later
+ horde_url = bd.horde_url # Will replace later
current_id = None
current_payload = None
loop_retry = 0
while True:
- ### Pop new request from the Horde
+ # Pop new request from the Horde
if loop_retry > 10 and current_id:
logger.error(f"Exceeded retry count {loop_retry} for generation id {current_id}. Aborting generation!")
current_id = None
@@ -109,7 +250,12 @@ def bridge(interval, model_manager, bd):
loop_retry += 1
else:
try:
- pop_req = requests.post(horde_url + '/api/v2/generate/pop', json = gen_dict, headers = headers, timeout=10)
+ pop_req = requests.post(
+ horde_url + "/api/v2/generate/pop",
+ json=gen_dict,
+ headers=headers,
+ timeout=10,
+ )
except requests.exceptions.ConnectionError:
logger.warning(f"Server {horde_url} unavailable during pop. Waiting 10 seconds...")
time.sleep(10)
@@ -128,38 +274,40 @@ def bridge(interval, model_manager, bd):
logger.error(f"Could not decode response from {horde_url} as json. Please inform its administrator!")
time.sleep(interval)
continue
- if pop == None:
+ if pop is None:
logger.error(f"Something has gone wrong with {horde_url}. Please inform its administrator!")
time.sleep(interval)
continue
if not pop_req.ok:
- message = pop['message']
- logger.warning(f"During gen pop, server {horde_url} responded with status code {pop_req.status_code}: {pop['message']}. Waiting for 10 seconds...")
- if 'errors' in pop:
+ logger.warning(
+ f"During gen pop, server {horde_url} responded with status code {pop_req.status_code}: "
+ f"{pop['message']}. Waiting for 10 seconds..."
+ )
+ if "errors" in pop:
logger.warning(f"Detailed Request Errors: {pop['errors']}")
time.sleep(10)
continue
if not pop.get("id"):
- skipped_info = pop.get('skipped')
+ skipped_info = pop.get("skipped")
if skipped_info and len(skipped_info):
skipped_info = f" Skipped Info: {skipped_info}."
else:
- skipped_info = ''
+ skipped_info = ""
logger.debug(f"Server {horde_url} has no valid generations to do for us.{skipped_info}")
time.sleep(interval)
continue
- current_id = pop['id']
- current_payload = pop['payload']
- ### Generate Image
+ current_id = pop["id"]
+ current_payload = pop["payload"]
+ # Generate Image
model = pop.get("model", available_models[0])
# logger.info([current_id,current_payload])
use_nsfw_censor = current_payload.get("use_nsfw_censor", False)
if bd.censor_nsfw and not bd.nsfw:
use_nsfw_censor = True
- elif any(word in current_payload['prompt'] for word in bd.censorlist):
+ elif any(word in current_payload["prompt"] for word in bd.censorlist):
use_nsfw_censor = True
- use_gfpgan = current_payload.get("use_gfpgan", True)
- use_real_esrgan = current_payload.get("use_real_esrgan", False)
+ # use_gfpgan = current_payload.get("use_gfpgan", True)
+ # use_real_esrgan = current_payload.get("use_real_esrgan", False)
source_processing = pop.get("source_processing")
source_image = pop.get("source_image")
source_mask = pop.get("source_mask")
@@ -168,7 +316,6 @@ def bridge(interval, model_manager, bd):
"prompt": current_payload["prompt"],
"height": current_payload["height"],
"width": current_payload["width"],
- "width": current_payload["width"],
"seed": current_payload["seed"],
"n_iter": 1,
"batch_size": 1,
@@ -176,25 +323,32 @@ def bridge(interval, model_manager, bd):
"save_grid": False,
}
# These params might not always exist in the horde payload
- if 'ddim_steps' in current_payload: gen_payload['ddim_steps'] = current_payload['ddim_steps']
- if 'sampler_name' in current_payload: gen_payload['sampler_name'] = current_payload['sampler_name']
- if 'cfg_scale' in current_payload: gen_payload['cfg_scale'] = current_payload['cfg_scale']
- if 'ddim_eta' in current_payload: gen_payload['ddim_eta'] = current_payload['ddim_eta']
- if 'denoising_strength' in current_payload and source_image:
- gen_payload['denoising_strength'] = current_payload['denoising_strength']
+ if "ddim_steps" in current_payload:
+ gen_payload["ddim_steps"] = current_payload["ddim_steps"]
+ if "sampler_name" in current_payload:
+ gen_payload["sampler_name"] = current_payload["sampler_name"]
+ if "cfg_scale" in current_payload:
+ gen_payload["cfg_scale"] = current_payload["cfg_scale"]
+ if "ddim_eta" in current_payload:
+ gen_payload["ddim_eta"] = current_payload["ddim_eta"]
+ if "denoising_strength" in current_payload and source_image:
+ gen_payload["denoising_strength"] = current_payload["denoising_strength"]
# logger.debug(gen_payload)
req_type = "txt2img"
if source_image:
- img_source = None
- img_mask = None
- if source_processing == "img2img":
- req_type = "img2img"
- elif source_processing == "inpainting":
- req_type = "inpainting"
- if source_processing == "outpainting":
- req_type = "outpainting"
+ img_source = None
+ img_mask = None
+ if source_processing == "img2img":
+ req_type = "img2img"
+ elif source_processing == "inpainting":
+ req_type = "inpainting"
+ if source_processing == "outpainting":
+ req_type = "outpainting"
# Prevent inpainting from picking text2img and img2img gens (as those go via compvis pipelines)
- if model == "stable_diffusion_inpainting" and req_type not in ["inpainting","outpainting"]:
+ if model == "stable_diffusion_inpainting" and req_type not in [
+ "inpainting",
+ "outpainting",
+ ]:
# Try to find any other model to do text2img or img2img
for m in available_models:
if m != "stable_diffusion_inpainting":
@@ -202,41 +356,60 @@ def bridge(interval, model_manager, bd):
# if the model persists as inpainting for text2img or img2img, we abort.
if model == "stable_diffusion_inpainting":
# We remove the base64 from the prompt to avoid flooding the output on the error
- if len(pop.get("source_image",'')) > 10:
- pop["source_image"] = len(pop.get("source_image",''))
- if len(pop.get("source_mask",'')) > 10:
- pop["source_mask"] = len(pop.get("source_mask",''))
- logger.error(f"Received an non-inpainting request for inpainting model. This shouldn't happen. Inform the developer. Current payload {pop}")
+ if len(pop.get("source_image", "")) > 10:
+ pop["source_image"] = len(pop.get("source_image", ""))
+ if len(pop.get("source_mask", "")) > 10:
+ pop["source_mask"] = len(pop.get("source_mask", ""))
+ logger.error(
+ "Received an non-inpainting request for inpainting model. This shouldn't happen. "
+ f"Inform the developer. Current payload {pop}"
+ )
current_id = None
current_payload = None
current_generation = None
loop_retry = 0
continue
- ## TODO: Send faulted
+ # TODO: Send faulted
logger.debug(f"{req_type} ({model}) request with id {current_id} picked up. Initiating work...")
try:
- safety_checker = model_manager.loaded_models['safety_checker']['model'] if 'safety_checker' in model_manager.loaded_models else None
+ safety_checker = (
+ model_manager.loaded_models["safety_checker"]["model"]
+ if "safety_checker" in model_manager.loaded_models
+ else None
+ )
if source_image:
- base64_bytes = source_image.encode('utf-8')
+ base64_bytes = source_image.encode("utf-8")
img_bytes = base64.b64decode(base64_bytes)
img_source = Image.open(BytesIO(img_bytes))
if source_mask:
- base64_bytes = source_mask.encode('utf-8')
+ base64_bytes = source_mask.encode("utf-8")
img_bytes = base64.b64decode(base64_bytes)
img_mask = Image.open(BytesIO(img_bytes))
if img_mask.size != img_source.size:
- logger.warning(f"Source image/mask mismatch. Resizing mask from {img_mask.size} to {img_source.size}")
+ logger.warning(
+ f"Source image/mask mismatch. Resizing mask from {img_mask.size} to {img_source.size}"
+ )
img_mask = img_mask.resize(img_source.size)
if req_type == "img2img":
- gen_payload['init_img'] = img_source
- generator = img2img(model_manager.loaded_models[model]["model"], model_manager.loaded_models[model]["device"], 'bridge_generations',
- load_concepts=True, concepts_dir='models/custom/sd-concepts-library', safety_checker=safety_checker, filter_nsfw=use_nsfw_censor,
- disable_voodoo=disable_voodoo.active)
+ gen_payload["init_img"] = img_source
+ generator = img2img(
+ model_manager.loaded_models[model]["model"],
+ model_manager.loaded_models[model]["device"],
+ "bridge_generations",
+ load_concepts=True,
+ concepts_dir="models/custom/sd-concepts-library",
+ safety_checker=safety_checker,
+ filter_nsfw=use_nsfw_censor,
+ disable_voodoo=disable_voodoo.active,
+ )
elif req_type == "inpainting" or req_type == "outpainting":
# These variables do not exist in the outpainting implementation
- if "save_grid" in gen_payload: del gen_payload["save_grid"]
- if "sampler_name" in gen_payload: del gen_payload["sampler_name"]
- if "denoising_strength" in gen_payload: del gen_payload["denoising_strength"]
+ if "save_grid" in gen_payload:
+ del gen_payload["save_grid"]
+ if "sampler_name" in gen_payload:
+ del gen_payload["sampler_name"]
+ if "denoising_strength" in gen_payload:
+ del gen_payload["denoising_strength"]
# We prevent sending an inpainting without mask or transparency, as it will crash us.
if img_mask is None:
try:
@@ -248,33 +421,60 @@ def bridge(interval, model_manager, bd):
current_generation = None
loop_retry = 0
continue
- ## TODO: Send faulted
+ # TODO: Send faulted
- gen_payload['inpaint_img'] = img_source
+ gen_payload["inpaint_img"] = img_source
if img_mask:
- gen_payload['inpaint_mask'] = img_mask
- generator = inpainting(model_manager.loaded_models[model]["model"], model_manager.loaded_models[model]["device"], 'bridge_generations',filter_nsfw=use_nsfw_censor)
+ gen_payload["inpaint_mask"] = img_mask
+ generator = inpainting(
+ model_manager.loaded_models[model]["model"],
+ model_manager.loaded_models[model]["device"],
+ "bridge_generations",
+ filter_nsfw=use_nsfw_censor,
+ )
else:
- generator = txt2img(model_manager.loaded_models[model]["model"], model_manager.loaded_models[model]["device"], 'bridge_generations',
- load_concepts=True, concepts_dir='models/custom/sd-concepts-library', safety_checker=safety_checker, filter_nsfw=use_nsfw_censor,
- disable_voodoo=disable_voodoo.active)
+ generator = txt2img(
+ model_manager.loaded_models[model]["model"],
+ model_manager.loaded_models[model]["device"],
+ "bridge_generations",
+ load_concepts=True,
+ concepts_dir="models/custom/sd-concepts-library",
+ safety_checker=safety_checker,
+ filter_nsfw=use_nsfw_censor,
+ disable_voodoo=disable_voodoo.active,
+ )
except KeyError:
continue
# If the received image is unreadable, we continue
except UnidentifiedImageError:
- logger.error(f"Source image received for img2img is unreadable. Falling back to text2img!")
- if 'denoising_strength' in gen_payload:
- del gen_payload['denoising_strength']
- generator = txt2img(model_manager.loaded_models[model]["model"], model_manager.loaded_models[model]["device"], 'bridge_generations', load_concepts=True, concepts_dir='models/custom/sd-concepts-library')
+ logger.error("Source image received for img2img is unreadable. Falling back to text2img!")
+ if "denoising_strength" in gen_payload:
+ del gen_payload["denoising_strength"]
+ generator = txt2img(
+ model_manager.loaded_models[model]["model"],
+ model_manager.loaded_models[model]["device"],
+ "bridge_generations",
+ load_concepts=True,
+ concepts_dir="models/custom/sd-concepts-library",
+ )
except binascii.Error:
- logger.error(f"Source image received for img2img is cannot be base64 decoded (binascii.Error). Falling back to text2img!")
- if 'denoising_strength' in gen_payload:
- del gen_payload['denoising_strength']
- generator = txt2img(model_manager.loaded_models[model]["model"], model_manager.loaded_models[model]["device"], 'bridge_generations', load_concepts=True, concepts_dir='models/custom/sd-concepts-library')
+ logger.error(
+ "Source image received for img2img is cannot be base64 decoded (binascii.Error). "
+ "Falling back to text2img!"
+ )
+ if "denoising_strength" in gen_payload:
+ del gen_payload["denoising_strength"]
+ generator = txt2img(
+ model_manager.loaded_models[model]["model"],
+ model_manager.loaded_models[model]["device"],
+ "bridge_generations",
+ load_concepts=True,
+ concepts_dir="models/custom/sd-concepts-library",
+ )
generator.generate(**gen_payload)
torch_gc()
- ### Submit back to horde
+ # Submit back to horde
# images, seed, info, stats = txt2img(**current_payload)
buffer = BytesIO()
# We send as WebP to avoid using all the horde bandwidth
@@ -290,45 +490,63 @@ def bridge(interval, model_manager, bd):
"max_pixels": bd.max_pixels,
}
current_generation = seed
- while current_id and current_generation != None:
+ while current_id and current_generation is not None:
try:
- submit_req = requests.post(horde_url + '/api/v2/generate/submit', json = submit_dict, headers = headers, timeout=20)
+ submit_req = requests.post(
+ horde_url + "/api/v2/generate/submit",
+ json=submit_dict,
+ headers=headers,
+ timeout=20,
+ )
try:
submit = submit_req.json()
except json.decoder.JSONDecodeError:
- logger.error(f"Something has gone wrong with {horde_url} during submit. Please inform its administrator! (Retry {loop_retry}/10)")
+ logger.error(
+ f"Something has gone wrong with {horde_url} during submit. "
+ f"Please inform its administrator! (Retry {loop_retry}/10)"
+ )
time.sleep(interval)
continue
if submit_req.status_code == 404:
- logger.warning(f"The generation we were working on got stale. Aborting!")
+ logger.warning("The generation we were working on got stale. Aborting!")
elif not submit_req.ok:
- logger.warning(f"During gen submit, server {horde_url} responded with status code {submit_req.status_code}: {submit['message']}. Waiting for 10 seconds... (Retry {loop_retry}/10)")
- if 'errors' in submit:
+ logger.warning(
+ f"During gen submit, server {horde_url} responded with status code {submit_req.status_code}: "
+ f"{submit['message']}. Waiting for 10 seconds... (Retry {loop_retry}/10)"
+ )
+ if "errors" in submit:
logger.warning(f"Detailed Request Errors: {submit['errors']}")
time.sleep(10)
continue
else:
- logger.info(f'Submitted generation with id {current_id} and contributed for {submit_req.json()["reward"]}')
+ logger.info(
+ f'Submitted generation with id {current_id} and contributed for {submit_req.json()["reward"]}'
+ )
current_id = None
current_payload = None
current_generation = None
loop_retry = 0
except requests.exceptions.ConnectionError:
- logger.warning(f"Server {horde_url} unavailable during submit. Waiting 10 seconds... (Retry {loop_retry}/10)")
+ logger.warning(
+ f"Server {horde_url} unavailable during submit. Waiting 10 seconds... (Retry {loop_retry}/10)"
+ )
time.sleep(10)
continue
except requests.exceptions.ReadTimeout:
- logger.warning(f"Server {horde_url} timed out during submit. Waiting 10 seconds... (Retry {loop_retry}/10)")
+ logger.warning(
+ f"Server {horde_url} timed out during submit. Waiting 10 seconds... (Retry {loop_retry}/10)"
+ )
time.sleep(10)
continue
time.sleep(interval)
+
def check_mm_auth(model_manager):
if model_manager.has_authentication():
return
try:
- from creds import hf_username,hf_password
- except:
+ from creds import hf_password, hf_username
+ except ImportError:
hf_username = input("Please type your huggingface.co username: ")
hf_password = getpass.getpass("Please type your huggingface.co Access Token or password: ")
hf_auth = {"username": hf_username, "password": hf_password}
@@ -338,64 +556,79 @@ def check_mm_auth(model_manager):
@logger.catch(reraise=True)
def check_models(models, mm):
logger.init("Models", status="Checking")
- from os.path import exists
- import sys
+
models_exist = True
not_found_models = []
for model in models:
model_info = mm.get_model(model)
if not model_info:
- logger.error(f"Model name requested {model} in bridgeData is unknown to us. Please check your configuration. Aborting!")
+ logger.error(
+ f"Model name requested {model} in bridgeData is unknown to us. "
+ "Please check your configuration. Aborting!"
+ )
sys.exit(1)
if not args.skip_md5 and not mm.validate_model(model):
models_exist = False
not_found_models.append(model)
# Diffusers library uses its own internal download mechanism
- if model_info['type'] == 'diffusers' and model_info['hf_auth']:
+ if model_info["type"] == "diffusers" and model_info["hf_auth"]:
check_mm_auth(mm)
if not models_exist:
- choice = input(f"You do not appear to have downloaded the models needed yet.\nYou need at least a main model to proceed. Would you like to download your prespecified models?\n\
+ choice = input(
+ "You do not appear to have downloaded the models needed yet.\nYou need at least a main model to proceed. "
+ f"Would you like to download your prespecified models?\n\
y: Download {not_found_models} (default).\n\
n: Abort and exit\n\
all: Download all models (This can take a significant amount of time and bandwidth)?\n\
- Please select an option: ")
- if choice not in ['y', 'Y', '', 'yes', 'all', 'a']:
+ Please select an option: "
+ )
+ if choice not in ["y", "Y", "", "yes", "all", "a"]:
sys.exit(1)
needs_hf = False
for model in not_found_models:
dl = mm.get_model_download(model)
for m in dl:
- if m.get('hf_auth', False):
+ if m.get("hf_auth", False):
needs_hf = True
- if choice in ['all', 'a']:
+ if choice in ["all", "a"]:
needs_hf = True
if needs_hf:
check_mm_auth(mm)
mm.init()
mm.taint_models(not_found_models)
- if choice in ['all', 'a']:
- mm.download_all()
- elif choice in ['y', 'Y', '', 'yes']:
+ if choice in ["all", "a"]:
+ mm.download_all()
+ elif choice in ["y", "Y", "", "yes"]:
for model in not_found_models:
logger.init(f"Model: {model}", status="Downloading")
if not mm.download_model(model):
- logger.message("Something went wrong when downloading the model and it does not fit the expected checksum. Please check that your HuggingFace authentication is correct and that you've accepted the model license from the browser.")
+ logger.message(
+ "Something went wrong when downloading the model and it does not fit the expected checksum. "
+ "Please check that your HuggingFace authentication is correct and that you've accepted the "
+ "model license from the browser."
+ )
sys.exit(1)
logger.init_ok("Models", status="OK")
- if exists('./bridgeData.py'):
+ if os.path.exists("./bridgeData.py"):
logger.init_ok("Bridge Config", status="OK")
- elif input("You do not appear to have a bridgeData.py. Would you like to create it from the template now? (y/n)") in ['y', 'Y', '', 'yes']:
- with open('bridgeData_template.py','r') as firstfile, open('bridgeData.py','a') as secondfile:
+ elif input(
+ "You do not appear to have a bridgeData.py. Would you like to create it from the template now? (y/n)"
+ ) in ["y", "Y", "", "yes"]:
+ with open("bridgeData_template.py", "r") as firstfile, open("bridgeData.py", "a") as secondfile:
for line in firstfile:
secondfile.write(line)
- logger.message("bridgeData.py created. Bridge will exit. Please edit bridgeData.py with your setup and restart the bridge")
+ logger.message(
+ "bridgeData.py created. Bridge will exit. Please edit bridgeData.py with your setup and restart the bridge"
+ )
sys.exit(2)
-
+
+
@logger.catch(reraise=True)
def load_bridge_data():
bridge_data = BridgeData()
try:
import bridgeData as bd
+
bridge_data.api_key = bd.api_key
bridge_data.worker_name = bd.worker_name
bridge_data.horde_url = bd.horde_url
@@ -403,7 +636,7 @@ def load_bridge_data():
bridge_data.max_power = bd.max_power
bridge_data.model_names = bd.models_to_load
try:
- bridge_data.nsfw = bd.nsfw
+ bridge_data.nsfw = bd.nsfw
except AttributeError:
pass
try:
@@ -430,33 +663,47 @@ def load_bridge_data():
bridge_data.allow_unsafe_ip = bd.allow_unsafe_ip
except AttributeError:
pass
- except:
+ except (ImportError, AttributeError):
logger.warning("bridgeData.py could not be loaded. Using defaults with anonymous account")
- if args.api_key: bridge_data.api_key = args.api_key
- if args.worker_name: bridge_data.worker_name = args.worker_name
- if args.horde_url: bridge_data.horde_url = args.horde_url
- if args.priority_usernames: bridge_data.priority_usernames = args.priority_usernames
- if args.max_power: bridge_data.max_power = args.max_power
- if args.model: bridge_data.model = [args.model]
- if args.sfw: bridge_data.nsfw = False
- if args.censor_nsfw: bridge_data.censor_nsfw = args.censor_nsfw
- if args.blacklist: bridge_data.blacklist = args.blacklist
- if args.censorlist: bridge_data.censorlist = args.censorlist
- if args.allow_img2img: bridge_data.allow_img2img = args.allow_img2img
- if args.allow_painting: bridge_data.allow_painting = args.allow_painting
- if args.allow_unsafe_ip: bridge_data.allow_unsafe_ip = args.allow_unsafe_ip
+ if args.api_key:
+ bridge_data.api_key = args.api_key
+ if args.worker_name:
+ bridge_data.worker_name = args.worker_name
+ if args.horde_url:
+ bridge_data.horde_url = args.horde_url
+ if args.priority_usernames:
+ bridge_data.priority_usernames = args.priority_usernames
+ if args.max_power:
+ bridge_data.max_power = args.max_power
+ if args.model:
+ bridge_data.model = [args.model]
+ if args.sfw:
+ bridge_data.nsfw = False
+ if args.censor_nsfw:
+ bridge_data.censor_nsfw = args.censor_nsfw
+ if args.blacklist:
+ bridge_data.blacklist = args.blacklist
+ if args.censorlist:
+ bridge_data.censorlist = args.censorlist
+ if args.allow_img2img:
+ bridge_data.allow_img2img = args.allow_img2img
+ if args.allow_painting:
+ bridge_data.allow_painting = args.allow_painting
+ if args.allow_unsafe_ip:
+ bridge_data.allow_unsafe_ip = args.allow_unsafe_ip
if bridge_data.max_power < 2:
bridge_data.max_power = 2
- bridge_data.max_pixels = 64*64*8*bridge_data.max_power
+ bridge_data.max_pixels = 64 * 64 * 8 * bridge_data.max_power
if bridge_data.censor_nsfw or len(bridge_data.censorlist):
- bridge_data.model_names.append('safety_checker')
- return(bridge_data)
+ bridge_data.model_names.append("safety_checker")
+ return bridge_data
+
if __name__ == "__main__":
-
+
set_logger_verbosity(args.verbosity)
if args.log_file:
- logger.add("koboldai_bridge_log.log", retention="7 days", level="warning") # Automatically rotate too big file
+ logger.add("koboldai_bridge_log.log", retention="7 days", level="warning") # Automatically rotate too big file
quiesce_logger(args.quiet)
bd = load_bridge_data()
# test_logger()
@@ -464,15 +711,21 @@ def load_bridge_data():
check_models(bd.model_names, model_manager)
model_manager.init()
for model in bd.model_names:
- logger.init(f'{model}', status="Loading")
+ logger.init(f"{model}", status="Loading")
success = model_manager.load_model(model)
if success:
- logger.init_ok(f'{model}', status="Loaded")
+ logger.init_ok(f"{model}", status="Loaded")
else:
- logger.init_err(f'{model}', status="Error")
- logger.init(f"API Key '{bd.api_key}'. Server Name '{bd.worker_name}'. Horde URL '{bd.horde_url}'. Max Pixels {bd.max_pixels}", status="Joining Horde")
+ logger.init_err(f"{model}", status="Error")
+ logger.init(
+ (
+ f"API Key '{bd.api_key}'. Server Name '{bd.worker_name}'. "
+ f"Horde URL '{bd.horde_url}'. Max Pixels {bd.max_pixels}"
+ ),
+ status="Joining Horde",
+ )
try:
bridge(args.interval, model_manager, bd)
except KeyboardInterrupt:
- logger.info(f"Keyboard Interrupt Received. Ending Process")
+ logger.info("Keyboard Interrupt Received. Ending Process")
logger.init(f"{bd.worker_name} Instance", status="Stopped")
diff --git a/bridgeData_template.py b/bridgeData_template.py
index 95b28b96..7d0454b3 100644
--- a/bridgeData_template.py
+++ b/bridgeData_template.py
@@ -6,7 +6,8 @@
# Visit https://stablehorde.net/register to create one before you can join
api_key = "0000000000"
# Put other users whose prompts you want to prioritize.
-# The owner's username is always included so you don't need to add it here, unless you want it to have lower priority than another user
+# The owner's username is always included so you don't need to add it here,
+# unless you want it to have lower priority than another user
priority_usernames = []
# The amount of power your system can handle
# 8 means 512*512. Each increase increases the possible resoluion by 64 pixes
@@ -28,23 +29,25 @@
# If set to False, this worker will no longer pick img2img jobs from unsafe IPs
allow_unsafe_ip = True
# The models to use. You can select a different main model, or select more than one.
-# With you can easily load 5 of these models with 32Gb RAM and 6G VRAM. Adjust how many models you load based on how much RAM (not VRAM) you have available
+# You can easily load 5 of these models with 32GB RAM and 6GB VRAM.
+# Adjust how many models you load based on how much RAM (not VRAM) you have available.
# The last model in this list takes priority when the client accepts more than 1
# if you do not know which models you can add here, use the below command
# python show_available_models.py
models_to_load = [
- "stable_diffusion", # This is the standard compvis model. It is not using Diffusers (yet)
- ## Specialized Style models
+ "stable_diffusion", # This is the standard compvis model. It is not using Diffusers (yet)
+ # Specialized Style models:
# "trinart",
# "Furry Epoch",
# "Yiffy",
# "waifu_diffusion",
- ## Dreambooth Models
+ # Dreambooth Models:
# "Arcane Diffusion",
# "Spier-Verse Diffusion",
# "Elden Ring Diffusion",
# "Robo-Diffusion",
# "mo-di-diffusion",
-
- # "stable_diffusion_inpainting", # Enable this to allow inpainting/outpainting. Careful of trying to enable this in tandem with other models if you have 8G or less VRAM!
-]
\ No newline at end of file
+ # "stable_diffusion_inpainting",
+ # Enable this to allow inpainting/outpainting.
+ # Careful of trying to enable this in tandem with other models if you have 8G or less VRAM!
+]
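
For reference, a hypothetical filled-in bridgeData.py using only fields that this template and load_bridge_data() read; every value below is a placeholder:

```python
# bridgeData.py - example values only, not part of this patch
horde_url = "https://stablehorde.net"
worker_name = "My Example Worker"   # shown publicly; must be unique
api_key = "0000000000"              # replace with your key from https://stablehorde.net/register
priority_usernames = []
max_power = 8                       # 8 corresponds to 512*512
nsfw = True
censor_nsfw = False
blacklist = []
censorlist = []
allow_img2img = True
allow_painting = True
allow_unsafe_ip = True
models_to_load = ["stable_diffusion"]
```
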
diff --git a/creds_template.py b/creds_template.py
index 15b05531..ab5bc204 100644
--- a/creds_template.py
+++ b/creds_template.py
@@ -1,2 +1,2 @@
hf_username = "username"
-hf_password = "**********"
\ No newline at end of file
+hf_password = "**********"
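
check_mm_auth() in bridge.py attempts `from creds import hf_password, hf_username` before falling back to an interactive prompt, so a local creds.py copied from this template can supply the HuggingFace credentials non-interactively. The values below are placeholders:

```python
# creds.py - placeholders only; never commit real credentials
hf_username = "your-huggingface-username"
hf_password = "your-access-token-or-password"
```
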
diff --git a/nataili/__init__.py b/nataili/__init__.py
index 751567c9..742005bf 100644
--- a/nataili/__init__.py
+++ b/nataili/__init__.py
@@ -1,3 +1,4 @@
from .util.switch import Switch
+
disable_xformers = Switch()
disable_voodoo = Switch()
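
These module-level Switch instances are toggled by bridge.py before the inference modules are imported, and those modules then read the `.active` attribute. The actual implementation lives in nataili/util/switch.py and is not shown in this diff; a minimal sketch consistent with that usage might look like:

```python
class Switch:
    """Tiny shared flag: bridge.py calls toggle(), importers read .active."""

    def __init__(self):
        self.active = False

    def toggle(self, value: bool):
        self.active = value
```
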
diff --git a/nataili/inference/compvis/img2img.py b/nataili/inference/compvis/img2img.py
index 29700185..9412f9f4 100644
--- a/nataili/inference/compvis/img2img.py
+++ b/nataili/inference/compvis/img2img.py
@@ -1,41 +1,54 @@
import os
import re
import sys
+from contextlib import nullcontext
+
import k_diffusion as K
-import tqdm
-from contextlib import contextmanager, nullcontext
-import skimage
import numpy as np
import PIL
+import skimage
import torch
+import tqdm
from einops import rearrange
+from slugify import slugify
from transformers import CLIPFeatureExtractor
+
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.kdiffusion import CFGMaskedDenoiser, KDiffusionSampler
from ldm.models.diffusion.plms import PLMSSampler
+from nataili.util import logger
from nataili.util.cache import torch_gc
from nataili.util.check_prompt_length import check_prompt_length
from nataili.util.get_next_sequence_number import get_next_sequence_number
-from nataili.util.image_grid import image_grid
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
from nataili.util.save_sample import save_sample
from nataili.util.seed_to_int import seed_to_int
-from slugify import slugify
-import PIL
-from nataili.util import logger
+
try:
from nataili.util.voodoo import load_from_plasma, performance
except ModuleNotFoundError as e:
from nataili import disable_voodoo
+
if not disable_voodoo.active:
raise e
class img2img:
- def __init__(self, model, device, output_dir, save_extension='jpg',
- output_file_path=False, load_concepts=False, concepts_dir=None,
- verify_input=True, auto_cast=True, filter_nsfw=False, safety_checker=None,
- disable_voodoo=False):
+ def __init__(
+ self,
+ model,
+ device,
+ output_dir,
+ save_extension="jpg",
+ output_file_path=False,
+ load_concepts=False,
+ concepts_dir=None,
+ verify_input=True,
+ auto_cast=True,
+ filter_nsfw=False,
+ safety_checker=None,
+ disable_voodoo=False,
+ ):
self.model = model
self.output_dir = output_dir
self.output_file_path = output_file_path
@@ -47,8 +60,8 @@ def __init__(self, model, device, output_dir, save_extension='jpg',
self.device = device
self.comments = []
self.output_images = []
- self.info = ''
- self.stats = ''
+ self.info = ""
+ self.stats = ""
self.images = []
self.filter_nsfw = filter_nsfw
self.safety_checker = safety_checker
@@ -73,22 +86,27 @@ def process_prompt_tokens(self, prompt_tokens, model):
# tokenizer = model.cond_stage_model.tokenizer
# text_encoder = model.cond_stage_model.transformer
# diffusers codebase
- #tokenizer = pipe.tokenizer
- #text_encoder = pipe.text_encoder
+ # tokenizer = pipe.tokenizer
+ # text_encoder = pipe.text_encoder
- ext = ('.pt', '.bin')
+ ext = (".pt", ".bin")
for token_name in prompt_tokens:
- embedding_path = os.path.join(self.concepts_dir, token_name)
+ embedding_path = os.path.join(self.concepts_dir, token_name)
if os.path.exists(embedding_path):
for files in os.listdir(embedding_path):
if files.endswith(ext):
- load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", model.cond_stage_model.transformer, model.cond_stage_model.tokenizer, f"<{token_name}>")
+ load_learned_embed_in_clip(
+ f"{os.path.join(embedding_path, files)}",
+ model.cond_stage_model.transformer,
+ model.cond_stage_model.tokenizer,
+ f"<{token_name}>",
+ )
else:
print(f"Concept {token_name} not found in {self.concepts_dir}")
return
def resize_image(self, resize_mode, im, width, height):
- LANCZOS = (PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, 'Resampling') else PIL.Image.LANCZOS)
+ LANCZOS = PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, "Resampling") else PIL.Image.LANCZOS
if resize_mode == "resize":
res = im.resize((width, height), resample=LANCZOS)
elif resize_mode == "crop":
@@ -114,42 +132,61 @@ def resize_image(self, resize_mode, im, width, height):
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ res.paste(
+ resized.resize((width, fill_height), box=(0, 0, width, 0)),
+ box=(0, 0),
+ )
+ res.paste(
+ resized.resize(
+ (width, fill_height),
+ box=(0, resized.height, width, resized.height),
+ ),
+ box=(0, fill_height + src_h),
+ )
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+ res.paste(
+ resized.resize((fill_width, height), box=(0, 0, 0, height)),
+ box=(0, 0),
+ )
+ res.paste(
+ resized.resize(
+ (fill_width, height),
+ box=(resized.width, 0, resized.width, height),
+ ),
+ box=(fill_width + src_w, 0),
+ )
return res
-
+
#
+
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(self, data):
- if data.ndim > 2: # has channels
+ if data.ndim > 2: # has channels
out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
for c in range(data.shape[2]):
- c_data = data[:,:,c]
- out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho")
- out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c])
- else: # one channel
+ c_data = data[:, :, c]
+ out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
+ out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
+ else: # one channel
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
- out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho")
- out_fft[:,:] = np.fft.ifftshift(out_fft[:,:])
+ out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
+ out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def _ifft2(self, data):
- if data.ndim > 2: # has channels
+ if data.ndim > 2: # has channels
out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
for c in range(data.shape[2]):
- c_data = data[:,:,c]
- out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho")
- out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c])
- else: # one channel
+ c_data = data[:, :, c]
+ out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
+ out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
+ else: # one channel
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
- out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho")
- out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:])
+ out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
+ out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
@@ -159,51 +196,64 @@ def _get_gaussian_window(self, width, height, std=3.14, mode=0):
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
- x = (np.arange(width) / width * 2. - 1.) * window_scale_x
+ x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
for y in range(height):
- fy = (y / height * 2. - 1.) * window_scale_y
+ fy = (y / height * 2.0 - 1.0) * window_scale_y
if mode == 0:
- window[:, y] = np.exp(-(x**2+fy**2) * std)
+ window[:, y] = np.exp(-(x**2 + fy**2) * std)
else:
- window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian
+ window[:, y] = (1 / ((x**2 + 1.0) * (fy**2 + 1.0))) ** (
+ std / 3.14
+ ) # hey wait a minute that's not gaussian
return window
- def _get_masked_window_rgb(self, np_mask_grey, hardness=1.):
+ def _get_masked_window_rgb(self, np_mask_grey, hardness=1.0):
np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
- if hardness != 1.:
+ if hardness != 1.0:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
- np_mask_rgb[:,:,c] = hardened[:]
+ np_mask_rgb[:, :, c] = hardened[:]
return np_mask_rgb
def get_matched_noise(self, _np_src_image, np_mask_rgb, noise_q, color_variation):
"""
Explanation:
Getting good results in/out-painting with stable diffusion can be challenging.
- Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data
- in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well.
+ Although there are simpler effective solutions for in-painting, out-painting can be especially challenging
+ because there is no color data in the masked area to help prompt the generator.
+
+ Ideally, even for in-painting we'd like to work effectively without that data as well.
Provided here is my take on a potential solution to this problem.
- By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src.
- Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence
- by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can
- be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color
- or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently)
- This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area.
- The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle.
+ By taking a fourier transform of the masked src img we get a function that tells us the presence and
+ orientation of each feature scale in the unmasked src.
+ Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations,
+ and positions increases output coherence by helping keep features aligned.
+ This technique is applicable to any continuous generation task such as audio or video, each of which can
+ be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased.
+ For multi-channel data such as color or stereo sound the "color tone" or histogram of the seed noise
+ can be matched to improve quality (using scikit-image currently)
+ This method is quite robust and has the added benefit of being fast independently of the size of the
+ out-painted area.
+ The effects of this method include things like helping the generator integrate the pre-existing
+ view distance and camera angle.
Carefully managing color and brightness with histogram matching is also essential to achieving good coherence.
- noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.)
- color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01)
+ noise_q controls the exponent in the fall-off of the distribution; it can be any positive number,
+ lower values mean higher detail (range > 0, default 1.)
+ color_variation controls how much freedom is allowed for the colors/palette of the out-painted area
+ (range 0..1, default 0.01)
This code is provided as is under the Unlicense (https://unlicense.org/)
- Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz].
+ Although you have no obligation to do so, if you found this code helpful please find it in your heart
+ to credit me [parlance-zz].
Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/)
- This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot)
+ This code is part of a new branch of a discord bot I am working on integrating with diffusers
+ (https://github.com/parlance-zz/g-diffuser-bot)
"""
@@ -214,72 +264,96 @@ def get_matched_noise(self, _np_src_image, np_mask_rgb, noise_q, color_variation
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
- np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
- np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.)
- np_src_grey = (np.sum(np_src_image, axis=2)/3.)
- all_mask = np.ones((width, height), dtype=bool)
+ # FIXME: the commented lines are never used. remove?
+ # np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
+ np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
+ # np_src_grey = np.sum(np_src_image, axis=2) / 3.0
+ # all_mask = np.ones((width, height), dtype=bool)
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
- windowed_image = _np_src_image * (1.-self._get_masked_window_rgb(np_mask_grey))
+ windowed_image = _np_src_image * (1.0 - self._get_masked_window_rgb(np_mask_grey))
windowed_image /= np.max(windowed_image)
- windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
- #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area
- #_save_debug_img(windowed_image, "windowed_src_img")
-
- src_fft = self._fft2(windowed_image) # get feature statistics from masked src img
+ windowed_image += np.average(_np_src_image) * np_mask_rgb
+ # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black,
+ # we get better results from fft by filling the average unmasked color
+ # windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) /
+ # (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area
+ # _save_debug_img(windowed_image, "windowed_src_img")
+
+ src_fft = self._fft2(windowed_image) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
- #_save_debug_img(src_dist, "windowed_src_dist")
+ # _save_debug_img(src_dist, "windowed_src_dist")
noise_window = self._get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
noise_rgb = np.random.random_sample((width, height, num_channels))
- noise_grey = (np.sum(noise_rgb, axis=2)/3.)
- noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
+ noise_grey = np.sum(noise_rgb, axis=2) / 3.0
+ noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
for c in range(num_channels):
- noise_rgb[:,:,c] += (1. - color_variation) * noise_grey
+ noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
noise_fft = self._fft2(noise_rgb)
for c in range(num_channels):
- noise_fft[:,:,c] *= noise_window
+ noise_fft[:, :, c] *= noise_window
noise_rgb = np.real(self._ifft2(noise_fft))
shaped_noise_fft = self._fft2(noise_rgb)
- shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
+ shaped_noise_fft[:, :, :] = (
+ np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist**noise_q) * src_phase
+ ) # perform the actual shaping
- brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now
- contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
+ brightness_variation = 0.0 # color_variation
+ # todo: temporarily tying brightness variation to color variation for now
+ contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.0) - brightness_variation * 2.0
# scikit-image is used for histogram matching, very convenient!
shaped_noise = np.real(self._ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
- shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1)
- shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
- #_save_debug_img(shaped_noise, "shaped_noise")
+ shaped_noise[img_mask, :] = skimage.exposure.match_histograms(
+ shaped_noise[img_mask, :] ** 1.0,
+ contrast_adjusted_np_src[ref_mask, :],
+ channel_axis=1,
+ )
+ shaped_noise = _np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
+ # _save_debug_img(shaped_noise, "shaped_noise")
matched_noise = np.zeros((width, height, num_channels))
matched_noise = shaped_noise[:]
- #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1)
- #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb
+ # matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:],
+ # _np_src_image[ref_mask,:], channel_axis=1)
+ # matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb
- #_save_debug_img(matched_noise, "matched_noise")
+ # _save_debug_img(matched_noise, "matched_noise")
"""
todo:
- color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled
+ color_variation doesnt have to be a single number,
+ the overall color tone of the out-painted area could be param controlled
"""
- return np.clip(matched_noise, 0., 1.)
-
- def find_noise_for_image(self, model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None):
+ return np.clip(matched_noise, 0.0, 1.0)
+
+ def find_noise_for_image(
+ self,
+ model,
+ device,
+ init_image,
+ prompt,
+ steps=200,
+ cond_scale=2.0,
+ verbose=False,
+ normalize=False,
+ generation_callback=None,
+ ):
image = np.array(init_image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
- image = 2. * image - 1.
+ image = 2.0 * image - 1.0
image = image.to(device)
x = model.get_first_stage_encoding(model.encode_first_stage(image))
- uncond = model.get_learned_conditioning([''])
+ uncond = model.get_learned_conditioning([""])
cond = model.get_learned_conditioning([prompt])
s_in = x.new_ones([x.shape[0]])
@@ -315,17 +389,37 @@ def find_noise_for_image(self, model, device, init_image, prompt, steps=200, con
x = x + d * dt
return x / sigmas[-1]
+
@performance
- def generate(self, prompt: str, init_img=None, init_mask=None, mask_mode='mask', resize_mode='resize', noise_mode='seed',
- denoising_strength:float=0.8, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
- height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
+ def generate(
+ self,
+ prompt: str,
+ init_img=None,
+ init_mask=None,
+ mask_mode="mask",
+ resize_mode="resize",
+ noise_mode="seed",
+ denoising_strength: float = 0.8,
+ ddim_steps=50,
+ sampler_name="k_lms",
+ n_iter=1,
+ batch_size=1,
+ cfg_scale=7.5,
+ seed=None,
+ height=512,
+ width=512,
+ save_individual_images: bool = True,
+ save_grid: bool = True,
+ ddim_eta: float = 0.0,
+ ):
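+        # img2img / inpainting path: init_img seeds the latent and init_mask marks the region to regenerate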
torch_gc()
+
def process_init_mask(init_mask: PIL.Image):
if init_mask.mode == "RGBA":
- init_mask = init_mask.convert('RGBA')
- background = PIL.Image.new('RGBA', init_mask.size, (0, 0, 0))
+ init_mask = init_mask.convert("RGBA")
+ background = PIL.Image.new("RGBA", init_mask.size, (0, 0, 0))
init_mask = PIL.Image.alpha_composite(background, init_mask)
- init_mask = init_mask.convert('RGB')
+ init_mask = init_mask.convert("RGB")
return init_mask
if mask_mode == "mask":
@@ -336,77 +430,83 @@ def process_init_mask(init_mask: PIL.Image):
init_mask = process_init_mask(init_mask)
init_mask = PIL.ImageOps.invert(init_mask)
elif mask_mode == "alpha":
- init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
+ init_img_transparency = init_img.split()[-1].convert(
+ "L"
+ ) # .point(lambda x: 255 if x > 0 else 0, mode='1')
init_mask = init_img_transparency
init_mask = init_mask.convert("RGB")
init_mask = self.resize_image(resize_mode, init_mask, width, height)
init_mask = init_mask.convert("RGB")
- assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ assert 0.0 <= denoising_strength <= 1.0, "can only work with strength in [0.0, 1.0]"
t_enc = int(denoising_strength * ddim_steps)
- if init_mask is not None and (noise_mode == "matched" or noise_mode == "find_and_matched") and init_img is not None:
+ if (
+ init_mask is not None
+ and (noise_mode == "matched" or noise_mode == "find_and_matched")
+ and init_img is not None
+ ):
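+            # shape seed noise to the source image's frequency spectrum and blend it into the masked area before encoding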
noise_q = 0.99
color_variation = 0.0
mask_blend_factor = 1.0
- np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
- np_mask_rgb = 1. - (np.asarray(PIL.ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
+ np_init = (np.asarray(init_img.convert("RGB")) / 255.0).astype(
+ np.float64
+ ) # annoyingly complex mask fixing
+ np_mask_rgb = 1.0 - (np.asarray(PIL.ImageOps.invert(init_mask).convert("RGB")) / 255.0).astype(np.float64)
np_mask_rgb -= np.min(np_mask_rgb)
np_mask_rgb /= np.max(np_mask_rgb)
- np_mask_rgb = 1. - np_mask_rgb
- np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
- blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
- blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
- #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
- #np_mask_rgb = np_mask_rgb + blurred
- np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
- np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
+ np_mask_rgb = 1.0 - np_mask_rgb
+ np_mask_rgb_hardened = 1.0 - (np_mask_rgb < 0.99).astype(np.float64)
+ blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16.0, channel_axis=2, truncate=32.0)
+ blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16.0, channel_axis=2, truncate=32.0)
+ # np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
+ # np_mask_rgb = np_mask_rgb + blurred
+ np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0.0, 1.0)
+ np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0.0, 1.0)
noise_rgb = self.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
- blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
+ blend_mask_rgb = np.clip(np_mask_rgb_dilated, 0.0, 1.0) ** (mask_blend_factor)
noised = noise_rgb[:]
- blend_mask_rgb **= (2.)
- noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
+ blend_mask_rgb **= 2.0
+ noised = np_init[:] * (1.0 - blend_mask_rgb) + noised * blend_mask_rgb
- np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
+ np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
ref_mask = np_mask_grey < 1e-3
all_mask = np.ones((height, width), dtype=bool)
- noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
+ noised[all_mask, :] = skimage.exposure.match_histograms(
+ noised[all_mask, :] ** 1.0, noised[ref_mask, :], channel_axis=1
+ )
- init_img = PIL.Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
+ init_img = PIL.Image.fromarray(np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8), mode="RGB")
if not self.disable_voodoo:
with load_from_plasma(self.model, self.device) as model:
seed = seed_to_int(seed)
- image_dict = {
- "seed": seed
- }
+ image_dict = {"seed": seed}
# Init image is assumed to be a PIL image
- init_img = self.resize_image('resize', init_img, width, height)
- if sampler_name == 'PLMS':
+ init_img = self.resize_image("resize", init_img, width, height)
+ if sampler_name == "PLMS":
sampler = PLMSSampler(model)
- elif sampler_name == 'DDIM':
+ elif sampler_name == "DDIM":
sampler = DDIMSampler(model)
- elif sampler_name == 'k_dpm_2_a':
- sampler = KDiffusionSampler(model,'dpm_2_ancestral')
- elif sampler_name == 'k_dpm_2':
- sampler = KDiffusionSampler(model,'dpm_2')
- elif sampler_name == 'k_euler_a':
- sampler = KDiffusionSampler(model,'euler_ancestral')
- elif sampler_name == 'k_euler':
- sampler = KDiffusionSampler(model,'euler')
- elif sampler_name == 'k_heun':
- sampler = KDiffusionSampler(model,'heun')
- elif sampler_name == 'k_lms':
- sampler = KDiffusionSampler(model,'lms')
+ elif sampler_name == "k_dpm_2_a":
+ sampler = KDiffusionSampler(model, "dpm_2_ancestral")
+ elif sampler_name == "k_dpm_2":
+ sampler = KDiffusionSampler(model, "dpm_2")
+ elif sampler_name == "k_euler_a":
+ sampler = KDiffusionSampler(model, "euler_ancestral")
+ elif sampler_name == "k_euler":
+ sampler = KDiffusionSampler(model, "euler")
+ elif sampler_name == "k_heun":
+ sampler = KDiffusionSampler(model, "heun")
+ elif sampler_name == "k_lms":
+ sampler = KDiffusionSampler(model, "lms")
else:
raise Exception("Unknown sampler: " + sampler_name)
-
-
def init():
- image = init_img.convert('RGB')
+ image = init_img.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -419,25 +519,31 @@ def init():
mask = None
if mask_channel is not None:
mask = np.array(mask_channel).astype(np.float32) / 255.0
- mask = (1 - mask)
+ mask = 1 - mask
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3)
mask = torch.from_numpy(mask).to(model.device)
- init_image = 2. * image - 1.
+ init_image = 2.0 * image - 1.0
init_image = init_image.to(model.device)
- init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
+ init_latent = model.get_first_stage_encoding(
+ model.encode_first_stage(init_image)
+ ) # move to latent space
- return init_latent, mask,
+ return (
+ init_latent,
+ mask,
+ )
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ nonlocal sampler
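+            # sampler is bound in the enclosing generate() scope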
t_enc_steps = t_enc
obliterate = False
if ddim_steps == t_enc_steps:
t_enc_steps = t_enc_steps - 1
obliterate = True
- if sampler_name != 'DDIM':
+ if sampler_name != "DDIM":
x0, z_mask = init_data
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
@@ -448,36 +554,55 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=xi.device)
- xi = (z_mask * noise) + ((1-z_mask) * xi)
+ xi = (z_mask * noise) + ((1 - z_mask) * xi)
- sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
+ sigma_sched = sigmas[ddim_steps - t_enc_steps - 1 :]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
- samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
- extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
- 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
+ samples_ddim = K.sampling.__dict__[f"sample_{sampler.get_sampler_name()}"](
+ model_wrap_cfg,
+ xi,
+ sigma_sched,
+ extra_args={
+ "cond": conditioning,
+ "uncond": unconditional_conditioning,
+ "cond_scale": cfg_scale,
+ "mask": z_mask,
+ "x0": x0,
+ "xi": xi,
+ },
+ disable=False,
+ )
else:
x0, z_mask = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
- z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(model.device))
+ z_enc = sampler.stochastic_encode(
+ x0,
+ torch.tensor([t_enc_steps] * batch_size).to(model.device),
+ )
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=z_enc.device)
- z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
-
- # decode it
- samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
- unconditional_guidance_scale=cfg_scale,
- unconditional_conditioning=unconditional_conditioning,
- z_mask=z_mask, x0=x0)
+ z_enc = (z_mask * random) + ((1 - z_mask) * z_enc)
+
+ # decode it
+ samples_ddim = sampler.decode(
+ z_enc,
+ conditioning,
+ t_enc_steps,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ z_mask=z_mask,
+ x0=x0,
+ )
return samples_ddim
torch_gc()
-
+
if self.load_concepts and self.concepts_dir is not None:
- prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+ prompt_tokens = re.findall("<([a-zA-Z0-9-]+)>", prompt)
if prompt_tokens:
self.process_prompt_tokens(prompt_tokens)
@@ -489,8 +614,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
if self.verify_input:
try:
check_prompt_length(model, prompt, self.comments)
- except:
+ except Exception:
import traceback
+
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -502,10 +628,10 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
with torch.no_grad(), precision_scope("cuda"):
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
- prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
- seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+ prompts = all_prompts[n * batch_size : (n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size : (n + 1) * batch_size]
- uc = model.get_learned_conditioning(len(prompts) * [''])
+ uc = model.get_learned_conditioning(len(prompts) * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
@@ -518,41 +644,43 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
x = self.create_random_tensors(shape, seeds=seeds)
init_data = init()
- samples_ddim = sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+ samples_ddim = sample(
+ init_data=init_data,
+ x=x,
+ conditioning=c,
+ unconditional_conditioning=uc,
+ sampler_name=sampler_name,
+ )
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
else:
seed = seed_to_int(seed)
- image_dict = {
- "seed": seed
- }
+ image_dict = {"seed": seed}
# Init image is assumed to be a PIL image
- init_img = self.resize_image('resize', init_img, width, height)
- if sampler_name == 'PLMS':
+ init_img = self.resize_image("resize", init_img, width, height)
+ if sampler_name == "PLMS":
sampler = PLMSSampler(self.model)
- elif sampler_name == 'DDIM':
+ elif sampler_name == "DDIM":
sampler = DDIMSampler(self.model)
- elif sampler_name == 'k_dpm_2_a':
- sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
- elif sampler_name == 'k_dpm_2':
- sampler = KDiffusionSampler(self.model,'dpm_2')
- elif sampler_name == 'k_euler_a':
- sampler = KDiffusionSampler(self.model,'euler_ancestral')
- elif sampler_name == 'k_euler':
- sampler = KDiffusionSampler(self.model,'euler')
- elif sampler_name == 'k_heun':
- sampler = KDiffusionSampler(self.model,'heun')
- elif sampler_name == 'k_lms':
- sampler = KDiffusionSampler(self.model,'lms')
+ elif sampler_name == "k_dpm_2_a":
+ sampler = KDiffusionSampler(self.model, "dpm_2_ancestral")
+ elif sampler_name == "k_dpm_2":
+ sampler = KDiffusionSampler(self.model, "dpm_2")
+ elif sampler_name == "k_euler_a":
+ sampler = KDiffusionSampler(self.model, "euler_ancestral")
+ elif sampler_name == "k_euler":
+ sampler = KDiffusionSampler(self.model, "euler")
+ elif sampler_name == "k_heun":
+ sampler = KDiffusionSampler(self.model, "heun")
+ elif sampler_name == "k_lms":
+ sampler = KDiffusionSampler(self.model, "lms")
else:
raise Exception("Unknown sampler: " + sampler_name)
-
-
def init():
- image = init_img.convert('RGB')
+ image = init_img.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -565,16 +693,21 @@ def init():
mask = None
if mask_channel is not None:
mask = np.array(mask_channel).astype(np.float32) / 255.0
- mask = (1 - mask)
+ mask = 1 - mask
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3)
mask = torch.from_numpy(mask).to(self.model.device)
- init_image = 2. * image - 1.
+ init_image = 2.0 * image - 1.0
init_image = init_image.to(self.model.device)
- init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
+ init_latent = self.model.get_first_stage_encoding(
+ self.model.encode_first_stage(init_image)
+ ) # move to latent space
- return init_latent, mask,
+ return (
+ init_latent,
+ mask,
+ )
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
t_enc_steps = t_enc
@@ -583,7 +716,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
t_enc_steps = t_enc_steps - 1
obliterate = True
- if sampler_name != 'DDIM':
+ if sampler_name != "DDIM":
x0, z_mask = init_data
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
@@ -594,36 +727,55 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=xi.device)
- xi = (z_mask * noise) + ((1-z_mask) * xi)
+ xi = (z_mask * noise) + ((1 - z_mask) * xi)
- sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
+ sigma_sched = sigmas[ddim_steps - t_enc_steps - 1 :]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
- samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
- extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
- 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
+ samples_ddim = K.sampling.__dict__[f"sample_{sampler.get_sampler_name()}"](
+ model_wrap_cfg,
+ xi,
+ sigma_sched,
+ extra_args={
+ "cond": conditioning,
+ "uncond": unconditional_conditioning,
+ "cond_scale": cfg_scale,
+ "mask": z_mask,
+ "x0": x0,
+ "xi": xi,
+ },
+ disable=False,
+ )
else:
x0, z_mask = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
- z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(self.model.device))
+ z_enc = sampler.stochastic_encode(
+ x0,
+ torch.tensor([t_enc_steps] * batch_size).to(self.model.device),
+ )
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=z_enc.device)
- z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
-
- # decode it
- samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
- unconditional_guidance_scale=cfg_scale,
- unconditional_conditioning=unconditional_conditioning,
- z_mask=z_mask, x0=x0)
+ z_enc = (z_mask * random) + ((1 - z_mask) * z_enc)
+
+ # decode it
+ samples_ddim = sampler.decode(
+ z_enc,
+ conditioning,
+ t_enc_steps,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ z_mask=z_mask,
+ x0=x0,
+ )
return samples_ddim
torch_gc()
-
+
if self.load_concepts and self.concepts_dir is not None:
- prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+ prompt_tokens = re.findall("<([a-zA-Z0-9-]+)>", prompt)
if prompt_tokens:
self.process_prompt_tokens(prompt_tokens)
@@ -635,8 +787,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
if self.verify_input:
try:
check_prompt_length(self.model, prompt, self.comments)
- except:
+ except Exception:
import traceback
+
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -648,10 +801,10 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
with torch.no_grad(), precision_scope("cuda"):
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
- prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
- seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+ prompts = all_prompts[n * batch_size : (n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size : (n + 1) * batch_size]
- uc = self.model.get_learned_conditioning(len(prompts) * [''])
+ uc = self.model.get_learned_conditioning(len(prompts) * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
@@ -664,7 +817,13 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
x = self.create_random_tensors(shape, seeds=seeds)
init_data = init()
- samples_ddim = sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+ samples_ddim = sample(
+ init_data=init_data,
+ x=x,
+ conditioning=c,
+ unconditional_conditioning=uc,
+ sampler_name=sampler_name,
+ )
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -674,9 +833,11 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
full_path = os.path.join(os.getcwd(), sample_path)
sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i)
- filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
+ filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[
+ : 200 - len(full_path)
+ ]
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
image = PIL.Image.fromarray(x_sample)
if self.safety_checker is not None and self.filter_nsfw:
@@ -688,11 +849,11 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
logger.info(f"Image {filename} has NSFW concept")
image = output_images[0]
image = PIL.Image.fromarray(image)
- image_dict['image'] = image
+ image_dict["image"] = image
self.images.append(image_dict)
if save_individual_images:
- path = os.path.join(sample_path, filename + '.' + self.save_extension)
+ path = os.path.join(sample_path, filename + "." + self.save_extension)
success = save_sample(image, filename, sample_path_i, self.save_extension)
if success:
if self.output_file_path:
@@ -705,8 +866,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
{prompt}
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
""".strip()
- self.stats = f'''
- '''
+ self.stats = """
+ """
for comment in self.comments:
self.info += "\n\n" + comment
diff --git a/nataili/inference/compvis/txt2img.py b/nataili/inference/compvis/txt2img.py
index a55bfa00..1970a138 100644
--- a/nataili/inference/compvis/txt2img.py
+++ b/nataili/inference/compvis/txt2img.py
@@ -1,38 +1,51 @@
import os
import re
import sys
-from contextlib import contextmanager, nullcontext
+from contextlib import nullcontext
import numpy as np
import PIL
import torch
from einops import rearrange
+from slugify import slugify
+from transformers import CLIPFeatureExtractor
+
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.kdiffusion import KDiffusionSampler
from ldm.models.diffusion.plms import PLMSSampler
-from transformers import CLIPFeatureExtractor
+from nataili.util import logger
from nataili.util.cache import torch_gc
from nataili.util.check_prompt_length import check_prompt_length
from nataili.util.get_next_sequence_number import get_next_sequence_number
-from nataili.util.image_grid import image_grid
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
from nataili.util.save_sample import save_sample
from nataili.util.seed_to_int import seed_to_int
-from slugify import slugify
-from nataili.util import logger
+
try:
from nataili.util.voodoo import load_from_plasma, performance
except ModuleNotFoundError as e:
from nataili import disable_voodoo
+
if not disable_voodoo.active:
raise e
class txt2img:
- def __init__(self, model, device, output_dir, save_extension='jpg',
- output_file_path=False, load_concepts=False, concepts_dir=None,
- verify_input=True, auto_cast=True, filter_nsfw=False, safety_checker=None,
- disable_voodoo=False):
+ def __init__(
+ self,
+ model,
+ device,
+ output_dir,
+ save_extension="jpg",
+ output_file_path=False,
+ load_concepts=False,
+ concepts_dir=None,
+ verify_input=True,
+ auto_cast=True,
+ filter_nsfw=False,
+ safety_checker=None,
+ disable_voodoo=False,
+ ):
self.model = model
self.output_dir = output_dir
self.output_file_path = output_file_path
@@ -44,8 +57,8 @@ def __init__(self, model, device, output_dir, save_extension='jpg',
self.device = device
self.comments = []
self.output_images = []
- self.info = ''
- self.stats = ''
+ self.info = ""
+ self.stats = ""
self.images = []
self.filter_nsfw = filter_nsfw
self.safety_checker = safety_checker
@@ -70,23 +83,41 @@ def process_prompt_tokens(self, prompt_tokens, model):
# tokenizer = model.cond_stage_model.tokenizer
# text_encoder = model.cond_stage_model.transformer
# diffusers codebase
- #tokenizer = pipe.tokenizer
- #text_encoder = pipe.text_encoder
+ # tokenizer = pipe.tokenizer
+ # text_encoder = pipe.text_encoder
- ext = ('.pt', '.bin')
+ ext = (".pt", ".bin")
for token_name in prompt_tokens:
- embedding_path = os.path.join(self.concepts_dir, token_name)
+ embedding_path = os.path.join(self.concepts_dir, token_name)
if os.path.exists(embedding_path):
for files in os.listdir(embedding_path):
if files.endswith(ext):
- load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", model.cond_stage_model.transformer, model.cond_stage_model.tokenizer, f"<{token_name}>")
+ load_learned_embed_in_clip(
+ f"{os.path.join(embedding_path, files)}",
+ model.cond_stage_model.transformer,
+ model.cond_stage_model.tokenizer,
+ f"<{token_name}>",
+ )
else:
print(f"Concept {token_name} not found in {self.concepts_dir}")
return
@performance
- def generate(self, prompt: str, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
- height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
+ def generate(
+ self,
+ prompt: str,
+ ddim_steps=50,
+ sampler_name="k_lms",
+ n_iter=1,
+ batch_size=1,
+ cfg_scale=7.5,
+ seed=None,
+ height=512,
+ width=512,
+ save_individual_images: bool = True,
+ save_grid: bool = True,
+ ddim_eta: float = 0.0,
+ ):
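+        # text-to-image path; with voodoo enabled the model is first fetched via load_from_plasma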
if not self.disable_voodoo:
with load_from_plasma(self.model, self.device) as model:
# not needed?
@@ -94,43 +125,46 @@ def generate(self, prompt: str, ddim_steps=50, sampler_name='k_lms', n_iter=1, b
model.eval()
seed = seed_to_int(seed)
- image_dict = {
- "seed": seed
- }
- negprompt = ''
- if '###' in prompt:
- prompt, negprompt = prompt.split('###', 1)
+ image_dict = {"seed": seed}
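+            # an optional negative prompt can follow "###" in the prompt string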
+ negprompt = ""
+ if "###" in prompt:
+ prompt, negprompt = prompt.split("###", 1)
prompt = prompt.strip()
negprompt = negprompt.strip()
- if sampler_name == 'PLMS':
+ if sampler_name == "PLMS":
sampler = PLMSSampler(model)
- elif sampler_name == 'DDIM':
+ elif sampler_name == "DDIM":
sampler = DDIMSampler(model)
- elif sampler_name == 'k_dpm_2_a':
- sampler = KDiffusionSampler(model,'dpm_2_ancestral')
- elif sampler_name == 'k_dpm_2':
- sampler = KDiffusionSampler(model,'dpm_2')
- elif sampler_name == 'k_euler_a':
- sampler = KDiffusionSampler(model,'euler_ancestral')
- elif sampler_name == 'k_euler':
- sampler = KDiffusionSampler(model,'euler')
- elif sampler_name == 'k_heun':
- sampler = KDiffusionSampler(model,'heun')
- elif sampler_name == 'k_lms':
- sampler = KDiffusionSampler(model,'lms')
+ elif sampler_name == "k_dpm_2_a":
+ sampler = KDiffusionSampler(model, "dpm_2_ancestral")
+ elif sampler_name == "k_dpm_2":
+ sampler = KDiffusionSampler(model, "dpm_2")
+ elif sampler_name == "k_euler_a":
+ sampler = KDiffusionSampler(model, "euler_ancestral")
+ elif sampler_name == "k_euler":
+ sampler = KDiffusionSampler(model, "euler")
+ elif sampler_name == "k_heun":
+ sampler = KDiffusionSampler(model, "heun")
+ elif sampler_name == "k_lms":
+ sampler = KDiffusionSampler(model, "lms")
else:
raise Exception("Unknown sampler: " + sampler_name)
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, unconditional_guidance_scale=cfg_scale,
- unconditional_conditioning=unconditional_conditioning, x_T=x)
+ samples_ddim, _ = sampler.sample(
+ S=ddim_steps,
+ conditioning=conditioning,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ x_T=x,
+ )
return samples_ddim
torch_gc()
-
+
if self.load_concepts and self.concepts_dir is not None:
- prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+ prompt_tokens = re.findall("<([a-zA-Z0-9-]+)>", prompt)
if prompt_tokens:
self.process_prompt_tokens(prompt_tokens, model)
@@ -142,8 +176,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
if self.verify_input:
try:
check_prompt_length(model, prompt, self.comments)
- except:
+ except Exception:
import traceback
+
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -153,8 +188,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
with torch.no_grad():
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
- prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
- seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+ prompts = all_prompts[n * batch_size : (n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size : (n + 1) * batch_size]
uc = model.get_learned_conditioning(len(prompts) * [negprompt])
@@ -169,50 +204,60 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
x = self.create_random_tensors(shape, seeds=seeds)
- samples_ddim = sample(init_data=None, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+ samples_ddim = sample(
+ init_data=None,
+ x=x,
+ conditioning=c,
+ unconditional_conditioning=uc,
+ sampler_name=sampler_name,
+ )
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
else:
seed = seed_to_int(seed)
- image_dict = {
- "seed": seed
- }
- negprompt = ''
- if '###' in prompt:
- prompt, negprompt = prompt.split('###', 1)
+ image_dict = {"seed": seed}
+ negprompt = ""
+ if "###" in prompt:
+ prompt, negprompt = prompt.split("###", 1)
prompt = prompt.strip()
negprompt = negprompt.strip()
- if sampler_name == 'PLMS':
+ if sampler_name == "PLMS":
sampler = PLMSSampler(self.model)
- elif sampler_name == 'DDIM':
+ elif sampler_name == "DDIM":
sampler = DDIMSampler(self.model)
- elif sampler_name == 'k_dpm_2_a':
- sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
- elif sampler_name == 'k_dpm_2':
- sampler = KDiffusionSampler(self.model,'dpm_2')
- elif sampler_name == 'k_euler_a':
- sampler = KDiffusionSampler(self.model,'euler_ancestral')
- elif sampler_name == 'k_euler':
- sampler = KDiffusionSampler(self.model,'euler')
- elif sampler_name == 'k_heun':
- sampler = KDiffusionSampler(self.model,'heun')
- elif sampler_name == 'k_lms':
- sampler = KDiffusionSampler(self.model,'lms')
+ elif sampler_name == "k_dpm_2_a":
+ sampler = KDiffusionSampler(self.model, "dpm_2_ancestral")
+ elif sampler_name == "k_dpm_2":
+ sampler = KDiffusionSampler(self.model, "dpm_2")
+ elif sampler_name == "k_euler_a":
+ sampler = KDiffusionSampler(self.model, "euler_ancestral")
+ elif sampler_name == "k_euler":
+ sampler = KDiffusionSampler(self.model, "euler")
+ elif sampler_name == "k_heun":
+ sampler = KDiffusionSampler(self.model, "heun")
+ elif sampler_name == "k_lms":
+ sampler = KDiffusionSampler(self.model, "lms")
else:
raise Exception("Unknown sampler: " + sampler_name)
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, unconditional_guidance_scale=cfg_scale,
- unconditional_conditioning=unconditional_conditioning, x_T=x)
+ nonlocal sampler
+ samples_ddim, _ = sampler.sample(
+ S=ddim_steps,
+ conditioning=conditioning,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ x_T=x,
+ )
return samples_ddim
torch_gc()
-
+
if self.load_concepts and self.concepts_dir is not None:
- prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+ prompt_tokens = re.findall("<([a-zA-Z0-9-]+)>", prompt)
if prompt_tokens:
self.process_prompt_tokens(prompt_tokens)
@@ -224,8 +269,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
if self.verify_input:
try:
check_prompt_length(self.model, prompt, self.comments)
- except:
+ except Exception:
import traceback
+
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -237,8 +283,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
with torch.no_grad(), precision_scope("cuda"):
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
- prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
- seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+ prompts = all_prompts[n * batch_size : (n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size : (n + 1) * batch_size]
uc = self.model.get_learned_conditioning(len(prompts) * [negprompt])
@@ -253,7 +299,13 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
x = self.create_random_tensors(shape, seeds=seeds)
- samples_ddim = sample(init_data=None, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+ samples_ddim = sample(
+ init_data=None,
+ x=x,
+ conditioning=c,
+ unconditional_conditioning=uc,
+ sampler_name=sampler_name,
+ )
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -263,9 +315,11 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
full_path = os.path.join(os.getcwd(), sample_path)
sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i)
- filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
+ filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[
+ : 200 - len(full_path)
+ ]
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
image = PIL.Image.fromarray(x_sample)
if self.safety_checker is not None and self.filter_nsfw:
@@ -277,11 +331,11 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
logger.info(f"Image {filename} has NSFW concept")
image = output_images[0]
image = PIL.Image.fromarray(image)
- image_dict['image'] = image
+ image_dict["image"] = image
self.images.append(image_dict)
if save_individual_images:
- path = os.path.join(sample_path, filename + '.' + self.save_extension)
+ path = os.path.join(sample_path, filename + "." + self.save_extension)
success = save_sample(image, filename, sample_path_i, self.save_extension)
if success:
if self.output_file_path:
@@ -295,8 +349,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
{prompt}
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
""".strip()
- self.stats = f'''
- '''
+ self.stats = """
+ """
for comment in self.comments:
self.info += "\n\n" + comment
@@ -305,4 +359,4 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
del sampler
- return
\ No newline at end of file
+ return
diff --git a/nataili/inference/diffusers/inpainting.py b/nataili/inference/diffusers/inpainting.py
index b210eacf..94d1afbf 100644
--- a/nataili/inference/diffusers/inpainting.py
+++ b/nataili/inference/diffusers/inpainting.py
@@ -1,23 +1,33 @@
import os
import re
-import sys
+from contextlib import nullcontext
+
import PIL
import PIL.ImageOps
import torch
-from contextlib import nullcontext
from slugify import slugify
-from diffusers import StableDiffusionInpaintPipeline
+from nataili.util import logger
from nataili.util.cache import torch_gc
-from nataili.util.check_prompt_length import check_prompt_length
from nataili.util.get_next_sequence_number import get_next_sequence_number
from nataili.util.save_sample import save_sample
from nataili.util.seed_to_int import seed_to_int
-from nataili.util import logger
+
class inpainting:
- def __init__(self, pipe, device, output_dir, save_extension='jpg', output_file_path=False, load_concepts=False,
- concepts_dir=None, verify_input=True, auto_cast=True, filter_nsfw = False):
+ def __init__(
+ self,
+ pipe,
+ device,
+ output_dir,
+ save_extension="jpg",
+ output_file_path=False,
+ load_concepts=False,
+ concepts_dir=None,
+ verify_input=True,
+ auto_cast=True,
+ filter_nsfw=False,
+ ):
self.output_dir = output_dir
self.output_file_path = output_file_path
self.save_extension = save_extension
@@ -29,13 +39,13 @@ def __init__(self, pipe, device, output_dir, save_extension='jpg', output_file_p
self.device = device
self.comments = []
self.output_images = []
- self.info = ''
- self.stats = ''
+ self.info = ""
+ self.stats = ""
self.images = []
self.filter_nsfw = filter_nsfw
def resize_image(self, resize_mode, im, width, height):
- LANCZOS = (PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, 'Resampling') else PIL.Image.LANCZOS)
+ LANCZOS = PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, "Resampling") else PIL.Image.LANCZOS
if resize_mode == "resize":
res = im.resize((width, height), resample=LANCZOS)
elif resize_mode == "crop":
@@ -61,42 +71,72 @@ def resize_image(self, resize_mode, im, width, height):
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ res.paste(
+ resized.resize((width, fill_height), box=(0, 0, width, 0)),
+ box=(0, 0),
+ )
+ res.paste(
+ resized.resize(
+ (width, fill_height),
+ box=(0, resized.height, width, resized.height),
+ ),
+ box=(0, fill_height + src_h),
+ )
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+ res.paste(
+ resized.resize((fill_width, height), box=(0, 0, 0, height)),
+ box=(0, 0),
+ )
+ res.paste(
+ resized.resize(
+ (fill_width, height),
+ box=(resized.width, 0, resized.width, height),
+ ),
+ box=(fill_width + src_w, 0),
+ )
return res
- def generate(self, prompt: str, inpaint_img=None, inpaint_mask=None, ddim_steps=50, n_iter=1, batch_size=1,
- cfg_scale=7.5, seed=None, height=512, width=512, save_individual_images: bool = True):
+ def generate(
+ self,
+ prompt: str,
+ inpaint_img=None,
+ inpaint_mask=None,
+ ddim_steps=50,
+ n_iter=1,
+ batch_size=1,
+ cfg_scale=7.5,
+ seed=None,
+ height=512,
+ width=512,
+ save_individual_images: bool = True,
+ ):
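+        # diffusers inpainting: the mask may be passed explicitly or derived from the image's alpha channel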
safety_checker = None
if not self.filter_nsfw:
safety_checker = self.pipe.safety_checker
self.pipe.safety_checker = None
seed = seed_to_int(seed)
- inpaint_img = self.resize_image('resize', inpaint_img, width, height)
+ inpaint_img = self.resize_image("resize", inpaint_img, width, height)
         # mask information has been transferred into the alpha channel of the inpaint image
logger.debug(inpaint_mask)
if inpaint_mask is None:
- try:
- red, green, blue, alpha = inpaint_img.split()
- except ValueError:
- raise Exception("inpainting image doesn't have an alpha channel.")
-
- inpaint_mask = alpha
- inpaint_mask = PIL.ImageOps.invert(inpaint_mask)
+ try:
+ red, green, blue, alpha = inpaint_img.split()
+ except ValueError:
+ raise Exception("inpainting image doesn't have an alpha channel.")
+
+ inpaint_mask = alpha
+ inpaint_mask = PIL.ImageOps.invert(inpaint_mask)
else:
- inpaint_mask = self.resize_image('resize', inpaint_mask, width, height)
+ inpaint_mask = self.resize_image("resize", inpaint_mask, width, height)
torch_gc()
if self.load_concepts and self.concepts_dir is not None:
- prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+ prompt_tokens = re.findall("<([a-zA-Z0-9-]+)>", prompt)
if prompt_tokens:
self.process_prompt_tokens(prompt_tokens)
@@ -111,61 +151,58 @@ def generate(self, prompt: str, inpaint_img=None, inpaint_mask=None, ddim_steps=
precision_scope = torch.autocast if self.auto_cast else nullcontext
with torch.no_grad(), precision_scope("cuda"):
- for n in range(batch_size):
- print(f"Iteration: {n+1}/{batch_size}")
-
- prompt = all_prompts[n]
- seed = all_seeds[n]
- print("prompt: " + prompt + ", seed: " + str(seed))
-
- generator = torch.Generator(device=self.device).manual_seed(seed)
-
- x_samples = self.pipe(
- prompt=prompt,
- image=inpaint_img,
- mask_image=inpaint_mask,
- guidance_scale=cfg_scale,
- num_inference_steps=ddim_steps,
- generator=generator,
- num_images_per_prompt=n_iter,
- width=width,
- height=height
- ).images
-
- for i, x_sample in enumerate(x_samples):
- image_dict = {
- "seed": seed,
- "image": x_sample
- }
-
- self.images.append(image_dict)
- if safety_checker:
- self.pipe.safety_checker = safety_checker
-
- if save_individual_images:
- sanitized_prompt = slugify(prompt)
- sample_path_i = sample_path
- base_count = get_next_sequence_number(sample_path_i)
- full_path = os.path.join(os.getcwd(), sample_path)
- filename = f"{base_count:05}-{ddim_steps}_{seed}_{sanitized_prompt}"[:200-len(full_path)]
-
- path = os.path.join(sample_path, filename + '.' + self.save_extension)
- success = save_sample(x_sample, filename, sample_path_i, self.save_extension)
-
- if success:
- if self.output_file_path:
- self.output_images.append(path)
- else:
- self.output_images.append(x_sample)
- else:
- return
+ for n in range(batch_size):
+ print(f"Iteration: {n+1}/{batch_size}")
+
+ prompt = all_prompts[n]
+ seed = all_seeds[n]
+ print("prompt: " + prompt + ", seed: " + str(seed))
+
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+
+ x_samples = self.pipe(
+ prompt=prompt,
+ image=inpaint_img,
+ mask_image=inpaint_mask,
+ guidance_scale=cfg_scale,
+ num_inference_steps=ddim_steps,
+ generator=generator,
+ num_images_per_prompt=n_iter,
+ width=width,
+ height=height,
+ ).images
+
+ for i, x_sample in enumerate(x_samples):
+ image_dict = {"seed": seed, "image": x_sample}
+
+ self.images.append(image_dict)
+ if safety_checker:
+ self.pipe.safety_checker = safety_checker
+
+ if save_individual_images:
+ sanitized_prompt = slugify(prompt)
+ sample_path_i = sample_path
+ base_count = get_next_sequence_number(sample_path_i)
+ full_path = os.path.join(os.getcwd(), sample_path)
+ filename = f"{base_count:05}-{ddim_steps}_{seed}_{sanitized_prompt}"[: 200 - len(full_path)]
+
+ path = os.path.join(sample_path, filename + "." + self.save_extension)
+ success = save_sample(x_sample, filename, sample_path_i, self.save_extension)
+
+ if success:
+ if self.output_file_path:
+ self.output_images.append(path)
+ else:
+ self.output_images.append(x_sample)
+ else:
+ return
self.info = f"""
{prompt}
Steps: {ddim_steps}, CFG scale: {cfg_scale}, Seed: {seed}
""".strip()
- self.stats = f'''
- '''
+ self.stats = """
+ """
for comment in self.comments:
self.info += "\n\n" + comment
diff --git a/nataili/model_manager.py b/nataili/model_manager.py
index 8fef0bcf..31149de2 100644
--- a/nataili/model_manager.py
+++ b/nataili/model_manager.py
@@ -1,61 +1,64 @@
-import os
+import hashlib
import json
+import os
import shutil
import zipfile
-import requests
+
+import clip
import git
+import open_clip
+import requests
import torch
-import hashlib
-from ldm.util import instantiate_from_config
-from omegaconf import OmegaConf
-from transformers import logging
-
from basicsr.archs.rrdbnet_arch import RRDBNet
+from diffusers import StableDiffusionInpaintPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from gfpgan import GFPGANer
+from omegaconf import OmegaConf
from realesrgan import RealESRGANer
-from ldm.models.blip import blip_decoder
from tqdm import tqdm
-import open_clip
-import clip
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers import StableDiffusionInpaintPipeline
+from transformers import logging
+
+from ldm.models.blip import blip_decoder
+from ldm.util import instantiate_from_config
+
try:
- from nataili.util.voodoo import push_model_to_plasma, load_from_plasma
+ from nataili.util.voodoo import push_model_to_plasma
except ModuleNotFoundError as e:
from nataili import disable_voodoo
+
if not disable_voodoo.active:
raise e
-
-from nataili.util.cache import torch_gc
from nataili.util import logger
+from nataili.util.cache import torch_gc
logging.set_verbosity_error()
-models = json.load(open('./db.json'))
-dependencies = json.load(open('./db_dep.json'))
+models = json.load(open("./db.json"))
+dependencies = json.load(open("./db_dep.json"))
remote_models = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
remote_dependencies = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db_dep.json"
-class ModelManager():
- def __init__(self, hf_auth=None, download=True,disable_voodoo=True):
+
+class ModelManager:
+ def __init__(self, hf_auth=None, download=True, disable_voodoo=True):
if download:
try:
logger.init("Model Reference", status="Downloading")
r = requests.get(remote_models)
self.models = r.json()
r = requests.get(remote_dependencies)
- self.dependencies = json.load(open('./db_dep.json'))
+ self.dependencies = json.load(open("./db_dep.json"))
logger.init_ok("Model Reference", status="OK")
- except:
+ except Exception:
logger.init_err("Model Reference", status="Download Error")
- self.models = json.load(open('./db.json'))
- self.dependencies = json.load(open('./db_dep.json'))
+ self.models = json.load(open("./db.json"))
+ self.dependencies = json.load(open("./db_dep.json"))
logger.init_warn("Model Reference", status="Local")
else:
- self.models = json.load(open('./db.json'))
- self.dependencies = json.load(open('./db_dep.json'))
+ self.models = json.load(open("./db.json"))
+ self.dependencies = json.load(open("./db_dep.json"))
self.available_models = []
self.tainted_models = []
self.available_dependencies = []
@@ -78,11 +81,11 @@ def init(self):
self.available_models = models_available
if self.hf_auth is not None:
- if 'username' not in self.hf_auth and 'password' not in self.hf_auth:
- raise ValueError('hf_auth must contain username and password')
+ if "username" not in self.hf_auth and "password" not in self.hf_auth:
+ raise ValueError("hf_auth must contain username and password")
else:
- if self.hf_auth['username'] == '' or self.hf_auth['password'] == '':
- raise ValueError('hf_auth must contain username and password')
+ if self.hf_auth["username"] == "" or self.hf_auth["password"] == "":
+ raise ValueError("hf_auth must contain username and password")
return True
def set_authentication(self, hf_auth=None):
@@ -91,7 +94,7 @@ def set_authentication(self, hf_auth=None):
return
self.hf_auth = hf_auth
if hf_auth:
- os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_auth.get('password')
+ os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_auth.get("password")
def has_authentication(self):
if self.hf_auth:
@@ -100,11 +103,11 @@ def has_authentication(self):
def get_model(self, model_name):
return self.models.get(model_name)
-
+
def get_filtered_models(self, **kwargs):
- '''Get all model names.
+ """Get all model names.
Can filter based on metadata of the model reference db
- '''
+ """
filtered_models = self.models
for keyword in kwargs:
iterating_models = filtered_models.copy()
@@ -118,52 +121,52 @@ def get_filtered_models(self, **kwargs):
def get_filtered_model_names(self, **kwargs):
filtered_models = self.get_filtered_models(**kwargs)
return list(filtered_models.keys())
-
+
def get_dependency(self, dependency_name):
return self.dependencies[dependency_name]
def get_model_files(self, model_name):
- if self.models[model_name]['type'] == 'diffusers':
+ if self.models[model_name]["type"] == "diffusers":
return []
- return self.models[model_name]['config']['files']
-
+ return self.models[model_name]["config"]["files"]
+
def get_dependency_files(self, dependency_name):
- return self.dependencies[dependency_name]['config']['files']
-
+ return self.dependencies[dependency_name]["config"]["files"]
+
def get_model_download(self, model_name):
- return self.models[model_name]['config']['download']
-
+ return self.models[model_name]["config"]["download"]
+
def get_dependency_download(self, dependency_name):
- return self.dependencies[dependency_name]['config']['download']
-
+ return self.dependencies[dependency_name]["config"]["download"]
+
def get_available_models(self):
return self.available_models
-
+
def get_available_dependencies(self):
return self.available_dependencies
-
+
def get_loaded_models(self):
return self.loaded_models
-
+
def get_loaded_models_names(self):
return list(self.loaded_models.keys())
-
+
def get_loaded_model(self, model_name):
return self.loaded_models[model_name]
-
+
def unload_model(self, model_name):
if model_name in self.loaded_models:
del self.loaded_models[model_name]
return True
return False
-
+
def unload_all_models(self):
for model in self.loaded_models:
del self.loaded_models[model]
return True
-
- def taint_model(self,model_name):
- '''Marks a model as not valid by remiving it from available_models'''
+
+ def taint_model(self, model_name):
+        """Marks a model as not valid by removing it from available_models"""
if model_name in self.available_models:
self.available_models.remove(model_name)
self.tainted_models.append(model_name)
@@ -172,7 +175,7 @@ def taint_models(self, models):
for model in models:
self.taint_model(model)
- def load_model_from_config(self, model_path='', config_path='', map_location="cpu"):
+ def load_model_from_config(self, model_path="", config_path="", map_location="cpu"):
config = OmegaConf.load(config_path)
pl_sd = torch.load(model_path, map_location=map_location)
if "global_step" in pl_sd:
@@ -184,133 +187,170 @@ def load_model_from_config(self, model_path='', config_path='', map_location="cp
del pl_sd, sd, m, u
return model
- def load_ckpt(self, model_name='', precision='half', gpu_id=0):
- ckpt_path = self.get_model_files(model_name)[0]['path']
- config_path = self.get_model_files(model_name)[1]['path']
+ def load_ckpt(self, model_name="", precision="half", gpu_id=0):
+ ckpt_path = self.get_model_files(model_name)[0]["path"]
+ config_path = self.get_model_files(model_name)[1]["path"]
model = self.load_model_from_config(model_path=ckpt_path, config_path=config_path)
device = torch.device(f"cuda:{gpu_id}")
- model = model if precision == 'full' else model.half()
+ model = model if precision == "full" else model.half()
if not self.disable_voodoo:
logger.debug(f"Doing voodoo on {model_name}")
model = push_model_to_plasma(model) if isinstance(model, torch.nn.Module) else model
else:
- model = (model if precision=='full' else model.half()).to(device)
+ model = (model if precision == "full" else model.half()).to(device)
torch_gc()
- return {'model': model, 'device': device}
-
- def load_realesrgan(self, model_name='', precision='half', gpu_id=0):
-
- RealESRGAN_models = {
- 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
- 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
- }
+ return {"model": model, "device": device}
- model_path = self.get_model_files(model_name)[0]['path']
- device = torch.device(f"cuda:{gpu_id}")
- model = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[models[model_name]['name']],
- pre_pad=0, half=True if precision == 'half' else False, device=device)
- return {'model': model, 'device': device}
+ def load_realesrgan(self, model_name="", precision="half", gpu_id=0):
- def load_gfpgan(self, model_name='', gpu_id=0):
-
- model_path = self.get_model_files(model_name)[0]['path']
+ RealESRGAN_models = {
+ "RealESRGAN_x4plus": RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=4,
+ ),
+ "RealESRGAN_x4plus_anime_6B": RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=6,
+ num_grow_ch=32,
+ scale=4,
+ ),
+ }
+
+ model_path = self.get_model_files(model_name)[0]["path"]
device = torch.device(f"cuda:{gpu_id}")
- model = GFPGANer(model_path=model_path, upscale=1, arch='clean',
- channel_multiplier=2, bg_upsampler=None, device=device)
- return {'model': model, 'device': device}
-
- def load_blip(self, model_name='', precision='half', gpu_id=0, blip_image_eval_size=512, vit='base'):
+ model = RealESRGANer(
+ scale=2,
+ model_path=model_path,
+ model=RealESRGAN_models[models[model_name]["name"]],
+ pre_pad=0,
+ half=True if precision == "half" else False,
+ device=device,
+ )
+ return {"model": model, "device": device}
+
+ def load_gfpgan(self, model_name="", gpu_id=0):
+
+ model_path = self.get_model_files(model_name)[0]["path"]
+ device = torch.device(f"cuda:{gpu_id}")
+ model = GFPGANer(
+ model_path=model_path,
+ upscale=1,
+ arch="clean",
+ channel_multiplier=2,
+ bg_upsampler=None,
+ device=device,
+ )
+ return {"model": model, "device": device}
+
+ def load_blip(
+ self,
+ model_name="",
+ precision="half",
+ gpu_id=0,
+ blip_image_eval_size=512,
+ vit="base",
+ ):
# vit = 'base' or 'large'
- model_path = self.get_model_files(model_name)[0]['path']
+ model_path = self.get_model_files(model_name)[0]["path"]
device = torch.device(f"cuda:{gpu_id}")
- model = blip_decoder(pretrained=model_path,
- med_config="configs/blip/med_config.json",
- image_size=blip_image_eval_size, vit=vit)
+ model = blip_decoder(
+ pretrained=model_path,
+ med_config="configs/blip/med_config.json",
+ image_size=blip_image_eval_size,
+ vit=vit,
+ )
model = model.eval()
- model = (model if precision=='full' else model.half()).to(device)
- return {'model': model, 'device': device}
+ model = (model if precision == "full" else model.half()).to(device)
+ return {"model": model, "device": device}
- def load_open_clip(self, model_name='', precision='half', gpu_id=0):
- pretrained = self.get_model(model_name)['pretrained_name']
+ def load_open_clip(self, model_name="", precision="half", gpu_id=0):
+ pretrained = self.get_model(model_name)["pretrained_name"]
device = torch.device(f"cuda:{gpu_id}")
- model, _, preprocesses = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, cache_dir='models/clip')
+ model, _, preprocesses = open_clip.create_model_and_transforms(
+ model_name, pretrained=pretrained, cache_dir="models/clip"
+ )
model = model.eval()
- model = (model if precision=='full' else model.half()).to(device)
- return {'model': model, 'device': device, 'preprocesses': preprocesses}
+ model = (model if precision == "full" else model.half()).to(device)
+ return {"model": model, "device": device, "preprocesses": preprocesses}
- def load_clip(self, model_name='', precision='half', gpu_id=0):
+ def load_clip(self, model_name="", precision="half", gpu_id=0):
device = torch.device(f"cuda:{gpu_id}")
- model, preprocesses = clip.load(model_name, device=device, download_root='models/clip')
+ model, preprocesses = clip.load(model_name, device=device, download_root="models/clip")
model = model.eval()
- model = (model if precision=='full' else model.half()).to(device)
- return {'model': model, 'device': device, 'preprocesses': preprocesses}
+ model = (model if precision == "full" else model.half()).to(device)
+ return {"model": model, "device": device, "preprocesses": preprocesses}
- def load_diffuser(self, model_name=''):
- model_path = self.models[model_name]['hf_path']
+ def load_diffuser(self, model_name=""):
+ model_path = self.models[model_name]["hf_path"]
pipe = StableDiffusionInpaintPipeline.from_pretrained(
- model_path,
- revision="fp16",
- torch_dtype=torch.float16,
- use_auth_token=self.models[model_name]['hf_auth'],
+ model_path,
+ revision="fp16",
+ torch_dtype=torch.float16,
+ use_auth_token=self.models[model_name]["hf_auth"],
).to("cuda")
- return {'model': pipe, 'device': "cuda"}
+ return {"model": pipe, "device": "cuda"}
- def load_model(self, model_name='', precision='half', gpu_id=0):
+ def load_model(self, model_name="", precision="half", gpu_id=0):
if model_name not in self.available_models:
return False
- if self.models[model_name]['type'] == 'ckpt':
+ if self.models[model_name]["type"] == "ckpt":
self.loaded_models[model_name] = self.load_ckpt(model_name, precision, gpu_id)
return True
- elif self.models[model_name]['type'] == 'realesrgan':
+ elif self.models[model_name]["type"] == "realesrgan":
self.loaded_models[model_name] = self.load_realesrgan(model_name, precision, gpu_id)
return True
- elif self.models[model_name]['type'] == 'gfpgan':
+ elif self.models[model_name]["type"] == "gfpgan":
self.loaded_models[model_name] = self.load_gfpgan(model_name, gpu_id)
return True
- elif self.models[model_name]['type'] == 'blip':
- self.loaded_models[model_name] = self.load_blip(model_name, precision, gpu_id, 512, 'base')
+ elif self.models[model_name]["type"] == "blip":
+ self.loaded_models[model_name] = self.load_blip(model_name, precision, gpu_id, 512, "base")
return True
- elif self.models[model_name]['type'] == 'open_clip':
+ elif self.models[model_name]["type"] == "open_clip":
self.loaded_models[model_name] = self.load_open_clip(model_name, precision, gpu_id)
return True
- elif self.models[model_name]['type'] == 'clip':
+ elif self.models[model_name]["type"] == "clip":
self.loaded_models[model_name] = self.load_clip(model_name, precision, gpu_id)
return True
- elif self.models[model_name]['type'] == 'diffusers':
+ elif self.models[model_name]["type"] == "diffusers":
self.loaded_models[model_name] = self.load_diffuser(model_name)
return True
- elif self.models[model_name]['type'] == 'safety_checker':
+ elif self.models[model_name]["type"] == "safety_checker":
self.loaded_models[model_name] = self.load_safety_checker(model_name, gpu_id)
return True
else:
return False
- def load_safety_checker(self, model_name='', gpu_id=0):
- model_path = os.path.dirname(self.get_model_files(model_name)[0]['path'])
+ def load_safety_checker(self, model_name="", gpu_id=0):
+ model_path = os.path.dirname(self.get_model_files(model_name)[0]["path"])
device = torch.device(f"cuda:{gpu_id}")
model = StableDiffusionSafetyChecker.from_pretrained(model_path)
model = model.eval().to(device)
- return {'model': model, 'device': device}
+ return {"model": model, "device": device}
def validate_model(self, model_name):
files = self.get_model_files(model_name)
- all_ok = True
for file_details in files:
- if not self.check_file_available(file_details['path']):
+ if not self.check_file_available(file_details["path"]):
return False
if not self.validate_file(file_details):
return False
return True
def validate_file(self, file_details):
- if 'md5sum' in file_details:
- file_name = file_details['path']
+ if "md5sum" in file_details:
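+            # verify the local file against the md5 checksum from the model reference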
+ file_name = file_details["path"]
logger.debug(f"Getting md5sum of {file_name}")
- with open(file_name, 'rb') as file_to_check:
+ with open(file_name, "rb") as file_to_check:
file_hash = hashlib.md5()
while chunk := file_to_check.read(8192):
file_hash.update(chunk)
- if file_details['md5sum'] != file_hash.hexdigest():
+ if file_details["md5sum"] != file_hash.hexdigest():
return False
return True
@@ -320,25 +360,29 @@ def check_file_available(self, file_path):
def check_available(self, files):
available = True
for file in files:
- if not self.check_file_available(file['path']):
+ if not self.check_file_available(file["path"]):
available = False
return available
def download_file(self, url, file_path):
# make directory
os.makedirs(os.path.dirname(file_path), exist_ok=True)
- pbar_desc = file_path.split('/')[-1]
+ pbar_desc = file_path.split("/")[-1]
r = requests.get(url, stream=True, allow_redirects=True)
- with open(file_path, 'wb') as f:
+ with open(file_path, "wb") as f:
with tqdm(
# all optional kwargs
- unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
- desc=pbar_desc, total=int(r.headers.get('content-length', 0))
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ miniters=1,
+ desc=pbar_desc,
+ total=int(r.headers.get("content-length", 0)),
) as pbar:
- for chunk in r.iter_content(chunk_size=16*1024):
+ for chunk in r.iter_content(chunk_size=16 * 1024):
if chunk:
f.write(chunk)
- pbar.update(len(chunk))
+ pbar.update(len(chunk))
def download_model(self, model_name):
if model_name in self.available_models:
@@ -346,58 +390,68 @@ def download_model(self, model_name):
return True
download = self.get_model_download(model_name)
files = self.get_model_files(model_name)
- for i in range(len(download)):
- file_path = f"{download[i]['file_path']}/{download[i]['file_name']}" if 'file_path' in download[i] else files[i]['path']
-
- if 'file_url' in download[i]:
- download_url = download[i]['file_url']
- if 'hf_auth' in download[i]:
- username = self.hf_auth['username']
- password = self.hf_auth['password']
+ for i in range(len(download)):
+ file_path = (
+ f"{download[i]['file_path']}/{download[i]['file_name']}"
+ if "file_path" in download[i]
+ else files[i]["path"]
+ )
+
+ if "file_url" in download[i]:
+ download_url = download[i]["file_url"]
+ if "hf_auth" in download[i]:
+ username = self.hf_auth["username"]
+ password = self.hf_auth["password"]
download_url = download_url.format(username=username, password=password)
- if 'file_name' in download[i]:
- download_name = download[i]['file_name']
- if 'file_path' in download[i]:
- download_path = download[i]['file_path']
-
- if 'manual' in download[i]:
- logger.warning(f"The model {model_name} requires manual download from {download_url}. Please place it in {download_path}/{download_name} then press ENTER to continue...")
- input('')
+ if "file_name" in download[i]:
+ download_name = download[i]["file_name"]
+ if "file_path" in download[i]:
+ download_path = download[i]["file_path"]
+
+ if "manual" in download[i]:
+ logger.warning(
+ f"The model {model_name} requires manual download from {download_url}. "
+ f"Please place it in {download_path}/{download_name} then press ENTER to continue..."
+ )
+ input("")
continue
# TODO: simplify
if "file_content" in download[i]:
- file_content = download[i]['file_content']
+ file_content = download[i]["file_content"]
logger.info(f"writing {file_content} to {file_path}")
# make directory download_path
os.makedirs(download_path, exist_ok=True)
# write file_content to download_path/download_name
- with open(os.path.join(download_path, download_name), 'w') as f:
+ with open(os.path.join(download_path, download_name), "w") as f:
f.write(file_content)
- elif 'symlink' in download[i]:
+ elif "symlink" in download[i]:
logger.info(f"symlink {file_path} to {download[i]['symlink']}")
- symlink = download[i]['symlink']
+ symlink = download[i]["symlink"]
# make directory symlink
os.makedirs(download_path, exist_ok=True)
# make symlink from download_path/download_name to symlink
os.symlink(symlink, os.path.join(download_path, download_name))
- elif 'git' in download[i]:
+ elif "git" in download[i]:
logger.info(f"git clone {download_url} to {file_path}")
# make directory download_path
os.makedirs(file_path, exist_ok=True)
git.Git(file_path).clone(download_url)
- if 'post_process' in download[i]:
- for post_process in download[i]['post_process']:
- if 'delete' in post_process:
+ if "post_process" in download[i]:
+ for post_process in download[i]["post_process"]:
+ if "delete" in post_process:
# delete folder post_process['delete']
logger.info(f"delete {post_process['delete']}")
try:
- shutil.rmtree(post_process['delete'])
+ shutil.rmtree(post_process["delete"])
except PermissionError as e:
- logger.error(f"[!] Something went wrong while deleting the `{post_process['delete']}`. Please delete it manually.")
+ logger.error(
+ f"[!] Something went wrong while deleting the `{post_process['delete']}`. "
+ "Please delete it manually."
+ )
logger.error("PermissionError: ", e)
else:
if not self.check_file_available(file_path) or model_name in self.tainted_models:
- logger.debug(f'Downloading {download_url} to {file_path}')
+ logger.debug(f"Downloading {download_url} to {file_path}")
self.download_file(download_url, file_path)
if not self.validate_model(model_name):
return False
@@ -405,7 +459,7 @@ def download_model(self, model_name):
self.tainted_models.remove(model_name)
self.init()
return True
-
+
def download_dependency(self, dependency_name):
if dependency_name in self.available_dependencies:
logger.info(f"{dependency_name} is already installed.")
@@ -416,25 +470,25 @@ def download_dependency(self, dependency_name):
if "git" in download[i]:
logger.warning("git download not implemented yet")
break
-
- file_path = files[i]['path']
- if 'file_url' in download[i]:
- download_url = download[i]['file_url']
- if 'file_name' in download[i]:
- download_name = download[i]['file_name']
- if 'file_path' in download[i]:
- download_path = download[i]['file_path']
+
+ file_path = files[i]["path"]
+ if "file_url" in download[i]:
+ download_url = download[i]["file_url"]
+ if "file_name" in download[i]:
+ download_name = download[i]["file_name"]
+ if "file_path" in download[i]:
+ download_path = download[i]["file_path"]
logger.debug(download_name)
if "unzip" in download[i]:
- zip_path = f'temp/{download_name}.zip'
+ zip_path = f"temp/{download_name}.zip"
# os dirname zip_path
# mkdir temp
os.makedirs("temp", exist_ok=True)
self.download_file(download_url, zip_path)
logger.info(f"unzip {zip_path}")
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
- zip_ref.extractall('temp/')
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ zip_ref.extractall("temp/")
# move temp/sd-concepts-library-main/sd-concepts-library to download_path
logger.info(f"move temp/{download_name}-main/{download_name} to {download_path}")
shutil.move(f"temp/{download_name}-main/{download_name}", download_path)
@@ -444,34 +498,37 @@ def download_dependency(self, dependency_name):
shutil.rmtree(f"temp/{download_name}-main")
else:
if not self.check_file_available(file_path):
- logger.init(f'{file_path}', status="Downloading")
+ logger.init(f"{file_path}", status="Downloading")
self.download_file(download_url, file_path)
self.init()
return True
-
+
def download_all_models(self):
- for model in self.get_filtered_model_names(download_all = True):
+ for model in self.get_filtered_model_names(download_all=True):
if not self.check_model_available(model):
logger.init(f"{model}", status="Downloading")
self.download_model(model)
else:
logger.info(f"{model} is already downloaded.")
return True
-
+
def download_all_dependencies(self):
for dependency in self.dependencies:
if not self.check_dependency_available(dependency):
- logger.init(f"{dependency}",status="Downloading")
+ logger.init(f"{dependency}", status="Downloading")
self.download_dependency(dependency)
else:
logger.info(f"{dependency} is already installed.")
return True
-
+
def download_all(self):
self.download_all_dependencies()
self.download_all_models()
return True
-
+
+ """
+ FIXME: this method is defined twice; commenting out the first definition for now.
+
def check_all_available(self):
for model in self.models:
if not self.check_available(self.get_model_files(model)):
@@ -480,17 +537,18 @@ def check_all_available(self):
if not self.check_available(self.get_dependency_files(dependency)):
return False
return True
-
+ """
+
def check_model_available(self, model_name):
if model_name not in self.models:
return False
return self.check_available(self.get_model_files(model_name))
-
+
def check_dependency_available(self, dependency_name):
if dependency_name not in self.dependencies:
return False
return self.check_available(self.get_dependency_files(dependency_name))
-
+
def check_all_available(self):
for model in self.models:
if not self.check_model_available(model):
@@ -499,11 +557,3 @@ def check_all_available(self):
if not self.check_dependency_available(dependency):
return False
return True
-
-
-
-
-
-
-
-
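For reference, a minimal standalone sketch of the download-and-verify pattern that download_file and validate_file above implement: stream the file with a tqdm progress bar, then md5-check it in 8 KiB chunks. The url, file_path and expected_md5 values are placeholders, not taken from the model database.

import hashlib
import os

import requests
from tqdm import tqdm


def fetch_and_verify(url, file_path, expected_md5=None):
    # Stream the download so large checkpoints never sit fully in memory.
    if os.path.dirname(file_path):
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
    r = requests.get(url, stream=True, allow_redirects=True)
    with open(file_path, "wb") as f, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        miniters=1,
        desc=os.path.basename(file_path),
        total=int(r.headers.get("content-length", 0)),
    ) as pbar:
        for chunk in r.iter_content(chunk_size=16 * 1024):
            if chunk:
                f.write(chunk)
                pbar.update(len(chunk))
    if expected_md5 is None:
        return True
    # Hash in 8 KiB chunks, as validate_file does, to keep memory usage flat.
    file_hash = hashlib.md5()
    with open(file_path, "rb") as file_to_check:
        while chunk := file_to_check.read(8192):
            file_hash.update(chunk)
    return file_hash.hexdigest() == expected_md5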
diff --git a/nataili/postprocess/upscaler.py b/nataili/postprocess/upscaler.py
index 529b4380..0475c5b4 100644
--- a/nataili/postprocess/upscaler.py
+++ b/nataili/postprocess/upscaler.py
@@ -6,43 +6,42 @@
# - output_ext
# outputs:
# - output_images
-import PIL
-from torchvision import transforms
-import numpy as np
import os
+
import cv2
+import PIL
from nataili.util.save_sample import save_sample
+
class realesrgan:
- def __init__(self, model, device, output_dir, output_ext='jpg'):
+ def __init__(self, model, device, output_dir, output_ext="jpg"):
self.model = model
self.device = device
self.output_dir = output_dir
self.output_ext = output_ext
self.output_images = []
-
+
def generate(self, input_image):
# load image
img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
- img_mode = 'RGBA'
+ img_mode = "RGBA"
else:
img_mode = None
# upscale
output, _ = self.model.enhance(img)
- if img_mode == 'RGBA': # RGBA images should be saved in png format
- self.output_ext = 'png'
-
- esrgan_sample = output[:,:,::-1]
+ if img_mode == "RGBA": # RGBA images should be saved in png format
+ self.output_ext = "png"
+
+ esrgan_sample = output[:, :, ::-1]
esrgan_image = PIL.Image.fromarray(esrgan_sample)
# append model name to output image name
filename = os.path.basename(input_image)
filename = os.path.splitext(filename)[0]
- filename = f'{filename}_esrgan'
- filename_with_ext = f'{filename}.{self.output_ext}'
+ filename = f"{filename}_esrgan"
+ filename_with_ext = f"{filename}.{self.output_ext}"
output_image = os.path.join(self.output_dir, filename_with_ext)
save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
self.output_images.append(output_image)
return
-
diff --git a/nataili/upscalers/realesrgan.py b/nataili/upscalers/realesrgan.py
index 529b4380..0475c5b4 100644
--- a/nataili/upscalers/realesrgan.py
+++ b/nataili/upscalers/realesrgan.py
@@ -6,43 +6,42 @@
# - output_ext
# outputs:
# - output_images
-import PIL
-from torchvision import transforms
-import numpy as np
import os
+
import cv2
+import PIL
from nataili.util.save_sample import save_sample
+
class realesrgan:
- def __init__(self, model, device, output_dir, output_ext='jpg'):
+ def __init__(self, model, device, output_dir, output_ext="jpg"):
self.model = model
self.device = device
self.output_dir = output_dir
self.output_ext = output_ext
self.output_images = []
-
+
def generate(self, input_image):
# load image
img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
- img_mode = 'RGBA'
+ img_mode = "RGBA"
else:
img_mode = None
# upscale
output, _ = self.model.enhance(img)
- if img_mode == 'RGBA': # RGBA images should be saved in png format
- self.output_ext = 'png'
-
- esrgan_sample = output[:,:,::-1]
+ if img_mode == "RGBA": # RGBA images should be saved in png format
+ self.output_ext = "png"
+
+ esrgan_sample = output[:, :, ::-1]
esrgan_image = PIL.Image.fromarray(esrgan_sample)
# append model name to output image name
filename = os.path.basename(input_image)
filename = os.path.splitext(filename)[0]
- filename = f'{filename}_esrgan'
- filename_with_ext = f'{filename}.{self.output_ext}'
+ filename = f"{filename}_esrgan"
+ filename_with_ext = f"{filename}.{self.output_ext}"
output_image = os.path.join(self.output_dir, filename_with_ext)
save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
self.output_images.append(output_image)
return
-
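The realesrgan wrapper above only needs an object exposing .enhance(ndarray); in this repo that is what ModelManager.load_realesrgan places into loaded_models. A hedged usage sketch — the "RealESRGAN_x4plus" name comes from the commented model list in test.py, and ./01.png is the same test image used there:

from nataili.model_manager import ModelManager
from nataili.upscalers.realesrgan import realesrgan

mm = ModelManager()
mm.init()
mm.load_model("RealESRGAN_x4plus")

upscaler = realesrgan(
    model=mm.loaded_models["RealESRGAN_x4plus"]["model"],
    device=mm.loaded_models["RealESRGAN_x4plus"]["device"],
    output_dir="test_output",
)
upscaler.generate("./01.png")   # writes test_output/01_esrgan.jpg (.png for RGBA inputs)
print(upscaler.output_images)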
diff --git a/nataili/util/__init__.py b/nataili/util/__init__.py
index 766cb275..351adee3 100644
--- a/nataili/util/__init__.py
+++ b/nataili/util/__init__.py
@@ -1 +1,8 @@
-from nataili.util.logger import logger,set_logger_verbosity, quiesce_logger, test_logger
+from nataili.util.logger import logger, quiesce_logger, set_logger_verbosity, test_logger
+
+__all__ = [
+ "logger",
+ "quiesce_logger",
+ "set_logger_verbosity",
+ "test_logger",
+]
diff --git a/nataili/util/cache.py b/nataili/util/cache.py
index ef7711c6..65462891 100644
--- a/nataili/util/cache.py
+++ b/nataili/util/cache.py
@@ -1,11 +1,9 @@
import gc
import torch
-import threading
-import pynvml
-import time
with torch.no_grad():
+
def torch_gc():
for _ in range(2):
gc.collect()
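The torch_gc hunk is cut short here; the usual shape of this helper is a couple of gc passes followed by releasing the CUDA caching allocator. A sketch of that common pattern (the CUDA calls are the standard torch APIs, not copied from the truncated body):

import gc

import torch


def torch_gc():
    # Collect twice so objects freed by the first pass (e.g. reference cycles)
    # are reclaimed before CUDA memory is released.
    for _ in range(2):
        gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # return cached blocks to the driver
        torch.cuda.ipc_collect()   # free memory held for CUDA IPC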
diff --git a/nataili/util/check_prompt_length.py b/nataili/util/check_prompt_length.py
index 4953eacc..c94ace98 100644
--- a/nataili/util/check_prompt_length.py
+++ b/nataili/util/check_prompt_length.py
@@ -4,15 +4,23 @@ def check_prompt_length(model, prompt, comments):
tokenizer = model.cond_stage_model.tokenizer
max_length = model.cond_stage_model.max_length
- info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
- return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
- ovf = info['overflowing_tokens'][0]
+ info = model.cond_stage_model.tokenizer(
+ [prompt],
+ truncation=True,
+ max_length=max_length,
+ return_overflowing_tokens=True,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ ovf = info["overflowing_tokens"][0]
overflowing_count = ovf.shape[0]
if overflowing_count == 0:
return
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
- del tokenizer
\ No newline at end of file
+ overflowing_text = tokenizer.convert_tokens_to_string("".join(overflowing_words))
+ comments.append(
+ f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n"
+ )
+ del tokenizer
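The overflow warning above comes straight out of the tokenizer: asking for return_overflowing_tokens with truncation yields the tokens that did not fit. A standalone sketch of the same call against a plain transformers CLIPTokenizer (the checkpoint id and prompt are illustrative only):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
max_length = tokenizer.model_max_length   # 77 for CLIP

prompt = "a very long prompt " * 40
info = tokenizer(
    [prompt],
    truncation=True,
    max_length=max_length,
    return_overflowing_tokens=True,
    padding="max_length",
    return_tensors="pt",
)
overflowing = info["overflowing_tokens"][0]
print(f"{overflowing.shape[0]} tokens would be truncated")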
diff --git a/nataili/util/get_next_sequence_number.py b/nataili/util/get_next_sequence_number.py
index ede028f6..f7a24d56 100644
--- a/nataili/util/get_next_sequence_number.py
+++ b/nataili/util/get_next_sequence_number.py
@@ -1,6 +1,7 @@
from pathlib import Path
-def get_next_sequence_number(path, prefix=''):
+
+def get_next_sequence_number(path, prefix=""):
"""
Determines and returns the next sequence number to use when saving an
image in the specified directory.
@@ -13,10 +14,10 @@ def get_next_sequence_number(path, prefix=''):
"""
result = -1
for p in Path(path).iterdir():
- if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
- tmp = p.name[len(prefix):]
+ if p.name.endswith((".png", ".jpg")) and p.name.startswith(prefix):
+ tmp = p.name[len(prefix) :]
try:
- result = max(int(tmp.split('-')[0]), result)
+ result = max(int(tmp.split("-")[0]), result)
except ValueError:
pass
- return result + 1
\ No newline at end of file
+ return result + 1
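A quick worked example of the scan above: filenames are matched as <prefix><number>-..., non-matching names are skipped via the ValueError guard, and the next free number is returned. The directory and filenames are made up for illustration.

import tempfile
from pathlib import Path

from nataili.util.get_next_sequence_number import get_next_sequence_number

with tempfile.TemporaryDirectory() as outdir:
    for name in ("00000-seed1.png", "00001-seed2.png", "grid-00007-x.jpg"):
        Path(outdir, name).touch()
    print(get_next_sequence_number(outdir))           # 2  (highest unprefixed number is 1)
    print(get_next_sequence_number(outdir, "grid-"))  # 8  (only grid-00007-x.jpg matches)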
diff --git a/nataili/util/image_grid.py b/nataili/util/image_grid.py
index 7ea85eb7..cbc2257d 100644
--- a/nataili/util/image_grid.py
+++ b/nataili/util/image_grid.py
@@ -13,7 +13,7 @@ def image_grid(imgs, n_rows=None):
cols = math.ceil(len(imgs) / rows)
w, h = imgs[0].size
- grid = PIL.Image.new('RGB', size=(cols * w, rows * h), color='black')
+ grid = PIL.Image.new("RGB", size=(cols * w, rows * h), color="black")
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
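A small usage sketch of image_grid, assuming the helper returns the pasted grid image (the tile size and colours are arbitrary):

import PIL.Image

from nataili.util.image_grid import image_grid

tiles = [PIL.Image.new("RGB", (64, 64), colour) for colour in ("red", "green", "blue")]
grid = image_grid(tiles, n_rows=1)   # 1 row x 3 columns -> a 192x64 black-backed canvas
grid.save("grid_preview.png")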
diff --git a/nataili/util/load_learned_embed_in_clip.py b/nataili/util/load_learned_embed_in_clip.py
index 9507e58b..e0a2d1fe 100644
--- a/nataili/util/load_learned_embed_in_clip.py
+++ b/nataili/util/load_learned_embed_in_clip.py
@@ -6,19 +6,17 @@
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
# separate token and the embeds
- if learned_embeds_path.endswith('.pt'):
+ if learned_embeds_path.endswith(".pt"):
# old format
# token = * so replace with file directory name when converting
trained_token = os.path.basename(learned_embeds_path)
- params_dict = {
- trained_token: torch.tensor(list(loaded_learned_embeds['string_to_param'].items())[0][1])
- }
- learned_embeds_path = os.path.splitext(learned_embeds_path)[0] + '.bin'
+ params_dict = {trained_token: torch.tensor(list(loaded_learned_embeds["string_to_param"].items())[0][1])}
+ learned_embeds_path = os.path.splitext(learned_embeds_path)[0] + ".bin"
torch.save(params_dict, learned_embeds_path)
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]
- elif learned_embeds_path.endswith('.bin'):
+ elif learned_embeds_path.endswith(".bin"):
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]
@@ -29,7 +27,9 @@ def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, tok
# add the token in tokenizer
token = token if token is not None else trained_token
- num_added_tokens = tokenizer.add_tokens(token)
+
+ # FIXME: num_added_tokens was unused, but add_tokens() also registers the token; confirm before removing.
+ # num_added_tokens = tokenizer.add_tokens(token)
# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))
diff --git a/nataili/util/logger.py b/nataili/util/logger.py
index d012c830..bb5181b2 100644
--- a/nataili/util/logger.py
+++ b/nataili/util/logger.py
@@ -1,5 +1,6 @@
import sys
from functools import partialmethod
+
from loguru import logger
STDOUT_LEVELS = ["GENERATION", "PROMPT"]
@@ -9,6 +10,7 @@
verbosity = 20
quiet = 0
+
def set_logger_verbosity(count):
global verbosity
# The count comes reversed. So count = 0 means minimum verbosity
@@ -16,41 +18,49 @@ def set_logger_verbosity(count):
# So the more count we have, the lower we drop the verbosity maximum
verbosity = 20 - (count * 10)
+
def quiesce_logger(count):
global quiet
# The bigger the count, the more silent we want our logger
quiet = count * 10
+
def is_stdout_log(record):
if record["level"].name not in STDOUT_LEVELS:
- return(False)
+ return False
if record["level"].no < verbosity + quiet:
- return(False)
- return(True)
+ return False
+ return True
+
def is_init_log(record):
if record["level"].name not in INIT_LEVELS:
- return(False)
+ return False
if record["level"].no < verbosity + quiet:
- return(False)
- return(True)
+ return False
+ return True
+
def is_msg_log(record):
if record["level"].name not in MESSAGE_LEVELS:
- return(False)
+ return False
if record["level"].no < verbosity + quiet:
- return(False)
- return(True)
+ return False
+ return True
+
def is_stderr_log(record):
if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
- return(False)
+ return False
if record["level"].no < verbosity + quiet:
- return(False)
- return(True)
+ return False
+ return True
+
def test_logger():
- logger.generation("This is a generation message\nIt is typically multiline\nThee Lines".encode("unicode_escape").decode("utf-8"))
+ logger.generation(
+ "This is a generation message\nIt is typically multiline\nThee Lines".encode("unicode_escape").decode("utf-8")
+ )
logger.prompt("This is a prompt message")
logger.debug("Debug Message")
logger.info("Info Message")
@@ -65,7 +75,10 @@ def test_logger():
sys.exit()
-logfmt = "{level: <10} | {time:YYYY-MM-DD HH:mm:ss} | {name}:{function}:{line} - {message}"
+logfmt = (
+ "{level: <10} | {time:YYYY-MM-DD HH:mm:ss} | "
+ "{name}:{function}:{line} - {message}"
+)
genfmt = "{level: <10} @ {time:YYYY-MM-DD HH:mm:ss} | {message}"
initfmt = "INIT | {extra[status]: <11} | {message}"
msgfmt = "{level: <10} | {message}"
@@ -93,10 +106,33 @@ def test_logger():
config = {
"handlers": [
- {"sink": sys.stderr, "format": logfmt, "colorize":True, "filter": is_stderr_log},
- {"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize":True, "filter": is_stdout_log},
- {"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize":True, "filter": is_init_log},
- {"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize":True, "filter": is_msg_log}
+ {
+ "sink": sys.stderr,
+ "format": logfmt,
+ "colorize": True,
+ "filter": is_stderr_log,
+ },
+ {
+ "sink": sys.stdout,
+ "format": genfmt,
+ "level": "PROMPT",
+ "colorize": True,
+ "filter": is_stdout_log,
+ },
+ {
+ "sink": sys.stdout,
+ "format": initfmt,
+ "level": "INIT",
+ "colorize": True,
+ "filter": is_init_log,
+ },
+ {
+ "sink": sys.stdout,
+ "format": msgfmt,
+ "level": "MESSAGE",
+ "colorize": True,
+ "filter": is_msg_log,
+ },
],
}
logger.configure(**config)
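The handler table above fans records out by level name: the custom GENERATION/PROMPT/INIT/MESSAGE levels get their own stdout formats, everything else goes to stderr. A minimal self-contained sketch of the same loguru pattern — the "NOTICE" level and its number are invented for illustration:

import sys

from loguru import logger

logger.level("NOTICE", no=25, color="<yellow>")   # hypothetical custom level


def is_notice(record):
    return record["level"].name == "NOTICE"


def is_other(record):
    return record["level"].name != "NOTICE"


# configure() replaces any previously added handlers, including the default one.
logger.configure(
    handlers=[
        {"sink": sys.stdout, "format": "{level: <10} | {message}", "filter": is_notice},
        {"sink": sys.stderr, "format": "{level: <10} | {message}", "filter": is_other},
    ]
)

logger.log("NOTICE", "routed to stdout by the filter")
logger.warning("routed to stderr by the filter")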
diff --git a/nataili/util/save_sample.py b/nataili/util/save_sample.py
index 5c791d30..01606f7b 100644
--- a/nataili/util/save_sample.py
+++ b/nataili/util/save_sample.py
@@ -1,16 +1,26 @@
import os
-def save_sample(image, filename, sample_path, extension='png', jpg_quality=95, webp_quality=95, webp_lossless=True, png_compression=9):
- path = os.path.join(sample_path, filename + '.' + extension)
+
+def save_sample(
+ image,
+ filename,
+ sample_path,
+ extension="png",
+ jpg_quality=95,
+ webp_quality=95,
+ webp_lossless=True,
+ png_compression=9,
+):
+ path = os.path.join(sample_path, filename + "." + extension)
if os.path.exists(path):
return False
if not os.path.exists(sample_path):
os.makedirs(sample_path)
- if extension == 'png':
- image.save(path, format='PNG', compress_level=png_compression)
- elif extension == 'jpg':
+ if extension == "png":
+ image.save(path, format="PNG", compress_level=png_compression)
+ elif extension == "jpg":
image.save(path, quality=jpg_quality, optimize=True)
- elif extension == 'webp':
+ elif extension == "webp":
image.save(path, quality=webp_quality, lossless=webp_lossless)
else:
return False
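save_sample above never overwrites an existing file and rejects extensions it does not know; a quick usage sketch (the directory name is arbitrary):

import PIL.Image

from nataili.util.save_sample import save_sample

img = PIL.Image.new("RGB", (64, 64), "black")
save_sample(img, "sample_000", "test_output", extension="webp", webp_lossless=False)
print(save_sample(img, "sample_000", "test_output", extension="webp"))  # False: already on disk
print(save_sample(img, "sample_000", "test_output", extension="gif"))   # False: unsupported extension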
diff --git a/nataili/util/seed_to_int.py b/nataili/util/seed_to_int.py
index 61cc9fdd..5fc4819c 100644
--- a/nataili/util/seed_to_int.py
+++ b/nataili/util/seed_to_int.py
@@ -1,15 +1,16 @@
import random
+
def seed_to_int(s):
if type(s) is int:
return s
- if s is None or s == '':
+ if s is None or s == "":
return random.randint(0, 2**32 - 1)
if type(s) is list:
seed_list = []
for seed in s:
- if seed is None or seed == '':
+ if seed is None or seed == "":
seed_list.append(random.randint(0, 2**32 - 1))
else:
seed_list = s
@@ -19,4 +20,4 @@ def seed_to_int(s):
n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
while n >= 2**32:
n = n >> 32
- return n
\ No newline at end of file
+ return n
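seed_to_int passes ints through, draws a fresh 32-bit value for empty input, hashes arbitrary strings deterministically through random.Random(s), and folds oversized numeric strings back into 32 bits with the >> 32 loop. A short worked example:

from nataili.util.seed_to_int import seed_to_int

print(seed_to_int(1234))                              # 1234: ints pass through unchanged
print(0 <= seed_to_int("") < 2**32)                   # True: empty seed -> random 32-bit value
print(seed_to_int("corgi") == seed_to_int("corgi"))   # True: same string, same derived seed
print(seed_to_int(str(2**40)) < 2**32)                # True: numeric overflow folded by >> 32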
diff --git a/nataili/util/switch.py b/nataili/util/switch.py
index c1627651..007a99ce 100644
--- a/nataili/util/switch.py
+++ b/nataili/util/switch.py
@@ -3,9 +3,9 @@ class Switch:
def activate(self):
self.active = True
-
+
def disable(self):
self.active = False
- def toggle(self,value):
- self.active = value
\ No newline at end of file
+ def toggle(self, value):
+ self.active = value
diff --git a/nataili/util/voodoo.py b/nataili/util/voodoo.py
index 30d1c522..13b1177c 100644
--- a/nataili/util/voodoo.py
+++ b/nataili/util/voodoo.py
@@ -1,13 +1,14 @@
import contextlib
import copy
from functools import wraps
-from typing import Dict, List, Tuple, TypeVar, Union
+from typing import Dict, List, Tuple, TypeVar
import ray
import torch
T = TypeVar("T")
+
def performance(f: T) -> T:
@wraps(f)
def wrapper(*args, **kwargs):
@@ -15,17 +16,14 @@ def wrapper(*args, **kwargs):
return wrapper
+
def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
tensors = []
for _, module in m.named_modules():
params = {
- name: torch.clone(param).cpu().detach().numpy()
- for name, param in module.named_parameters(recurse=False)
- }
- buffers = {
- name: torch.clone(buf).cpu().detach().numpy()
- for name, buf in module.named_buffers(recurse=False)
+ name: torch.clone(param).cpu().detach().numpy() for name, param in module.named_parameters(recurse=False)
}
+ buffers = {name: torch.clone(buf).cpu().detach().numpy() for name, buf in module.named_buffers(recurse=False)}
tensors.append({"params": params, "buffers": buffers})
m_copy = copy.deepcopy(m)
@@ -46,9 +44,7 @@ def replace_tensors(m: torch.nn.Module, tensors: List[Dict], device="cuda"):
for name, array in tensor_dict["params"].items():
module.register_parameter(
name,
- torch.nn.Parameter(
- torch.as_tensor(array, device=device), requires_grad=False
- ),
+ torch.nn.Parameter(torch.as_tensor(array, device=device), requires_grad=False),
)
for name, array in tensor_dict["buffers"].items():
module.register_buffer(name, torch.as_tensor(array, device=device))
@@ -62,6 +58,7 @@ def load_from_plasma(ref, device="cuda"):
yield skeleton
torch.cuda.empty_cache()
+
def push_model_to_plasma(model: torch.nn.Module) -> ray.ObjectRef:
ref = ray.put(extract_tensors(model))
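push_model_to_plasma and load_from_plasma move a module's weights into Ray's object store as numpy arrays and rebuild the module on demand. A hedged usage sketch — ray.init() and the tiny Linear module are placeholders; any torch.nn.Module should behave the same way:

import ray
import torch

from nataili.util.voodoo import load_from_plasma, push_model_to_plasma

ray.init()                              # a local object store is enough for this sketch
model = torch.nn.Linear(4, 4)           # stand-in for a real checkpoint

ref = push_model_to_plasma(model)       # tensors -> numpy -> plasma, serialized once
with load_from_plasma(ref, device="cpu") as skeleton:
    out = skeleton(torch.ones(1, 4))    # weights are re-attached before use
print(out.shape)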
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..bdb94fe5
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,8 @@
+[tool.black]
+line-length = 119
+
+[tool.isort]
+py_version = 38
+line_length = 119
+profile = "black"
+known_third_party = ["creds"]
diff --git a/requirements.dev.txt b/requirements.dev.txt
new file mode 100644
index 00000000..69ed784e
--- /dev/null
+++ b/requirements.dev.txt
@@ -0,0 +1,3 @@
+black==22.3.0
+flake8==3.8.1
+isort==5.10.1
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..ed2a7c7d
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,8 @@
+[flake8]
+max-line-length = 119
+extend-ignore = E203
+
+[isort]
+multi_line_output = 3
+include_trailing_comma = True
+line_length = 119
diff --git a/setup.py b/setup.py
index 0b5e1d96..1cef98e0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,13 +1,13 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
setup(
- name='nataili',
- version='0.0.1',
- description='',
+ name="nataili",
+ version="0.0.1",
+ description="",
packages=find_packages(),
install_requires=[
- 'torch',
- 'numpy',
- 'tqdm',
+ "torch",
+ "numpy",
+ "tqdm",
],
)
diff --git a/show_available_models.py b/show_available_models.py
index 1bf5deaf..90329f7e 100644
--- a/show_available_models.py
+++ b/show_available_models.py
@@ -4,13 +4,13 @@
# TODO: huggingface_hub or some way to use token instead of username/password
mm = ModelManager()
-filtered_models = mm.get_filtered_models(type='ckpt')
-ppmodels = ''
+filtered_models = mm.get_filtered_models(type="ckpt")
+ppmodels = ""
for model_name in filtered_models:
if model_name == 'LDSR': continue
ppmodels += model_name
- if filtered_models[model_name].get('description'):
+ if filtered_models[model_name].get("description"):
ppmodels += f" : {filtered_models[model_name].get('description')}"
- ppmodels += '\n'
+ ppmodels += "\n"
print(f"## Known ckpt Models ##\n{ppmodels}")
-input("Press ENTER to continue")
\ No newline at end of file
+input("Press ENTER to continue")
diff --git a/style.sh b/style.sh
new file mode 100755
index 00000000..a4c8ec4b
--- /dev/null
+++ b/style.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Run this script with "--fix" to automatically fix the issues which can be fixed
+
+# Set the working directory to where this script is located
+cd "$(dirname ${BASH_SOURCE[0]})"
+
+# exit script directly if any command fails
+set -e
+
+if [ "$1" == "--fix" ]
+then
+ echo "fix requested"
+ BLACK_OPTS=""
+ ISORT_OPTS=""
+else
+ echo "fix not requested"
+ BLACK_OPTS="--check --diff"
+ ISORT_OPTS="--check-only --diff"
+fi
+
+SRC="*.py nataili"
+
+black --line-length=119 $BLACK_OPTS $SRC
+flake8 $SRC
+isort $ISORT_OPTS $SRC
diff --git a/test.py b/test.py
index 6e5a26a2..3c696b7b 100644
--- a/test.py
+++ b/test.py
@@ -1,89 +1,104 @@
+import time
+
from nataili.inference.compvis.img2img import img2img
-from nataili.model_manager import ModelManager
from nataili.inference.compvis.txt2img import txt2img
+from nataili.model_manager import ModelManager
from nataili.util.cache import torch_gc
from nataili.util.logger import logger
-import time
-import PIL
-
-init_image = './01.png'
+init_image = "./01.png"
mm = ModelManager()
mm.init()
-logger.debug(f'Available dependencies:')
+logger.debug("Available dependencies:")
for dependency in mm.available_dependencies:
logger.debug(dependency)
-logger.debug(f'Available models:')
+logger.debug("Available models:")
for model in mm.available_models:
logger.debug(model)
-models_to_load = [#'stable_diffusion',
- #'waifu_diffusion',
- 'trinart',
- #'GFPGAN', 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B',
- #'BLIP', 'ViT-L/14', 'ViT-g-14', 'ViT-H-14'
- ]
-logger.init(f'{models_to_load}', status="Loading")
+models_to_load = [
+ # 'stable_diffusion',
+ # 'waifu_diffusion',
+ "trinart",
+ # 'GFPGAN', 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B',
+ # 'BLIP', 'ViT-L/14', 'ViT-g-14', 'ViT-H-14'
+]
+logger.init(f"{models_to_load}", status="Loading")
+
@logger.catch
def test():
tic = time.time()
- model = 'safety_checker'
- logger.init(f'Model: {model}', status="Loading")
+ model = "safety_checker"
+ logger.init(f"Model: {model}", status="Loading")
success = mm.load_model(model)
toc = time.time()
- logger.init_ok(f'Loading {model}: Took {toc-tic} seconds', status=success)
+ logger.init_ok(f"Loading {model}: Took {toc-tic} seconds", status=success)
for model in models_to_load:
torch_gc()
tic = time.time()
- logger.init(f'Model: {model}', status="Loading")
-
+ logger.init(f"Model: {model}", status="Loading")
+
success = mm.load_model(model)
toc = time.time()
- logger.init_ok(f'Loading {model}: Took {toc-tic} seconds', status=success)
+ logger.init_ok(f"Loading {model}: Took {toc-tic} seconds", status=success)
torch_gc()
- if model in ['stable_diffusion', 'waifu_diffusion', 'trinart']:
- logger.debug(f'Running inference on {model}')
- logger.info(f'Testing txt2img with prompt "collosal corgi"')
+ if model in ["stable_diffusion", "waifu_diffusion", "trinart"]:
+ logger.debug(f"Running inference on {model}")
+ logger.info('Testing txt2img with prompt "collosal corgi"')
- t2i = txt2img(mm.loaded_models[model]["model"], mm.loaded_models[model]["device"], 'test_output')
- t2i.generate('collosal corgi')
+ t2i = txt2img(
+ mm.loaded_models[model]["model"],
+ mm.loaded_models[model]["device"],
+ "test_output",
+ )
+ t2i.generate("collosal corgi")
torch_gc()
- logger.info(f'Testing nsfw filter with prompt "boobs"')
+ logger.info('Testing nsfw filter with prompt "boobs"')
- t2i = txt2img(mm.loaded_models[model]["model"], mm.loaded_models[model]["device"], 'test_output', filter_nsfw=True, safety_checker=mm.loaded_models['safety_checker']['model'])
- t2i.generate('boobs')
+ t2i = txt2img(
+ mm.loaded_models[model]["model"],
+ mm.loaded_models[model]["device"],
+ "test_output",
+ filter_nsfw=True,
+ safety_checker=mm.loaded_models["safety_checker"]["model"],
+ )
+ t2i.generate("boobs")
torch_gc()
- logger.info(f'Testing img2img with prompt "cute anime girl"')
+ logger.info('Testing img2img with prompt "cute anime girl"')
- i2i = img2img(mm.loaded_models[model]["model"], mm.loaded_models[model]["device"], 'test_output')
- init_img = PIL.Image.open(init_img)
- i2i.generate('cute anime girl', init_image)
+ i2i = img2img(
+ mm.loaded_models[model]["model"],
+ mm.loaded_models[model]["device"],
+ "test_output",
+ )
+ # init_img = PIL.Image.open(init_img)
+ i2i.generate("cute anime girl", init_image)
torch_gc()
- logger.init_ok(f'Model {model}', status="Unloading")
+ logger.init_ok(f"Model {model}", status="Unloading")
mm.unload_model(model)
torch_gc()
while True:
- print('Enter model name to load:')
+ print("Enter model name to load:")
print(mm.available_models)
model = input()
- if model == 'exit':
+ if model == "exit":
break
- print(f'Loading {model}')
+ print(f"Loading {model}")
success = mm.load_model(model)
- print(f'Loading {model} successful: {success}')
- print('')
+ print(f"Loading {model} successful: {success}")
+ print("")
if __name__ == "__main__":
diff --git a/test_download_all_models.py b/test_download_all_models.py
index 3b73b15e..1d24ba01 100644
--- a/test_download_all_models.py
+++ b/test_download_all_models.py
@@ -1,11 +1,12 @@
# test_download_models
-from nataili.model_manager import ModelManager
import creds
+from nataili.model_manager import ModelManager
+
# TODO: huggingface_hub or some way to use token instead of username/password
hf_auth = {"username": creds.hf_username, "password": creds.hf_password}
mm = ModelManager(hf_auth=hf_auth)
mm.init()
-mm.download_all()
\ No newline at end of file
+mm.download_all()
diff --git a/test_inpainting.py b/test_inpainting.py
index 5bc4e8fc..afae4193 100644
--- a/test_inpainting.py
+++ b/test_inpainting.py
@@ -1,6 +1,7 @@
-from nataili.inference.diffusers.inpainting import inpainting
from PIL import Image
+from nataili.inference.diffusers.inpainting import inpainting
+
original = Image.open("./inpaint_original.png")
mask = Image.open("./inpaint_mask.png")
diff --git a/test_voodoo.py b/test_voodoo.py
index ed76834c..dcacddea 100644
--- a/test_voodoo.py
+++ b/test_voodoo.py
@@ -1,89 +1,109 @@
+import time
+
+import PIL
+
from nataili.inference.compvis.img2img import img2img
-from nataili.model_manager import ModelManager
from nataili.inference.compvis.txt2img import txt2img
+from nataili.model_manager import ModelManager
from nataili.util.cache import torch_gc
from nataili.util.logger import logger
-import time
-import PIL
-
-init_image = './01.png'
+init_image = "./01.png"
mm = ModelManager()
mm.init()
-logger.debug(f'Available dependencies:')
+logger.debug("Available dependencies:")
for dependency in mm.available_dependencies:
logger.debug(dependency)
-logger.debug(f'Available models:')
+logger.debug("Available models:")
for model in mm.available_models:
logger.debug(model)
-models_to_load = [#'stable_diffusion',
- #'waifu_diffusion',
- 'trinart',
- #'GFPGAN', 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B',
- #'BLIP', 'ViT-L/14', 'ViT-g-14', 'ViT-H-14'
- ]
-logger.init(f'{models_to_load}', status="Loading")
+models_to_load = [
+ # 'stable_diffusion',
+ # 'waifu_diffusion',
+ "trinart",
+ # 'GFPGAN', 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B',
+ # 'BLIP', 'ViT-L/14', 'ViT-g-14', 'ViT-H-14'
+]
+logger.init(f"{models_to_load}", status="Loading")
+
@logger.catch
def test():
tic = time.time()
- model = 'safety_checker'
- logger.init(f'Model: {model}', status="Loading")
+ model = "safety_checker"
+ logger.init(f"Model: {model}", status="Loading")
success = mm.load_model(model)
toc = time.time()
- logger.init_ok(f'Loading {model}: Took {toc-tic} seconds', status=success)
+ logger.init_ok(f"Loading {model}: Took {toc-tic} seconds", status=success)
for model in models_to_load:
torch_gc()
tic = time.time()
- logger.init(f'Model: {model}', status="Loading")
-
+ logger.init(f"Model: {model}", status="Loading")
+
success = mm.load_model(model, use_voodoo=True)
toc = time.time()
- logger.init_ok(f'Loading {model}: Took {toc-tic} seconds', status=success)
+ logger.init_ok(f"Loading {model}: Took {toc-tic} seconds", status=success)
torch_gc()
- if model in ['stable_diffusion', 'waifu_diffusion', 'trinart']:
- logger.debug(f'Running inference on {model}')
- logger.info(f'Testing txt2img with prompt "collosal corgi"')
+ if model in ["stable_diffusion", "waifu_diffusion", "trinart"]:
+ logger.debug(f"Running inference on {model}")
+ logger.info('Testing txt2img with prompt "collosal corgi"')
- t2i = txt2img(mm.loaded_models[model]["model"], mm.loaded_models[model]["device"], 'test_output', use_voodoo=True)
- t2i.generate('collosal corgi')
+ t2i = txt2img(
+ mm.loaded_models[model]["model"],
+ mm.loaded_models[model]["device"],
+ "test_output",
+ use_voodoo=True,
+ )
+ t2i.generate("collosal corgi")
torch_gc()
- logger.info(f'Testing nsfw filter with prompt "boobs"')
+ logger.info('Testing nsfw filter with prompt "boobs"')
- t2i = txt2img(mm.loaded_models[model]["model"], mm.loaded_models[model]["device"], 'test_output', filter_nsfw=True, safety_checker=mm.loaded_models['safety_checker']['model'], use_voodoo=True)
- t2i.generate('boobs')
+ t2i = txt2img(
+ mm.loaded_models[model]["model"],
+ mm.loaded_models[model]["device"],
+ "test_output",
+ filter_nsfw=True,
+ safety_checker=mm.loaded_models["safety_checker"]["model"],
+ use_voodoo=True,
+ )
+ t2i.generate("boobs")
torch_gc()
- logger.info(f'Testing img2img with prompt "cute anime girl"')
+ logger.info('Testing img2img with prompt "cute anime girl"')
- i2i = img2img(mm.loaded_models[model]["model"], mm.loaded_models[model]["device"], 'test_output', use_voodoo=True)
+ i2i = img2img(
+ mm.loaded_models[model]["model"],
+ mm.loaded_models[model]["device"],
+ "test_output",
+ use_voodoo=True,
+ )
init_img = PIL.Image.open(init_image)
- i2i.generate('cute anime girl', init_img)
+ i2i.generate("cute anime girl", init_img)
torch_gc()
- logger.init_ok(f'Model {model}', status="Unloading")
+ logger.init_ok(f"Model {model}", status="Unloading")
mm.unload_model(model)
torch_gc()
while True:
- print('Enter model name to load:')
+ print("Enter model name to load:")
print(mm.available_models)
model = input()
- if model == 'exit':
+ if model == "exit":
break
- print(f'Loading {model}')
+ print(f"Loading {model}")
success = mm.load_model(model)
- print(f'Loading {model} successful: {success}')
- print('')
+ print(f"Loading {model} successful: {success}")
+ print("")
if __name__ == "__main__":