Add missing files separating the frontend and backend

HappyZ 2023-05-08 16:06:09 -07:00
parent 4a9d60e00b
commit a609258a0c
3 changed files with 333 additions and 0 deletions

backend.py (new file, 162 lines added)

@@ -0,0 +1,162 @@
import argparse
from utilities.constants import LOGGER_NAME_BACKEND
from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import UUID
from utilities.constants import KEY_LANGUAGE
from utilities.constants import VALUE_LANGUAGE_EN
from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import VALUE_JOB_FAILED
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import REFERENCE_IMG
from utilities.translator import translate_prompt
from utilities.config import Config
from utilities.database import Database
from utilities.logger import Logger
from utilities.model import Model
from utilities.text2img import Text2Img
from utilities.img2img import Img2Img
from utilities.times import wait_for_seconds
logger = Logger(name=LOGGER_NAME_BACKEND)
database = Database(logger)


def load_model(logger: Logger, use_gpu: bool) -> Model:
    """Load the base diffusion model and the inpainting model used by the backend."""
    # model candidates:
    #   "runwayml/stable-diffusion-v1-5"
    #   "CompVis/stable-diffusion-v1-4"
    #   "stabilityai/stable-diffusion-2-1"
    #   "SG161222/Realistic_Vision_V2.0"
    #   "darkstorm2150/Protogen_x3.4_Official_Release"
    #   "darkstorm2150/Protogen_x5.8_Official_Release"
    #   "prompthero/openjourney"
    #   "naclbit/trinart_stable_diffusion_v2"
    #   "hakurei/waifu-diffusion"
    model_name = "SG161222/Realistic_Vision_V2.0"
    # inpainting model candidates:
    #   "runwayml/stable-diffusion-inpainting"
    inpainting_model_name = "runwayml/stable-diffusion-inpainting"
    model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
    if use_gpu:
        model.set_low_memory_mode()
    model.load_all()
    return model


def backend(model, is_debugging: bool):
    """Poll the database for pending jobs and run them one at a time."""
    text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
    text2img.breakfast()
    img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
    img2img.breakfast()
    while True:
        wait_for_seconds(1)
        if is_debugging:
            pending_jobs = database.get_jobs()
        else:
            pending_jobs = database.get_all_pending_jobs()
        if len(pending_jobs) == 0:
            continue
        next_job = pending_jobs[0]
        # mark the job as running; in debug mode job statuses are left untouched
        if not is_debugging:
            database.update_job(
                {KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID]
            )
        prompt = next_job[KEY_PROMPT]
        negative_prompt = next_job[KEY_NEG_PROMPT]
        if KEY_LANGUAGE in next_job:
            logger.info(
                f"found {next_job[KEY_LANGUAGE]}, translating prompt and negative prompt first"
            )
            if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]:
                prompt_en = translate_prompt(prompt, next_job[KEY_LANGUAGE])
                logger.info(f"translated {prompt} to {prompt_en}")
                prompt = prompt_en
                if negative_prompt:
                    negative_prompt_en = translate_prompt(
                        negative_prompt, next_job[KEY_LANGUAGE]
                    )
                    logger.info(f"translated {negative_prompt} to {negative_prompt_en}")
                    negative_prompt = negative_prompt_en
        # append fixed quality boosters to both prompts
        prompt += ", RAW photo, (high detailed skin:1.2), 8k uhd, dslr, high quality, film grain, Fujifilm XT3"
        negative_prompt += ", (deformed iris, deformed pupils:1.4), worst quality, low quality, jpeg artifacts, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
        config = Config().set_config(next_job)
        try:
            if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
                result_dict = text2img.lunch(
                    prompt=prompt, negative_prompt=negative_prompt, config=config
                )
            elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
                ref_img = next_job[REFERENCE_IMG]
                result_dict = img2img.lunch(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    reference_image=ref_img,
                    config=config,
                )
            else:
                # guard against unknown job types so result_dict is never left unbound
                raise ValueError(f"unsupported job type: {next_job[KEY_JOB_TYPE]}")
        except KeyboardInterrupt:
            break
        except BaseException as e:
            logger.error("job {} failed: {}".format(next_job[UUID], e))
            if not is_debugging:
                database.update_job(
                    {KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
                )
            continue
        if not is_debugging:
            database.update_job(
                {KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID]
            )
            database.update_job(result_dict, job_uuid=next_job[UUID])
    logger.critical("stopped")


def main(args):
    database.connect(args.db)
    model = load_model(logger, args.gpu)
    backend(model, args.debug)
    database.safe_disconnect()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Add an argument to set the 'debug' flag
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    # Add an argument to set the path of the database file
    parser.add_argument(
        "--db", type=str, default="happysd.db", help="Path to SQLite database file"
    )
    # Add an argument to enable the GPU device
    parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device")
    args = parser.parse_args()
    main(args)
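
# Example invocation (a sketch based on the flags defined above, not part of the
# original commit):
#   python backend.py --db happysd.db --gpu   # poll happysd.db and run jobs on the GPU
#   python backend.py --debug                 # debug mode: rerun jobs without updating their status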

frontend.py (new file, 154 lines added)

@@ -0,0 +1,154 @@
import argparse
import uuid
from flask import jsonify
from flask import Flask
from flask import render_template
from flask import request
from utilities.constants import LOGGER_NAME_FRONTEND
from utilities.logger import Logger
from utilities.constants import APIKEY
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import REFERENCE_IMG
from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import KEY_LANGUAGE
from utilities.constants import SUPPORTED_LANGS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.database import Database
logger = Logger(name=LOGGER_NAME_FRONTEND)
database = Database(logger)
app = Flask(__name__)
@app.route("/add_job", methods=["POST"])
def add_job():
req = request.get_json()
if APIKEY not in req:
logger.error(f"{APIKEY} not present in {req}")
return "", 401
user = database.validate_user(req[APIKEY])
if not user:
logger.error(f"user not found with {req[APIKEY]}")
return "", 401
for key in req.keys():
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
return jsonify({"msg": "provided one or more unrecognized keys"}), 404
for required_key in REQUIRED_KEYS:
if required_key not in req:
return jsonify({"msg": "missing one or more required keys"}), 404
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
return jsonify({"msg": "missing reference image"}), 404
if KEY_LANGUAGE in req and req[KEY_LANGUAGE] not in SUPPORTED_LANGS:
return jsonify({"msg": f"not suporting {req[KEY_LANGUAGE]}"}), 404
if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
return (
jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
500,
)
job_uuid = str(uuid.uuid4())
logger.info("adding a new job with uuid {}..".format(job_uuid))
database.insert_new_job(req, job_uuid=job_uuid)
return jsonify({"msg": "", UUID: job_uuid})
@app.route("/cancel_job", methods=["POST"])
def cancel_job():
req = request.get_json()
if APIKEY not in req:
return "", 401
user = database.validate_user(req[APIKEY])
if not user:
return "", 401
if UUID not in req:
return jsonify({"msg": "missing uuid"}), 404
logger.info("cancelling job with uuid {}..".format(req[UUID]))
result = database.cancel_job(job_uuid=req[UUID])
if result:
msg = "job with uuid {} removed".format(req[UUID])
return jsonify({"msg": msg})
jobs = database.get_jobs(job_uuid=req[UUID])
if jobs:
return (
jsonify(
{
"msg": "job {} is not in pending state, unable to cancel".format(
req[UUID]
)
}
),
405,
)
return (
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
404,
)
@app.route("/get_jobs", methods=["POST"])
def get_jobs():
req = request.get_json()
if APIKEY not in req:
return "", 401
user = database.validate_user(req[APIKEY])
if not user:
return "", 401
jobs = database.get_jobs(job_uuid=req[UUID])
return jsonify({"jobs": jobs})
@app.route("/")
def index():
return render_template("index.html")


def main(args):
    database.connect(args.db)
    if args.debug:
        app.run(host="0.0.0.0", port=5432)
    else:
        app.run(host="0.0.0.0", port=8888)
    database.safe_disconnect()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Add an argument to set the 'debug' flag
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    # Add an argument to set the path of the database file
    parser.add_argument(
        "--db", type=str, default="happysd.db", help="Path to SQLite database file"
    )
    args = parser.parse_args()
    main(args)
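
# Example client call (a sketch, not part of the original commit; the real JSON
# field names come from utilities.constants (APIKEY, KEY_JOB_TYPE, KEY_PROMPT, ...),
# whose string values are not shown in this diff, so the keys below are hypothetical):
#   import requests
#   resp = requests.post(
#       "http://localhost:8888/add_job",
#       json={"apikey": "<key>", "type": "txt2img", "prompt": "a watercolor fox"},
#   )
#   print(resp.json())  # on success: {"msg": "", <uuid key>: "<job uuid>"}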

utilities/translator.py (new file, 17 lines added)

@@ -0,0 +1,17 @@
from transformers import MBart50TokenizerFast
from transformers import MBartForConditionalGeneration

# mBART-50 many-to-many model used to translate non-English prompts into English
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
model = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)


def translate_prompt(prompt, src_lang):
    """Helper function to translate a prompt from src_lang into English."""
    # setting src_lang also updates the tokenizer's language-specific special tokens
    tokenizer.src_lang = src_lang
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to("cpu")
    generated_tokens = model.generate(
        **encoded_prompt,
        max_new_tokens=1000,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
    )
    en_trans = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return en_trans[0]
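
# Example usage (a sketch, not part of the original commit; "zh_CN" and "fr_XX"
# are mBART-50 source-language codes):
#   translate_prompt("一只可爱的猫", "zh_CN")    # -> something like "A cute cat"
#   translate_prompt("un renard roux", "fr_XX")   # -> something like "A red fox"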