Skip to content

Commit

Permalink
Migrate to an SQL DB ORM (#87)
Browse files Browse the repository at this point in the history
Refactored to an ORM
  • Loading branch information
db0 authored Dec 5, 2022
1 parent d2c38cc commit 3ff8d32
Show file tree
Hide file tree
Showing 49 changed files with 3,895 additions and 3,016 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,7 @@ db/*
cliRequestsData.py
*horde_generation*
test_commands.txt
horde.log
SQL_statements.txt
horde.log
horde.db
/.idea
26 changes: 18 additions & 8 deletions cli_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from cli_logger import logger, set_logger_verbosity, quiesce_logger, test_logger
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
from io import BytesIO
from requests.exceptions import ConnectionError

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-n', '--amount', action="store", required=False, type=int, help="The amount of images to generate with this prompt")
Expand Down Expand Up @@ -127,15 +128,24 @@ def generate():
logger.debug(submit_results)
req_id = submit_results['id']
is_done = False
retry = 0
while not is_done:
chk_req = requests.get(f'{args.horde}/api/v2/generate/check/{req_id}')
if not chk_req.ok:
logger.error(chk_req.text)
return
chk_results = chk_req.json()
logger.info(chk_results)
is_done = chk_results['done']
time.sleep(0.8)
try:
chk_req = requests.get(f'{args.horde}/api/v2/generate/check/{req_id}')
if not chk_req.ok:
logger.error(chk_req.text)
return
chk_results = chk_req.json()
logger.info(chk_results)
is_done = chk_results['done']
time.sleep(0.8)
except ConnectionError as e:
retry += 1
logger.error(f"Error {e} when retrieving status. Retry {retry}/10")
if retry < 10:
time.sleep(1)
continue
raise e
retrieve_req = requests.get(f'{args.horde}/api/v2/generate/status/{req_id}')
if not retrieve_req.ok:
logger.error(retrieve_req.text)
Expand Down
62 changes: 27 additions & 35 deletions horde/__init__.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,33 @@
from .logger import logger, set_logger_verbosity, quiesce_logger
from .argparser import args
import os
from uuid import uuid4

set_logger_verbosity(args.verbosity)
quiesce_logger(args.quiet)
horde_instance_id = str(uuid4())

from . import countermeasures as cm
from flask_dance.contrib.discord import make_discord_blueprint
from flask_dance.contrib.github import make_github_blueprint
from flask_dance.contrib.google import make_google_blueprint

from .switch import Switch
maintenance = Switch()
invite_only = Switch()
if args.worker_invite:
invite_only.activate()
raid = Switch()
if args.raid:
raid.activate()
from horde.routes import * # I don't like this, we should be refactoring what things are being loaded
from horde.apis import apiv1, apiv2
from horde.argparser import args, invite_only, raid, maintenance
from horde.flask import HORDE, cache
from horde.logger import logger

from .limiter import limiter
from flask import Flask, render_template, redirect, url_for, request, Blueprint
from .flask import HORDE, cache
from . import routes
from .apis import apiv1, apiv2
from flask_dance.contrib.google import make_google_blueprint, google
from flask_dance.contrib.discord import make_discord_blueprint, discord
from flask_dance.contrib.github import make_github_blueprint, github
import os
from horde.limiter import limiter

HORDE.register_blueprint(apiv2)
if args.horde == 'kobold':
HORDE.register_blueprint(apiv1)


@HORDE.after_request
def after_request(response):
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS, PUT, DELETE"
response.headers["Access-Control-Allow-Headers"] = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, apikey"
return response


google_client_id = os.getenv("GOOGLE_CLIENT_ID")
google_client_secret = os.getenv("GLOOGLE_CLIENT_SECRET")
discord_client_id = os.getenv("DISCORD_CLIENT_ID")
Expand All @@ -45,24 +37,24 @@ def after_request(response):
HORDE.secret_key = os.getenv("secret_key")
os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '1'
google_blueprint = make_google_blueprint(
client_id = google_client_id,
client_secret = google_client_secret,
reprompt_consent = True,
client_id=google_client_id,
client_secret=google_client_secret,
reprompt_consent=True,
redirect_url='/register',
scope = ["email"],
scope=["email"],
)
HORDE.register_blueprint(google_blueprint,url_prefix="/google")
HORDE.register_blueprint(google_blueprint, url_prefix="/google")
discord_blueprint = make_discord_blueprint(
client_id = discord_client_id,
client_secret = discord_client_secret,
scope = ["identify"],
client_id=discord_client_id,
client_secret=discord_client_secret,
scope=["identify"],
redirect_url='/finish_dance',
)
HORDE.register_blueprint(discord_blueprint,url_prefix="/discord")
HORDE.register_blueprint(discord_blueprint, url_prefix="/discord")
github_blueprint = make_github_blueprint(
client_id = github_client_id,
client_secret = github_client_secret,
scope = ["identify"],
client_id=github_client_id,
client_secret=github_client_secret,
scope=["identify"],
redirect_url='/finish_dance',
)
HORDE.register_blueprint(github_blueprint,url_prefix="/github")
HORDE.register_blueprint(github_blueprint, url_prefix="/github")
3 changes: 2 additions & 1 deletion horde/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .. import args
from horde.argparser import args
from importlib import import_module
from horde.logger import logger

ModelsV2 = import_module(name=f'horde.apis.models.{args.horde}_v2').Models
ParsersV2 = import_module(name=f'horde.apis.models.{args.horde}_v2').Parsers
Expand Down
4 changes: 2 additions & 2 deletions horde/apis/apiv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ..vars import horde_title


v1 = import_module(name=f'.{args.horde}_v1', package=f'horde.apis.v1').api
v2 = import_module(name=f'.{args.horde}', package=f'horde.apis.v2').api
v1 = import_module(name=f'.{args.horde}_v1', package='horde.apis.v1').api
v2 = import_module(name=f'.{args.horde}', package='horde.apis.v2').api

blueprint = Blueprint('apiv1', __name__, url_prefix='/api')
api = Api(blueprint,
Expand Down
4 changes: 2 additions & 2 deletions horde/apis/apiv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ..vars import horde_title

if args.horde == 'kobold':
v1 = import_module(name=f'.{args.horde}_v1', package=f'horde.apis.v1').api
v2 = import_module(name=f'.{args.horde}', package=f'horde.apis.v2').api
v1 = import_module(name=f'.{args.horde}_v1', package='horde.apis.v1').api
v2 = import_module(name=f'.{args.horde}', package='horde.apis.v2').api

blueprint = Blueprint('apiv2', __name__, url_prefix='/api')
api = Api(blueprint,
Expand Down
6 changes: 3 additions & 3 deletions horde/apis/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def __init__(self, username, old_name, new_name, object_type = 'worker'):
class ImageValidationFailed(wze.BadRequest):
def __init__(self, message = "Please ensure the source image payload for img2img is a valid base64 encoded image."):
self.specific = f"Image validation failed. {message}"
self.log = f"Source image validation failed for img2img"
self.log = "Source image validation failed for img2img"

class SourceMaskUnnecessary(wze.BadRequest):
def __init__(self):
self.specific = f"Please do not pass a source_mask unless you are sending a source_image as well"
self.log = f"Tried to pass source_mask with txt2img"
self.log = "Tried to pass source_mask with txt2img"

class UnsupportedSampler(wze.BadRequest):
def __init__(self):
Expand All @@ -68,7 +68,7 @@ def __init__(self):

class UnsupportedModel(wze.BadRequest):
def __init__(self):
self.specific = f"This model is not supported in this mode the moment"
self.specific = "This model is not supported in this mode the moment"
self.log = None

class InvalidAPIKey(wze.Unauthorized):
Expand Down
2 changes: 1 addition & 1 deletion horde/apis/models/kobold_v2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flask_restx import fields, reqparse
from flask_restx import fields
from . import v2


Expand Down
13 changes: 7 additions & 6 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from flask_restx import fields, reqparse
from flask_restx import fields
from . import v2
from horde.logger import logger


class Parsers(v2.Parsers):
Expand Down Expand Up @@ -31,10 +32,10 @@ def __init__(self,api):
'sampler_name': fields.String(required=False, default='k_euler_a',enum=["k_lms", "k_heun", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpm_fast", "k_dpm_adaptive", "k_dpmpp_2s_a", "k_dpmpp_2m", "dpmsolver"]),
'toggles': fields.List(fields.Integer,required=False, example=[1,4], description="Obsolete Toggles used in the SD Webui. To be removed. Do not modify unless you know what you're doing."),
'cfg_scale': fields.Float(required=False,default=5.0, min=-40, max=30, multiple=0.5),
'denoising_strength': fields.Float(required=False,example=0.75, min=0, max=1.0, multiple=0.01),
'denoising_strength': fields.Float(required=False,example=0.75, min=0, max=1.0),
'seed': fields.String(required=False,description="The seed to use to generete this request"),
'height': fields.Integer(required=False,default=512,description="The height of the image to generate", min=64, max=3072, multiple=64),
'width': fields.Integer(required=False,default=512,description="The width of the image to generate", min=64, max=3072, multiple=64),
'height': fields.Integer(required=False, default=512, description="The height of the image to generate", min=64, max=3072, multiple=64),
'width': fields.Integer(required=False, default=512, description="The width of the image to generate", min=64, max=3072, multiple=64),
'seed_variation': fields.Integer(required=False, example=1, min = 1, max=1000, description="If passed with multiple n, the provided seed will be incremented every time by this value"),
'post_processing': fields.List(fields.String(description="The list of post-processors to apply to the image, in the order to be applied",enum=["GFPGAN", "RealESRGAN_x4plus"]),unique=True),
'karras': fields.Boolean(default=False,description="Set to True to enable karras noise scheduling tweaks"),
Expand All @@ -46,8 +47,8 @@ def __init__(self,api):
'use_nsfw_censor': fields.Boolean(description="When true will apply NSFW censoring model on the generation"),
})
self.input_model_generation_payload = api.inherit('ModelGenerationInputStable', self.root_model_generation_payload_stable, {
'steps': fields.Integer(example=50, min = 1, max=500),
'n': fields.Integer(example=1, description="The amount of images to generate", min = 1, max=20),
'steps': fields.Integer(default=30, required=False, min = 1, max=500),
'n': fields.Integer(default=1, required=False, description="The amount of images to generate", min = 1, max=20),
})
self.response_model_generations_skipped = api.inherit('NoValidRequestFoundStable', self.response_model_generations_skipped, {
'max_pixels': fields.Integer(description="How many waiting requests were skipped because they demanded a higher size than this worker provides"),
Expand Down
2 changes: 1 addition & 1 deletion horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __init__(self,api):
"worker_invited": fields.Integer(description="Set to the amount of workers this user is allowed to join to the horde when in worker invite-only mode."),
"moderator": fields.Boolean(example=False,description="Set to true to Make this user a horde moderator"),
"public_workers": fields.Boolean(example=False,description="Set to true to Make this user a display their worker IDs"),
"monthly_kudos": fields.Integer(description="When specified, will start assigning the user monthly kudos, starting now!",min=0),
"monthly_kudos": fields.Integer(description="When specified, will start assigning the user monthly kudos, starting now!"),
"username": fields.String(description="When specified, will change the username. No profanity allowed!",min_length=3,max_length=100),
"trusted": fields.Boolean(example=False,description="When set to true,the user and their servers will not be affected by suspicion"),
"reset_suspicion": fields.Boolean(description="Set the user's suspicion back to 0"),
Expand Down
Loading

0 comments on commit 3ff8d32

Please sign in to comment.