diff --git a/BUILD b/BUILD index 2c27efb..80de63b 100644 --- a/BUILD +++ b/BUILD @@ -11,9 +11,11 @@ par_binary( "//utilities:constants", "//utilities:database", "//utilities:logger", + "//utilities:images", ], data=[ "templates/index.html", + "templates/restoration.html", "static/bootstrap.min.css", "static/jquery-3.6.1.min.js", "static/bootstrap.bundle.min.js", @@ -32,6 +34,7 @@ par_binary( "//utilities:constants", "//utilities:database", "//utilities:memory", + "//utilities:external", "//utilities:logger", "//utilities:model", "//utilities:config", diff --git a/backend.py b/backend.py index a3235e4..89db89b 100644 --- a/backend.py +++ b/backend.py @@ -19,6 +19,7 @@ from utilities.constants import KEY_JOB_TYPE from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_INPAINTING +from utilities.constants import VALUE_JOB_RESTORATION from utilities.constants import REFERENCE_IMG from utilities.constants import MASK_IMG @@ -32,6 +33,7 @@ from utilities.img2img import Img2Img from utilities.inpainting import Inpainting from utilities.times import wait_for_seconds from utilities.memory import empty_memory_cache +from utilities.external import gfpgan logger = Logger(name=LOGGER_NAME_BACKEND) @@ -62,7 +64,7 @@ def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Mode return model -def backend(model, is_debugging: bool): +def backend(model, gfpgan_folderpath, is_debugging: bool): text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG)) text2img.breakfast() img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG)) @@ -87,10 +89,14 @@ def backend(model, is_debugging: bool): {KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID] ) - prompt = next_job[KEY_PROMPT] - negative_prompt = next_job[KEY_NEG_PROMPT] + prompt = next_job.get(KEY_PROMPT, "") + negative_prompt = next_job.get(KEY_NEG_PROMPT, "") - if KEY_LANGUAGE in next_job: + if ( + next_job[KEY_JOB_TYPE] + in [VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_TXT2IMG] + and KEY_LANGUAGE in next_job + ): if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]: logger.info( f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first" @@ -130,6 +136,13 @@ def backend(model, is_debugging: bool): mask_image=mask_img, config=config, ) + elif next_job[KEY_JOB_TYPE] == VALUE_JOB_RESTORATION: + ref_img_filepath = next_job[REFERENCE_IMG] + result_dict = gfpgan(gfpgan_folderpath, next_job[UUID], ref_img_filepath, config=config, logger=logger) + if not result_dict: + raise ValueError("failed to run gfpgan") + else: + raise ValueError("unrecognized job type") except KeyboardInterrupt: break except BaseException as e: @@ -154,7 +167,7 @@ def main(args): database.connect(args.db) model = load_model(logger, args.gpu, args.reduce_memory_usage) - backend(model, args.debug) + backend(model, args.gfpgan, args.debug) database.safe_disconnect() @@ -180,6 +193,14 @@ if __name__ == "__main__": help="Reduce memory usage when using GPU", ) + # Add an argument to reduce memory usage + parser.add_argument( + "--gfpgan", + type=str, + default="", + help="GFPGAN folderpath", + ) + # Add an argument to set the path of the database file parser.add_argument( "--image-output-folder", diff --git a/frontend.py b/frontend.py index 229cf68..3f0903b 100644 --- a/frontend.py +++ b/frontend.py @@ -13,7 +13,9 @@ from utilities.logger import Logger from utilities.constants import APIKEY from utilities.constants import KEY_JOB_TYPE +from utilities.constants import BASE64IMAGE from utilities.constants import REFERENCE_IMG +from utilities.constants import MASK_IMG from utilities.constants import MAX_JOB_NUMBER from utilities.constants import OPTIONAL_KEYS from utilities.constants import KEY_LANGUAGE @@ -23,7 +25,10 @@ from utilities.constants import UUID from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_INPAINTING +from utilities.constants import VALUE_JOB_RESTORATION +from utilities.constants import IMAGE_NOT_FOUND_BASE64 from utilities.database import Database +from utilities.images import load_image logger = Logger(name=LOGGER_NAME_FRONTEND) database = Database(logger) @@ -35,6 +40,7 @@ limiter = Limiter( @app.route("/add_job", methods=["POST"]) +@limiter.limit("1/second") def add_job(): req = request.get_json() @@ -49,13 +55,23 @@ def add_job(): for key in req.keys(): if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS): return jsonify({"msg": "provided one or more unrecognized keys"}), 404 - for required_key in REQUIRED_KEYS: - if required_key not in req: - return jsonify({"msg": "missing one or more required keys"}), 404 - if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req: + # only checks required key for non-restoration jobs + if req.get(KEY_JOB_TYPE, None) != VALUE_JOB_RESTORATION: + for required_key in REQUIRED_KEYS: + if required_key not in req: + return jsonify({"msg": "missing one or more required keys"}), 404 + + if ( + req[KEY_JOB_TYPE] + in [VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_RESTORATION] + and REFERENCE_IMG not in req + ): return jsonify({"msg": "missing reference image"}), 404 + if req[KEY_JOB_TYPE] == VALUE_JOB_INPAINTING and MASK_IMG not in req: + return jsonify({"msg": "missing mask image"}), 404 + if KEY_LANGUAGE in req and req[KEY_LANGUAGE] not in SUPPORTED_LANGS: return jsonify({"msg": f"not suporting {req[KEY_LANGUAGE]}"}), 404 @@ -74,6 +90,7 @@ def add_job(): @app.route("/cancel_job", methods=["POST"]) +@limiter.limit("1/second") def cancel_job(): req = request.get_json() if APIKEY not in req: @@ -114,6 +131,7 @@ def cancel_job(): @app.route("/get_jobs", methods=["POST"]) +@limiter.limit("1/second") def get_jobs(): req = request.get_json() if APIKEY not in req: @@ -127,10 +145,24 @@ def get_jobs(): if UUID in req: jobs = database.get_jobs( - job_uuid=req[UUID], apikey=req[APIKEY], limit_count=job_count_limit + job_uuid=req[UUID], + apikey=req[APIKEY], + job_types=req.get[KEY_JOB_TYPE].split(",") if req.get(KEY_JOB_TYPE, "") else [], + limit_count=job_count_limit, ) else: - jobs = database.get_jobs(apikey=req[APIKEY], limit_count=job_count_limit) + jobs = database.get_jobs( + apikey=req[APIKEY], + job_types=req.get[KEY_JOB_TYPE].split(",") if req.get(KEY_JOB_TYPE, "") else [], + limit_count=job_count_limit, + ) + + for job in jobs: + # load image to job if has one + for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]: + if key in job and "base64" not in job[key]: + data = load_image(job[key], to_base64=True) + job[key] = data if data else IMAGE_NOT_FOUND_BASE64 return jsonify({"jobs": jobs}) @@ -143,14 +175,28 @@ def random_jobs(): jobs = database.get_random_jobs(limit_count=job_count_limit) + for job in jobs: + # load image to job if has one + for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]: + if key in job and "base64" not in job[key]: + data = load_image(job[key], to_base64=True) + job[key] = data if data else IMAGE_NOT_FOUND_BASE64 + return jsonify({"jobs": jobs}) @app.route("/") +@limiter.limit("1/second") def index(): return render_template("index.html") +@app.route("/restoration") +@limiter.limit("1/second") +def restoration(): + return render_template("restoration.html") + + def main(args): database.set_image_output_folder(args.image_output_folder) database.connect(args.db) diff --git a/requirements.txt b/requirements.txt index 351fc2b..07fd7f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ torch==2.0.0 transformers==4.28.1 sentencepiece==0.1.99 Flask-Limiter==3.3.1 +protobuf==3.20 \ No newline at end of file diff --git a/templates/index.html b/templates/index.html index 24558b7..567747b 100644 --- a/templates/index.html +++ b/templates/index.html @@ -23,7 +23,11 @@
@@ -496,12 +500,11 @@ } } } - setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second + setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second }, error: function (xhr, status, error) { - // Handle error response - console.log(xhr.responseText); - $('#txt2ImgStatus').html('failed'); + console.log(error); + setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second } }); } @@ -631,16 +634,12 @@ $('#getJobHistory').click(function () { var apikeyValue = $('#apiKey').val(); var uuidValue = $('#lookupUUID').val(); - if (uuidValue == null) { - alert("no UUID specified"); - return; - } $.ajax({ type: 'POST', url: '/get_jobs', contentType: 'application/json; charset=utf-8', dataType: 'json', - data: JSON.stringify({ 'apikey': apikeyValue, 'uuid': uuidValue }), + data: JSON.stringify({ 'apikey': apikeyValue, 'uuid': uuidValue, 'type': 'txt,img,inpaint' }), success: function (response) { var jobsLength = response.jobs.length; if (jobsLength == 0) { @@ -662,8 +661,9 @@ "