adds missing files separating frontend and backend
This commit is contained in:
parent
4a9d60e00b
commit
a609258a0c
|
|
@ -0,0 +1,162 @@
|
|||
import argparse
|
||||
|
||||
from utilities.constants import LOGGER_NAME_BACKEND
|
||||
from utilities.constants import LOGGER_NAME_TXT2IMG
|
||||
from utilities.constants import LOGGER_NAME_IMG2IMG
|
||||
|
||||
from utilities.constants import UUID
|
||||
from utilities.constants import KEY_LANGUAGE
|
||||
from utilities.constants import VALUE_LANGUAGE_EN
|
||||
from utilities.constants import KEY_PROMPT
|
||||
from utilities.constants import KEY_NEG_PROMPT
|
||||
from utilities.constants import KEY_JOB_STATUS
|
||||
from utilities.constants import VALUE_JOB_DONE
|
||||
from utilities.constants import VALUE_JOB_FAILED
|
||||
from utilities.constants import VALUE_JOB_RUNNING
|
||||
from utilities.constants import KEY_JOB_TYPE
|
||||
from utilities.constants import VALUE_JOB_TXT2IMG
|
||||
from utilities.constants import VALUE_JOB_IMG2IMG
|
||||
from utilities.constants import REFERENCE_IMG
|
||||
|
||||
from utilities.translator import translate_prompt
|
||||
from utilities.config import Config
|
||||
from utilities.database import Database
|
||||
from utilities.logger import Logger
|
||||
from utilities.model import Model
|
||||
from utilities.text2img import Text2Img
|
||||
from utilities.img2img import Img2Img
|
||||
from utilities.times import wait_for_seconds
|
||||
|
||||
|
||||
logger = Logger(name=LOGGER_NAME_BACKEND)
|
||||
database = Database(logger)
|
||||
|
||||
|
||||
def load_model(logger: Logger, use_gpu: bool) -> Model:
|
||||
# model candidates:
|
||||
# "runwayml/stable-diffusion-v1-5"
|
||||
# "CompVis/stable-diffusion-v1-4"
|
||||
# "stabilityai/stable-diffusion-2-1"
|
||||
# "SG161222/Realistic_Vision_V2.0"
|
||||
# "darkstorm2150/Protogen_x3.4_Official_Release"
|
||||
# "darkstorm2150/Protogen_x5.8_Official_Release"
|
||||
# "prompthero/openjourney"
|
||||
# "naclbit/trinart_stable_diffusion_v2"
|
||||
# "hakurei/waifu-diffusion"
|
||||
model_name = "SG161222/Realistic_Vision_V2.0"
|
||||
# inpainting model candidates:
|
||||
# "runwayml/stable-diffusion-inpainting"
|
||||
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
|
||||
if use_gpu:
|
||||
model.set_low_memory_mode()
|
||||
model.load_all()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def backend(model, is_debugging: bool):
|
||||
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
|
||||
text2img.breakfast()
|
||||
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
|
||||
img2img.breakfast()
|
||||
|
||||
while 1:
|
||||
wait_for_seconds(1)
|
||||
|
||||
if is_debugging:
|
||||
pending_jobs = database.get_jobs()
|
||||
else:
|
||||
pending_jobs = database.get_all_pending_jobs()
|
||||
if len(pending_jobs) == 0:
|
||||
continue
|
||||
|
||||
next_job = pending_jobs[0]
|
||||
|
||||
if not is_debugging:
|
||||
database.update_job(
|
||||
{KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID]
|
||||
)
|
||||
|
||||
prompt = next_job[KEY_PROMPT]
|
||||
negative_prompt = next_job[KEY_NEG_PROMPT]
|
||||
|
||||
if KEY_LANGUAGE in next_job:
|
||||
logger.info(
|
||||
f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
|
||||
)
|
||||
if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]:
|
||||
prompt_en = translate_prompt(prompt, next_job[KEY_LANGUAGE])
|
||||
logger.info(f"translated {prompt} to {prompt_en}")
|
||||
prompt = prompt_en
|
||||
if negative_prompt:
|
||||
negative_prompt_en = translate_prompt(
|
||||
negative_prompt, next_job[KEY_LANGUAGE]
|
||||
)
|
||||
logger.info(f"translated {negative_prompt} to {negative_prompt_en}")
|
||||
negative_prompt = negative_prompt_en
|
||||
|
||||
prompt += "RAW photo, (high detailed skin:1.2), 8k uhd, dslr, high quality, film grain, Fujifilm XT3"
|
||||
negative_prompt += "(deformed iris, deformed pupils:1.4), worst quality, low quality, jpeg artifacts, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
||||
|
||||
config = Config().set_config(next_job)
|
||||
|
||||
try:
|
||||
if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
|
||||
result_dict = text2img.lunch(
|
||||
prompt=prompt, negative_prompt=negative_prompt, config=config
|
||||
)
|
||||
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
|
||||
ref_img = next_job[REFERENCE_IMG]
|
||||
result_dict = img2img.lunch(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
reference_image=ref_img,
|
||||
config=config,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except BaseException as e:
|
||||
logger.error("text2img.lunch error: {}".format(e))
|
||||
if not is_debugging:
|
||||
database.update_job(
|
||||
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
|
||||
)
|
||||
continue
|
||||
|
||||
if not is_debugging:
|
||||
database.update_job(
|
||||
{KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID]
|
||||
)
|
||||
database.update_job(result_dict, job_uuid=next_job[UUID])
|
||||
|
||||
logger.critical("stopped")
|
||||
|
||||
|
||||
def main(args):
|
||||
database.connect(args.db)
|
||||
|
||||
model = load_model(logger, args.gpu)
|
||||
backend(model, args.debug)
|
||||
|
||||
database.safe_disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Add an argument to set the 'debug' flag
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
||||
|
||||
# Add an argument to set the path of the database file
|
||||
parser.add_argument(
|
||||
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
|
||||
)
|
||||
|
||||
# Add an argument to set the path of the database file
|
||||
parser.add_argument("--gpu", action="store_true", help="Enable to use GPU device")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
import argparse
|
||||
import uuid
|
||||
from flask import jsonify
|
||||
from flask import Flask
|
||||
from flask import render_template
|
||||
from flask import request
|
||||
|
||||
from utilities.constants import LOGGER_NAME_FRONTEND
|
||||
|
||||
from utilities.logger import Logger
|
||||
|
||||
from utilities.constants import APIKEY
|
||||
from utilities.constants import KEY_JOB_TYPE
|
||||
from utilities.constants import REFERENCE_IMG
|
||||
from utilities.constants import MAX_JOB_NUMBER
|
||||
from utilities.constants import OPTIONAL_KEYS
|
||||
from utilities.constants import KEY_LANGUAGE
|
||||
from utilities.constants import SUPPORTED_LANGS
|
||||
from utilities.constants import REQUIRED_KEYS
|
||||
from utilities.constants import UUID
|
||||
from utilities.constants import VALUE_JOB_TXT2IMG
|
||||
from utilities.constants import VALUE_JOB_IMG2IMG
|
||||
from utilities.constants import VALUE_JOB_INPAINTING
|
||||
from utilities.database import Database
|
||||
|
||||
logger = Logger(name=LOGGER_NAME_FRONTEND)
|
||||
database = Database(logger)
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route("/add_job", methods=["POST"])
|
||||
def add_job():
|
||||
req = request.get_json()
|
||||
|
||||
if APIKEY not in req:
|
||||
logger.error(f"{APIKEY} not present in {req}")
|
||||
return "", 401
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
logger.error(f"user not found with {req[APIKEY]}")
|
||||
return "", 401
|
||||
|
||||
for key in req.keys():
|
||||
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
|
||||
return jsonify({"msg": "provided one or more unrecognized keys"}), 404
|
||||
for required_key in REQUIRED_KEYS:
|
||||
if required_key not in req:
|
||||
return jsonify({"msg": "missing one or more required keys"}), 404
|
||||
|
||||
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
|
||||
return jsonify({"msg": "missing reference image"}), 404
|
||||
|
||||
if KEY_LANGUAGE in req and req[KEY_LANGUAGE] not in SUPPORTED_LANGS:
|
||||
return jsonify({"msg": f"not suporting {req[KEY_LANGUAGE]}"}), 404
|
||||
|
||||
if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
|
||||
return (
|
||||
jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
|
||||
500,
|
||||
)
|
||||
|
||||
job_uuid = str(uuid.uuid4())
|
||||
logger.info("adding a new job with uuid {}..".format(job_uuid))
|
||||
|
||||
database.insert_new_job(req, job_uuid=job_uuid)
|
||||
|
||||
return jsonify({"msg": "", UUID: job_uuid})
|
||||
|
||||
|
||||
@app.route("/cancel_job", methods=["POST"])
|
||||
def cancel_job():
|
||||
req = request.get_json()
|
||||
if APIKEY not in req:
|
||||
return "", 401
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
return "", 401
|
||||
|
||||
if UUID not in req:
|
||||
return jsonify({"msg": "missing uuid"}), 404
|
||||
|
||||
logger.info("cancelling job with uuid {}..".format(req[UUID]))
|
||||
|
||||
result = database.cancel_job(job_uuid=req[UUID])
|
||||
|
||||
if result:
|
||||
msg = "job with uuid {} removed".format(req[UUID])
|
||||
return jsonify({"msg": msg})
|
||||
|
||||
jobs = database.get_jobs(job_uuid=req[UUID])
|
||||
|
||||
if jobs:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"msg": "job {} is not in pending state, unable to cancel".format(
|
||||
req[UUID]
|
||||
)
|
||||
}
|
||||
),
|
||||
405,
|
||||
)
|
||||
|
||||
return (
|
||||
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
|
||||
404,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/get_jobs", methods=["POST"])
|
||||
def get_jobs():
|
||||
req = request.get_json()
|
||||
if APIKEY not in req:
|
||||
return "", 401
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
return "", 401
|
||||
|
||||
jobs = database.get_jobs(job_uuid=req[UUID])
|
||||
|
||||
return jsonify({"jobs": jobs})
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return render_template("index.html")
|
||||
|
||||
|
||||
def main(args):
|
||||
database.connect(args.db)
|
||||
|
||||
if args.debug:
|
||||
app.run(host="0.0.0.0", port="5432")
|
||||
else:
|
||||
app.run(host="0.0.0.0", port="8888")
|
||||
|
||||
database.safe_disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Add an argument to set the 'debug' flag
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
||||
|
||||
# Add an argument to set the path of the database file
|
||||
parser.add_argument(
|
||||
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from transformers import MBart50TokenizerFast
|
||||
from transformers import MBartForConditionalGeneration
|
||||
|
||||
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||
|
||||
def translate_prompt(prompt, src_lang):
|
||||
"""helper function to translate prompt to English"""
|
||||
|
||||
tokenizer.set_src_lang_special_tokens(src_lang)
|
||||
tokenizer.src_lang = src_lang
|
||||
|
||||
encoded_prompt = tokenizer(prompt, return_tensors="pt").to("cpu")
|
||||
generated_tokens = model.generate(**encoded_prompt, max_new_tokens=1000, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
|
||||
|
||||
en_trans = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
||||
return en_trans[0]
|
||||
Loading…
Reference in New Issue