adds img2img capability

This commit is contained in:
HappyZ 2023-04-30 19:40:03 -07:00
parent 3db3ed283e
commit 6d0bc3ebab
6 changed files with 275 additions and 32 deletions

View File

@ -44,7 +44,7 @@ from utilities.img2img import Img2Img
app = Flask(__name__)
fast_web_debugging = False
app.config['TESTING'] = False
memory_lock = Lock()
event_termination = Event()
logger = Logger(name=LOGGER_NAME)
@ -270,9 +270,9 @@ def backend(event_termination):
def main():
if fast_web_debugging:
if app.testing:
try:
app.run(host="0.0.0.0")
app.run(host="0.0.0.0", port="5000")
except KeyboardInterrupt:
pass
return
@ -281,7 +281,7 @@ def main():
# ugly solution for now
# TODO: use a database to track instead of internal memory
try:
app.run(host="0.0.0.0")
app.run(host="0.0.0.0", port="8888")
thread.join()
except KeyboardInterrupt:
event_termination.set()

View File

@ -76,12 +76,22 @@
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
max="1024">
</div>
<button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button>
<div class="form-row">
<label for="guidanceScale">Guidance Scale</label>
<input type="number" class="form-control" id="inputGuidanceScale"
aria-describedby="inputGuidanceScaleHelp" placeholder="25" min="1" max="30">
<div id="inputGuidanceScaleHelp" class="form-text">How much guidance to follow from
description. 20 strictly follows the prompt; 7 is creative/artistic.
</div>
</div>
<div class="row">
<button id="newTxt2ImgJob" class="btn btn-primary">Let's Go!</button>
</div>
</div>
<div class="col-md-9">
<div class="card mb-3">
<div class="card-header">
<ul class="nav nav-tabs card-header-tabs">
<ul class="nav nav-pills card-header-pills">
<li class="nav-item">
<a class="nav-link active" href="#card-txt">Text-to-Image</a>
</li>
@ -94,7 +104,57 @@
</ul>
</div>
<div class="card-body card-specific" id="card-img" style="display:none">
img
<div class="row">
<div class="col-md-6">
<div class="card">
<div class="card-header">
Reference Image
</div>
<div class="card-body">
<div class="row">
<button id="copy-txt-to-img" class="btn btn-primary mb-3">Copy from
Txt-to-Image</button>
<button id="copy-last-img" class="btn btn-primary mb-3">Copy from
Last
Image
Result</button>
<button id="upload-img" class="btn btn-primary mb-3">Upload
Image</button>
</div>
<div class="form-row">
<label for="strength">Strength</label>
<input type="number" class="form-control" id="inputStrength"
aria-describedby="inputStrengthHelp" placeholder="0.5" min="0"
max="1">
<div id="inputStrengthHelp" class="form-text">How semantically
consistent the result is with the original image.
</div>
</div>
<div class="row">
<button id="newImg2ImgJob" class="btn btn-primary mb-3">Let's Go
with Image Below!</button>
</div>
</div>
<img class="card-img-bottom" id="reference-img">
</div>
</div>
<div class="col-md-6">
<div class="card">
<div class="card-header">
Result
</div>
<div class="card-body">
<ul class="list-group">
<li class="list-group-item d-flex justify-content-between align-items-center"
id="img2ImgStatus"></li>
<li class="list-group-item d-flex justify-content-between align-items-center"
id="img2ImgSeed"></li>
</ul>
</div>
<img class="card-img-bottom" id="img2ImgImg">
</div>
</div>
</div>
</div>
<div class="card-body card-specific" id="card-inpainting" style="display:none">
TBD
@ -107,12 +167,12 @@
<div class="card-body">
<ul class="list-group">
<li class="list-group-item d-flex justify-content-between align-items-center"
id="resultStatus"></li>
id="txt2ImgStatus"></li>
<li class="list-group-item d-flex justify-content-between align-items-center"
id="resultSeed"></li>
id="txt2ImgSeed"></li>
</ul>
</div>
<img class="card-img-bottom" id="newJobImg">
<img class="card-img-bottom" id="txt2ImgImg">
</div>
</div>
</div>
@ -153,22 +213,34 @@
success: function (response) {
console.log(response);
if (response.jobs.length == 1) {
$('#resultStatus').html(response.jobs[0].status)
$('#resultSeed').html("seed: " + response.jobs[0].seed)
if (response.jobs[0].status == "done") {
$('#newJobImg').attr('src', response.jobs[0].img);
return;
}
if (response.jobs[0].status == "failed") {
return;
if (response.jobs[0].type == 'txt') {
$('#txt2ImgStatus').html(response.jobs[0].status);
$('#txt2ImgSeed').html("seed: " + response.jobs[0].seed);
if (response.jobs[0].status == "done") {
$('#txt2ImgImg').attr('src', response.jobs[0].img);
return;
}
if (response.jobs[0].status == "failed") {
return;
}
} else if (response.jobs[0].type == 'img') {
$('#img2ImgStatus').html(response.jobs[0].status);
$('#img2ImgSeed').html("seed: " + response.jobs[0].seed);
if (response.jobs[0].status == "done") {
$('#img2ImgImg').attr('src', response.jobs[0].img);
return;
}
if (response.jobs[0].status == "failed") {
return;
}
}
}
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1000); // refresh every second
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second
},
error: function (xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
$('#resultStatus').html('failed');
$('#txt2ImgStatus').html('failed');
}
});
}
@ -189,6 +261,43 @@
}
});
// Cache variable to store the selected image data
var imageData = null;
$("#copy-txt-to-img").click(function () {
    // Reuse the latest txt2img result as the img2img reference image.
    // FIX: declare with `var` — the original assigned to `data` without a
    // declaration, creating an implicit global (a ReferenceError in strict mode).
    var data = $("#txt2ImgImg").attr("src");
    if (data == null || data == "") {
        alert("nothing found from txt-to-img result");
        return;
    }
    imageData = data; // cache for the img2img job submission
    $("#reference-img").attr("src", imageData);
});
$("#copy-last-img").click(function () {
    // Reuse the latest img2img result as the next reference image.
    // FIX: declare with `var` — the original assigned to `data` without a
    // declaration, creating an implicit global (a ReferenceError in strict mode).
    var data = $("#img2ImgImg").attr("src");
    if (data == null || data == "") {
        alert("nothing found from img-to-img result");
        return;
    }
    imageData = data; // cache for the img2img job submission
    $("#reference-img").attr("src", imageData);
});
$("#upload-img").click(function () {
    // Open a native file picker and cache the chosen image as a data URL,
    // previewing it in the reference-image slot.
    var picker = $("<input type='file' accept='image/*'>");
    picker.on("change", function () {
        var reader = new FileReader();
        reader.onload = function (evt) {
            imageData = evt.target.result; // base64 data URL
            $("#reference-img").attr("src", imageData);
        };
        reader.readAsDataURL(picker[0].files[0]);
    });
    picker.click();
});
$(".nav-link").click(function (e) {
e.preventDefault();
var target = $(this).attr("href"); // get the href value of the clicked link
@ -200,7 +309,7 @@
$(target).show();
});
$('#newJob').click(function (e) {
$('#newTxt2ImgJob').click(function (e) {
e.preventDefault(); // Prevent the default form submission
// Gather input field values
@ -211,6 +320,10 @@
if (seedVal == "0" || seedVal == "") {
seedVal = "0";
}
var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val())
if (isNaN(guidanceScaleVal)) {
guidanceScaleVal = 25.0;
}
var stepsVal = parseInt($('#inputSteps').val());
if (isNaN(stepsVal)) {
stepsVal = 50;
@ -229,6 +342,11 @@
return;
}
if (guidanceScaleVal < 1 || guidanceScaleVal > 30) {
alert("guidance scale must be between 1 and 30");
return;
}
if (widthVal < 8 || widthVal > 960) {
alert("width must be between 8 and 960!");
return;
@ -268,6 +386,7 @@
'steps': stepsVal,
'width': widthVal,
'height': heightVal,
'guidance_scale': guidanceScaleVal,
'neg_prompt': negPromptVal
}),
success: function (response) {
@ -275,13 +394,95 @@
if (response.uuid) {
$('#jobuuid').val(response.uuid);
}
$('#resultStatus').html('submitting new job..');
$('#txt2ImgStatus').html('submitting new job..');
waitForImage(apikeyVal, response.uuid);
},
error: function (xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
$('#resultStatus').html('failed');
$('#txt2ImgStatus').html('failed');
}
});
});
$('#newImg2ImgJob').click(function (e) {
    e.preventDefault(); // Prevent the default form submission
    // A reference image must already be cached (via the copy/upload buttons).
    if (imageData == null) {
        alert("No image cached");
        return;
    }
    // Gather input field values
    var apikeyVal = $('#apiKey').val();
    var promptVal = $('#prompt').val();
    var negPromptVal = $('#negPrompt').val();
    var seedVal = $('#inputSeed').val();
    if (seedVal == "0" || seedVal == "") {
        seedVal = "0"; // 0 means "pick a random seed" server-side — TODO confirm
    }
    var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val());
    if (isNaN(guidanceScaleVal)) {
        guidanceScaleVal = 25.0;
    }
    var stepsVal = parseInt($('#inputSteps').val());
    if (isNaN(stepsVal)) {
        stepsVal = 50;
    }
    // FIX: was parseInt, which truncates fractional input ("0.5" -> 0);
    // strength is a float in [0, 1], so parse it with parseFloat.
    var strengthVal = parseFloat($('#inputStrength').val());
    if (isNaN(strengthVal)) {
        strengthVal = 0.5;
    }
    // Validate inputs before submitting the job.
    if (promptVal == "") {
        alert("missing prompt!");
        return;
    }
    if (guidanceScaleVal < 1 || guidanceScaleVal > 30) {
        alert("guidance scale must be between 1 and 30");
        return;
    }
    if (strengthVal < 0 || strengthVal > 1) {
        alert("strength must be between 0 and 1");
        return;
    }
    if (stepsVal > 200 || stepsVal < 1) {
        alert("steps value must be between 1 and 200!");
        return;
    }
    // Send POST request using Ajax
    $.ajax({
        type: 'POST',
        url: '/add_job',
        contentType: 'application/json; charset=utf-8',
        dataType: 'json',
        data: JSON.stringify({
            'api_key': apikeyVal,
            'type': 'img',
            'ref_img': imageData,
            'prompt': promptVal,
            'seed': seedVal,
            'steps': stepsVal,
            'guidance_scale': guidanceScaleVal,
            'strength': strengthVal,
            'neg_prompt': negPromptVal
        }),
        success: function (response) {
            console.log(response);
            if (response.uuid) {
                $('#jobuuid').val(response.uuid);
            }
            $('#img2ImgStatus').html('submitting new job..');
            waitForImage(apikeyVal, response.uuid);
        },
        error: function (xhr, status, error) {
            // Handle error response
            console.log(xhr.responseText);
            $('#img2ImgStatus').html('failed');
        }
    });
});

View File

@ -8,6 +8,8 @@ from utilities.constants import KEY_GUIDANCE_SCALE
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
from utilities.constants import KEY_HEIGHT
from utilities.constants import VALUE_HEIGHT_DEFAULT
from utilities.constants import KEY_STRENGTH
from utilities.constants import VALUE_STRENGTH_DEFAULT
from utilities.constants import KEY_PREVIEW
from utilities.constants import VALUE_PREVIEW_DEFAULT
from utilities.constants import KEY_SCHEDULER
@ -139,3 +141,15 @@ class Config:
)
self.__config[KEY_WIDTH] = value
return self
def get_strength(self) -> float:
    """Return the img2img strength setting, or the default when unset."""
    value = self.__config.get(KEY_STRENGTH, VALUE_STRENGTH_DEFAULT)
    return float(value)
def set_strength(self, strength: float):
    """Update the img2img strength, logging old and new values.

    Returns self so calls can be chained, matching the other setters.
    """
    previous = self.get_strength()
    self.__logger.info(
        "{} changed from {} to {}".format(KEY_STRENGTH, previous, strength)
    )
    self.__config[KEY_STRENGTH] = strength
    return self

View File

@ -20,7 +20,10 @@ KEY_HEIGHT = "HEIGHT"
VALUE_HEIGHT_DEFAULT = 512
KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0
KEY_STRENGTH = "STRENGTH"
VALUE_STRENGTH_DEFAULT = 0.5
KEY_STEPS = "STEPS"
VALUE_STEPS_DEFAULT = 50
@ -69,4 +72,6 @@ OPTIONAL_KEYS = [
KEY_GUIDANCE_SCALE.lower(),
KEY_STEPS.lower(),
KEY_SCHEDULER.lower(),
KEY_STRENGTH.lower(),
REFERENCE_IMG.lower(),
]

View File

@ -15,7 +15,9 @@ def load_image(image: Union[str, bytes]) -> Union[Image.Image, None]:
return None
def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool = False) -> bool:
def save_image(
image: Union[bytes, Image.Image], filepath: str, override: bool = False
) -> bool:
if os.path.isfile(filepath) and not override:
return False
try:
@ -31,13 +33,15 @@ def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool =
def crop_image(image: Image.Image, boundary: tuple) -> Image.Image:
    """
    Crop an image based on boundary defined in boundary tuple.

    Args:
        image: source PIL image.
        boundary: (left, upper, right, lower) pixel box passed to Image.crop.

    Returns:
        The cropped PIL image.
    """
    # NOTE(review): the diff left both the old ''' and new """ docstring
    # delimiters in place; this is the reconstructed post-commit version.
    return image.crop(boundary)
def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = "png") -> str:
def image_to_base64(
image: Union[bytes, str, Image.Image], image_format: str = "png"
) -> str:
if isinstance(image, str):
# this is a filepath
if not os.path.isfile(image):
@ -49,7 +53,18 @@ def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = "
rawbytes = io.BytesIO()
image.save(rawbytes, format=image_format)
image = rawbytes.getvalue()
return "data:image/{};base64,".format(image_format) + base64.b64encode(image).decode()
return (
"data:image/{};base64,".format(image_format) + base64.b64encode(image).decode()
)
def base64_to_image(image: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URI) into a PIL image.

    Accepts either a bare base64 payload or a data URI such as
    ``data:image/png;base64,...``; only the payload after the first comma
    is decoded in the latter case.
    """
    pieces = image.split(",")
    payload = pieces[1] if len(pieces) > 1 else image
    return Image.open(io.BytesIO(base64.b64decode(payload)))
from skimage import io as skimageio
@ -57,7 +72,13 @@ from skimage import transform
from skimage import img_as_ubyte
def load_and_transform_image_for_torch(img_filepath: str, dimension: tuple = (), force_rgb: bool = True, transpose: bool = True, use_ubyte: bool = False) -> np.ndarray:
def load_and_transform_image_for_torch(
img_filepath: str,
dimension: tuple = (),
force_rgb: bool = True,
transpose: bool = True,
use_ubyte: bool = False,
) -> np.ndarray:
img = skimageio.imread(img_filepath)
if force_rgb:
img = img[:, :, :3]

View File

@ -13,6 +13,7 @@ from utilities.memory import empty_memory_cache
from utilities.model import Model
from utilities.times import get_epoch_now
from utilities.images import image_to_base64
from utilities.images import base64_to_image
class Img2Img:
@ -59,13 +60,14 @@ class Img2Img:
self.__logger.info("current seed: {}".format(seed))
if isinstance(reference_image, str):
reference_image
reference_image = base64_to_image(reference_image).convert('RGB')
result = self.model.txt2img_pipeline(
result = self.model.img2img_pipeline(
prompt=prompt,
image=reference_image,
negative_prompt=negative_prompt,
guidance_scale=config.get_guidance_scale(),
strength=config.get_strength(),
num_inference_steps=config.get_steps(),
generator=generator,
callback=None,