Skip to content

Commit

Permalink
R2 (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 authored Dec 8, 2022
1 parent 3ff8d32 commit 99784ea
Show file tree
Hide file tree
Showing 21 changed files with 132 additions and 34 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,4 @@ SQL_statements.txt
horde.log
horde.db
/.idea
/boto3oeo.py
1 change: 1 addition & 0 deletions cli_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test_logger():
],
}
logger.configure(**config)
logger.add("cliRequests.log", retention="7 days", level=19)
logger.disable("__main__")
logger.warning("disabled")
logger.enable("")
Expand Down
21 changes: 15 additions & 6 deletions cli_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def __init__(self):
"nsfw": False,
"censor_nsfw": False,
"trusted_workers": False,
"models": ["stable_diffusion"]
"models": ["stable_diffusion"],
"r2": True
}
self.source_image = None
self.source_processing = "img2img"
Expand Down Expand Up @@ -160,14 +161,22 @@ def generate():
return
results = results_json['generations']
for iter in range(len(results)):
b64img = results[iter]["img"]
base64_bytes = b64img.encode('utf-8')
img_bytes = base64.b64decode(base64_bytes)
img = Image.open(BytesIO(img_bytes))
final_filename = request_data.filename
if len(results) > 1:
final_filename = f"{iter}_{request_data.filename}"
img.save(final_filename)
if request_data.get_submit_dict()["r2"]:
try:
img_data = requests.get(results[iter]["img"]).content
except:
logger.error("Received b64 again")
with open(final_filename, 'wb') as handler:
handler.write(img_data)
else:
b64img = results[iter]["img"]
base64_bytes = b64img.encode('utf-8')
img_bytes = base64.b64decode(base64_bytes)
img = Image.open(BytesIO(img_bytes))
img.save(final_filename)
logger.info(f"Saved {final_filename}")
else:
logger.error(submit_req.text)
Expand Down
3 changes: 3 additions & 0 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(self):
self.generate_parser.add_argument("source_processing", type=str, default="img2img", required=False, help="If source_image is provided, specifies how to process it.", location="json")
self.generate_parser.add_argument("source_mask", type=str, required=False, help="If img_processing is set to 'inpainting' or 'outpainting', this parameter can be optionally provided as the mask of the areas to inpaint. If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel", location="json")
self.generate_parser.add_argument("models", type=list, required=False, default=['stable_diffusion'], help="The acceptable models with which to generate", location="json")
self.generate_parser.add_argument("r2", type=bool, default=False, required=False, help="If True, the image will be sent via cloudflare r2 download link", location="json")
self.job_pop_parser.add_argument("max_pixels", type=int, required=False, default=512*512, help="The maximum amount of pixels this worker can generate", location="json")
self.job_pop_parser.add_argument("allow_img2img", type=bool, required=False, default=True, help="If True, this worker will pick up img2img requests", location="json")
self.job_pop_parser.add_argument("allow_painting", type=bool, required=False, default=True, help="If True, this worker will pick up inpainting/outpaining requests", location="json")
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(self,api):
'source_image': fields.String(description="The Base64-encoded webp to use for img2img"),
'source_processing': fields.String(required=False, default='img2img',enum=["img2img", "inpainting", "outpainting"], description="If source_image is provided, specifies how to process it."),
'source_mask': fields.String(description="If img_processing is set to 'inpainting' or 'outpainting', this parameter can be optionally provided as the mask of the areas to inpaint. If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel"),
'r2_upload': fields.String(description="The r2 upload link to use to upload this image"),
})
self.input_model_job_pop = api.inherit('PopInputStable', self.input_model_job_pop, {
'max_pixels': fields.Integer(default=512*512,description="The maximum amount of pixels this worker can generate"),
Expand All @@ -83,6 +85,7 @@ def __init__(self,api):
'source_image': fields.String(required=False, description="The Base64-encoded webp to use for img2img"),
'source_processing': fields.String(required=False, default='img2img',enum=["img2img", "inpainting", "outpainting"], description="If source_image is provided, specifies how to process it."),
'source_mask': fields.String(description="If source_processing is set to 'inpainting' or 'outpainting', this parameter can be optionally provided as the Base64-encoded webp mask of the areas to inpaint. If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel"),
'r2': fields.Boolean(default=False, description="If True, the image will be sent via cloudflare r2 download link"),
})
self.response_model_worker_details = api.inherit('WorkerDetailsStable', self.response_model_worker_details, {
"max_pixels": fields.Integer(example=262144,description="The maximum pixels in resolution this worker can generate"),
Expand Down
6 changes: 4 additions & 2 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def validate(self):
if self.args.source_image:
if self.args.source_processing == "img2img" and self.params.get("sampler_name") in ["k_dpm_fast", "k_dpm_adaptive", "k_dpmpp_2s_a", "k_dpmpp_2m"]:
raise e.UnsupportedSampler
if "stable_diffusion_2.0" in self.args.models:
if any(model_name.startswith("stable_diffusion_2") for model_name in self.args.models):
raise e.UnsupportedModel
if self.args.models != ["stable_diffusion_2.0"] and self.params.get("sampler_name") in ["dpmsolver"]:
if not any(model_name.startswith("stable_diffusion_2") for model_name in self.args.models) and self.params.get("sampler_name") in ["dpmsolver"]:
raise e.UnsupportedSampler
# if self.args.models == ["stable_diffusion_2.0"] and self.params.get("sampler_name") not in ["dpmsolver"]:
# raise e.UnsupportedSampler
Expand Down Expand Up @@ -105,6 +105,7 @@ def initiate_waiting_prompt(self):
source_mask = convert_source_image_to_webp(self.args.source_mask),
ipaddr = self.user_ip,
safe_ip=self.safe_ip,
r2=self.args.r2,
)
needs_kudos,resolution = self.wp.requires_upfront_kudos(database.retrieve_totals())
if needs_kudos:
Expand Down Expand Up @@ -149,6 +150,7 @@ def initiate_waiting_prompt(self):
source_mask = convert_source_image_to_webp(self.args.source_mask),
ipaddr = self.user_ip,
safe_ip=self.safe_ip,
r2=self.args.r2,
)
needs_kudos,resolution = self.wp.requires_upfront_kudos(database.retrieve_totals())
if needs_kudos:
Expand Down
3 changes: 2 additions & 1 deletion horde/apis/v2/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def post(self):
# self.prioritized_wp.append(wp)
# logger.warning(datetime.utcnow())
## End prioritize by bridge request ##
for wp in self.get_sorted_wp(): # TODO this should also filter on .n>0
for wp in self.get_sorted_wp():
if wp not in self.prioritized_wp:
self.prioritized_wp.append(wp)
# logger.warning(datetime.utcnow())
Expand All @@ -412,6 +412,7 @@ def post(self):
# logger.debug(worker_ret)
if worker_ret is None:
continue
# logger.debug(worker_ret)
return(worker_ret, 200)
# We report maintenance exception only if we couldn't find any jobs
if self.worker.maintenance:
Expand Down
1 change: 1 addition & 0 deletions horde/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
arg_parser.add_argument('--worker_invite', action="store_true", help="If set, Will start the horde in worker invite-only mode")
arg_parser.add_argument('--raid', action="store_true", help="If set, Will start the horde in raid prevention mode")
arg_parser.add_argument('--allow_all_ips', action="store_true", help="If set, will consider all IPs safe")
arg_parser.add_argument('--quorum', action="store_true", help="If set, will forcefully grab the quorum")
args = arg_parser.parse_args()

maintenance = Switch()
Expand Down
6 changes: 3 additions & 3 deletions horde/classes/base/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,13 @@ def reset_suspicion(self):
'''Clears the user's suspicion and resets their reasons'''
if self.is_anon():
return
#TODO Select from UserSuspicions DB and delete all matching user ID
db.session.query(UserSuspicions).filter_by(user_id=self.id).delete()
db.session.commit()
for worker in self.workers:
worker.reset_suspicion()

def get_suspicion(self):
return(db.session.query(UserSuspicions).filter(user_id=self.id).count())
return(db.session.query(UserSuspicions).filter_by(user_id=self.id).count())

def count_workers(self):
return(len(self.workers))
Expand Down Expand Up @@ -374,7 +374,7 @@ def get_details(self, details_privilege = 0):
}
ret_dict["evaluating_kudos"] = self.evaluating_kudos
ret_dict["monthly_kudos"] = mk_dict
ret_dict["suspicious"] = self.suspicious
ret_dict["suspicious"] = len(self.suspicions)
return(ret_dict)


Expand Down
1 change: 0 additions & 1 deletion horde/classes/base/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def set_workers(self, worker_ids = None):
def set_models(self, model_names = None):
if not model_names: model_names = []
# We don't allow more workers to claim they can server more than 50 models atm (to prevent abuse)
logger.debug(model_names)
for model in model_names:
model_entry = WPModels(model=model,wp_id=self.id)
db.session.add(model_entry)
Expand Down
7 changes: 1 addition & 6 deletions horde/classes/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def report_suspicion(self, amount = 1, reason = Suspicions.WORKER_PROFANITY, for

def reset_suspicion(self):
'''Clears the worker's suspicion and resets their reasons'''
#TODO Select from WorkerSuspicions DB and delete all matching user ID
db.session.query(WorkerSuspicions).filter_by(worker_id=self.id).delete()
db.session.commit()

def get_suspicion(self):
Expand Down Expand Up @@ -269,11 +269,6 @@ def can_generate(self, waiting_prompt):
# We don't consider stale workers in the request, so we don't need to report a reason
is_matching = False
return([is_matching,skipped_reason])
# If the request specified only specific workers to fulfill it, and we're not one of them, we skip
#logger.warning(datetime.utcnow())
if len(waiting_prompt.workers) >= 1 and self not in waiting_prompt.workers:
is_matching = False
skipped_reason = 'worker_id'
#logger.warning(datetime.utcnow())
if waiting_prompt.nsfw and not self.nsfw:
is_matching = False
Expand Down
6 changes: 6 additions & 0 deletions horde/classes/stable/news.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

class NewsExtended(News):


STABLE_HORDE_NEWS = [
{
"date_published": "2022-12-08",
"newspiece": "The Stable Horde workers now support dynamically swapping models. This means that models will always switch to support the most in demand models every minute, allowing us to support demand much better!",
"importance": "Information"
},
{
"date_published": "2022-11-28",
"newspiece": "The Horde has undertaken a massive code refactoring to allow me to move to a proper SQL DB. This will finally allow me to scale the frontend systems horizontally and allow for way more capacity!",
Expand Down
6 changes: 5 additions & 1 deletion horde/classes/stable/processing_generation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from horde.logger import logger
from horde.classes.base.processing_generation import ProcessingGeneration
from horde.r2 import generate_download_url


class ProcessingGenerationExtended(ProcessingGeneration):

def get_details(self):
'''Returns a dictionary with details about this processing generation'''
generation = self.generation
if generation == "R2":
generation = generate_download_url(str(self.id))
ret_dict = {
"img": self.generation,
"img": generation,
"seed": self.seed,
"worker_id": self.worker.id,
"worker_name": self.worker.name,
Expand Down
10 changes: 7 additions & 3 deletions horde/classes/stable/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from horde.flask import db
from horde.utils import get_random_seed
from horde.classes.base.waiting_prompt import WaitingPrompt
from horde.r2 import generate_upload_url


class WaitingPromptExtended(WaitingPrompt):
Expand All @@ -15,6 +16,7 @@ class WaitingPromptExtended(WaitingPrompt):
seed = db.Column(db.BigInteger, default=None, nullable=True)
seed_variation = db.Column(db.Integer, default=None)
kudos = db.Column(db.Float, default=0, nullable=False)
r2 = db.Column(db.Boolean, default=False, nullable=False)

@logger.catch(reraise=True)
def extract_params(self):
Expand All @@ -40,7 +42,7 @@ def extract_params(self):
self.width = self.params["width"]
self.height = self.params["height"]
# Silent change
if self.get_model_names() == ["stable_diffusion_2.0"]:
if any(model_name.startswith("stable_diffusion_2") for model_name in self.get_model_names()):
self.params['sampler_name'] = "dpmsolver"
# The total amount of to pixelsteps requested.
if self.params.get('seed') == '':
Expand All @@ -55,7 +57,7 @@ def extract_params(self):
# It then crashes in self.gen_payload["seed"] += self.seed_variation trying to None + Int
if self.seed is None:
self.seed = self.seed_to_int(self.seed)
logger.debug(self.params)
# logger.debug(self.params)
# logger.debug([self.prompt,self.params['width'],self.params['sampler_name']])
self.things = self.width * self.height * self.get_accurate_steps()
self.total_usage = round(self.things * self.n / thing_divisor,2)
Expand Down Expand Up @@ -119,6 +121,8 @@ def get_pop_payload(self, procgen):
prompt_payload["source_processing"] = self.source_processing
if self.source_mask:
prompt_payload["source_mask"] = self.source_mask
if procgen.worker.bridge_version >= 8 and self.r2:
prompt_payload["r2_upload"] = generate_upload_url(str(procgen.id))
else:
prompt_payload = {}
self.faulted = True
Expand Down Expand Up @@ -182,7 +186,7 @@ def requires_upfront_kudos(self, counted_totals):
if max_res < 576:
max_res = 576
# SD 2.0 requires at least 768 to do its thing
if max_res < 768 and len(self.models) > 1 and "stable_diffusion_2.0" in self.models:
if max_res < 768 and len(self.models) >= 1 and "stable_diffusion_2." in self.models:
max_res = 768
if max_res > 1024:
max_res = 1024
Expand Down
2 changes: 1 addition & 1 deletion horde/countermeasures.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def is_ip_safe(ipaddr):
if probability == int(os.getenv("IP_CHECKER_LC")):
is_safe = CounterMeasures.set_safe(ipaddr,True)
else:
is_safe = CounterMeasures.set_safe(ipaddr,False)
is_safe = CounterMeasures.set_safe(ipaddr,True) # True until I can improve my load
logger.error(f"An error occured while validating IP. Return Code: {result.text}")
else:
probability = float(result.content)
Expand Down
2 changes: 1 addition & 1 deletion horde/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
quorum = Quorum(1, get_quorum)
wp_list_cacher = PrimaryTimedFunction(1, store_prioritized_wp_queue, quorum=quorum)
worker_cacher = PrimaryTimedFunction(25, store_worker_list, quorum=quorum)
model_cacher = PrimaryTimedFunction(2, store_available_models, quorum=quorum)
model_cacher = PrimaryTimedFunction(5, store_available_models, quorum=quorum)
wp_cleaner = PrimaryTimedFunction(60, check_waiting_prompts, quorum=quorum)
monthly_kudos = PrimaryTimedFunction(86400, assign_monthly_kudos, quorum=quorum)
store_totals = PrimaryTimedFunction(60, store_totals, quorum=quorum)
Expand Down
2 changes: 0 additions & 2 deletions horde/database/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def run(self):
# This allows me to change the primary node on-the-fly
if self.cancel:
break
if self.quorum_thread:
logger.debug(self.quorum_thread.quorum)
if self.quorum_thread and self.quorum_thread.quorum != horde_instance_id:
time.sleep(self.interval)
continue
Expand Down
9 changes: 8 additions & 1 deletion horde/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_available_models():

def retrieve_available_models():
'''Retrieves model details from Redis cache, or from DB if cache is unavailable'''
models_ret = horde_r.get('models_cache')
models_ret = json.loads(horde_r.get('models_cache'))
if models_ret is None:
models_ret = get_available_models()
return(models_ret)
Expand Down Expand Up @@ -378,6 +378,13 @@ def get_sorted_wp_filtered_to_worker(worker, models_list = None, blacklist = Non
WaitingPrompt.user_id == worker.user_id,
),
),
or_(
worker.bridge_version >= 8,
and_(
worker.bridge_version < 8,
WaitingPrompt.r2 == False,
),
),
).order_by(
WaitingPrompt.extra_priority.desc(),
WaitingPrompt.created.asc()
Expand Down
Loading

0 comments on commit 99784ea

Please sign in to comment.