adds inpainting capability

This commit is contained in:
HappyZ 2023-05-14 18:26:21 -07:00
parent 1735857e27
commit 605b05aa8b
11 changed files with 436 additions and 14 deletions

View File

@ -3,6 +3,7 @@ import argparse
from utilities.constants import LOGGER_NAME_BACKEND from utilities.constants import LOGGER_NAME_BACKEND
from utilities.constants import LOGGER_NAME_TXT2IMG from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import LOGGER_NAME_IMG2IMG from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import LOGGER_NAME_INPAINT
from utilities.constants import UUID from utilities.constants import UUID
from utilities.constants import KEY_LANGUAGE from utilities.constants import KEY_LANGUAGE
@ -16,7 +17,9 @@ from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import KEY_JOB_TYPE from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import REFERENCE_IMG from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG
from utilities.translator import translate_prompt from utilities.translator import translate_prompt
from utilities.config import Config from utilities.config import Config
@ -25,6 +28,7 @@ from utilities.logger import Logger
from utilities.model import Model from utilities.model import Model
from utilities.text2img import Text2Img from utilities.text2img import Text2Img
from utilities.img2img import Img2Img from utilities.img2img import Img2Img
from utilities.inpainting import Inpainting
from utilities.times import wait_for_seconds from utilities.times import wait_for_seconds
@ -61,6 +65,8 @@ def backend(model, is_debugging: bool):
text2img.breakfast() text2img.breakfast()
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG)) img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
img2img.breakfast() img2img.breakfast()
inpainting = Inpainting(model, logger=Logger(name=LOGGER_NAME_INPAINT))
inpainting.breakfast()
while 1: while 1:
wait_for_seconds(1) wait_for_seconds(1)
@ -112,10 +118,20 @@ def backend(model, is_debugging: bool):
reference_image=ref_img, reference_image=ref_img,
config=config, config=config,
) )
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_INPAINTING:
ref_img = next_job[REFERENCE_IMG]
mask_img = next_job[MASK_IMG]
result_dict = inpainting.lunch(
prompt=prompt,
negative_prompt=negative_prompt,
reference_image=ref_img,
mask_image=mask_img,
config=config,
)
except KeyboardInterrupt: except KeyboardInterrupt:
break break
except BaseException as e: except BaseException as e:
logger.error("text2img.lunch error: {}".format(e)) logger.error(e)
if not is_debugging: if not is_debugging:
database.update_job( database.update_job(
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID] {KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]

View File

@ -46,6 +46,7 @@ def create_table_history(c):
neg_prompt TEXT, neg_prompt TEXT,
seed TEXT, seed TEXT,
ref_img TEXT, ref_img TEXT,
mask_img TEXT,
img TEXT, img TEXT,
width INT, width INT,
height INT, height INT,
@ -184,7 +185,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details: if details:
c.execute( c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?", "SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
(user[1],), (user[1],),
) )
rows = c.fetchall() rows = c.fetchall()
@ -201,7 +202,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details: if details:
c.execute( c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?", "SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
(user[1],), (user[1],),
) )
rows = c.fetchall() rows = c.fetchall()

7
static/bootstrap.bundle.min.js vendored Normal file

File diff suppressed because one or more lines are too long

2
static/jquery-3.6.1.min.js vendored Normal file

File diff suppressed because one or more lines are too long

1
static/jquery.sketchable.min.js vendored Normal file

File diff suppressed because one or more lines are too long

1
static/jsketch.min.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -176,7 +176,70 @@
</div> </div>
</div> </div>
<div class="card-body card-specific" id="card-inpainting" style="display:none"> <div class="card-body card-specific" id="card-inpainting" style="display:none">
TBD <div class="row">
<div class="col-md-4">
<div class="card">
<div class="card-header" data-en_XX="Original Image" data-zh_CN="原图">
Original Image
</div>
<div class="card-body">
<div class="row">
<button id="copy-txt-to-img-inpaint" class="btn btn-primary mb-3"
data-en_XX="Copy from text-to-image"
data-zh_CN="从【文字->图片】结果复制">Copy
from text-to-image</button>
<button id="copy-last-img-inpaint" class="btn btn-primary mb-3"
data-en_XX="Copy from last image result"
data-zh_CN="从【图片->图片】结果复制">Copy from last image result</button>
<button id="upload-img-inpaint" class="btn btn-primary mb-3"
data-en_XX="Upload image" data-zh_CN="上传一张图片">Upload
image</button>
</div>
</div>
<img class="card-img-bottom" id="inpaint-img">
</div>
</div>
<div class="col-md-4">
<div class="card">
<div class="card-header" data-en_XX="Mask Image" data-zh_CN="修复部分">
Mask Image
</div>
<div class="card-body">
<div class="row">
</div>
<div class="row">
<button id="newInpaintingJob" class="btn btn-primary mb-3"
data-en_XX="Let's Go with Image + Mask Below!"
data-zh_CN="就用下面的图进行修复!">Let's
Go with Image + Mask Below!</button>
</div>
</div>
<div class="card-img-bottom" style="position: relative;">
<img id="inpaint-img-for-mask" width="100%">
<canvas style="cursor: pointer; position: absolute; top: 0; left: 0;"
id="inpaint-img-mask">
</div>
</div>
</div>
<div class="col-md-4">
<div class="card">
<div class="card-header" data-en_XX="Result" data-zh_CN="结果">
Result
</div>
<div class="card-body">
<ul class="list-group">
<li class="list-group-item d-flex justify-content-between align-items-center"
id="inpaintJobUUID"></li>
<li class="list-group-item d-flex justify-content-between align-items-center"
id="inpaintStatus"></li>
<li class="list-group-item d-flex justify-content-between align-items-center"
id="inpaintSeed"></li>
</ul>
</div>
<img class="card-img-bottom" id="inpaintImg">
</div>
</div>
</div>
</div> </div>
<div class="card-body card-specific" id="card-txt" style="display:none"> <div class="card-body card-specific" id="card-txt" style="display:none">
<div class="card"> <div class="card">
@ -218,11 +281,10 @@
</div> </div>
</div> </div>
<script src="https://code.jquery.com/jquery-3.6.1.min.js" <script src="{{ url_for('static',filename='jquery-3.6.1.min.js') }}"></script>
integrity="sha256-o88AwQnZB+VDvE9tvIXrMQaPlFFSUTR+nldQm1LuPXQ=" crossorigin="anonymous"></script> <script src="{{ url_for('static',filename='bootstrap.bundle.min.js') }}"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.2/dist/js/bootstrap.bundle.min.js" <script src="{{ url_for('static',filename='jsketch.min.js') }}"></script>
integrity="sha384-OERcA2EqjJCMA+/3y+gxIOqMEjwtxJY7qPCqsdltbNJuaOe923+mo//f6V8Qbsw3" <script src="{{ url_for('static',filename='jquery.sketchable.min.js') }}"></script>
crossorigin="anonymous"></script>
<script> <script>
function waitForImage(apikeyVal, uuidValue) { function waitForImage(apikeyVal, uuidValue) {
@ -256,6 +318,16 @@
if (response.jobs[0].status == "failed") { if (response.jobs[0].status == "failed") {
return; return;
} }
} else if (response.jobs[0].type == 'inpaint') {
$('#inpaintStatus').html(response.jobs[0].status);
$('#inpaintSeed').html("seed: " + response.jobs[0].seed);
if (response.jobs[0].status == "done") {
$('#inpaintImg').attr('src', response.jobs[0].img);
return;
}
if (response.jobs[0].status == "failed") {
return;
}
} }
} }
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second
@ -298,7 +370,7 @@
}); });
}); });
// Cache variable to store the selected image data // Cache variable to store the selected image data for img2img
var imageData = null; var imageData = null;
$("#copy-txt-to-img").click(function () { $("#copy-txt-to-img").click(function () {
@ -561,6 +633,135 @@
}); });
}); });
$('#newInpaintingJob').click(function (e) {
e.preventDefault(); // Prevent the default form submission
if (inpaintOriginalImg == null) {
alert("No image cached")
return;
}
var canvas = $('#inpaint-img-mask')[0];
var ctx = canvas.getContext('2d');
var maskImageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
// Loop through the pixels and change the colors
for (var i = 0; i < maskImageData.data.length; i += 4) {
if (maskImageData.data[i + 3] == 0) { // If pixel is transparent, change to black
maskImageData.data[i] = 0;
maskImageData.data[i + 1] = 0;
maskImageData.data[i + 2] = 0;
maskImageData.data[i + 3] = 255;
} else { // If pixel is not transparent, change to white
maskImageData.data[i] = 255;
maskImageData.data[i + 1] = 255;
maskImageData.data[i + 2] = 255;
maskImageData.data[i + 3] = 255;
}
}
var tempCanvas = document.createElement('canvas'); // Create a new canvas element
tempCanvas.width = canvas.width; // Set the width of the new canvas to match the original canvas
tempCanvas.height = canvas.height; // Set the height of the new canvas to match the original canvas
var tempCtx = tempCanvas.getContext('2d');
tempCtx.putImageData(maskImageData, 0, 0); // Put modified image data onto the new canvas
var inpaintMaskImg = tempCanvas.toDataURL(); // Get the modified base64-encoded image data
// Gather input field values
var apikeyVal = $('#apiKey').val();
var promptVal = $('#prompt').val();
var negPromptVal = $('#negPrompt').val();
var seedVal = $('#inputSeed').val();
if (seedVal == "0" || seedVal == "") {
seedVal = "0";
}
var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val())
if (isNaN(guidanceScaleVal)) {
guidanceScaleVal = 25.0;
}
var stepsVal = parseInt($('#inputSteps').val());
if (isNaN(stepsVal)) {
stepsVal = 50;
}
var widthVal = parseInt($('#inputWidth').val());
if (isNaN(widthVal)) {
widthVal = 512;
}
var heightVal = parseInt($('#inputHeight').val());
if (isNaN(heightVal)) {
heightVal = 512;
}
if (promptVal == "") {
alert("missing prompt!");
return;
}
if (guidanceScaleVal < 1 || guidanceScaleVal > 30) {
alert("guidance scale must be between 1 and 30");
return;
}
if (widthVal < 8 || widthVal > 960) {
alert("width must be between 8 and 960!");
return;
}
if (widthVal % 8 != 0) {
alert("width must be divisible by 8!");
return;
}
if (heightVal < 8 || heightVal > 960) {
alert("height must be between 8 and 960!");
return;
}
if (heightVal % 8 != 0) {
alert("height must be divisible by 8!");
return;
}
if (stepsVal > 200 || stepsVal < 1) {
alert("steps value must be between 1 and 200!");
return;
}
// Send POST request using Ajax
$.ajax({
type: 'POST',
url: '/add_job',
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({
'apikey': apikeyVal,
'type': 'inpaint',
'ref_img': inpaintOriginalImg,
'mask_img': inpaintMaskImg,
'prompt': promptVal,
'seed': seedVal,
'steps': stepsVal,
'width': widthVal,
'height': heightVal,
'lang': $("#language option:selected").val(),
'guidance_scale': guidanceScaleVal,
'neg_prompt': negPromptVal
}),
success: function (response) {
console.log(response);
if (response.uuid) {
$('#inpaintJobUUID').val(response.uuid);
}
$('#inpaintStatus').html('submitting new job..');
waitForImage(apikeyVal, response.uuid);
},
error: function (xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
$('#inpaintStatus').html('failed');
}
});
});
// Listen for changes to the select element // Listen for changes to the select element
$("#language").change(function () { $("#language").change(function () {
// Get the newly selected value // Get the newly selected value
@ -584,6 +785,66 @@
$("#language").change(); $("#language").change();
} }
// Cache variable to store the selected image data for inpainting
var inpaintOriginalImg = null;
$("#copy-txt-to-img-inpaint").click(function () {
data = $("#txt2ImgImg").attr("src");
if (data == null || data == "") {
alert("nothing found from txt-to-img result");
return;
}
inpaintOriginalImg = data
$("#inpaint-img").attr("src", inpaintOriginalImg);
$("#inpaint-img").trigger("change");
});
$("#copy-last-img-inpaint").click(function () {
data = $("#img2ImgImg").attr("src");
if (data == null || data == "") {
alert("nothing found from img-to-img result");
return;
}
inpaintOriginalImg = data;
$("#inpaint-img").attr("src", inpaintOriginalImg);
$("#inpaint-img").trigger("change");
});
$("#upload-img-inpaint").click(function () {
var input = $("<input type='file' accept='image/*'>");
input.on("change", function () {
var reader = new FileReader();
reader.onload = function (e) {
inpaintOriginalImg = e.target.result;
$("#inpaint-img").attr("src", inpaintOriginalImg);
var img = new Image();
img.src = inpaintOriginalImg;
img.onload = function() {
$("#inpaint-img").trigger("change");
};
};
reader.readAsDataURL(input[0].files[0]);
});
input.click();
});
$("#inpaint-img").on("change", function () {
var src = $(this).attr("src");
$("#inpaint-img-for-mask").attr("src", src);
$('#inpaint-img-mask').width($(this).width());
$('#inpaint-img-mask').height($(this).height());
var options = {
graphics: {
firstPointSize: 0,
lineWidth: 5,
strokeStyle: 'black',
}
};
var $sketcher = $('#inpaint-img-mask').sketchable(options);
});
}); });
</script> </script>
</body> </body>

View File

@ -95,6 +95,20 @@ py_library(
], ],
) )
py_library(
name="inpainting",
srcs=["inpainting.py"],
deps=[
":constants",
":config",
":logger",
":images",
":memory",
":model",
":times",
],
)
py_library( py_library(
name="times", name="times",
srcs=["times.py"], srcs=["times.py"],

View File

@ -5,6 +5,7 @@ LOGGER_NAME_FRONTEND = VALUE_APP + "_fe"
LOGGER_NAME_BACKEND = VALUE_APP + "_be" LOGGER_NAME_BACKEND = VALUE_APP + "_be"
LOGGER_NAME_TXT2IMG = VALUE_APP + "_txt2img" LOGGER_NAME_TXT2IMG = VALUE_APP + "_txt2img"
LOGGER_NAME_IMG2IMG = VALUE_APP + "_img2img" LOGGER_NAME_IMG2IMG = VALUE_APP + "_img2img"
LOGGER_NAME_INPAINT = VALUE_APP + "_inpaint"
MAX_JOB_NUMBER = 10 MAX_JOB_NUMBER = 10
LOCK_FILEPATH = "/tmp/happysd_db.lock" LOCK_FILEPATH = "/tmp/happysd_db.lock"
@ -29,6 +30,7 @@ KEY_JOB_TYPE = "type"
VALUE_JOB_TXT2IMG = "txt" # default value for KEY_JOB_TYPE VALUE_JOB_TXT2IMG = "txt" # default value for KEY_JOB_TYPE
VALUE_JOB_IMG2IMG = "img" VALUE_JOB_IMG2IMG = "img"
REFERENCE_IMG = "ref_img" REFERENCE_IMG = "ref_img"
MASK_IMG = "mask_img"
VALUE_JOB_INPAINTING = "inpaint" VALUE_JOB_INPAINTING = "inpaint"
KEY_LANGUAGE = "lang" KEY_LANGUAGE = "lang"
@ -70,8 +72,9 @@ OPTIONAL_KEYS = [
KEY_STEPS, # int KEY_STEPS, # int
KEY_SCHEDULER, # str KEY_SCHEDULER, # str
KEY_STRENGTH, # float KEY_STRENGTH, # float
REFERENCE_IMG, # str (base64) REFERENCE_IMG, # str (base64 or filepath)
KEY_LANGUAGE, MASK_IMG, # str (base64 or filepath)
KEY_LANGUAGE, # str
] ]
# - output only # - output only

View File

@ -22,6 +22,7 @@ from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS from utilities.constants import REQUIRED_KEYS
from utilities.constants import REFERENCE_IMG from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG
from utilities.constants import BASE64IMAGE from utilities.constants import BASE64IMAGE
from utilities.constants import IMAGE_NOT_FOUND_BASE64 from utilities.constants import IMAGE_NOT_FOUND_BASE64
@ -164,7 +165,7 @@ class Database:
columns[i]: row[i] for i in range(len(columns)) if row[i] is not None columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
} }
# load image to job if has one # load image to job if has one
for key in [BASE64IMAGE, REFERENCE_IMG]: for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
if key in job and "base64" not in job[key]: if key in job and "base64" not in job[key]:
data = load_image(job[key], to_base64=True) data = load_image(job[key], to_base64=True)
job[key] = data if data else IMAGE_NOT_FOUND_BASE64 job[key] = data if data else IMAGE_NOT_FOUND_BASE64
@ -184,16 +185,26 @@ class Database:
job_uuid = str(uuid.uuid4()) job_uuid = str(uuid.uuid4())
self.__logger.info(f"inserting a new job with {job_uuid}") self.__logger.info(f"inserting a new job with {job_uuid}")
current_epoch = get_epoch_now()
# store image to job_dict if has one # store image to job_dict if has one
if ( if (
self.__image_output_folder self.__image_output_folder
and REFERENCE_IMG in job_dict and REFERENCE_IMG in job_dict
and "base64" in job_dict[REFERENCE_IMG] and "base64" in job_dict[REFERENCE_IMG]
): ):
ref_img_filepath = f"{self.__image_output_folder}/{get_epoch_now()}_ref.png" ref_img_filepath = f"{self.__image_output_folder}/{current_epoch}_ref.png"
self.__logger.info(f"saving reference image to {ref_img_filepath}") self.__logger.info(f"saving reference image to {ref_img_filepath}")
if save_image(job_dict[REFERENCE_IMG], ref_img_filepath): if save_image(job_dict[REFERENCE_IMG], ref_img_filepath):
job_dict[REFERENCE_IMG] = ref_img_filepath job_dict[REFERENCE_IMG] = ref_img_filepath
if (
self.__image_output_folder
and MASK_IMG in job_dict
and "base64" in job_dict[MASK_IMG]
):
mask_img_filepath = f"{self.__image_output_folder}/{current_epoch}_mask.png"
self.__logger.info(f"saving mask image to {mask_img_filepath}")
if save_image(job_dict[MASK_IMG], mask_img_filepath):
job_dict[MASK_IMG] = mask_img_filepath
values = [job_uuid, VALUE_JOB_PENDING, datetime.datetime.now()] values = [job_uuid, VALUE_JOB_PENDING, datetime.datetime.now()]
columns = [UUID, KEY_JOB_STATUS, "created_at"] + REQUIRED_KEYS + OPTIONAL_KEYS columns = [UUID, KEY_JOB_STATUS, "created_at"] + REQUIRED_KEYS + OPTIONAL_KEYS

105
utilities/inpainting.py Normal file
View File

@ -0,0 +1,105 @@
import torch
from typing import Union
from PIL import Image
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_SEED
from utilities.constants import KEY_WIDTH
from utilities.constants import KEY_HEIGHT
from utilities.constants import KEY_STEPS
from utilities.config import Config
from utilities.logger import DummyLogger
from utilities.memory import empty_memory_cache
from utilities.model import Model
from utilities.times import get_epoch_now
from utilities.images import image_to_base64
from utilities.images import base64_to_image
class Inpainting:
"""
Inpainting class.
"""
def __init__(
self,
model: Model,
output_folder: str = "",
logger: DummyLogger = DummyLogger(),
):
self.model = model
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
self.__output_folder = output_folder
self.__logger = logger
def brunch(self, prompt: str, negative_prompt: str = ""):
self.breakfast()
self.lunch(prompt, negative_prompt)
def breakfast(self):
pass
def lunch(
self,
prompt: str,
negative_prompt: str = "",
reference_image: Union[Image.Image, None, str] = None,
mask_image: Union[Image.Image, None, str] = None,
config: Config = Config(),
) -> dict:
if not prompt:
self.__logger.error("no prompt provided, won't proceed")
return {}
if reference_image is None:
return {}
if mask_image is None:
return {}
self.model.set_inpaint_scheduler(config.get_scheduler())
t = get_epoch_now()
seed = config.get_seed()
generator = torch.Generator(self.__device).manual_seed(seed)
self.__logger.info("current seed: {}".format(seed))
if isinstance(reference_image, str):
reference_image = base64_to_image(reference_image).convert("RGB")
reference_image.thumbnail((config.get_width(), config.get_height()))
if isinstance(mask_image, str):
mask_image = base64_to_image(mask_image).convert("RGB")
# assume mask image and reference image size ratio is the same
if mask_image.size[0] < reference_image.size[0]:
mask_image = mask_image.resize(reference_image.size)
elif mask_image.size[0] > reference_image.size[0]:
mask_image = mask_image.resize(reference_image.size, resample=Image.LANCZOS)
result = self.model.inpaint_pipeline(
prompt=prompt,
image=reference_image.resize((512, 512)), # must use size 512 for inpaint model
mask_image=mask_image.convert("L").resize((512, 512)), # must use size 512 for inpaint model
negative_prompt=negative_prompt,
guidance_scale=config.get_guidance_scale(),
num_inference_steps=config.get_steps(),
generator=generator,
callback=None,
callback_steps=10,
)
# resize it back based on ratio (keep width 512)
result_img = result.images[0].resize((512, int(512 * reference_image.size[1] / reference_image.size[0])))
if self.__output_folder:
out_filepath = "{}/{}.png".format(self.__output_folder, t)
result_img.save(out_filepath)
self.__logger.info("output to file: {}".format(out_filepath))
empty_memory_cache()
return {
BASE64IMAGE: image_to_base64(result_img),
KEY_SEED: str(seed),
KEY_WIDTH: config.get_width(),
KEY_HEIGHT: config.get_height(),
KEY_STEPS: config.get_steps(),
}