From 7e84144432a0de39035eca35875b8822f3299a81 Mon Sep 17 00:00:00 2001 From: HappyZ Date: Sun, 7 May 2023 21:43:14 -0700 Subject: [PATCH] separate frontend and backend, adds multilingual support (chinese), improve multiaccess to database --- BUILD | 24 +++- main.py | 283 ----------------------------------------- manage_db.py | 7 +- requirements.txt | 2 + templates/index.html | 154 ++++++++++++++-------- utilities/BUILD | 5 + utilities/constants.py | 22 +++- utilities/database.py | 186 +++++++++++++-------------- utilities/text2img.py | 4 +- 9 files changed, 240 insertions(+), 447 deletions(-) delete mode 100644 main.py diff --git a/BUILD b/BUILD index e400d06..a678067 100644 --- a/BUILD +++ b/BUILD @@ -4,19 +4,31 @@ load("@subpar//:subpar.bzl", "par_binary") package(default_visibility=["//visibility:public"]) par_binary( - name="main", - srcs=["main.py"], + name="frontend", + srcs=["frontend.py"], deps=[ "//utilities:constants", "//utilities:database", "//utilities:logger", - "//utilities:model", - "//utilities:text2img", - "//utilities:img2img", - "//utilities:envvar", + "//utilities:times", ], data=[ "templates/index.html", ], ) + +par_binary( + name="backend", + srcs=["backend.py"], + deps=[ + "//utilities:constants", + "//utilities:database", + "//utilities:logger", + "//utilities:model", + "//utilities:text2img", + "//utilities:translator", + "//utilities:img2img", + "//utilities:times", + ], +) diff --git a/main.py b/main.py deleted file mode 100644 index e4b7a4b..0000000 --- a/main.py +++ /dev/null @@ -1,283 +0,0 @@ -import argparse -import copy -import tempfile -import pkgutil -import uuid -from flask import jsonify -from flask import Flask -from flask import render_template -from flask import request -from threading import Event -from threading import Thread -from threading import Lock - -from utilities.constants import APIKEY -from utilities.constants import KEY_APP -from utilities.constants import KEY_JOB_STATUS -from utilities.constants import KEY_JOB_TYPE -from utilities.constants import KEY_PROMPT -from utilities.constants import KEY_NEG_PROMPT -from utilities.constants import LOGGER_NAME -from utilities.constants import LOGGER_NAME_IMG2IMG -from utilities.constants import LOGGER_NAME_TXT2IMG -from utilities.constants import REFERENCE_IMG -from utilities.constants import MAX_JOB_NUMBER -from utilities.constants import OPTIONAL_KEYS -from utilities.constants import REQUIRED_KEYS -from utilities.constants import UUID -from utilities.constants import VALUE_APP -from utilities.constants import VALUE_JOB_TXT2IMG -from utilities.constants import VALUE_JOB_IMG2IMG -from utilities.constants import VALUE_JOB_INPAINTING -from utilities.constants import VALUE_JOB_PENDING -from utilities.constants import VALUE_JOB_RUNNING -from utilities.constants import VALUE_JOB_DONE -from utilities.constants import VALUE_JOB_FAILED -from utilities.database import Database -from utilities.envvar import get_env_var_with_default -from utilities.envvar import get_env_var -from utilities.times import wait_for_seconds -from utilities.logger import Logger -from utilities.model import Model -from utilities.config import Config -from utilities.text2img import Text2Img -from utilities.img2img import Img2Img - - -app = Flask(__name__) -memory_lock = Lock() -event_termination = Event() -logger = Logger(name=LOGGER_NAME) -database = Database(logger) -use_gpu = True - -local_job_stack = [] -local_completed_jobs = [] - - -@app.route("/add_job", methods=["POST"]) -def add_job(): - req = request.get_json() - - if APIKEY not in req: - logger.error(f"{APIKEY} not present in {req}") - return "", 401 - with memory_lock: - user = database.validate_user(req[APIKEY]) - if not user: - logger.error(f"user not found with {req[APIKEY]}") - return "", 401 - - for key in req.keys(): - if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS): - return jsonify({"msg": "provided one or more unrecognized keys"}), 404 - for required_key in REQUIRED_KEYS: - if required_key not in req: - return jsonify({"msg": "missing one or more required keys"}), 404 - - if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req: - return jsonify({"msg": "missing reference image"}), 404 - - if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER: - return ( - jsonify({"msg": "too many jobs in queue, please wait or cancel some"}), - 500, - ) - - job_uuid = str(uuid.uuid4()) - logger.info("adding a new job with uuid {}..".format(job_uuid)) - - with memory_lock: - database.insert_new_job(req, job_uuid=job_uuid) - - return jsonify({"msg": "", UUID: job_uuid}) - - -@app.route("/cancel_job", methods=["POST"]) -def cancel_job(): - req = request.get_json() - if APIKEY not in req: - return "", 401 - with memory_lock: - user = database.validate_user(req[APIKEY]) - if not user: - return "", 401 - - if UUID not in req: - return jsonify({"msg": "missing uuid"}), 404 - - logger.info("cancelling job with uuid {}..".format(req[UUID])) - - with memory_lock: - result = database.cancel_job(job_uuid=req[UUID]) - - if result: - msg = "job with uuid {} removed".format(req[UUID]) - return jsonify({"msg": msg}) - - with memory_lock: - jobs = database.get_jobs(job_uuid=req[UUID]) - - if jobs: - return ( - jsonify( - { - "msg": "job {} is not in pending state, unable to cancel".format( - req[UUID] - ) - } - ), - 405, - ) - - return ( - jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}), - 404, - ) - - -@app.route("/get_jobs", methods=["POST"]) -def get_jobs(): - req = request.get_json() - if APIKEY not in req: - return "", 401 - with memory_lock: - user = database.validate_user(req[APIKEY]) - if not user: - return "", 401 - - with memory_lock: - jobs = database.get_jobs(job_uuid=req[UUID]) - - return jsonify({"jobs": jobs}) - - -@app.route("/") -def index(): - return render_template("index.html") - - -def load_model(logger: Logger) -> Model: - # model candidates: - # "runwayml/stable-diffusion-v1-5" - # "CompVis/stable-diffusion-v1-4" - # "stabilityai/stable-diffusion-2-1" - # "SG161222/Realistic_Vision_V2.0" - # "darkstorm2150/Protogen_x3.4_Official_Release" - # "darkstorm2150/Protogen_x5.8_Official_Release" - # "prompthero/openjourney" - # "naclbit/trinart_stable_diffusion_v2" - # "hakurei/waifu-diffusion" - model_name = "darkstorm2150/Protogen_x5.8_Official_Release" - # inpainting model candidates: - # "runwayml/stable-diffusion-inpainting" - inpainting_model_name = "runwayml/stable-diffusion-inpainting" - - model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu) - if use_gpu: - model.set_low_memory_mode() - model.load_all() - - return model - - -def backend(event_termination, db): - model = load_model(logger) - text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG)) - img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG)) - - text2img.breakfast() - img2img.breakfast() - - while not event_termination.is_set(): - wait_for_seconds(1) - - with memory_lock: - pending_jobs = database.get_all_pending_jobs() - - if len(pending_jobs) == 0: - continue - - next_job = pending_jobs[0] - - with memory_lock: - database.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID]) - - prompt = next_job[KEY_PROMPT] - negative_prompt = next_job[KEY_NEG_PROMPT] - - config = Config().set_config(next_job) - - try: - if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG: - result_dict = text2img.lunch( - prompt=prompt, negative_prompt=negative_prompt, config=config - ) - elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG: - ref_img = next_job[REFERENCE_IMG] - result_dict = img2img.lunch( - prompt=prompt, - negative_prompt=negative_prompt, - reference_image=ref_img, - config=config, - ) - except BaseException as e: - logger.error("text2img.lunch error: {}".format(e)) - with memory_lock: - database.update_job( - {KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID] - ) - continue - - with memory_lock: - database.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID]) - database.update_job(result_dict, job_uuid=next_job[UUID]) - - logger.critical("stopped") - - -def main(db_filepath, is_testing: bool = False): - database.connect(db_filepath) - - if is_testing: - try: - app.run(host="0.0.0.0", port="5000") - except KeyboardInterrupt: - pass - return - thread = Thread( - target=backend, - args=( - event_termination, - database, - ), - ) - thread.start() - # ugly solution for now - # TODO: use a database to track instead of internal memory - try: - app.run(host="0.0.0.0", port="8888") - thread.join() - except KeyboardInterrupt: - event_termination.set() - - database.safe_disconnect() - - thread.join(2) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - # Add an argument to set the 'testing' flag - parser.add_argument("--testing", action="store_true", help="Enable testing mode") - - # Add an argument to set the path of the database file - parser.add_argument( - "--db", type=str, default="happysd.db", help="Path to SQLite database file" - ) - - args = parser.parse_args() - logger.info(args) - - main(args.db, args.testing) diff --git a/manage_db.py b/manage_db.py index f3679c2..bd9f8bc 100644 --- a/manage_db.py +++ b/manage_db.py @@ -34,13 +34,14 @@ def create_table_history(c): c.execute( """CREATE TABLE IF NOT EXISTS history (uuid TEXT PRIMARY KEY, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP, updated_at TIMESTAMP, apikey TEXT, priority INT, type TEXT, status TEXT, prompt TEXT, + lang TEXT, neg_prompt TEXT, seed TEXT, ref_img TEXT, @@ -146,7 +147,7 @@ def show_users(c, username="", details=False): print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") if details: c.execute( - "SELECT uuid, created_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?", + "SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?", (user[1],), ) rows = c.fetchall() @@ -163,7 +164,7 @@ def show_users(c, username="", details=False): print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") if details: c.execute( - "SELECT uuid, created_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?", + "SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?", (user[1],), ) rows = c.fetchall() diff --git a/requirements.txt b/requirements.txt index ca9a5df..b44b5f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,5 @@ Pillow scikit-image torch transformers +sentencepiece +fcntl diff --git a/templates/index.html b/templates/index.html index 423b78a..19a49f9 100644 --- a/templates/index.html +++ b/templates/index.html @@ -1,4 +1,4 @@ - + @@ -17,75 +17,84 @@
- - + +
-
+
- + -
Less than 77 words otherwise it'll be truncated. Example: - "photo of cute cat, RAW photo, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high - quality, film grain, Fujifilm XT3"
+
Less than 77 words. Example: photo of a cute cat. + Use () to emphasize.
- + -
Less than 77 words otherwise it'll be truncated. - Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, - drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg - artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn - hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad - proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, - missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long - neck"
+
Less than 77 words. Optional.
- + -
Leave it empty or put 0 to use a random - seed +
+ Leave it empty or set 0 to use a random seed
- + -
Each step is about 38s (CPU) or 0.1s - (GPU) + placeholder="50"> +
More steps better image but longer time to generate
- +
- +
- + -
How much guidance to follow from - description. 20 strictly follow prompt, 7 creative/artistic. + aria-describedby="inputGuidanceScaleHelp" placeholder="12.5" min="1" max="30"> +
+ Don't set it to the extremes (1 or 30). 20 means strictly follow prompt, 7 + creative/artistic. Lower this number if you see bad images.
- +
@@ -93,54 +102,62 @@