adds inpainting capability
This commit is contained in:
parent
1735857e27
commit
605b05aa8b
18
backend.py
18
backend.py
|
|
@ -3,6 +3,7 @@ import argparse
|
|||
from utilities.constants import LOGGER_NAME_BACKEND
|
||||
from utilities.constants import LOGGER_NAME_TXT2IMG
|
||||
from utilities.constants import LOGGER_NAME_IMG2IMG
|
||||
from utilities.constants import LOGGER_NAME_INPAINT
|
||||
|
||||
from utilities.constants import UUID
|
||||
from utilities.constants import KEY_LANGUAGE
|
||||
|
|
@ -16,7 +17,9 @@ from utilities.constants import VALUE_JOB_RUNNING
|
|||
from utilities.constants import KEY_JOB_TYPE
|
||||
from utilities.constants import VALUE_JOB_TXT2IMG
|
||||
from utilities.constants import VALUE_JOB_IMG2IMG
|
||||
from utilities.constants import VALUE_JOB_INPAINTING
|
||||
from utilities.constants import REFERENCE_IMG
|
||||
from utilities.constants import MASK_IMG
|
||||
|
||||
from utilities.translator import translate_prompt
|
||||
from utilities.config import Config
|
||||
|
|
@ -25,6 +28,7 @@ from utilities.logger import Logger
|
|||
from utilities.model import Model
|
||||
from utilities.text2img import Text2Img
|
||||
from utilities.img2img import Img2Img
|
||||
from utilities.inpainting import Inpainting
|
||||
from utilities.times import wait_for_seconds
|
||||
|
||||
|
||||
|
|
@ -61,6 +65,8 @@ def backend(model, is_debugging: bool):
|
|||
text2img.breakfast()
|
||||
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
|
||||
img2img.breakfast()
|
||||
inpainting = Inpainting(model, logger=Logger(name=LOGGER_NAME_INPAINT))
|
||||
inpainting.breakfast()
|
||||
|
||||
while 1:
|
||||
wait_for_seconds(1)
|
||||
|
|
@ -112,10 +118,20 @@ def backend(model, is_debugging: bool):
|
|||
reference_image=ref_img,
|
||||
config=config,
|
||||
)
|
||||
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_INPAINTING:
|
||||
ref_img = next_job[REFERENCE_IMG]
|
||||
mask_img = next_job[MASK_IMG]
|
||||
result_dict = inpainting.lunch(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
reference_image=ref_img,
|
||||
mask_image=mask_img,
|
||||
config=config,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except BaseException as e:
|
||||
logger.error("text2img.lunch error: {}".format(e))
|
||||
logger.error(e)
|
||||
if not is_debugging:
|
||||
database.update_job(
|
||||
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ def create_table_history(c):
|
|||
neg_prompt TEXT,
|
||||
seed TEXT,
|
||||
ref_img TEXT,
|
||||
mask_img TEXT,
|
||||
img TEXT,
|
||||
width INT,
|
||||
height INT,
|
||||
|
|
@ -184,7 +185,7 @@ def show_users(c, username="", details=False):
|
|||
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
|
||||
if details:
|
||||
c.execute(
|
||||
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
|
||||
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
|
||||
(user[1],),
|
||||
)
|
||||
rows = c.fetchall()
|
||||
|
|
@ -201,7 +202,7 @@ def show_users(c, username="", details=False):
|
|||
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
|
||||
if details:
|
||||
c.execute(
|
||||
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
|
||||
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
|
||||
(user[1],),
|
||||
)
|
||||
rows = c.fetchall()
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -176,7 +176,70 @@
|
|||
</div>
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-inpainting" style="display:none">
|
||||
TBD
|
||||
<div class="row">
|
||||
<div class="col-md-4">
|
||||
<div class="card">
|
||||
<div class="card-header" data-en_XX="Original Image" data-zh_CN="原图">
|
||||
Original Image
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="row">
|
||||
<button id="copy-txt-to-img-inpaint" class="btn btn-primary mb-3"
|
||||
data-en_XX="Copy from text-to-image"
|
||||
data-zh_CN="从【文字->图片】结果复制">Copy
|
||||
from text-to-image</button>
|
||||
<button id="copy-last-img-inpaint" class="btn btn-primary mb-3"
|
||||
data-en_XX="Copy from last image result"
|
||||
data-zh_CN="从【图片->图片】结果复制">Copy from last image result</button>
|
||||
<button id="upload-img-inpaint" class="btn btn-primary mb-3"
|
||||
data-en_XX="Upload image" data-zh_CN="上传一张图片">Upload
|
||||
image</button>
|
||||
</div>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="inpaint-img">
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-4">
|
||||
<div class="card">
|
||||
<div class="card-header" data-en_XX="Mask Image" data-zh_CN="修复部分">
|
||||
Mask Image
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="row">
|
||||
</div>
|
||||
<div class="row">
|
||||
<button id="newInpaintingJob" class="btn btn-primary mb-3"
|
||||
data-en_XX="Let's Go with Image + Mask Below!"
|
||||
data-zh_CN="就用下面的图进行修复!">Let's
|
||||
Go with Image + Mask Below!</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card-img-bottom" style="position: relative;">
|
||||
<img id="inpaint-img-for-mask" width="100%">
|
||||
<canvas style="cursor: pointer; position: absolute; top: 0; left: 0;"
|
||||
id="inpaint-img-mask">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-4">
|
||||
<div class="card">
|
||||
<div class="card-header" data-en_XX="Result" data-zh_CN="结果">
|
||||
Result
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<ul class="list-group">
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="inpaintJobUUID"></li>
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="inpaintStatus"></li>
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="inpaintSeed"></li>
|
||||
</ul>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="inpaintImg">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-txt" style="display:none">
|
||||
<div class="card">
|
||||
|
|
@ -218,11 +281,10 @@
|
|||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://code.jquery.com/jquery-3.6.1.min.js"
|
||||
integrity="sha256-o88AwQnZB+VDvE9tvIXrMQaPlFFSUTR+nldQm1LuPXQ=" crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.2/dist/js/bootstrap.bundle.min.js"
|
||||
integrity="sha384-OERcA2EqjJCMA+/3y+gxIOqMEjwtxJY7qPCqsdltbNJuaOe923+mo//f6V8Qbsw3"
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="{{ url_for('static',filename='jquery-3.6.1.min.js') }}"></script>
|
||||
<script src="{{ url_for('static',filename='bootstrap.bundle.min.js') }}"></script>
|
||||
<script src="{{ url_for('static',filename='jsketch.min.js') }}"></script>
|
||||
<script src="{{ url_for('static',filename='jquery.sketchable.min.js') }}"></script>
|
||||
|
||||
<script>
|
||||
function waitForImage(apikeyVal, uuidValue) {
|
||||
|
|
@ -256,6 +318,16 @@
|
|||
if (response.jobs[0].status == "failed") {
|
||||
return;
|
||||
}
|
||||
} else if (response.jobs[0].type == 'inpaint') {
|
||||
$('#inpaintStatus').html(response.jobs[0].status);
|
||||
$('#inpaintSeed').html("seed: " + response.jobs[0].seed);
|
||||
if (response.jobs[0].status == "done") {
|
||||
$('#inpaintImg').attr('src', response.jobs[0].img);
|
||||
return;
|
||||
}
|
||||
if (response.jobs[0].status == "failed") {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second
|
||||
|
|
@ -298,7 +370,7 @@
|
|||
});
|
||||
});
|
||||
|
||||
// Cache variable to store the selected image data
|
||||
// Cache variable to store the selected image data for img2img
|
||||
var imageData = null;
|
||||
|
||||
$("#copy-txt-to-img").click(function () {
|
||||
|
|
@ -561,6 +633,135 @@
|
|||
});
|
||||
});
|
||||
|
||||
$('#newInpaintingJob').click(function (e) {
|
||||
e.preventDefault(); // Prevent the default form submission
|
||||
|
||||
if (inpaintOriginalImg == null) {
|
||||
alert("No image cached")
|
||||
return;
|
||||
}
|
||||
|
||||
var canvas = $('#inpaint-img-mask')[0];
|
||||
var ctx = canvas.getContext('2d');
|
||||
var maskImageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
|
||||
|
||||
// Loop through the pixels and change the colors
|
||||
for (var i = 0; i < maskImageData.data.length; i += 4) {
|
||||
if (maskImageData.data[i + 3] == 0) { // If pixel is transparent, change to black
|
||||
maskImageData.data[i] = 0;
|
||||
maskImageData.data[i + 1] = 0;
|
||||
maskImageData.data[i + 2] = 0;
|
||||
maskImageData.data[i + 3] = 255;
|
||||
} else { // If pixel is not transparent, change to white
|
||||
maskImageData.data[i] = 255;
|
||||
maskImageData.data[i + 1] = 255;
|
||||
maskImageData.data[i + 2] = 255;
|
||||
maskImageData.data[i + 3] = 255;
|
||||
}
|
||||
}
|
||||
|
||||
var tempCanvas = document.createElement('canvas'); // Create a new canvas element
|
||||
tempCanvas.width = canvas.width; // Set the width of the new canvas to match the original canvas
|
||||
tempCanvas.height = canvas.height; // Set the height of the new canvas to match the original canvas
|
||||
var tempCtx = tempCanvas.getContext('2d');
|
||||
tempCtx.putImageData(maskImageData, 0, 0); // Put modified image data onto the new canvas
|
||||
var inpaintMaskImg = tempCanvas.toDataURL(); // Get the modified base64-encoded image data
|
||||
|
||||
// Gather input field values
|
||||
var apikeyVal = $('#apiKey').val();
|
||||
var promptVal = $('#prompt').val();
|
||||
var negPromptVal = $('#negPrompt').val();
|
||||
var seedVal = $('#inputSeed').val();
|
||||
if (seedVal == "0" || seedVal == "") {
|
||||
seedVal = "0";
|
||||
}
|
||||
var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val())
|
||||
if (isNaN(guidanceScaleVal)) {
|
||||
guidanceScaleVal = 25.0;
|
||||
}
|
||||
var stepsVal = parseInt($('#inputSteps').val());
|
||||
if (isNaN(stepsVal)) {
|
||||
stepsVal = 50;
|
||||
}
|
||||
var widthVal = parseInt($('#inputWidth').val());
|
||||
if (isNaN(widthVal)) {
|
||||
widthVal = 512;
|
||||
}
|
||||
var heightVal = parseInt($('#inputHeight').val());
|
||||
if (isNaN(heightVal)) {
|
||||
heightVal = 512;
|
||||
}
|
||||
|
||||
if (promptVal == "") {
|
||||
alert("missing prompt!");
|
||||
return;
|
||||
}
|
||||
|
||||
if (guidanceScaleVal < 1 || guidanceScaleVal > 30) {
|
||||
alert("guidance scale must be between 1 and 30");
|
||||
return;
|
||||
}
|
||||
|
||||
if (widthVal < 8 || widthVal > 960) {
|
||||
alert("width must be between 8 and 960!");
|
||||
return;
|
||||
}
|
||||
|
||||
if (widthVal % 8 != 0) {
|
||||
alert("width must be divisible by 8!");
|
||||
return;
|
||||
}
|
||||
|
||||
if (heightVal < 8 || heightVal > 960) {
|
||||
alert("height must be between 8 and 960!");
|
||||
return;
|
||||
}
|
||||
|
||||
if (heightVal % 8 != 0) {
|
||||
alert("height must be divisible by 8!");
|
||||
return;
|
||||
}
|
||||
if (stepsVal > 200 || stepsVal < 1) {
|
||||
alert("steps value must be between 1 and 200!");
|
||||
return;
|
||||
}
|
||||
|
||||
// Send POST request using Ajax
|
||||
$.ajax({
|
||||
type: 'POST',
|
||||
url: '/add_job',
|
||||
contentType: 'application/json; charset=utf-8',
|
||||
dataType: 'json',
|
||||
data: JSON.stringify({
|
||||
'apikey': apikeyVal,
|
||||
'type': 'inpaint',
|
||||
'ref_img': inpaintOriginalImg,
|
||||
'mask_img': inpaintMaskImg,
|
||||
'prompt': promptVal,
|
||||
'seed': seedVal,
|
||||
'steps': stepsVal,
|
||||
'width': widthVal,
|
||||
'height': heightVal,
|
||||
'lang': $("#language option:selected").val(),
|
||||
'guidance_scale': guidanceScaleVal,
|
||||
'neg_prompt': negPromptVal
|
||||
}),
|
||||
success: function (response) {
|
||||
console.log(response);
|
||||
if (response.uuid) {
|
||||
$('#inpaintJobUUID').val(response.uuid);
|
||||
}
|
||||
$('#inpaintStatus').html('submitting new job..');
|
||||
waitForImage(apikeyVal, response.uuid);
|
||||
},
|
||||
error: function (xhr, status, error) {
|
||||
// Handle error response
|
||||
console.log(xhr.responseText);
|
||||
$('#inpaintStatus').html('failed');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Listen for changes to the select element
|
||||
$("#language").change(function () {
|
||||
// Get the newly selected value
|
||||
|
|
@ -584,6 +785,66 @@
|
|||
$("#language").change();
|
||||
}
|
||||
|
||||
// Cache variable to store the selected image data for inpainting
|
||||
var inpaintOriginalImg = null;
|
||||
|
||||
$("#copy-txt-to-img-inpaint").click(function () {
|
||||
data = $("#txt2ImgImg").attr("src");
|
||||
if (data == null || data == "") {
|
||||
alert("nothing found from txt-to-img result");
|
||||
return;
|
||||
}
|
||||
inpaintOriginalImg = data
|
||||
$("#inpaint-img").attr("src", inpaintOriginalImg);
|
||||
$("#inpaint-img").trigger("change");
|
||||
});
|
||||
|
||||
$("#copy-last-img-inpaint").click(function () {
|
||||
data = $("#img2ImgImg").attr("src");
|
||||
if (data == null || data == "") {
|
||||
alert("nothing found from img-to-img result");
|
||||
return;
|
||||
}
|
||||
inpaintOriginalImg = data;
|
||||
$("#inpaint-img").attr("src", inpaintOriginalImg);
|
||||
$("#inpaint-img").trigger("change");
|
||||
});
|
||||
|
||||
$("#upload-img-inpaint").click(function () {
|
||||
var input = $("<input type='file' accept='image/*'>");
|
||||
input.on("change", function () {
|
||||
var reader = new FileReader();
|
||||
reader.onload = function (e) {
|
||||
inpaintOriginalImg = e.target.result;
|
||||
$("#inpaint-img").attr("src", inpaintOriginalImg);
|
||||
var img = new Image();
|
||||
img.src = inpaintOriginalImg;
|
||||
img.onload = function() {
|
||||
$("#inpaint-img").trigger("change");
|
||||
};
|
||||
};
|
||||
reader.readAsDataURL(input[0].files[0]);
|
||||
});
|
||||
input.click();
|
||||
});
|
||||
|
||||
$("#inpaint-img").on("change", function () {
|
||||
var src = $(this).attr("src");
|
||||
$("#inpaint-img-for-mask").attr("src", src);
|
||||
$('#inpaint-img-mask').width($(this).width());
|
||||
$('#inpaint-img-mask').height($(this).height());
|
||||
|
||||
var options = {
|
||||
graphics: {
|
||||
firstPointSize: 0,
|
||||
lineWidth: 5,
|
||||
strokeStyle: 'black',
|
||||
}
|
||||
};
|
||||
|
||||
var $sketcher = $('#inpaint-img-mask').sketchable(options);
|
||||
});
|
||||
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
|
|
|
|||
|
|
@ -95,6 +95,20 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="inpainting",
|
||||
srcs=["inpainting.py"],
|
||||
deps=[
|
||||
":constants",
|
||||
":config",
|
||||
":logger",
|
||||
":images",
|
||||
":memory",
|
||||
":model",
|
||||
":times",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="times",
|
||||
srcs=["times.py"],
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ LOGGER_NAME_FRONTEND = VALUE_APP + "_fe"
|
|||
LOGGER_NAME_BACKEND = VALUE_APP + "_be"
|
||||
LOGGER_NAME_TXT2IMG = VALUE_APP + "_txt2img"
|
||||
LOGGER_NAME_IMG2IMG = VALUE_APP + "_img2img"
|
||||
LOGGER_NAME_INPAINT = VALUE_APP + "_inpaint"
|
||||
MAX_JOB_NUMBER = 10
|
||||
|
||||
LOCK_FILEPATH = "/tmp/happysd_db.lock"
|
||||
|
|
@ -29,6 +30,7 @@ KEY_JOB_TYPE = "type"
|
|||
VALUE_JOB_TXT2IMG = "txt" # default value for KEY_JOB_TYPE
|
||||
VALUE_JOB_IMG2IMG = "img"
|
||||
REFERENCE_IMG = "ref_img"
|
||||
MASK_IMG = "mask_img"
|
||||
VALUE_JOB_INPAINTING = "inpaint"
|
||||
|
||||
KEY_LANGUAGE = "lang"
|
||||
|
|
@ -70,8 +72,9 @@ OPTIONAL_KEYS = [
|
|||
KEY_STEPS, # int
|
||||
KEY_SCHEDULER, # str
|
||||
KEY_STRENGTH, # float
|
||||
REFERENCE_IMG, # str (base64)
|
||||
KEY_LANGUAGE,
|
||||
REFERENCE_IMG, # str (base64 or filepath)
|
||||
MASK_IMG, # str (base64 or filepath)
|
||||
KEY_LANGUAGE, # str
|
||||
]
|
||||
|
||||
# - output only
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from utilities.constants import OPTIONAL_KEYS
|
|||
from utilities.constants import REQUIRED_KEYS
|
||||
|
||||
from utilities.constants import REFERENCE_IMG
|
||||
from utilities.constants import MASK_IMG
|
||||
from utilities.constants import BASE64IMAGE
|
||||
from utilities.constants import IMAGE_NOT_FOUND_BASE64
|
||||
|
||||
|
|
@ -164,7 +165,7 @@ class Database:
|
|||
columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
|
||||
}
|
||||
# load image to job if has one
|
||||
for key in [BASE64IMAGE, REFERENCE_IMG]:
|
||||
for key in [BASE64IMAGE, REFERENCE_IMG, MASK_IMG]:
|
||||
if key in job and "base64" not in job[key]:
|
||||
data = load_image(job[key], to_base64=True)
|
||||
job[key] = data if data else IMAGE_NOT_FOUND_BASE64
|
||||
|
|
@ -184,16 +185,26 @@ class Database:
|
|||
job_uuid = str(uuid.uuid4())
|
||||
self.__logger.info(f"inserting a new job with {job_uuid}")
|
||||
|
||||
current_epoch = get_epoch_now()
|
||||
# store image to job_dict if has one
|
||||
if (
|
||||
self.__image_output_folder
|
||||
and REFERENCE_IMG in job_dict
|
||||
and "base64" in job_dict[REFERENCE_IMG]
|
||||
):
|
||||
ref_img_filepath = f"{self.__image_output_folder}/{get_epoch_now()}_ref.png"
|
||||
ref_img_filepath = f"{self.__image_output_folder}/{current_epoch}_ref.png"
|
||||
self.__logger.info(f"saving reference image to {ref_img_filepath}")
|
||||
if save_image(job_dict[REFERENCE_IMG], ref_img_filepath):
|
||||
job_dict[REFERENCE_IMG] = ref_img_filepath
|
||||
if (
|
||||
self.__image_output_folder
|
||||
and MASK_IMG in job_dict
|
||||
and "base64" in job_dict[MASK_IMG]
|
||||
):
|
||||
mask_img_filepath = f"{self.__image_output_folder}/{current_epoch}_mask.png"
|
||||
self.__logger.info(f"saving mask image to {mask_img_filepath}")
|
||||
if save_image(job_dict[MASK_IMG], mask_img_filepath):
|
||||
job_dict[MASK_IMG] = mask_img_filepath
|
||||
|
||||
values = [job_uuid, VALUE_JOB_PENDING, datetime.datetime.now()]
|
||||
columns = [UUID, KEY_JOB_STATUS, "created_at"] + REQUIRED_KEYS + OPTIONAL_KEYS
|
||||
|
|
|
|||
|
|
@ -0,0 +1,105 @@
|
|||
import torch
|
||||
from typing import Union
|
||||
from PIL import Image
|
||||
|
||||
from utilities.constants import BASE64IMAGE
|
||||
from utilities.constants import KEY_SEED
|
||||
from utilities.constants import KEY_WIDTH
|
||||
from utilities.constants import KEY_HEIGHT
|
||||
from utilities.constants import KEY_STEPS
|
||||
from utilities.config import Config
|
||||
from utilities.logger import DummyLogger
|
||||
from utilities.memory import empty_memory_cache
|
||||
from utilities.model import Model
|
||||
from utilities.times import get_epoch_now
|
||||
from utilities.images import image_to_base64
|
||||
from utilities.images import base64_to_image
|
||||
|
||||
|
||||
class Inpainting:
|
||||
"""
|
||||
Inpainting class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
output_folder: str = "",
|
||||
logger: DummyLogger = DummyLogger(),
|
||||
):
|
||||
self.model = model
|
||||
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
|
||||
self.__output_folder = output_folder
|
||||
self.__logger = logger
|
||||
|
||||
def brunch(self, prompt: str, negative_prompt: str = ""):
|
||||
self.breakfast()
|
||||
self.lunch(prompt, negative_prompt)
|
||||
|
||||
def breakfast(self):
|
||||
pass
|
||||
|
||||
def lunch(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
reference_image: Union[Image.Image, None, str] = None,
|
||||
mask_image: Union[Image.Image, None, str] = None,
|
||||
config: Config = Config(),
|
||||
) -> dict:
|
||||
if not prompt:
|
||||
self.__logger.error("no prompt provided, won't proceed")
|
||||
return {}
|
||||
if reference_image is None:
|
||||
return {}
|
||||
if mask_image is None:
|
||||
return {}
|
||||
|
||||
self.model.set_inpaint_scheduler(config.get_scheduler())
|
||||
|
||||
t = get_epoch_now()
|
||||
seed = config.get_seed()
|
||||
generator = torch.Generator(self.__device).manual_seed(seed)
|
||||
self.__logger.info("current seed: {}".format(seed))
|
||||
|
||||
if isinstance(reference_image, str):
|
||||
reference_image = base64_to_image(reference_image).convert("RGB")
|
||||
reference_image.thumbnail((config.get_width(), config.get_height()))
|
||||
|
||||
if isinstance(mask_image, str):
|
||||
mask_image = base64_to_image(mask_image).convert("RGB")
|
||||
# assume mask image and reference image size ratio is the same
|
||||
if mask_image.size[0] < reference_image.size[0]:
|
||||
mask_image = mask_image.resize(reference_image.size)
|
||||
elif mask_image.size[0] > reference_image.size[0]:
|
||||
mask_image = mask_image.resize(reference_image.size, resample=Image.LANCZOS)
|
||||
|
||||
result = self.model.inpaint_pipeline(
|
||||
prompt=prompt,
|
||||
image=reference_image.resize((512, 512)), # must use size 512 for inpaint model
|
||||
mask_image=mask_image.convert("L").resize((512, 512)), # must use size 512 for inpaint model
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=config.get_guidance_scale(),
|
||||
num_inference_steps=config.get_steps(),
|
||||
generator=generator,
|
||||
callback=None,
|
||||
callback_steps=10,
|
||||
)
|
||||
|
||||
# resize it back based on ratio (keep width 512)
|
||||
result_img = result.images[0].resize((512, int(512 * reference_image.size[1] / reference_image.size[0])))
|
||||
|
||||
if self.__output_folder:
|
||||
out_filepath = "{}/{}.png".format(self.__output_folder, t)
|
||||
result_img.save(out_filepath)
|
||||
self.__logger.info("output to file: {}".format(out_filepath))
|
||||
|
||||
empty_memory_cache()
|
||||
|
||||
return {
|
||||
BASE64IMAGE: image_to_base64(result_img),
|
||||
KEY_SEED: str(seed),
|
||||
KEY_WIDTH: config.get_width(),
|
||||
KEY_HEIGHT: config.get_height(),
|
||||
KEY_STEPS: config.get_steps(),
|
||||
}
|
||||
Loading…
Reference in New Issue