# stable-diffusion-for-fun/frontend.py
#
# 246 lines
# 7.0 KiB
# Python

import argparse
import uuid
from flask import jsonify
from flask import Flask
from flask import render_template
from flask import request
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from utilities.constants import LOGGER_NAME_FRONTEND
from utilities.logger import Logger
from utilities.constants import APIKEY
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import BASE64IMAGE
from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG
from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import KEY_LANGUAGE
from utilities.constants import SUPPORTED_LANGS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_RESTORATION
from utilities.constants import IMAGE_NOT_FOUND_BASE64
from utilities.database import Database
from utilities.images import load_image
# Module-level singletons shared by every route handler in this file.
logger = Logger(name=LOGGER_NAME_FRONTEND)
database = Database(logger)

app = Flask(__name__)
# The limiter is bound to the Flask app and uses get_remote_address as its
# key function; per-route limits are declared with @limiter.limit(...).
limiter = Limiter(key_func=get_remote_address, app=app)
@app.route("/add_job", methods=["POST"])
@limiter.limit("4/second")
def add_job():
    """Validate a job submission and insert it into the database.

    The request body must be a JSON object carrying the API key plus the job
    parameters.

    Returns:
        * ``("", 401)`` when the body is not a JSON object, the API key is
          absent, or ``database.validate_user`` rejects the key.
        * ``(json {"msg": ...}, 404)`` when the payload has unrecognized keys,
          lacks required keys, lacks a reference/mask image needed by the job
          type, or requests an unsupported language.
        * ``(json {"msg": ...}, 500)`` when the caller has more than
          ``MAX_JOB_NUMBER`` pending jobs.
        * ``json {"msg": "", UUID: <new uuid>}`` on success.
    """
    req = request.get_json()

    # A body such as JSON `null`, a number or a bare string cannot carry an
    # API key; treat it like a missing key instead of crashing with TypeError.
    if not isinstance(req, dict) or APIKEY not in req:
        logger.error(f"{APIKEY} not present in {req}")
        return "", 401
    if not database.validate_user(req[APIKEY]):
        logger.error(f"user not found with {req[APIKEY]}")
        return "", 401

    if any(key not in REQUIRED_KEYS and key not in OPTIONAL_KEYS for key in req):
        return jsonify({"msg": "provided one or more unrecognized keys"}), 404

    # only checks required key for non-restoration jobs
    if req.get(KEY_JOB_TYPE, None) != VALUE_JOB_RESTORATION:
        if any(required_key not in req for required_key in REQUIRED_KEYS):
            return jsonify({"msg": "missing one or more required keys"}), 404

    # The checks below index the job type; guard explicitly so a payload
    # without it gets a client error rather than an unhandled KeyError.
    if KEY_JOB_TYPE not in req:
        return jsonify({"msg": "missing one or more required keys"}), 404
    job_type = req[KEY_JOB_TYPE]

    needs_reference = job_type in (
        VALUE_JOB_IMG2IMG,
        VALUE_JOB_INPAINTING,
        VALUE_JOB_RESTORATION,
    )
    if needs_reference and REFERENCE_IMG not in req:
        return jsonify({"msg": "missing reference image"}), 404
    if job_type == VALUE_JOB_INPAINTING and MASK_IMG not in req:
        return jsonify({"msg": "missing mask image"}), 404

    if KEY_LANGUAGE in req and req[KEY_LANGUAGE] not in SUPPORTED_LANGS:
        return jsonify({"msg": f"not suporting {req[KEY_LANGUAGE]}"}), 404

    # NOTE(review): with `>` a caller whose pending count equals
    # MAX_JOB_NUMBER can still add one more job -- confirm this is intended.
    if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
        return (
            jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
            500,
        )

    job_uuid = str(uuid.uuid4())
    logger.info("adding a new job with uuid {}..".format(job_uuid))
    database.insert_new_job(req, job_uuid=job_uuid)

    return jsonify({"msg": "", UUID: job_uuid})
@app.route("/cancel_job", methods=["POST"])
@limiter.limit("4/second")
def cancel_job():
    """Cancel one of the caller's jobs, identified by its uuid.

    The request body must be a JSON object carrying the API key and the job
    uuid.

    Returns:
        * ``("", 401)`` when the body is not a JSON object, the API key is
          absent, or ``database.validate_user`` rejects the key.
        * ``(json {"msg": ...}, 404)`` when the uuid is missing from the
          request or no matching job is found for this API key.
        * ``(json {"msg": ...}, 405)`` when the job exists but
          ``database.cancel_job`` did not cancel it.
        * ``json {"msg": ...}`` on successful cancellation.
    """
    req = request.get_json()

    # A non-object JSON body (e.g. `null`) cannot carry an API key; reject it
    # the same way as a missing key instead of raising TypeError.
    if not isinstance(req, dict) or APIKEY not in req:
        return "", 401
    if not database.validate_user(req[APIKEY]):
        return "", 401

    if UUID not in req:
        return jsonify({"msg": "missing uuid"}), 404
    job_uuid = req[UUID]

    logger.info("cancelling job with uuid {}..".format(job_uuid))

    if database.cancel_job(job_uuid=job_uuid, apikey=req[APIKEY]):
        return jsonify({"msg": "your job with uuid {} removed".format(job_uuid)})

    # Cancellation failed: work out why. The lookup is scoped to the caller's
    # API key so a job owned by someone else is reported as "not found"
    # rather than as "your job ... is not in pending state".
    jobs = database.get_jobs(job_uuid=job_uuid, apikey=req[APIKEY])
    if jobs:
        return (
            jsonify(
                {
                    "msg": "your job {} is not in pending state, unable to cancel".format(
                        job_uuid
                    )
                }
            ),
            405,
        )

    return (
        jsonify({"msg": "unable to find your job with uuid {}".format(job_uuid)}),
        404,
    )
@app.route("/get_jobs", methods=["POST"])
@limiter.limit("4/second")
def get_jobs():
    """Return up to 20 of the caller's jobs, optionally filtered.

    The request body must be a JSON object carrying the API key. It may also
    carry a job uuid and a comma-separated job-type filter.

    Returns:
        * ``("", 401)`` when the body is not a JSON object, the API key is
          absent, or ``database.validate_user`` rejects the key.
        * ``(json {"msg": ...}, 404)`` when the job-type filter is present
          but is not a string.
        * ``json {"jobs": [...]}`` otherwise. Image fields whose value does
          not contain the substring "base64" are replaced by the result of
          ``load_image(..., to_base64=True)``, or by
          ``IMAGE_NOT_FOUND_BASE64`` when that result is falsy.
    """
    req = request.get_json()

    # A non-object JSON body (e.g. `null`) cannot carry an API key; reject it
    # the same way as a missing key instead of raising TypeError.
    if not isinstance(req, dict) or APIKEY not in req:
        return "", 401
    if not database.validate_user(req[APIKEY]):
        return "", 401

    # The filter is split on commas, so it has to be a string; anything else
    # would otherwise raise AttributeError and surface as a server error.
    raw_job_types = req.get(KEY_JOB_TYPE, "")
    if raw_job_types and not isinstance(raw_job_types, str):
        return jsonify({"msg": "job type must be a comma-separated string"}), 404

    # define max number of jobs to fetch from db
    job_count_limit = 20
    query = {
        "apikey": req[APIKEY],
        "job_types": raw_job_types.split(",") if raw_job_types else [],
        "limit_count": job_count_limit,
    }
    # Only narrow the query to a single job when the caller supplied a uuid.
    if UUID in req:
        query["job_uuid"] = req[UUID]
    jobs = database.get_jobs(**query)

    for job in jobs:
        # load image to job if has one
        for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
            if key in job and "base64" not in job[key]:
                data = load_image(job[key], to_base64=True)
                job[key] = data if data else IMAGE_NOT_FOUND_BASE64

    return jsonify({"jobs": jobs})
@app.route("/random_jobs", methods=["GET"])
@limiter.limit("4/second")
def random_jobs():
    """Return up to 20 jobs picked by ``database.get_random_jobs``.

    No API key is checked on this route. Image fields whose value does not
    contain the substring "base64" are replaced by the result of
    ``load_image(..., to_base64=True)``, falling back to
    ``IMAGE_NOT_FOUND_BASE64`` when that result is falsy.
    """
    # upper bound on the number of jobs requested from the database
    max_jobs = 20
    sampled = database.get_random_jobs(limit_count=max_jobs)

    image_fields = (BASE64IMAGE, REFERENCE_IMG, MASK_IMG)
    for entry in sampled:
        for field in image_fields:
            # Skip fields the job lacks or that already hold base64 data.
            if field not in entry or "base64" in entry[field]:
                continue
            encoded = load_image(entry[field], to_base64=True)
            entry[field] = encoded or IMAGE_NOT_FOUND_BASE64

    return jsonify({"jobs": sampled})
@app.route("/")
@limiter.limit("1/second")
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/restoration")
@limiter.limit("1/second")
def restoration():
    """Serve the restoration page by rendering ``restoration.html``."""
    page = render_template("restoration.html")
    return page
def main(args):
    """Configure the database and Flask app from *args*, then serve requests.

    Args:
        args: parsed command-line namespace; this function reads
            ``image_output_folder``, ``db``, ``title`` and ``port``.

    The database is always disconnected when the server stops, including
    when ``app.run`` exits through an exception.
    """
    database.set_image_output_folder(args.image_output_folder)
    database.connect(args.db)
    app.config["TITLE"] = args.title
    try:
        app.run(host="0.0.0.0", port=args.port)
    finally:
        # Release the connection even if the server terminates abnormally.
        database.safe_disconnect()
def build_arg_parser():
    """Build the command-line parser for the frontend service.

    Returns:
        argparse.ArgumentParser: parser exposing ``--debug``, ``--db``,
        ``--title``, ``--image-output-folder``/``-o`` and ``--port``.
    """
    parser = argparse.ArgumentParser()

    # Add an argument to set the 'debug' flag
    # NOTE(review): this flag is parsed, but main() in this file does not
    # read it -- confirm whether it should be wired up or removed.
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    # Add an argument to set the path of the database file
    parser.add_argument(
        "--db", type=str, default="happysd.db", help="Path to SQLite database file"
    )

    # Add an argument to set the title of service
    parser.add_argument(
        "--title", type=str, default="Happy Diffusion", help="Title of the webpage"
    )

    # Add an argument to set the folder where output images are stored
    parser.add_argument(
        "--image-output-folder",
        "-o",
        type=str,
        default="",
        help="Path to output images",
    )

    # Add an argument to set the port; parsed as int so a non-numeric value
    # is rejected by argparse with a usage error instead of failing later.
    parser.add_argument(
        "--port",
        type=int,
        default=8888,
        help="Port to expose the service",
    )

    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)