adds restoration via external GFPGAN

This commit is contained in:
HappyZ 2023-05-27 22:58:13 -07:00
parent 203268252f
commit 3097b262f3
12 changed files with 563 additions and 42 deletions

3
BUILD
View File

@ -11,9 +11,11 @@ par_binary(
"//utilities:constants", "//utilities:constants",
"//utilities:database", "//utilities:database",
"//utilities:logger", "//utilities:logger",
"//utilities:images",
], ],
data=[ data=[
"templates/index.html", "templates/index.html",
"templates/restoration.html",
"static/bootstrap.min.css", "static/bootstrap.min.css",
"static/jquery-3.6.1.min.js", "static/jquery-3.6.1.min.js",
"static/bootstrap.bundle.min.js", "static/bootstrap.bundle.min.js",
@ -32,6 +34,7 @@ par_binary(
"//utilities:constants", "//utilities:constants",
"//utilities:database", "//utilities:database",
"//utilities:memory", "//utilities:memory",
"//utilities:external",
"//utilities:logger", "//utilities:logger",
"//utilities:model", "//utilities:model",
"//utilities:config", "//utilities:config",

View File

@ -19,6 +19,7 @@ from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_RESTORATION
from utilities.constants import REFERENCE_IMG from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG from utilities.constants import MASK_IMG
@ -32,6 +33,7 @@ from utilities.img2img import Img2Img
from utilities.inpainting import Inpainting from utilities.inpainting import Inpainting
from utilities.times import wait_for_seconds from utilities.times import wait_for_seconds
from utilities.memory import empty_memory_cache from utilities.memory import empty_memory_cache
from utilities.external import gfpgan
logger = Logger(name=LOGGER_NAME_BACKEND) logger = Logger(name=LOGGER_NAME_BACKEND)
@ -62,7 +64,7 @@ def load_model(logger: Logger, use_gpu: bool, reduce_memory_usage: bool) -> Mode
return model return model
def backend(model, is_debugging: bool): def backend(model, gfpgan_folderpath, is_debugging: bool):
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG)) text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
text2img.breakfast() text2img.breakfast()
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG)) img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
@ -87,10 +89,14 @@ def backend(model, is_debugging: bool):
{KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID] {KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID]
) )
prompt = next_job[KEY_PROMPT] prompt = next_job.get(KEY_PROMPT, "")
negative_prompt = next_job[KEY_NEG_PROMPT] negative_prompt = next_job.get(KEY_NEG_PROMPT, "")
if KEY_LANGUAGE in next_job: if (
next_job[KEY_JOB_TYPE]
in [VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_TXT2IMG]
and KEY_LANGUAGE in next_job
):
if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]: if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]:
logger.info( logger.info(
f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first" f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
@ -130,6 +136,13 @@ def backend(model, is_debugging: bool):
mask_image=mask_img, mask_image=mask_img,
config=config, config=config,
) )
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_RESTORATION:
ref_img_filepath = next_job[REFERENCE_IMG]
result_dict = gfpgan(gfpgan_folderpath, next_job[UUID], ref_img_filepath, config=config, logger=logger)
if not result_dict:
raise ValueError("failed to run gfpgan")
else:
raise ValueError("unrecognized job type")
except KeyboardInterrupt: except KeyboardInterrupt:
break break
except BaseException as e: except BaseException as e:
@ -154,7 +167,7 @@ def main(args):
database.connect(args.db) database.connect(args.db)
model = load_model(logger, args.gpu, args.reduce_memory_usage) model = load_model(logger, args.gpu, args.reduce_memory_usage)
backend(model, args.debug) backend(model, args.gfpgan, args.debug)
database.safe_disconnect() database.safe_disconnect()
@ -180,6 +193,14 @@ if __name__ == "__main__":
help="Reduce memory usage when using GPU", help="Reduce memory usage when using GPU",
) )
# Add an argument to reduce memory usage
parser.add_argument(
"--gfpgan",
type=str,
default="",
help="GFPGAN folderpath",
)
# Add an argument to set the path of the database file # Add an argument to set the path of the database file
parser.add_argument( parser.add_argument(
"--image-output-folder", "--image-output-folder",

View File

@ -13,7 +13,9 @@ from utilities.logger import Logger
from utilities.constants import APIKEY from utilities.constants import APIKEY
from utilities.constants import KEY_JOB_TYPE from utilities.constants import KEY_JOB_TYPE
from utilities.constants import BASE64IMAGE
from utilities.constants import REFERENCE_IMG from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG
from utilities.constants import MAX_JOB_NUMBER from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS from utilities.constants import OPTIONAL_KEYS
from utilities.constants import KEY_LANGUAGE from utilities.constants import KEY_LANGUAGE
@ -23,7 +25,10 @@ from utilities.constants import UUID
from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_RESTORATION
from utilities.constants import IMAGE_NOT_FOUND_BASE64
from utilities.database import Database from utilities.database import Database
from utilities.images import load_image
logger = Logger(name=LOGGER_NAME_FRONTEND) logger = Logger(name=LOGGER_NAME_FRONTEND)
database = Database(logger) database = Database(logger)
@ -35,6 +40,7 @@ limiter = Limiter(
@app.route("/add_job", methods=["POST"]) @app.route("/add_job", methods=["POST"])
@limiter.limit("1/second")
def add_job(): def add_job():
req = request.get_json() req = request.get_json()
@ -49,13 +55,23 @@ def add_job():
for key in req.keys(): for key in req.keys():
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS): if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
return jsonify({"msg": "provided one or more unrecognized keys"}), 404 return jsonify({"msg": "provided one or more unrecognized keys"}), 404
# only checks required key for non-restoration jobs
if req.get(KEY_JOB_TYPE, None) != VALUE_JOB_RESTORATION:
for required_key in REQUIRED_KEYS: for required_key in REQUIRED_KEYS:
if required_key not in req: if required_key not in req:
return jsonify({"msg": "missing one or more required keys"}), 404 return jsonify({"msg": "missing one or more required keys"}), 404
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req: if (
req[KEY_JOB_TYPE]
in [VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_RESTORATION]
and REFERENCE_IMG not in req
):
return jsonify({"msg": "missing reference image"}), 404 return jsonify({"msg": "missing reference image"}), 404
if req[KEY_JOB_TYPE] == VALUE_JOB_INPAINTING and MASK_IMG not in req:
return jsonify({"msg": "missing mask image"}), 404
if KEY_LANGUAGE in req and req[KEY_LANGUAGE] not in SUPPORTED_LANGS: if KEY_LANGUAGE in req and req[KEY_LANGUAGE] not in SUPPORTED_LANGS:
return jsonify({"msg": f"not suporting {req[KEY_LANGUAGE]}"}), 404 return jsonify({"msg": f"not suporting {req[KEY_LANGUAGE]}"}), 404
@ -74,6 +90,7 @@ def add_job():
@app.route("/cancel_job", methods=["POST"]) @app.route("/cancel_job", methods=["POST"])
@limiter.limit("1/second")
def cancel_job(): def cancel_job():
req = request.get_json() req = request.get_json()
if APIKEY not in req: if APIKEY not in req:
@ -114,6 +131,7 @@ def cancel_job():
@app.route("/get_jobs", methods=["POST"]) @app.route("/get_jobs", methods=["POST"])
@limiter.limit("1/second")
def get_jobs(): def get_jobs():
req = request.get_json() req = request.get_json()
if APIKEY not in req: if APIKEY not in req:
@ -127,10 +145,24 @@ def get_jobs():
if UUID in req: if UUID in req:
jobs = database.get_jobs( jobs = database.get_jobs(
job_uuid=req[UUID], apikey=req[APIKEY], limit_count=job_count_limit job_uuid=req[UUID],
apikey=req[APIKEY],
job_types=req.get[KEY_JOB_TYPE].split(",") if req.get(KEY_JOB_TYPE, "") else [],
limit_count=job_count_limit,
) )
else: else:
jobs = database.get_jobs(apikey=req[APIKEY], limit_count=job_count_limit) jobs = database.get_jobs(
apikey=req[APIKEY],
job_types=req.get[KEY_JOB_TYPE].split(",") if req.get(KEY_JOB_TYPE, "") else [],
limit_count=job_count_limit,
)
for job in jobs:
# load image to job if has one
for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
if key in job and "base64" not in job[key]:
data = load_image(job[key], to_base64=True)
job[key] = data if data else IMAGE_NOT_FOUND_BASE64
return jsonify({"jobs": jobs}) return jsonify({"jobs": jobs})
@ -143,14 +175,28 @@ def random_jobs():
jobs = database.get_random_jobs(limit_count=job_count_limit) jobs = database.get_random_jobs(limit_count=job_count_limit)
for job in jobs:
# load image to job if has one
for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
if key in job and "base64" not in job[key]:
data = load_image(job[key], to_base64=True)
job[key] = data if data else IMAGE_NOT_FOUND_BASE64
return jsonify({"jobs": jobs}) return jsonify({"jobs": jobs})
@app.route("/") @app.route("/")
@limiter.limit("1/second")
def index(): def index():
return render_template("index.html") return render_template("index.html")
@app.route("/restoration")
@limiter.limit("1/second")
def restoration():
return render_template("restoration.html")
def main(args): def main(args):
database.set_image_output_folder(args.image_output_folder) database.set_image_output_folder(args.image_output_folder)
database.connect(args.db) database.connect(args.db)

View File

@ -9,3 +9,4 @@ torch==2.0.0
transformers==4.28.1 transformers==4.28.1
sentencepiece==0.1.99 sentencepiece==0.1.99
Flask-Limiter==3.3.1 Flask-Limiter==3.3.1
protobuf==3.20

View File

@ -23,7 +23,11 @@
<div class="collapse navbar-collapse" id="navbarSupportedContent"> <div class="collapse navbar-collapse" id="navbarSupportedContent">
<ul class="navbar-nav me-auto mb-2 mb-lg-0"> <ul class="navbar-nav me-auto mb-2 mb-lg-0">
<li class="nav-item"> <li class="nav-item">
<a class="nav-link active" data-en_XX="Home" data-zh_CN="主页">Home</a> <a class="nav-link active" data-en_XX="Home" data-zh_CN="主页" href=".">Home</a>
</li>
<li class="nav-item">
<a class="nav-link" data-en_XX="Restoration" data-zh_CN="图像修复"
href="restoration">Restoration</a>
</li> </li>
<li class="nav-item"> <li class="nav-item">
<a class="nav-link" data-en_XX="Help" data-zh_CN="帮助">Help</a> <a class="nav-link" data-en_XX="Help" data-zh_CN="帮助">Help</a>
@ -38,7 +42,7 @@
<label class="input-group-text" for="language" data-en_XX="Language" <label class="input-group-text" for="language" data-en_XX="Language"
data-zh_CN="语言">Language</label> data-zh_CN="语言">Language</label>
<select class="form-select" id="language"> <select class="form-select" id="language">
<option value="zh_CN">中文</option> <option value="zh_CN">中文(测试)</option>
<option selected value="en_XX">English</option> <option selected value="en_XX">English</option>
</select> </select>
</div> </div>
@ -496,12 +500,11 @@
} }
} }
} }
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second
}, },
error: function (xhr, status, error) { error: function (xhr, status, error) {
// Handle error response console.log(error);
console.log(xhr.responseText); setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second
$('#txt2ImgStatus').html('failed');
} }
}); });
} }
@ -631,16 +634,12 @@
$('#getJobHistory').click(function () { $('#getJobHistory').click(function () {
var apikeyValue = $('#apiKey').val(); var apikeyValue = $('#apiKey').val();
var uuidValue = $('#lookupUUID').val(); var uuidValue = $('#lookupUUID').val();
if (uuidValue == null) {
alert("no UUID specified");
return;
}
$.ajax({ $.ajax({
type: 'POST', type: 'POST',
url: '/get_jobs', url: '/get_jobs',
contentType: 'application/json; charset=utf-8', contentType: 'application/json; charset=utf-8',
dataType: 'json', dataType: 'json',
data: JSON.stringify({ 'apikey': apikeyValue, 'uuid': uuidValue }), data: JSON.stringify({ 'apikey': apikeyValue, 'uuid': uuidValue, 'type': 'txt,img,inpaint' }),
success: function (response) { success: function (response) {
var jobsLength = response.jobs.length; var jobsLength = response.jobs.length;
if (jobsLength == 0) { if (jobsLength == 0) {
@ -662,8 +661,9 @@
"<li class='list-group-item'>seed: " + response.jobs[i].seed + "</li>" + "<li class='list-group-item'>seed: " + response.jobs[i].seed + "</li>" +
"<li class='list-group-item'>uuid: " + response.jobs[i].uuid + "</li>" + "<li class='list-group-item'>uuid: " + response.jobs[i].uuid + "</li>" +
"<li class='list-group-item'>w x h: " + response.jobs[i].width + " x " + response.jobs[i].height + "</li>" + "<li class='list-group-item'>w x h: " + response.jobs[i].width + " x " + response.jobs[i].height + "</li>" +
"</ul>" + "</ul></div>" +
"</div></div></div>"); (response.jobs[i].ref_img ? ("<img src='" + response.jobs[i].ref_img + "' class='card-img-bottom'>") : "") +
"</div></div>");
// Add event handler for click to toggle blurriness // Add event handler for click to toggle blurriness
if (isPrivate === 1) { if (isPrivate === 1) {
element.find('.card').addClass('private-card'); element.find('.card').addClass('private-card');

378
templates/restoration.html Normal file
View File

@ -0,0 +1,378 @@
<html>
<head>
<meta charset="utf-8">
<title>{{ config.TITLE }}</title>
<meta name="description" content="Restoration Online">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<link href="{{ url_for('static',filename='bootstrap.min.css') }}" rel="stylesheet">
</head>
<body>
<div class="container">
<nav class="navbar navbar-expand-lg bg-body-tertiary">
<div class="container-fluid">
<a class="navbar-brand" href="#">{{ config.TITLE }}</a>
<button class="navbar-toggler" type="button" data-bs-toggle="collapse"
data-bs-target="#navbarSupportedContent" aria-controls="navbarSupportedContent"
aria-expanded="false" aria-label="Toggle navigation">
<span class="navbar-toggler-icon"></span>
</button>
<div class="collapse navbar-collapse" id="navbarSupportedContent">
<ul class="navbar-nav me-auto mb-2 mb-lg-0">
<li class="nav-item">
<a class="nav-link" data-en_XX="Home" data-zh_CN="主页" href=".">Home</a>
</li>
<li class="nav-item">
<a class="nav-link active" data-en_XX="Restoration" data-zh_CN="图像修复"
href="restoration">Restoration</a>
</li>
<li class="nav-item">
<a class="nav-link" data-en_XX="Help" data-zh_CN="帮助">Help</a>
</li>
<li class="nav-item">
<a class="nav-link" data-en_XX="About" data-zh_CN="关于">About</a>
</li>
</ul>
</div>
<div class="col-md-4">
<div class="input-group">
<label class="input-group-text" for="language" data-en_XX="Language"
data-zh_CN="语言">Language</label>
<select class="form-select" id="language">
<option value="zh_CN">中文(测试)</option>
<option selected value="en_XX">English</option>
</select>
</div>
</div>
</div>
</nav>
<div class="card mb-3">
<div class="card-header">
<span data-en_XX="Restoration" data-zh_CN="图像修复">Restoration</span>
</div>
<div class="card-body">
<div class="row">
<div class="col mb-3">
<div class="input-group">
<label for="apiKey" class="input-group-text" data-en_XX="API Key" data-zh_CN="API 密钥">API
Key</label>
<input type="password" class="form-control" id="apiKey" value="">
<div class="input-group-text">
<input class="form-check-input" type="checkbox" value="" id="isPrivate">
</div>
<label for="isPrivate" class="input-group-text" data-en_XX="Generate Private Images"
data-zh_CN="生成非公开图片">Generate Private Images</label>
</div>
</div>
</div>
<div class="row">
<div class="col-md-3 mb-3">
<div class="input-group input-group-sm">
<label for="restorationUpscale" class="input-group-text" data-en_XX="Upscale"
data-zh_CN="放大倍数">Upscale</label>
<input type="number" class="form-control" id="restorationUpscale"
aria-describedby="upscaleHelp" placeholder="1" min="1" max="8">
</div>
<div id="upscaleHelp" class="form-text" data-en_XX="Upsampling scale of the image"
data-zh_CN="将图像放大多少倍">
Upsampling scale of the image
</div>
</div>
<div class="col-md-3 mb-3">
<div class="input-group input-group-sm">
<label for="restorationWeight" class="input-group-text" data-en_XX="Weight"
data-zh_CN="比重">Weight</label>
<input type="number" class="form-control" id="restorationWeight"
aria-describedby="weightHelp" placeholder="0.5" min="0" max="1">
</div>
<div id="weightHelp" class="form-text" data-en_XX="Adjustable weights" data-zh_CN="0-1比重">
Adjustable weights
</div>
</div>
</div>
<div class="row">
<div class="col-md-6">
<div class="card">
<div class="card-header" data-en_XX="1. Choose Original Image" data-zh_CN="1. 选择原图">
1. Choose Original Image
</div>
<div class="card-body">
<div class="row">
<button id="restoreUploadImg" class="btn btn-primary mb-3" data-en_XX="Upload image"
data-zh_CN="上传一张图片">Upload
image</button>
</div>
<div class="row">
<button id="newRestorationJob" class="btn btn-primary mb-3"
data-en_XX="Let's Go with Image Below!" data-zh_CN="修复下面的图!">Restore Image
Below!</button>
</div>
</div>
<img class="card-img-bottom" id="restoreOriginalImg">
</div>
</div>
<div class="col-md-6">
<div class="card">
<div class="card-header" data-en_XX="2. Result" data-zh_CN="2. 结果">
2. Result
</div>
<div class="card-body">
<ul class="list-group">
<li class="list-group-item d-flex justify-content-between align-items-center">
<span id="restorationJobUUID"></span>
<span class="badge bg-primary rounded-pill" data-en_XX="Job UUID"
data-zh_CN="图片唯一识别码">Job UUID</span>
</li>
<li class="list-group-item d-flex justify-content-between align-items-center">
<span id="restorationStatus"></span>
<span class="badge bg-primary rounded-pill" data-en_XX="Job Status"
data-zh_CN="生成状态">Job Status</span>
</li>
</ul>
</div>
<img class="card-img-bottom" id="restorationImg">
</div>
</div>
</div>
</div>
</div>
<div class="card mb-3">
<div class="card-header">
<span data-en_XX="History" data-zh_CN="历史">History</span>
</div>
<div class="card-body">
<div class="input-group mb-3">
<label for="lookupUUID" class="input-group-text" data-en_XX="UUID (Optional)"
data-zh_CN="图片唯一识别码(选填)">UUID (Optional)</label>
<input type="text" class="form-control" id="lookupUUID" value="">
<button id="getJobHistory" class="btn btn-primary" data-en_XX="Get Job(s)" data-zh_CN="搜索历史">Get
Job(s)</button>
</div>
<div id="joblist"></div>
</div>
</div>
</div>
<!-- CSS code -->
<style>
</style>
<script src="{{ url_for('static',filename='jquery-3.6.1.min.js') }}"></script>
<script src="{{ url_for('static',filename='bootstrap.bundle.min.js') }}"></script>
<script src="{{ url_for('static',filename='jsketch.min.js') }}"></script>
<script src="{{ url_for('static',filename='jquery.sketchable.min.js') }}"></script>
<script src="{{ url_for('static',filename='jquery.sketchable.memento.min.js') }}"></script>
<script src="{{ url_for('static',filename='masonry.pkgd.min.js') }}"></script>
<script src="{{ url_for('static',filename='imagesloaded.pkgd.min.js') }}"></script>
<script>
$(document).ready(function () {
$('input[id]').on('input', function () {
var input = $(this);
var key = input.attr('id');
var value = input.val();
if (input.attr('type') == "checkbox") {
value = input.is(":checked");
}
localStorage.setItem(key, value);
}).each(function () {
var input = $(this);
var key = input.attr('id');
var value = localStorage.getItem(key);
if (input.attr('type') == "checkbox") {
input.prop('checked', value == "true");
} else if (value) {
input.val(value);
}
});
// Define the function to update the text based on the selected language
function updateText(language) {
$("[data-" + language + "]").each(function () {
$(this).text($(this).data(language.toLowerCase()));
});
}
// Listen for changes to the select element
$("#language").change(function () {
// Get the newly selected value
var newLanguage = $(this).val();
// Store the selected value in cache
localStorage.setItem("selectedLanguage", newLanguage);
// Update the text based on the selected language
updateText(newLanguage);
});
// Get the selected value from cache (if it exists)
var cachedLanguage = localStorage.getItem("selectedLanguage");
if (cachedLanguage) {
// Set the selected value
$("#language").val(cachedLanguage);
// Update the text based on the selected language
updateText(cachedLanguage);
}
var restoreOriginalImgData = null;
$("#restoreUploadImg").click(function () {
var input = $("<input type='file' accept='image/*'>").on("change", function () {
var reader = new FileReader();
reader.onload = function (e) {
restoreOriginalImgData = e.target.result;
$("#restoreOriginalImg").attr("src", restoreOriginalImgData);
};
reader.readAsDataURL(this.files[0]);
});
input.click();
});
function waitForImage(apikeyVal, uuidValue) {
// Wait until image is done
$.ajax({
type: 'POST',
url: '/get_jobs',
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({ 'apikey': apikeyVal, 'uuid': uuidValue }),
success: function (response) {
console.log(response);
if (response.jobs.length == 1) {
if (response.jobs[0].type == 'restoration') {
$('#restorationStatus').html(response.jobs[0].status);
$('#restorationJobUUID').html(uuidValue);
if (response.jobs[0].status == "done") {
$('#restorationImg').attr('src', response.jobs[0].img);
return;
}
if (response.jobs[0].status == "failed") {
return;
}
}
}
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second
},
error: function (xhr, status, error) {
console.log(error);
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every 1.5 second
}
});
}
function submitJob(formData, uuidSelector, statusSelector) {
$.ajax({
type: 'POST',
url: '/add_job',
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify(formData),
success: function (response) {
if (response.uuid) {
$(uuidSelector).html(response.uuid);
}
$(statusSelector).html('Submitting new job..');
waitForImage(formData.apikey, response.uuid);
},
error: function (xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
$(statusSelector).html('Failed');
}
});
}
$('#newRestorationJob').click(function (e) {
e.preventDefault(); // Prevent the default form submission
if (restoreOriginalImgData == null) {
alert("No image cached");
return;
}
// Helper function to get input field value or a default value if empty
function getInputValue(id, defaultValue) {
var value = $(id).val().trim();
return value !== '' ? value : defaultValue;
}
// Validate input field values
var restorationUpscaleVal = parseInt(getInputValue('#restorationUpscale', '1'));
var restorationWeightVal = parseFloat(getInputValue('#restorationWeight', '0.5'));
var apikeyVal = $('#apiKey').val();
var formData = {
'apikey': apikeyVal,
'type': 'restoration',
'ref_img': restoreOriginalImgData,
'steps': restorationUpscaleVal, // reuse sd keys
'strength': restorationWeightVal, // reuse sd keys
'lang': $("#language option:selected").val(),
'is_private': $('#isPrivate').is(":checked") ? 1 : 0
};
submitJob(formData, '#restorationJobUUID', '#restorationStatus');
});
$('#getJobHistory').click(function () {
var apikeyValue = $('#apiKey').val();
var uuidValue = $('#lookupUUID').val();
$.ajax({
type: 'POST',
url: '/get_jobs',
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({ 'apikey': apikeyValue, 'uuid': uuidValue, 'type': 'restoration' }),
success: function (response) {
var jobsLength = response.jobs.length;
if (jobsLength == 0) {
$('#joblist').html("found nothing");
return;
}
var $joblist = $('#joblist');
var $grid = $('<div class="row"></div>');
$joblist.html($grid);
for (var i = 0; i < jobsLength; i++) {
console.log(response.jobs[i]);
var isPrivate = response.jobs[i].is_private;
var element = $("<div class='col col-sm-6 col-md-6 col-lg-4 mb-3'><div class='card'>" +
(response.jobs[i].img ? ("<img src='" + response.jobs[i].img + "' class='card-img-top'><div class='card-body'>") : "") +
"<ul class='list-group list-group-flush'>" +
"<li class='list-group-item'>status: " + response.jobs[i].status + "</li>" +
"<li class='list-group-item'>scale: " + response.jobs[i].steps + "</li>" +
"<li class='list-group-item'>weight: " + response.jobs[i].strength + "</li>" +
"<li class='list-group-item'>uuid: " + response.jobs[i].uuid + "</li>" +
"<li class='list-group-item'>w x h: " + response.jobs[i].width + " x " + response.jobs[i].height + "</li>" +
"</ul></div>" +
(response.jobs[i].ref_img ? ("<img src='" + response.jobs[i].ref_img + "' class='card-img-bottom'>") : "") +
"</div></div>");
$grid.append(element);
};
$grid.imagesLoaded().progress(function () {
$grid.masonry({
itemSelector: '.col',
columnWidth: '.col',
percentPosition: true
});
});
},
error: function (xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
$('#joblist').html("found nothing");
}
});
});
});
</script>
</body>
</html>

View File

@ -27,6 +27,17 @@ py_library(
], ],
) )
py_library(
name="external",
srcs=["external.py"],
deps=[
":logger",
":config",
":constants",
":images",
],
)
py_library( py_library(
name="envvar", name="envvar",
srcs=["envvar.py"], srcs=["envvar.py"],

View File

@ -32,6 +32,7 @@ VALUE_JOB_IMG2IMG = "img"
REFERENCE_IMG = "ref_img" REFERENCE_IMG = "ref_img"
MASK_IMG = "mask_img" MASK_IMG = "mask_img"
VALUE_JOB_INPAINTING = "inpaint" VALUE_JOB_INPAINTING = "inpaint"
VALUE_JOB_RESTORATION = "restoration"
KEY_LANGUAGE = "lang" KEY_LANGUAGE = "lang"
VALUE_LANGUAGE_ZH_CN = "zh_CN" VALUE_LANGUAGE_ZH_CN = "zh_CN"

View File

@ -12,6 +12,7 @@ from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_RESTORATION
from utilities.constants import KEY_JOB_STATUS from utilities.constants import KEY_JOB_STATUS
from utilities.constants import VALUE_JOB_PENDING from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_DONE from utilities.constants import VALUE_JOB_DONE
@ -26,7 +27,6 @@ from utilities.constants import INTERNAL_KEYS
from utilities.constants import REFERENCE_IMG from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG from utilities.constants import MASK_IMG
from utilities.constants import BASE64IMAGE from utilities.constants import BASE64IMAGE
from utilities.constants import IMAGE_NOT_FOUND_BASE64
from utilities.constants import HISTORY_TABLE_NAME from utilities.constants import HISTORY_TABLE_NAME
from utilities.constants import USERS_TABLE_NAME from utilities.constants import USERS_TABLE_NAME
@ -35,7 +35,6 @@ from utilities.logger import DummyLogger
from utilities.times import get_epoch_now from utilities.times import get_epoch_now
from utilities.times import epoch_to_string from utilities.times import epoch_to_string
from utilities.images import save_image from utilities.images import save_image
from utilities.images import load_image
# Function to acquire a lock on the database file # Function to acquire a lock on the database file
@ -130,7 +129,7 @@ class Database:
return result[0] return result[0]
def get_random_jobs(self, limit_count=0) -> list: def get_random_jobs(self, limit_count=0) -> list:
query = f"SELECT {', '.join(ANONYMOUS_KEYS)} FROM {HISTORY_TABLE_NAME} WHERE {KEY_JOB_STATUS} = ? AND {KEY_IS_PRIVATE} = ? AND rowid IN (SELECT rowid FROM {HISTORY_TABLE_NAME} ORDER BY RANDOM() LIMIT ?)" query = f"SELECT {', '.join(ANONYMOUS_KEYS)} FROM {HISTORY_TABLE_NAME} WHERE {KEY_JOB_STATUS} = ? AND {KEY_IS_PRIVATE} = ? AND rowid IN (SELECT rowid FROM {HISTORY_TABLE_NAME} ORDER BY RANDOM() LIMIT ?) AND {KEY_JOB_TYPE} IN ({VALUE_JOB_IMG2IMG, VALUE_JOB_INPAINTING, VALUE_JOB_RESTORATION})"
# execute the query and return the results # execute the query and return the results
c = self.get_cursor() c = self.get_cursor()
@ -143,20 +142,17 @@ class Database:
for i in range(len(ANONYMOUS_KEYS)) for i in range(len(ANONYMOUS_KEYS))
if row[i] is not None if row[i] is not None
} }
# load image to job if has one
for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
if key in job and "base64" not in job[key]:
data = load_image(job[key], to_base64=True)
job[key] = data if data else IMAGE_NOT_FOUND_BASE64
jobs.append(job) jobs.append(job)
return jobs return jobs
def get_jobs(self, job_uuid="", apikey="", job_status="", limit_count=0) -> list: def get_jobs(
self, job_uuid="", apikey="", job_status="", job_types=[], limit_count=0
) -> list:
""" """
Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters. Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters.
If `job_uuid` or `apikey` or `job_status` is provided, the query will include that filter. If `job_uuid` or `apikey` or `job_status` or `job_type` is provided, the query will include that filter.
Returns a list of jobs matching the filters provided. Returns a list of jobs matching the filters provided.
""" """
@ -172,6 +168,11 @@ class Database:
if job_status: if job_status:
query_filters.append(f"{KEY_JOB_STATUS} = ?") query_filters.append(f"{KEY_JOB_STATUS} = ?")
values.append(job_status) values.append(job_status)
if job_types:
query_filters.append(
f"{KEY_JOB_TYPE} IN ({', '.join(['?' for _ in job_types])})"
)
values += job_types
columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
query = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}" query = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
@ -190,11 +191,6 @@ class Database:
job = { job = {
columns[i]: row[i] for i in range(len(columns)) if row[i] is not None columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
} }
# load image to job if has one
for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
if key in job and "base64" not in job[key]:
data = load_image(job[key], to_base64=True)
job[key] = data if data else IMAGE_NOT_FOUND_BASE64
jobs.append(job) jobs.append(job)
return jobs return jobs
@ -254,6 +250,9 @@ class Database:
Returns True if the update was successful, otherwise False. Returns True if the update was successful, otherwise False.
""" """
if not job_dict:
return False
# store image to job_dict if has one # store image to job_dict if has one
if ( if (
self.__image_output_folder self.__image_output_folder

48
utilities/external.py Normal file
View File

@ -0,0 +1,48 @@
import os
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_WIDTH
from utilities.constants import KEY_HEIGHT
from utilities.constants import KEY_BASE_MODEL
from utilities.config import Config
from utilities.logger import DummyLogger
from utilities.images import image_to_base64
from utilities.images import load_image
from utilities.images import save_image
from utilities.images import base64_to_image
def gfpgan(
gfpgan_folderpath, job_uuid, img_filepath, config=Config(), logger=DummyLogger()
):
if not os.path.isdir(gfpgan_folderpath):
logger.error(f"unable to find GFPGAN folder {gfpgan_folderpath}")
return {}
if not os.path.isfile(img_filepath):
logger.error(f"unable to find image file {img_filepath}")
return {}
tmp_output_dir = f"/tmp/{job_uuid}"
os.makedirs(tmp_output_dir, exist_ok=True)
cmd = f"/usr/bin/python {gfpgan_folderpath}/inference_gfpgan.py -i {img_filepath} -o {tmp_output_dir} -v 1.3 -s {config.get_steps()} -w {config.get_strength()}"
logger.info(f"running: {cmd}")
os.system(cmd)
img_output_path = os.path.join(tmp_output_dir, "restored_imgs", os.path.basename(img_filepath))
logger.info(f"image path: {img_output_path}")
try:
image = load_image(img_output_path)
width, height = image.size
return {
BASE64IMAGE: image_to_base64(image),
KEY_WIDTH: width,
KEY_HEIGHT: height,
KEY_BASE_MODEL: "gfpgan",
}
except Exception as e:
logger.error(f"Scaling failed: {e}")
return {}

View File

@ -15,6 +15,7 @@ from utilities.memory import empty_memory_cache
from utilities.model import Model from utilities.model import Model
from utilities.times import get_epoch_now from utilities.times import get_epoch_now
from utilities.images import image_to_base64 from utilities.images import image_to_base64
from utilities.images import load_image
from utilities.images import base64_to_image from utilities.images import base64_to_image
@ -116,7 +117,11 @@ class Img2Img:
self.__logger.info("current seed: {}".format(seed)) self.__logger.info("current seed: {}".format(seed))
if isinstance(reference_image, str): if isinstance(reference_image, str):
if "base64" in reference_image:
reference_image = base64_to_image(reference_image).convert("RGB") reference_image = base64_to_image(reference_image).convert("RGB")
else:
# is filepath
reference_image = load_image(reference_image).convert("RGB")
reference_image.thumbnail((config.get_width(), config.get_height())) reference_image.thumbnail((config.get_width(), config.get_height()))
( (

View File

@ -119,11 +119,19 @@ class Inpainting:
self.__logger.info("current seed: {}".format(seed)) self.__logger.info("current seed: {}".format(seed))
if isinstance(reference_image, str): if isinstance(reference_image, str):
if "base64" in reference_image:
reference_image = base64_to_image(reference_image).convert("RGB") reference_image = base64_to_image(reference_image).convert("RGB")
else:
# is filepath
reference_image = load_image(reference_image).convert("RGB")
reference_image.thumbnail((config.get_width(), config.get_height())) reference_image.thumbnail((config.get_width(), config.get_height()))
if isinstance(mask_image, str): if isinstance(mask_image, str):
if "base64" in mask_image:
mask_image = base64_to_image(mask_image).convert("RGB") mask_image = base64_to_image(mask_image).convert("RGB")
else:
# is filepath
mask_image = load_image(mask_image).convert("RGB")
# assume mask image and reference image size ratio is the same # assume mask image and reference image size ratio is the same
if mask_image.size[0] < reference_image.size[0]: if mask_image.size[0] < reference_image.size[0]:
mask_image = mask_image.resize(reference_image.size) mask_image = mask_image.resize(reference_image.size)