From 3097b262f3eda250e6ee51e074aa162b4ec73ed6 Mon Sep 17 00:00:00 2001 From: HappyZ Date: Sat, 27 May 2023 22:58:13 -0700 Subject: [PATCH] adds restoration via external GFPGAN --- BUILD | 3 + backend.py | 31 ++- frontend.py | 58 +++++- requirements.txt | 1 + templates/index.html | 26 +-- templates/restoration.html | 378 +++++++++++++++++++++++++++++++++++++ utilities/BUILD | 11 ++ utilities/constants.py | 1 + utilities/database.py | 29 ++- utilities/external.py | 48 +++++ utilities/img2img.py | 7 +- utilities/inpainting.py | 12 +- 12 files changed, 563 insertions(+), 42 deletions(-) create mode 100644 templates/restoration.html create mode 100644 utilities/external.py diff --git a/BUILD b/BUILD index 2c27efb..80de63b 100644 --- a/BUILD +++ b/BUILD @@ -11,9 +11,11 @@ par_binary( "//utilities:constants", "//utilities:database", "//utilities:logger", + "//utilities:images", ], data=[ "templates/index.html", + "templates/restoration.html", "static/bootstrap.min.css", "static/jquery-3.6.1.min.js", "static/bootstrap.bundle.min.js", @@ -32,6 +34,7 @@ par_binary( "//utilities:constants", "//utilities:database", "//utilities:memory", + "//utilities:external", "//utilities:logger", "//utilities:model", "//utilities:config", diff --git a/backend.py b/backend.py index a3235e4..89db89b 100644 --- a/backend.py +++ b/backend.py @@ -19,6 +19,7 @@ from utilities.constants import KEY_JOB_TYPE from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_INPAINTING +from utilities.constants import VALUE_JOB_RESTORATION from utilities.constants import REFERENCE_IMG from utilities.constants import MASK_IMG @@ -32,6 +33,7 @@ from utilities.img2img import Img2Img from utilities.inpainting import Inpainting from utilities.times import wait_for_seconds from utilities.memory import empty_memory_cache +from utilities.external import gfpgan logger = Logger(name=LOGGER_NAME_BACKEND) @@ -62,7 +64,7 @@ def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Mode return model -def backend(model, is_debugging: bool): +def backend(model, gfpgan_folderpath, is_debugging: bool): text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG)) text2img.breakfast() img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG)) @@ -87,10 +89,14 @@ def backend(model, is_debugging: bool): {KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID] ) - prompt = next_job[KEY_PROMPT] - negative_prompt = next_job[KEY_NEG_PROMPT] + prompt = next_job.get(KEY_PROMPT, "") + negative_prompt = next_job.get(KEY_NEG_PROMPT, "") - if KEY_LANGUAGE in next_job: + if ( + next_job[KEY_JOB_TYPE] + in [VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_TXT2IMG] + and KEY_LANGUAGE in next_job + ): if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]: logger.info( f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first" @@ -130,6 +136,13 @@ def backend(model, is_debugging: bool): mask_image=mask_img, config=config, ) + elif next_job[KEY_JOB_TYPE] == VALUE_JOB_RESTORATION: + ref_img_filepath = next_job[REFERENCE_IMG] + result_dict = gfpgan(gfpgan_folderpath, next_job[UUID], ref_img_filepath, config=config, logger=logger) + if not result_dict: + raise ValueError("failed to run gfpgan") + else: + raise ValueError("unrecognized job type") except KeyboardInterrupt: break except BaseException as e: @@ -154,7 +167,7 @@ def main(args): database.connect(args.db) model = load_model(logger, args.gpu, args.reduce_memory_usage) - backend(model, args.debug) + backend(model, args.gfpgan, args.debug) database.safe_disconnect() @@ -180,6 +193,14 @@ if __name__ == "__main__": help="Reduce memory usage when using GPU", ) + # Add an argument to reduce memory usage + parser.add_argument( + "--gfpgan", + type=str, + default="", + help="GFPGAN folderpath", + ) + # Add an argument to set the path of the database file parser.add_argument( "--image-output-folder", diff --git a/frontend.py b/frontend.py index 229cf68..3f0903b 100644 --- a/frontend.py +++ b/frontend.py @@ -13,7 +13,9 @@ from utilities.logger import Logger from utilities.constants import APIKEY from utilities.constants import KEY_JOB_TYPE +from utilities.constants import BASE64IMAGE from utilities.constants import REFERENCE_IMG +from utilities.constants import MASK_IMG from utilities.constants import MAX_JOB_NUMBER from utilities.constants import OPTIONAL_KEYS from utilities.constants import KEY_LANGUAGE @@ -23,7 +25,10 @@ from utilities.constants import UUID from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_INPAINTING +from utilities.constants import VALUE_JOB_RESTORATION +from utilities.constants import IMAGE_NOT_FOUND_BASE64 from utilities.database import Database +from utilities.images import load_image logger = Logger(name=LOGGER_NAME_FRONTEND) database = Database(logger) @@ -35,6 +40,7 @@ limiter = Limiter( @app.route("/add_job", methods=["POST"]) +@limiter.limit("1/second") def add_job(): req = request.get_json() @@ -49,13 +55,23 @@ def add_job(): for key in req.keys(): if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS): return jsonify({"msg": "provided one or more unrecognized keys"}), 404 - for required_key in REQUIRED_KEYS: - if required_key not in req: - return jsonify({"msg": "missing one or more required keys"}), 404 - if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req: + # only checks required key for non-restoration jobs + if req.get(KEY_JOB_TYPE, None) != VALUE_JOB_RESTORATION: + for required_key in REQUIRED_KEYS: + if required_key not in req: + return jsonify({"msg": "missing one or more required keys"}), 404 + + if ( + req[KEY_JOB_TYPE] + in [VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_RESTORATION] + and REFERENCE_IMG not in req + ): return jsonify({"msg": "missing reference image"}), 404 + if req[KEY_JOB_TYPE] == VALUE_JOB_INPAINTING and MASK_IMG not in req: + return jsonify({"msg": "missing mask image"}), 404 + if KEY_LANGUAGE in req and req[KEY_LANGUAGE] not in SUPPORTED_LANGS: return jsonify({"msg": f"not suporting {req[KEY_LANGUAGE]}"}), 404 @@ -74,6 +90,7 @@ def add_job(): @app.route("/cancel_job", methods=["POST"]) +@limiter.limit("1/second") def cancel_job(): req = request.get_json() if APIKEY not in req: @@ -114,6 +131,7 @@ def cancel_job(): @app.route("/get_jobs", methods=["POST"]) +@limiter.limit("1/second") def get_jobs(): req = request.get_json() if APIKEY not in req: @@ -127,10 +145,24 @@ def get_jobs(): if UUID in req: jobs = database.get_jobs( - job_uuid=req[UUID], apikey=req[APIKEY], limit_count=job_count_limit + job_uuid=req[UUID], + apikey=req[APIKEY], + job_types=req.get[KEY_JOB_TYPE].split(",") if req.get(KEY_JOB_TYPE, "") else [], + limit_count=job_count_limit, ) else: - jobs = database.get_jobs(apikey=req[APIKEY], limit_count=job_count_limit) + jobs = database.get_jobs( + apikey=req[APIKEY], + job_types=req.get[KEY_JOB_TYPE].split(",") if req.get(KEY_JOB_TYPE, "") else [], + limit_count=job_count_limit, + ) + + for job in jobs: + # load image to job if has one + for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]: + if key in job and "base64" not in job[key]: + data = load_image(job[key], to_base64=True) + job[key] = data if data else IMAGE_NOT_FOUND_BASE64 return jsonify({"jobs": jobs}) @@ -143,14 +175,28 @@ def random_jobs(): jobs = database.get_random_jobs(limit_count=job_count_limit) + for job in jobs: + # load image to job if has one + for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]: + if key in job and "base64" not in job[key]: + data = load_image(job[key], to_base64=True) + job[key] = data if data else IMAGE_NOT_FOUND_BASE64 + return jsonify({"jobs": jobs}) @app.route("/") +@limiter.limit("1/second") def index(): return render_template("index.html") +@app.route("/restoration") +@limiter.limit("1/second") +def restoration(): + return render_template("restoration.html") + + def main(args): database.set_image_output_folder(args.image_output_folder) database.connect(args.db) diff --git a/requirements.txt b/requirements.txt index 351fc2b..07fd7f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ torch==2.0.0 transformers==4.28.1 sentencepiece==0.1.99 Flask-Limiter==3.3.1 +protobuf==3.20 \ No newline at end of file diff --git a/templates/index.html b/templates/index.html index 24558b7..567747b 100644 --- a/templates/index.html +++ b/templates/index.html @@ -23,7 +23,11 @@ @@ -496,12 +500,11 @@ } } } - setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second + setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second }, error: function (xhr, status, error) { - // Handle error response - console.log(xhr.responseText); - $('#txt2ImgStatus').html('failed'); + console.log(error); + setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second } }); } @@ -631,16 +634,12 @@ $('#getJobHistory').click(function () { var apikeyValue = $('#apiKey').val(); var uuidValue = $('#lookupUUID').val(); - if (uuidValue == null) { - alert("no UUID specified"); - return; - } $.ajax({ type: 'POST', url: '/get_jobs', contentType: 'application/json; charset=utf-8', dataType: 'json', - data: JSON.stringify({ 'apikey': apikeyValue, 'uuid': uuidValue }), + data: JSON.stringify({ 'apikey': apikeyValue, 'uuid': uuidValue, 'type': 'txt,img,inpaint' }), success: function (response) { var jobsLength = response.jobs.length; if (jobsLength == 0) { @@ -662,8 +661,9 @@ "
  • seed: " + response.jobs[i].seed + "
  • " + "
  • uuid: " + response.jobs[i].uuid + "
  • " + "
  • w x h: " + response.jobs[i].width + " x " + response.jobs[i].height + "
  • " + - "" + - ""); + "" + + (response.jobs[i].ref_img ? ("") : "") + + ""); // Add event handler for click to toggle blurriness if (isPrivate === 1) { element.find('.card').addClass('private-card'); diff --git a/templates/restoration.html b/templates/restoration.html new file mode 100644 index 0000000..12b0275 --- /dev/null +++ b/templates/restoration.html @@ -0,0 +1,378 @@ + + + + + + {{ config.TITLE }} + + + + + + + +
    + + +
    +
    + Restoration +
    +
    +
    +
    +
    + + +
    + +
    + +
    +
    +
    +
    +
    +
    + + +
    +
    + Upsampling scale of the image +
    +
    +
    +
    + + +
    +
    + Adjustable weights +
    +
    +
    +
    +
    +
    +
    + 1. Choose Original Image +
    +
    +
    + +
    +
    + +
    +
    + +
    +
    +
    +
    +
    + 2. Result +
    +
    +
      +
    • + + Job UUID +
    • +
    • + + Job Status +
    • +
    +
    + +
    +
    +
    +
    +
    + +
    +
    + History +
    +
    +
    + + + +
    +
    +
    +
    + +
    + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/utilities/BUILD b/utilities/BUILD index 80dc768..303c121 100644 --- a/utilities/BUILD +++ b/utilities/BUILD @@ -27,6 +27,17 @@ py_library( ], ) +py_library( + name="external", + srcs=["external.py"], + deps=[ + ":logger", + ":config", + ":constants", + ":images", + ], +) + py_library( name="envvar", srcs=["envvar.py"], diff --git a/utilities/constants.py b/utilities/constants.py index f0dcbb2..c316390 100644 --- a/utilities/constants.py +++ b/utilities/constants.py @@ -32,6 +32,7 @@ VALUE_JOB_IMG2IMG = "img" REFERENCE_IMG = "ref_img" MASK_IMG = "mask_img" VALUE_JOB_INPAINTING = "inpaint" +VALUE_JOB_RESTORATION = "restoration" KEY_LANGUAGE = "lang" VALUE_LANGUAGE_ZH_CN = "zh_CN" diff --git a/utilities/database.py b/utilities/database.py index 687c582..cb4fb08 100644 --- a/utilities/database.py +++ b/utilities/database.py @@ -12,6 +12,7 @@ from utilities.constants import KEY_JOB_TYPE from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_INPAINTING +from utilities.constants import VALUE_JOB_RESTORATION from utilities.constants import KEY_JOB_STATUS from utilities.constants import VALUE_JOB_PENDING from utilities.constants import VALUE_JOB_DONE @@ -26,7 +27,6 @@ from utilities.constants import INTERNAL_KEYS from utilities.constants import REFERENCE_IMG from utilities.constants import MASK_IMG from utilities.constants import BASE64IMAGE -from utilities.constants import IMAGE_NOT_FOUND_BASE64 from utilities.constants import HISTORY_TABLE_NAME from utilities.constants import USERS_TABLE_NAME @@ -35,7 +35,6 @@ from utilities.logger import DummyLogger from utilities.times import get_epoch_now from utilities.times import epoch_to_string from utilities.images import save_image -from utilities.images import load_image # Function to acquire a lock on the database file @@ -130,7 +129,7 @@ class Database: return result[0] def get_random_jobs(self, limit_count=0) -> list: - query = f"SELECT {', '.join(ANONYMOUS_KEYS)} FROM {HISTORY_TABLE_NAME} WHERE {KEY_JOB_STATUS} = ? AND {KEY_IS_PRIVATE} = ? AND rowid IN (SELECT rowid FROM {HISTORY_TABLE_NAME} ORDER BY RANDOM() LIMIT ?)" + query = f"SELECT {', '.join(ANONYMOUS_KEYS)} FROM {HISTORY_TABLE_NAME} WHERE {KEY_JOB_STATUS} = ? AND {KEY_IS_PRIVATE} = ? AND rowid IN (SELECT rowid FROM {HISTORY_TABLE_NAME} ORDER BY RANDOM() LIMIT ?) AND {KEY_JOB_TYPE} IN ({VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_RESTORATION})" # execute the query and return the results c = self.get_cursor() @@ -143,20 +142,17 @@ class Database: for i in range(len(ANONYMOUS_KEYS)) if row[i] is not None } - # load image to job if has one - for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]: - if key in job and "base64" not in job[key]: - data = load_image(job[key], to_base64=True) - job[key] = data if data else IMAGE_NOT_FOUND_BASE64 jobs.append(job) return jobs - def get_jobs(self, job_uuid="", apikey="", job_status="", limit_count=0) -> list: + def get_jobs( + self, job_uuid="", apikey="", job_status="", job_types=[], limit_count=0 + ) -> list: """ Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters. - If `job_uuid` or `apikey` or `job_status` is provided, the query will include that filter. + If `job_uuid` or `apikey` or `job_status` or `job_type` is provided, the query will include that filter. Returns a list of jobs matching the filters provided. """ @@ -172,6 +168,11 @@ class Database: if job_status: query_filters.append(f"{KEY_JOB_STATUS} = ?") values.append(job_status) + if job_types: + query_filters.append( + f"{KEY_JOB_TYPE} IN ({', '.join(['?' for _ in job_types])})" + ) + values += job_types columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS query = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}" @@ -190,11 +191,6 @@ class Database: job = { columns[i]: row[i] for i in range(len(columns)) if row[i] is not None } - # load image to job if has one - for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]: - if key in job and "base64" not in job[key]: - data = load_image(job[key], to_base64=True) - job[key] = data if data else IMAGE_NOT_FOUND_BASE64 jobs.append(job) return jobs @@ -254,6 +250,9 @@ class Database: Returns True if the update was successful, otherwise False. """ + if not job_dict: + return False + # store image to job_dict if has one if ( self.__image_output_folder diff --git a/utilities/external.py b/utilities/external.py new file mode 100644 index 0000000..f1ae972 --- /dev/null +++ b/utilities/external.py @@ -0,0 +1,48 @@ +import os + +from utilities.constants import BASE64IMAGE +from utilities.constants import KEY_WIDTH +from utilities.constants import KEY_HEIGHT +from utilities.constants import KEY_BASE_MODEL + +from utilities.config import Config +from utilities.logger import DummyLogger +from utilities.images import image_to_base64 +from utilities.images import load_image +from utilities.images import save_image +from utilities.images import base64_to_image + + +def gfpgan( + gfpgan_folderpath, job_uuid, img_filepath, config=Config(), logger=DummyLogger() +): + if not os.path.isdir(gfpgan_folderpath): + logger.error(f"unable to find GFPGAN folder {gfpgan_folderpath}") + return {} + + if not os.path.isfile(img_filepath): + logger.error(f"unable to find image file {img_filepath}") + return {} + + tmp_output_dir = f"/tmp/{job_uuid}" + os.makedirs(tmp_output_dir, exist_ok=True) + + cmd = f"/usr/bin/python {gfpgan_folderpath}/inference_gfpgan.py -i {img_filepath} -o {tmp_output_dir} -v 1.3 -s {config.get_steps()} -w {config.get_strength()}" + logger.info(f"running: {cmd}") + os.system(cmd) + + img_output_path = os.path.join(tmp_output_dir, "restored_imgs", os.path.basename(img_filepath)) + logger.info(f"image path: {img_output_path}") + try: + + image = load_image(img_output_path) + width, height = image.size + return { + BASE64IMAGE: image_to_base64(image), + KEY_WIDTH: width, + KEY_HEIGHT: height, + KEY_BASE_MODEL: "gfpgan", + } + except Exception as e: + logger.error(f"Scaling failed: {e}") + return {} diff --git a/utilities/img2img.py b/utilities/img2img.py index 033ac92..710122b 100644 --- a/utilities/img2img.py +++ b/utilities/img2img.py @@ -15,6 +15,7 @@ from utilities.memory import empty_memory_cache from utilities.model import Model from utilities.times import get_epoch_now from utilities.images import image_to_base64 +from utilities.images import load_image from utilities.images import base64_to_image @@ -116,7 +117,11 @@ class Img2Img: self.__logger.info("current seed: {}".format(seed)) if isinstance(reference_image, str): - reference_image = base64_to_image(reference_image).convert("RGB") + if "base64" in reference_image: + reference_image = base64_to_image(reference_image).convert("RGB") + else: + # is filepath + reference_image = load_image(reference_image).convert("RGB") reference_image.thumbnail((config.get_width(), config.get_height())) ( diff --git a/utilities/inpainting.py b/utilities/inpainting.py index 072186e..54b6e82 100644 --- a/utilities/inpainting.py +++ b/utilities/inpainting.py @@ -119,11 +119,19 @@ class Inpainting: self.__logger.info("current seed: {}".format(seed)) if isinstance(reference_image, str): - reference_image = base64_to_image(reference_image).convert("RGB") + if "base64" in reference_image: + reference_image = base64_to_image(reference_image).convert("RGB") + else: + # is filepath + reference_image = load_image(reference_image).convert("RGB") reference_image.thumbnail((config.get_width(), config.get_height())) if isinstance(mask_image, str): - mask_image = base64_to_image(mask_image).convert("RGB") + if "base64" in mask_image: + mask_image = base64_to_image(mask_image).convert("RGB") + else: + # is filepath + mask_image = load_image(mask_image).convert("RGB") # assume mask image and reference image size ratio is the same if mask_image.size[0] < reference_image.size[0]: mask_image = mask_image.resize(reference_image.size)