adds inpainting capability

This commit is contained in:
HappyZ 2023-05-14 18:26:21 -07:00
parent 1735857e27
commit 605b05aa8b
11 changed files with 436 additions and 14 deletions

View File

@ -3,6 +3,7 @@ import argparse
from utilities.constants import LOGGER_NAME_BACKEND
from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import LOGGER_NAME_INPAINT
from utilities.constants import UUID
from utilities.constants import KEY_LANGUAGE
@ -16,7 +17,9 @@ from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG
from utilities.translator import translate_prompt
from utilities.config import Config
@ -25,6 +28,7 @@ from utilities.logger import Logger
from utilities.model import Model
from utilities.text2img import Text2Img
from utilities.img2img import Img2Img
from utilities.inpainting import Inpainting
from utilities.times import wait_for_seconds
@ -61,6 +65,8 @@ def backend(model, is_debugging: bool):
text2img.breakfast()
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
img2img.breakfast()
inpainting = Inpainting(model, logger=Logger(name=LOGGER_NAME_INPAINT))
inpainting.breakfast()
while 1:
wait_for_seconds(1)
@ -112,10 +118,20 @@ def backend(model, is_debugging: bool):
reference_image=ref_img,
config=config,
)
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_INPAINTING:
ref_img = next_job[REFERENCE_IMG]
mask_img = next_job[MASK_IMG]
result_dict = inpainting.lunch(
prompt=prompt,
negative_prompt=negative_prompt,
reference_image=ref_img,
mask_image=mask_img,
config=config,
)
except KeyboardInterrupt:
break
except BaseException as e:
logger.error("text2img.lunch error: {}".format(e))
logger.error(e)
if not is_debugging:
database.update_job(
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]

View File

@ -46,6 +46,7 @@ def create_table_history(c):
neg_prompt TEXT,
seed TEXT,
ref_img TEXT,
mask_img TEXT,
img TEXT,
width INT,
height INT,
@ -184,7 +185,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
(user[1],),
)
rows = c.fetchall()
@ -201,7 +202,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
(user[1],),
)
rows = c.fetchall()

7
static/bootstrap.bundle.min.js vendored Normal file

File diff suppressed because one or more lines are too long

2
static/jquery-3.6.1.min.js vendored Normal file

File diff suppressed because one or more lines are too long

1
static/jquery.sketchable.min.js vendored Normal file

File diff suppressed because one or more lines are too long

1
static/jsketch.min.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -176,7 +176,70 @@
</div>
</div>
<div class="card-body card-specific" id="card-inpainting" style="display:none">
TBD
<div class="row">
<div class="col-md-4">
<div class="card">
<div class="card-header" data-en_XX="Original Image" data-zh_CN="原图">
Original Image
</div>
<div class="card-body">
<div class="row">
<button id="copy-txt-to-img-inpaint" class="btn btn-primary mb-3"
data-en_XX="Copy from text-to-image"
data-zh_CN="从【文字->图片】结果复制">Copy
from text-to-image</button>
<button id="copy-last-img-inpaint" class="btn btn-primary mb-3"
data-en_XX="Copy from last image result"
data-zh_CN="从【图片->图片】结果复制">Copy from last image result</button>
<button id="upload-img-inpaint" class="btn btn-primary mb-3"
data-en_XX="Upload image" data-zh_CN="上传一张图片">Upload
image</button>
</div>
</div>
<img class="card-img-bottom" id="inpaint-img">
</div>
</div>
<div class="col-md-4">
<div class="card">
<div class="card-header" data-en_XX="Mask Image" data-zh_CN="修复部分">
Mask Image
</div>
<div class="card-body">
<div class="row">
</div>
<div class="row">
<button id="newInpaintingJob" class="btn btn-primary mb-3"
data-en_XX="Let's Go with Image + Mask Below!"
data-zh_CN="就用下面的图进行修复!">Let's
Go with Image + Mask Below!</button>
</div>
</div>
<div class="card-img-bottom" style="position: relative;">
<img id="inpaint-img-for-mask" width="100%">
<canvas style="cursor: pointer; position: absolute; top: 0; left: 0;"
id="inpaint-img-mask"></canvas>
</div>
</div>
</div>
<div class="col-md-4">
<div class="card">
<div class="card-header" data-en_XX="Result" data-zh_CN="结果">
Result
</div>
<div class="card-body">
<ul class="list-group">
<li class="list-group-item d-flex justify-content-between align-items-center"
id="inpaintJobUUID"></li>
<li class="list-group-item d-flex justify-content-between align-items-center"
id="inpaintStatus"></li>
<li class="list-group-item d-flex justify-content-between align-items-center"
id="inpaintSeed"></li>
</ul>
</div>
<img class="card-img-bottom" id="inpaintImg">
</div>
</div>
</div>
</div>
<div class="card-body card-specific" id="card-txt" style="display:none">
<div class="card">
@ -218,11 +281,10 @@
</div>
</div>
<script src="https://code.jquery.com/jquery-3.6.1.min.js"
integrity="sha256-o88AwQnZB+VDvE9tvIXrMQaPlFFSUTR+nldQm1LuPXQ=" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.2/dist/js/bootstrap.bundle.min.js"
integrity="sha384-OERcA2EqjJCMA+/3y+gxIOqMEjwtxJY7qPCqsdltbNJuaOe923+mo//f6V8Qbsw3"
crossorigin="anonymous"></script>
<script src="{{ url_for('static',filename='jquery-3.6.1.min.js') }}"></script>
<script src="{{ url_for('static',filename='bootstrap.bundle.min.js') }}"></script>
<script src="{{ url_for('static',filename='jsketch.min.js') }}"></script>
<script src="{{ url_for('static',filename='jquery.sketchable.min.js') }}"></script>
<script>
function waitForImage(apikeyVal, uuidValue) {
@ -256,6 +318,16 @@
if (response.jobs[0].status == "failed") {
return;
}
} else if (response.jobs[0].type == 'inpaint') {
$('#inpaintStatus').html(response.jobs[0].status);
$('#inpaintSeed').html("seed: " + response.jobs[0].seed);
if (response.jobs[0].status == "done") {
$('#inpaintImg').attr('src', response.jobs[0].img);
return;
}
if (response.jobs[0].status == "failed") {
return;
}
}
}
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second
@ -298,7 +370,7 @@
});
});
// Cache variable to store the selected image data
// Cache variable to store the selected image data for img2img
var imageData = null;
$("#copy-txt-to-img").click(function () {
@ -561,6 +633,135 @@
});
});
// Submit a new inpainting job: the cached original image plus a black/white
// mask derived from whatever was sketched on the #inpaint-img-mask canvas.
$('#newInpaintingJob').click(function (e) {
    e.preventDefault(); // Prevent the default form submission

    if (inpaintOriginalImg == null) {
        alert("No image cached")
        return;
    }

    // Read the sketch canvas and binarize it: untouched (fully transparent)
    // pixels become opaque black, every drawn pixel becomes opaque white.
    var canvas = $('#inpaint-img-mask')[0];
    var ctx = canvas.getContext('2d');
    var maskImageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    var pixels = maskImageData.data;
    for (var i = 0; i < pixels.length; i += 4) {
        var shade = (pixels[i + 3] == 0) ? 0 : 255;
        pixels[i] = shade;
        pixels[i + 1] = shade;
        pixels[i + 2] = shade;
        pixels[i + 3] = 255;
    }

    // Render the binarized pixels on an off-screen canvas of the same size so
    // the visible sketch is left untouched, then export as a base64 data URL.
    var tempCanvas = document.createElement('canvas');
    tempCanvas.width = canvas.width;
    tempCanvas.height = canvas.height;
    var tempCtx = tempCanvas.getContext('2d');
    tempCtx.putImageData(maskImageData, 0, 0);
    var inpaintMaskImg = tempCanvas.toDataURL();

    // Gather input field values, falling back to defaults when a field is
    // empty or not a number.
    var apikeyVal = $('#apiKey').val();
    var promptVal = $('#prompt').val();
    var negPromptVal = $('#negPrompt').val();
    var seedVal = $('#inputSeed').val();
    if (seedVal == "") {
        seedVal = "0";
    }
    var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val())
    if (isNaN(guidanceScaleVal)) {
        guidanceScaleVal = 25.0;
    }
    var stepsVal = parseInt($('#inputSteps').val());
    if (isNaN(stepsVal)) {
        stepsVal = 50;
    }
    var widthVal = parseInt($('#inputWidth').val());
    if (isNaN(widthVal)) {
        widthVal = 512;
    }
    var heightVal = parseInt($('#inputHeight').val());
    if (isNaN(heightVal)) {
        heightVal = 512;
    }

    // Validate before sending; each failure alerts and aborts the submission.
    if (promptVal == "") {
        alert("missing prompt!");
        return;
    }
    if (guidanceScaleVal < 1 || guidanceScaleVal > 30) {
        alert("guidance scale must be between 1 and 30");
        return;
    }
    if (widthVal < 8 || widthVal > 960) {
        alert("width must be between 8 and 960!");
        return;
    }
    if (widthVal % 8 != 0) {
        alert("width must be divisible by 8!");
        return;
    }
    if (heightVal < 8 || heightVal > 960) {
        alert("height must be between 8 and 960!");
        return;
    }
    if (heightVal % 8 != 0) {
        alert("height must be divisible by 8!");
        return;
    }
    if (stepsVal > 200 || stepsVal < 1) {
        alert("steps value must be between 1 and 200!");
        return;
    }

    // Send POST request using Ajax
    $.ajax({
        type: 'POST',
        url: '/add_job',
        contentType: 'application/json; charset=utf-8',
        dataType: 'json',
        data: JSON.stringify({
            'apikey': apikeyVal,
            'type': 'inpaint',
            'ref_img': inpaintOriginalImg,
            'mask_img': inpaintMaskImg,
            'prompt': promptVal,
            'seed': seedVal,
            'steps': stepsVal,
            'width': widthVal,
            'height': heightVal,
            'lang': $("#language option:selected").val(),
            'guidance_scale': guidanceScaleVal,
            'neg_prompt': negPromptVal
        }),
        success: function (response) {
            console.log(response);
            if (response.uuid) {
                // #inpaintJobUUID is an <li>, not a form control, so .val()
                // would not display anything; set its text content instead.
                $('#inpaintJobUUID').text(response.uuid);
            }
            $('#inpaintStatus').html('submitting new job..');
            // Start polling for the job result.
            waitForImage(apikeyVal, response.uuid);
        },
        error: function (xhr, status, error) {
            // Handle error response
            console.log(xhr.responseText);
            $('#inpaintStatus').html('failed');
        }
    });
});
// Listen for changes to the select element
$("#language").change(function () {
// Get the newly selected value
@ -584,6 +785,66 @@
$("#language").change();
}
// Cache variable to store the selected image data for inpainting
var inpaintOriginalImg = null;
// Use the current text-to-image result as the original image for inpainting.
$("#copy-txt-to-img-inpaint").click(function () {
    // Declared locally so the handler does not create an implicit global.
    var data = $("#txt2ImgImg").attr("src");
    if (data == null || data == "") {
        alert("nothing found from txt-to-img result");
        return;
    }
    inpaintOriginalImg = data;
    $("#inpaint-img").attr("src", inpaintOriginalImg);
    // Notify the "change" listener on #inpaint-img that the image was replaced.
    $("#inpaint-img").trigger("change");
});
// Use the current image-to-image result as the original image for inpainting.
$("#copy-last-img-inpaint").click(function () {
    // Declared locally so the handler does not create an implicit global.
    var data = $("#img2ImgImg").attr("src");
    if (data == null || data == "") {
        alert("nothing found from img-to-img result");
        return;
    }
    inpaintOriginalImg = data;
    $("#inpaint-img").attr("src", inpaintOriginalImg);
    // Notify the "change" listener on #inpaint-img that the image was replaced.
    $("#inpaint-img").trigger("change");
});
// Let the user pick a local image file and use it as the original image for
// inpainting. A throwaway <input type="file"> is created and clicked.
$("#upload-img-inpaint").click(function () {
    var input = $("<input type='file' accept='image/*'>");
    input.on("change", function () {
        var file = input[0].files[0];
        if (!file) {
            // No file selected; readAsDataURL would throw on undefined.
            return;
        }
        var reader = new FileReader();
        reader.onload = function (e) {
            inpaintOriginalImg = e.target.result;
            $("#inpaint-img").attr("src", inpaintOriginalImg);
            // Fire "change" on #inpaint-img only after the data URL has
            // loaded in a scratch Image. The handler is attached before
            // src is set so the load event cannot be missed.
            var img = new Image();
            img.onload = function () {
                $("#inpaint-img").trigger("change");
            };
            img.src = inpaintOriginalImg;
        };
        reader.readAsDataURL(file);
    });
    input.click();
});
// When the original inpainting image changes, mirror it under the mask canvas,
// size the canvas to the displayed image, and (re)attach the sketch plugin.
$("#inpaint-img").on("change", function () {
    var $original = $(this);
    var $maskCanvas = $('#inpaint-img-mask');

    $("#inpaint-img-for-mask").attr("src", $original.attr("src"));

    // Match the canvas' displayed size to the displayed original image.
    $maskCanvas.width($original.width());
    $maskCanvas.height($original.height());

    // Brush settings passed to the sketchable plugin.
    $maskCanvas.sketchable({
        graphics: {
            firstPointSize: 0,
            lineWidth: 5,
            strokeStyle: 'black',
        }
    });
});
});
</script>
</body>

View File

@ -95,6 +95,20 @@ py_library(
],
)
# Library target for inpainting.py. The deps correspond to the `utilities.*`
# modules that file imports (constants, config, logger, images, memory,
# model, times).
py_library(
name="inpainting",
srcs=["inpainting.py"],
deps=[
":constants",
":config",
":logger",
":images",
":memory",
":model",
":times",
],
)
py_library(
name="times",
srcs=["times.py"],

View File

@ -5,6 +5,7 @@ LOGGER_NAME_FRONTEND = VALUE_APP + "_fe"
LOGGER_NAME_BACKEND = VALUE_APP + "_be"
LOGGER_NAME_TXT2IMG = VALUE_APP + "_txt2img"
LOGGER_NAME_IMG2IMG = VALUE_APP + "_img2img"
LOGGER_NAME_INPAINT = VALUE_APP + "_inpaint"
MAX_JOB_NUMBER = 10
LOCK_FILEPATH = "/tmp/happysd_db.lock"
@ -29,6 +30,7 @@ KEY_JOB_TYPE = "type"
VALUE_JOB_TXT2IMG = "txt" # default value for KEY_JOB_TYPE
VALUE_JOB_IMG2IMG = "img"
REFERENCE_IMG = "ref_img"
MASK_IMG = "mask_img"
VALUE_JOB_INPAINTING = "inpaint"
KEY_LANGUAGE = "lang"
@ -70,8 +72,9 @@ OPTIONAL_KEYS = [
KEY_STEPS, # int
KEY_SCHEDULER, # str
KEY_STRENGTH, # float
REFERENCE_IMG, # str (base64)
KEY_LANGUAGE,
REFERENCE_IMG, # str (base64 or filepath)
MASK_IMG, # str (base64 or filepath)
KEY_LANGUAGE, # str
]
# - output only

View File

@ -22,6 +22,7 @@ from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG
from utilities.constants import BASE64IMAGE
from utilities.constants import IMAGE_NOT_FOUND_BASE64
@ -164,7 +165,7 @@ class Database:
columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
}
# load image to job if has one
for key in [BASE64IMAGE, REFERENCE_IMG]:
for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
if key in job and "base64" not in job[key]:
data = load_image(job[key], to_base64=True)
job[key] = data if data else IMAGE_NOT_FOUND_BASE64
@ -184,16 +185,26 @@ class Database:
job_uuid = str(uuid.uuid4())
self.__logger.info(f"inserting a new job with {job_uuid}")
current_epoch = get_epoch_now()
# store image to job_dict if has one
if (
self.__image_output_folder
and REFERENCE_IMG in job_dict
and "base64" in job_dict[REFERENCE_IMG]
):
ref_img_filepath = f"{self.__image_output_folder}/{get_epoch_now()}_ref.png"
ref_img_filepath = f"{self.__image_output_folder}/{current_epoch}_ref.png"
self.__logger.info(f"saving reference image to {ref_img_filepath}")
if save_image(job_dict[REFERENCE_IMG], ref_img_filepath):
job_dict[REFERENCE_IMG] = ref_img_filepath
if (
self.__image_output_folder
and MASK_IMG in job_dict
and "base64" in job_dict[MASK_IMG]
):
mask_img_filepath = f"{self.__image_output_folder}/{current_epoch}_mask.png"
self.__logger.info(f"saving mask image to {mask_img_filepath}")
if save_image(job_dict[MASK_IMG], mask_img_filepath):
job_dict[MASK_IMG] = mask_img_filepath
values = [job_uuid, VALUE_JOB_PENDING, datetime.datetime.now()]
columns = [UUID, KEY_JOB_STATUS, "created_at"] + REQUIRED_KEYS + OPTIONAL_KEYS

105
utilities/inpainting.py Normal file
View File

@ -0,0 +1,105 @@
import torch
from typing import Union
from PIL import Image
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_SEED
from utilities.constants import KEY_WIDTH
from utilities.constants import KEY_HEIGHT
from utilities.constants import KEY_STEPS
from utilities.config import Config
from utilities.logger import DummyLogger
from utilities.memory import empty_memory_cache
from utilities.model import Model
from utilities.times import get_epoch_now
from utilities.images import image_to_base64
from utilities.images import base64_to_image
class Inpainting:
    """
    Prompt-guided inpainting on top of a shared ``Model``.

    The workflow methods follow a meal naming scheme: ``breakfast`` is the
    preparation step (currently a no-op), ``lunch`` runs one inpainting job,
    and ``brunch`` runs both back to back.
    """

    def __init__(
        self,
        model: Model,
        output_folder: str = "",
        logger: DummyLogger = DummyLogger(),
    ):
        """
        Args:
            model: Model wrapper; must provide ``use_gpu()``,
                ``set_inpaint_scheduler()`` and ``inpaint_pipeline``.
            output_folder: When non-empty, every result image is also saved
                there as ``<epoch>.png``.
            logger: Logger used for info/error messages.
        """
        self.model = model
        # The random generator is created on the same device the model uses.
        self.__device = "cpu" if not self.model.use_gpu() else "cuda"
        self.__output_folder = output_folder
        self.__logger = logger

    def brunch(
        self,
        prompt: str,
        negative_prompt: str = "",
        reference_image: Union[Image.Image, None, str] = None,
        mask_image: Union[Image.Image, None, str] = None,
    ):
        """
        Run ``breakfast`` followed by ``lunch`` and return ``lunch``'s result.

        ``reference_image`` and ``mask_image`` are optional for backward
        compatibility, but ``lunch`` returns an empty dict when either one is
        missing.
        """
        self.breakfast()
        return self.lunch(
            prompt,
            negative_prompt,
            reference_image=reference_image,
            mask_image=mask_image,
        )

    def breakfast(self):
        """Preparation hook; nothing to prepare for inpainting at the moment."""
        pass

    def lunch(
        self,
        prompt: str,
        negative_prompt: str = "",
        reference_image: Union[Image.Image, None, str] = None,
        mask_image: Union[Image.Image, None, str] = None,
        config: Config = Config(),
    ) -> dict:
        """
        Inpaint ``reference_image`` where ``mask_image`` marks it, guided by
        ``prompt``.

        Args:
            prompt: Text prompt; required.
            negative_prompt: Optional negative prompt passed to the pipeline.
            reference_image: Image to repair, as a PIL image or a string that
                ``base64_to_image`` can decode.
            mask_image: Mask for the region to repaint, as a PIL image or a
                string that ``base64_to_image`` can decode.
            config: Job configuration (scheduler, seed, width, height,
                guidance scale, steps).

        Returns:
            A dict with the base64-encoded result image, the seed (as a
            string) and the configured width, height and steps; an empty dict
            when the prompt, reference image or mask image is missing.
        """
        if not prompt:
            self.__logger.error("no prompt provided, won't proceed")
            return {}
        if reference_image is None:
            self.__logger.error("no reference image provided, won't proceed")
            return {}
        if mask_image is None:
            self.__logger.error("no mask image provided, won't proceed")
            return {}

        self.model.set_inpaint_scheduler(config.get_scheduler())

        # Taken before inference; used as the output file name.
        t = get_epoch_now()
        seed = config.get_seed()
        generator = torch.Generator(self.__device).manual_seed(seed)
        self.__logger.info("current seed: {}".format(seed))

        if isinstance(reference_image, str):
            reference_image = base64_to_image(reference_image).convert("RGB")
            # Shrink in place to fit the requested size, keeping aspect ratio.
            # NOTE(review): applied only to images decoded here, so a
            # caller-supplied PIL image is never mutated - confirm this
            # matches the intended scope.
            reference_image.thumbnail((config.get_width(), config.get_height()))

        if isinstance(mask_image, str):
            mask_image = base64_to_image(mask_image).convert("RGB")
        # assume mask image and reference image size ratio is the same
        if mask_image.size[0] < reference_image.size[0]:
            mask_image = mask_image.resize(reference_image.size)
        elif mask_image.size[0] > reference_image.size[0]:
            mask_image = mask_image.resize(
                reference_image.size, resample=Image.LANCZOS
            )

        try:
            result = self.model.inpaint_pipeline(
                prompt=prompt,
                # must use size 512 for inpaint model
                image=reference_image.resize((512, 512)),
                # must use size 512 for inpaint model
                mask_image=mask_image.convert("L").resize((512, 512)),
                negative_prompt=negative_prompt,
                guidance_scale=config.get_guidance_scale(),
                num_inference_steps=config.get_steps(),
                generator=generator,
                callback=None,
                callback_steps=10,
            )

            # resize it back based on ratio (keep width 512)
            ref_width, ref_height = reference_image.size
            result_img = result.images[0].resize(
                (512, int(512 * ref_height / ref_width))
            )

            if self.__output_folder:
                out_filepath = "{}/{}.png".format(self.__output_folder, t)
                result_img.save(out_filepath)
                self.__logger.info("output to file: {}".format(out_filepath))
        finally:
            # Release cached memory even when inference or saving fails, so a
            # failed job does not keep memory from the next one.
            empty_memory_cache()

        # NOTE(review): width/height below are the configured values, while
        # the returned image is always 512 pixels wide - confirm callers
        # expect the requested size rather than the actual one.
        return {
            BASE64IMAGE: image_to_base64(result_img),
            KEY_SEED: str(seed),
            KEY_WIDTH: config.get_width(),
            KEY_HEIGHT: config.get_height(),
            KEY_STEPS: config.get_steps(),
        }