diff --git a/.gitignore b/.gitignore index 8b8e9634..b646ca26 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,4 @@ SQL_statements.txt horde.log horde.db /.idea +/boto3oeo.py \ No newline at end of file diff --git a/cli_logger.py b/cli_logger.py index 21d66e40..3fb7d292 100644 --- a/cli_logger.py +++ b/cli_logger.py @@ -97,6 +97,7 @@ def test_logger(): ], } logger.configure(**config) +logger.add("cliRequests.log", retention="7 days", level=19) logger.disable("__main__") logger.warning("disabled") logger.enable("") diff --git a/cli_request.py b/cli_request.py index 4382f3a1..fbef2234 100644 --- a/cli_request.py +++ b/cli_request.py @@ -43,7 +43,8 @@ def __init__(self): "nsfw": False, "censor_nsfw": False, "trusted_workers": False, - "models": ["stable_diffusion"] + "models": ["stable_diffusion"], + "r2": True } self.source_image = None self.source_processing = "img2img" @@ -160,14 +161,22 @@ def generate(): return results = results_json['generations'] for iter in range(len(results)): - b64img = results[iter]["img"] - base64_bytes = b64img.encode('utf-8') - img_bytes = base64.b64decode(base64_bytes) - img = Image.open(BytesIO(img_bytes)) final_filename = request_data.filename if len(results) > 1: final_filename = f"{iter}_{request_data.filename}" - img.save(final_filename) + if request_data.get_submit_dict()["r2"]: + try: + img_data = requests.get(results[iter]["img"]).content + except: + logger.error("Received b64 again") + with open(final_filename, 'wb') as handler: + handler.write(img_data) + else: + b64img = results[iter]["img"] + base64_bytes = b64img.encode('utf-8') + img_bytes = base64.b64decode(base64_bytes) + img = Image.open(BytesIO(img_bytes)) + img.save(final_filename) logger.info(f"Saved {final_filename}") else: logger.error(submit_req.text) diff --git a/horde/apis/models/stable_v2.py b/horde/apis/models/stable_v2.py index cde5b04e..a611aaef 100644 --- a/horde/apis/models/stable_v2.py +++ b/horde/apis/models/stable_v2.py @@ -10,6 +10,7 @@ def 
__init__(self): self.generate_parser.add_argument("source_processing", type=str, default="img2img", required=False, help="If source_image is provided, specifies how to process it.", location="json") self.generate_parser.add_argument("source_mask", type=str, required=False, help="If img_processing is set to 'inpainting' or 'outpainting', this parameter can be optionally provided as the mask of the areas to inpaint. If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel", location="json") self.generate_parser.add_argument("models", type=list, required=False, default=['stable_diffusion'], help="The acceptable models with which to generate", location="json") + self.generate_parser.add_argument("r2", type=bool, default=False, required=False, help="If True, the image will be sent via cloudflare r2 download link", location="json") self.job_pop_parser.add_argument("max_pixels", type=int, required=False, default=512*512, help="The maximum amount of pixels this worker can generate", location="json") self.job_pop_parser.add_argument("allow_img2img", type=bool, required=False, default=True, help="If True, this worker will pick up img2img requests", location="json") self.job_pop_parser.add_argument("allow_painting", type=bool, required=False, default=True, help="If True, this worker will pick up inpainting/outpaining requests", location="json") @@ -64,6 +65,7 @@ def __init__(self,api): 'source_image': fields.String(description="The Base64-encoded webp to use for img2img"), 'source_processing': fields.String(required=False, default='img2img',enum=["img2img", "inpainting", "outpainting"], description="If source_image is provided, specifies how to process it."), 'source_mask': fields.String(description="If img_processing is set to 'inpainting' or 'outpainting', this parameter can be optionally provided as the mask of the areas to inpaint. 
If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel"), + 'r2_upload': fields.String(description="The r2 upload link to use to upload this image"), }) self.input_model_job_pop = api.inherit('PopInputStable', self.input_model_job_pop, { 'max_pixels': fields.Integer(default=512*512,description="The maximum amount of pixels this worker can generate"), @@ -83,6 +85,7 @@ def __init__(self,api): 'source_image': fields.String(required=False, description="The Base64-encoded webp to use for img2img"), 'source_processing': fields.String(required=False, default='img2img',enum=["img2img", "inpainting", "outpainting"], description="If source_image is provided, specifies how to process it."), 'source_mask': fields.String(description="If source_processing is set to 'inpainting' or 'outpainting', this parameter can be optionally provided as the Base64-encoded webp mask of the areas to inpaint. If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel"), + 'r2': fields.Boolean(default=False, description="If True, the image will be sent via cloudflare r2 download link"), }) self.response_model_worker_details = api.inherit('WorkerDetailsStable', self.response_model_worker_details, { "max_pixels": fields.Integer(example=262144,description="The maximum pixels in resolution this worker can generate"), diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index 09bde79b..ffc0cce2 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -74,9 +74,9 @@ def validate(self): if self.args.source_image: if self.args.source_processing == "img2img" and self.params.get("sampler_name") in ["k_dpm_fast", "k_dpm_adaptive", "k_dpmpp_2s_a", "k_dpmpp_2m"]: raise e.UnsupportedSampler - if "stable_diffusion_2.0" in self.args.models: + if any(model_name.startswith("stable_diffusion_2") for model_name in self.args.models): raise e.UnsupportedModel - if self.args.models != ["stable_diffusion_2.0"] and 
self.params.get("sampler_name") in ["dpmsolver"]: + if not any(model_name.startswith("stable_diffusion_2") for model_name in self.args.models) and self.params.get("sampler_name") in ["dpmsolver"]: raise e.UnsupportedSampler # if self.args.models == ["stable_diffusion_2.0"] and self.params.get("sampler_name") not in ["dpmsolver"]: # raise e.UnsupportedSampler @@ -105,6 +105,7 @@ def initiate_waiting_prompt(self): source_mask = convert_source_image_to_webp(self.args.source_mask), ipaddr = self.user_ip, safe_ip=self.safe_ip, + r2=self.args.r2, ) needs_kudos,resolution = self.wp.requires_upfront_kudos(database.retrieve_totals()) if needs_kudos: @@ -149,6 +150,7 @@ def initiate_waiting_prompt(self): source_mask = convert_source_image_to_webp(self.args.source_mask), ipaddr = self.user_ip, safe_ip=self.safe_ip, + r2=self.args.r2, ) needs_kudos,resolution = self.wp.requires_upfront_kudos(database.retrieve_totals()) if needs_kudos: diff --git a/horde/apis/v2/v2.py b/horde/apis/v2/v2.py index 1aac793b..09d862ba 100644 --- a/horde/apis/v2/v2.py +++ b/horde/apis/v2/v2.py @@ -389,7 +389,7 @@ def post(self): # self.prioritized_wp.append(wp) # logger.warning(datetime.utcnow()) ## End prioritize by bridge request ## - for wp in self.get_sorted_wp(): # TODO this should also filter on .n>0 + for wp in self.get_sorted_wp(): if wp not in self.prioritized_wp: self.prioritized_wp.append(wp) # logger.warning(datetime.utcnow()) @@ -412,6 +412,7 @@ def post(self): # logger.debug(worker_ret) if worker_ret is None: continue + # logger.debug(worker_ret) return(worker_ret, 200) # We report maintenance exception only if we couldn't find any jobs if self.worker.maintenance: diff --git a/horde/argparser.py b/horde/argparser.py index 7e848bb2..eff964ee 100644 --- a/horde/argparser.py +++ b/horde/argparser.py @@ -12,6 +12,7 @@ arg_parser.add_argument('--worker_invite', action="store_true", help="If set, Will start the horde in worker invite-only mode") arg_parser.add_argument('--raid', 
action="store_true", help="If set, Will start the horde in raid prevention mode") arg_parser.add_argument('--allow_all_ips', action="store_true", help="If set, will consider all IPs safe") +arg_parser.add_argument('--quorum', action="store_true", help="If set, will forcefully grab the quorum") args = arg_parser.parse_args() maintenance = Switch() diff --git a/horde/classes/base/user.py b/horde/classes/base/user.py index 0fcf5716..2ca61677 100644 --- a/horde/classes/base/user.py +++ b/horde/classes/base/user.py @@ -272,13 +272,13 @@ def reset_suspicion(self): '''Clears the user's suspicion and resets their reasons''' if self.is_anon(): return - #TODO Select from UserSuspicions DB and delete all matching user ID + db.session.query(UserSuspicions).filter_by(user_id=self.id).delete() db.session.commit() for worker in self.workers: worker.reset_suspicion() def get_suspicion(self): - return(db.session.query(UserSuspicions).filter(user_id=self.id).count()) + return(db.session.query(UserSuspicions).filter_by(user_id=self.id).count()) def count_workers(self): return(len(self.workers)) @@ -374,7 +374,7 @@ def get_details(self, details_privilege = 0): } ret_dict["evaluating_kudos"] = self.evaluating_kudos ret_dict["monthly_kudos"] = mk_dict - ret_dict["suspicious"] = self.suspicious + ret_dict["suspicious"] = len(self.suspicions) return(ret_dict) diff --git a/horde/classes/base/waiting_prompt.py b/horde/classes/base/waiting_prompt.py index ba3b27d0..45e8535b 100644 --- a/horde/classes/base/waiting_prompt.py +++ b/horde/classes/base/waiting_prompt.py @@ -98,7 +98,6 @@ def set_workers(self, worker_ids = None): def set_models(self, model_names = None): if not model_names: model_names = [] # We don't allow more workers to claim they can server more than 50 models atm (to prevent abuse) - logger.debug(model_names) for model in model_names: model_entry = WPModels(model=model,wp_id=self.id) db.session.add(model_entry) diff --git a/horde/classes/base/worker.py 
b/horde/classes/base/worker.py index 33c368e6..aedb9043 100644 --- a/horde/classes/base/worker.py +++ b/horde/classes/base/worker.py @@ -130,7 +130,7 @@ def report_suspicion(self, amount = 1, reason = Suspicions.WORKER_PROFANITY, for def reset_suspicion(self): '''Clears the worker's suspicion and resets their reasons''' - #TODO Select from WorkerSuspicions DB and delete all matching user ID + db.session.query(WorkerSuspicions).filter_by(worker_id=self.id).delete() db.session.commit() def get_suspicion(self): @@ -269,11 +269,6 @@ def can_generate(self, waiting_prompt): # We don't consider stale workers in the request, so we don't need to report a reason is_matching = False return([is_matching,skipped_reason]) - # If the request specified only specific workers to fulfill it, and we're not one of them, we skip - #logger.warning(datetime.utcnow()) - if len(waiting_prompt.workers) >= 1 and self not in waiting_prompt.workers: - is_matching = False - skipped_reason = 'worker_id' #logger.warning(datetime.utcnow()) if waiting_prompt.nsfw and not self.nsfw: is_matching = False diff --git a/horde/classes/stable/news.py b/horde/classes/stable/news.py index d6376934..b027eac2 100644 --- a/horde/classes/stable/news.py +++ b/horde/classes/stable/news.py @@ -2,7 +2,13 @@ class NewsExtended(News): + STABLE_HORDE_NEWS = [ + { + "date_published": "2022-12-08", + "newspiece": "The Stable Horde workers now support dynamically swapping models. This means that models will always switch to support the most in demand models every minute, allowing us to support demand much better!", + "importance": "Information" + }, { "date_published": "2022-11-28", "newspiece": "The Horde has undertaken a massive code refactoring to allow me to move to a proper SQL DB. 
This will finally allow me to scale the frontend systems horizontally and allow for way more capacity!", diff --git a/horde/classes/stable/processing_generation.py b/horde/classes/stable/processing_generation.py index 0cfdedf7..7b398565 100644 --- a/horde/classes/stable/processing_generation.py +++ b/horde/classes/stable/processing_generation.py @@ -1,13 +1,17 @@ from horde.logger import logger from horde.classes.base.processing_generation import ProcessingGeneration +from horde.r2 import generate_download_url class ProcessingGenerationExtended(ProcessingGeneration): def get_details(self): '''Returns a dictionary with details about this processing generation''' + generation = self.generation + if generation == "R2": + generation = generate_download_url(str(self.id)) ret_dict = { - "img": self.generation, + "img": generation, "seed": self.seed, "worker_id": self.worker.id, "worker_name": self.worker.name, diff --git a/horde/classes/stable/waiting_prompt.py b/horde/classes/stable/waiting_prompt.py index 6fe5f01e..29e97d86 100644 --- a/horde/classes/stable/waiting_prompt.py +++ b/horde/classes/stable/waiting_prompt.py @@ -5,6 +5,7 @@ from horde.flask import db from horde.utils import get_random_seed from horde.classes.base.waiting_prompt import WaitingPrompt +from horde.r2 import generate_upload_url class WaitingPromptExtended(WaitingPrompt): @@ -15,6 +16,7 @@ class WaitingPromptExtended(WaitingPrompt): seed = db.Column(db.BigInteger, default=None, nullable=True) seed_variation = db.Column(db.Integer, default=None) kudos = db.Column(db.Float, default=0, nullable=False) + r2 = db.Column(db.Boolean, default=False, nullable=False) @logger.catch(reraise=True) def extract_params(self): @@ -40,7 +42,7 @@ def extract_params(self): self.width = self.params["width"] self.height = self.params["height"] # Silent change - if self.get_model_names() == ["stable_diffusion_2.0"]: + if any(model_name.startswith("stable_diffusion_2") for model_name in self.get_model_names()): 
self.params['sampler_name'] = "dpmsolver" # The total amount of to pixelsteps requested. if self.params.get('seed') == '': @@ -55,7 +57,7 @@ def extract_params(self): # It then crashes in self.gen_payload["seed"] += self.seed_variation trying to None + Int if self.seed is None: self.seed = self.seed_to_int(self.seed) - logger.debug(self.params) + # logger.debug(self.params) # logger.debug([self.prompt,self.params['width'],self.params['sampler_name']]) self.things = self.width * self.height * self.get_accurate_steps() self.total_usage = round(self.things * self.n / thing_divisor,2) @@ -119,6 +121,8 @@ def get_pop_payload(self, procgen): prompt_payload["source_processing"] = self.source_processing if self.source_mask: prompt_payload["source_mask"] = self.source_mask + if procgen.worker.bridge_version >= 8 and self.r2: + prompt_payload["r2_upload"] = generate_upload_url(str(procgen.id)) else: prompt_payload = {} self.faulted = True @@ -182,7 +186,7 @@ def requires_upfront_kudos(self, counted_totals): if max_res < 576: max_res = 576 # SD 2.0 requires at least 768 to do its thing - if max_res < 768 and len(self.models) > 1 and "stable_diffusion_2.0" in self.models: + if max_res < 768 and len(self.models) >= 1 and "stable_diffusion_2." in self.models: max_res = 768 if max_res > 1024: max_res = 1024 diff --git a/horde/countermeasures.py b/horde/countermeasures.py index 3b8a8bee..a22ae44f 100644 --- a/horde/countermeasures.py +++ b/horde/countermeasures.py @@ -70,7 +70,7 @@ def is_ip_safe(ipaddr): if probability == int(os.getenv("IP_CHECKER_LC")): is_safe = CounterMeasures.set_safe(ipaddr,True) else: - is_safe = CounterMeasures.set_safe(ipaddr,False) + is_safe = CounterMeasures.set_safe(ipaddr,True) # True until I can improve my load logger.error(f"An error occured while validating IP. 
Return Code: {result.text}") else: probability = float(result.content) diff --git a/horde/database/__init__.py b/horde/database/__init__.py index 869a8995..1cf9cf8f 100644 --- a/horde/database/__init__.py +++ b/horde/database/__init__.py @@ -5,7 +5,7 @@ quorum = Quorum(1, get_quorum) wp_list_cacher = PrimaryTimedFunction(1, store_prioritized_wp_queue, quorum=quorum) worker_cacher = PrimaryTimedFunction(25, store_worker_list, quorum=quorum) -model_cacher = PrimaryTimedFunction(2, store_available_models, quorum=quorum) +model_cacher = PrimaryTimedFunction(5, store_available_models, quorum=quorum) wp_cleaner = PrimaryTimedFunction(60, check_waiting_prompts, quorum=quorum) monthly_kudos = PrimaryTimedFunction(86400, assign_monthly_kudos, quorum=quorum) store_totals = PrimaryTimedFunction(60, store_totals, quorum=quorum) diff --git a/horde/database/classes.py b/horde/database/classes.py index d8f421a8..57bd82a7 100644 --- a/horde/database/classes.py +++ b/horde/database/classes.py @@ -37,8 +37,6 @@ def run(self): # This allows me to change the primary node on-the-fly if self.cancel: break - if self.quorum_thread: - logger.debug(self.quorum_thread.quorum) if self.quorum_thread and self.quorum_thread.quorum != horde_instance_id: time.sleep(self.interval) continue diff --git a/horde/database/functions.py b/horde/database/functions.py index 7ced5a53..27d5282b 100644 --- a/horde/database/functions.py +++ b/horde/database/functions.py @@ -172,7 +172,7 @@ def get_available_models(): def retrieve_available_models(): '''Retrieves model details from Redis cache, or from DB if cache is unavailable''' - models_ret = horde_r.get('models_cache') + models_ret = json.loads(horde_r.get('models_cache')) if models_ret is None: models_ret = get_available_models() return(models_ret) @@ -378,6 +378,13 @@ def get_sorted_wp_filtered_to_worker(worker, models_list = None, blacklist = Non WaitingPrompt.user_id == worker.user_id, ), ), + or_( + worker.bridge_version >= 8, + and_( + 
worker.bridge_version < 8, + WaitingPrompt.r2 == False, + ), + ), ).order_by( WaitingPrompt.extra_priority.desc(), WaitingPrompt.created.asc() diff --git a/horde/database/threads.py b/horde/database/threads.py index db713ec2..fdf19390 100644 --- a/horde/database/threads.py +++ b/horde/database/threads.py @@ -12,6 +12,8 @@ from horde.database.functions import query_prioritized_wps, get_active_workers, get_available_models, count_totals, prune_expired_stats from horde import horde_instance_id from horde.argparser import args +from horde.r2 import delete_procgen_image +from horde.argparser import args @logger.catch(reraise=True) @@ -27,6 +29,10 @@ def get_quorum(): horde_r.setex('horde_quorum', timedelta(seconds=2), horde_instance_id) logger.debug(f"Quorum retained in port {args.port} with ID {horde_instance_id}") # We return None which will make other threads sleep one iteration to ensure no other node raced us to the quorum + elif args.quorum: + horde_r.setex('horde_quorum', timedelta(seconds=2), horde_instance_id) + logger.debug(f"Forcing Picking Quorum on port {args.port} with ID {horde_instance_id}") + # We return None which will make other threads sleep one iteration to ensure no other node raced us to the quorum return(quorum) @@ -102,16 +108,33 @@ def check_waiting_prompts(): with HORDE.app_context(): # Cleans expired WPs expired_wps = db.session.query(WaitingPrompt).filter(WaitingPrompt.expiry < datetime.utcnow()) + expired_r_wps = expired_wps.filter(WaitingPrompt.r2 == True) + all_wp_r_id = [wp.id for wp in expired_r_wps.all()] + expired_r2_procgens = db.session.query( + ProcessingGeneration.id, + ).filter( + ProcessingGeneration.wp_id.in_(all_wp_r_id) + ).all() + # logger.debug([expired_r_wps, expired_r2_procgens]) + for procgen in expired_r2_procgens: + delete_procgen_image(str(procgen.id)) logger.info(f"Pruned {expired_wps.count()} expired Waiting Prompts") expired_wps.delete() db.session.commit() # Faults stale ProcGens - all_proc_gen = 
db.session.query(ProcessingGeneration).filter(ProcessingGeneration.generation is None).filter().all() + all_proc_gen = db.session.query( + ProcessingGeneration, + ).join( + WaitingPrompt, + ).filter( + ProcessingGeneration.generation is None, + ProcessingGeneration.faulted == False, + # datetime.utcnow() - ProcessingGeneration.start_time > WaitingPrompt.job_ttl, # How do we calculate this in the query? Maybe I need to set an expiry time in procgen as well? + ).all() for proc_gen in all_proc_gen: - proc_gen = proc_gen.Join(WaitingPrompt, WaitingPrompt.id == ProcessingGeneration.wp_id).filter(WaitingPrompt.faulted == False).filter(ProcessingGeneration.faulted == False) - if proc_gen.is_stale(wp.job_ttl): + if proc_gen.is_stale(proc_gen.wp.job_ttl): proc_gen.abort() - wp.n += 1 + proc_gen.wp.n += 1 db.session.commit() # Faults WP with 3 or more faulted Procgens @@ -137,7 +160,7 @@ def store_available_models(): with HORDE.app_context(): json_models = json.dumps(get_available_models()) try: - horde_r.setex('model_cache', timedelta(seconds=10), json_models) + horde_r.setex('models_cache', timedelta(seconds=10), json_models) except (TypeError, OverflowError) as e: logger.error(f"Failed serializing workers with error: {e}") diff --git a/horde/r2.py b/horde/r2.py new file mode 100644 index 00000000..9fcb9e41 --- /dev/null +++ b/horde/r2.py @@ -0,0 +1,43 @@ +import uuid +from datetime import datetime +from horde.logger import logger +import boto3 +from botocore.exceptions import ClientError + +s3_client = boto3.client('s3', endpoint_url="https://a223539ccf6caa2d76459c9727d276e6.r2.cloudflarestorage.com") + +@logger.catch(reraise=True) +def generate_presigned_url(client_method, method_parameters, expires_in): + """ + Generate a presigned Amazon S3 URL that can be used to perform an action. + + :param s3_client: A Boto3 Amazon S3 client. + :param client_method: The name of the client method that the URL performs. 
+ :param method_parameters: The parameters of the specified client method. + :param expires_in: The number of seconds the presigned URL is valid for. + :return: The presigned URL. + """ + try: + url = s3_client.generate_presigned_url( + ClientMethod=client_method, + Params=method_parameters, + ExpiresIn=expires_in + ) + except ClientError: + logger.exception( + f"Couldn't get a presigned URL for client method {client_method}", ) + raise + # logger.debug(url) + return url + +def generate_upload_url(procgen_id): + return generate_presigned_url("put_object", {'Bucket': "stable-horde", 'Key': f"{procgen_id}.webp"}, 1800) + +def generate_download_url(procgen_id): + return generate_presigned_url("get_object", {'Bucket': "stable-horde", 'Key': f"{procgen_id}.webp"}, 1800) + +def delete_procgen_image(procgen_id): + response = s3_client.delete_object( + Bucket="stable-horde", + Key=f"{procgen_id}.webp" + ) diff --git a/icon.png b/icon.png new file mode 100644 index 00000000..ea6e62c6 Binary files /dev/null and b/icon.png differ diff --git a/requirements.txt b/requirements.txt index b282955f..d5e3be19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,5 @@ pillow~=9.3.0 flask_sqlalchemy oauthlib~=3.2.2 SQLAlchemy~=1.4.44 -psycopg2-binary \ No newline at end of file +psycopg2-binary +boto3 \ No newline at end of file