adds img2img capability
This commit is contained in:
parent
3db3ed283e
commit
6d0bc3ebab
8
main.py
8
main.py
|
|
@ -44,7 +44,7 @@ from utilities.img2img import Img2Img
|
|||
|
||||
|
||||
app = Flask(__name__)
|
||||
fast_web_debugging = False
|
||||
app.config['TESTING'] = False
|
||||
memory_lock = Lock()
|
||||
event_termination = Event()
|
||||
logger = Logger(name=LOGGER_NAME)
|
||||
|
|
@ -270,9 +270,9 @@ def backend(event_termination):
|
|||
|
||||
|
||||
def main():
|
||||
if fast_web_debugging:
|
||||
if app.testing:
|
||||
try:
|
||||
app.run(host="0.0.0.0")
|
||||
app.run(host="0.0.0.0", port="5000")
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
return
|
||||
|
|
@ -281,7 +281,7 @@ def main():
|
|||
# ugly solution for now
|
||||
# TODO: use a database to track instead of internal memory
|
||||
try:
|
||||
app.run(host="0.0.0.0")
|
||||
app.run(host="0.0.0.0", port="8888")
|
||||
thread.join()
|
||||
except KeyboardInterrupt:
|
||||
event_termination.set()
|
||||
|
|
|
|||
|
|
@ -76,12 +76,22 @@
|
|||
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
|
||||
max="1024">
|
||||
</div>
|
||||
<button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button>
|
||||
<div class="form-row">
|
||||
<label for="guidanceScale">Guidance Scale</label>
|
||||
<input type="number" class="form-control" id="inputGuidanceScale"
|
||||
aria-describedby="inputGuidanceScaleHelp" placeholder="25" min="1" max="30">
|
||||
<div id="inputGuidanceScaleHelp" class="form-text">How much guidance to follow from
|
||||
description. 20 strictly follow prompt, 7 creative/artistic.
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<button id="newTxt2ImgJob" class="btn btn-primary">Let's Go!</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-9">
|
||||
<div class="card mb-3">
|
||||
<div class="card-header">
|
||||
<ul class="nav nav-tabs card-header-tabs">
|
||||
<ul class="nav nav-pills card-header-pills">
|
||||
<li class="nav-item">
|
||||
<a class="nav-link active" href="#card-txt">Text-to-Image</a>
|
||||
</li>
|
||||
|
|
@ -94,7 +104,57 @@
|
|||
</ul>
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-img" style="display:none">
|
||||
img
|
||||
<div class="row">
|
||||
<div class="col-md-6">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
Reference Image
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="row">
|
||||
<button id="copy-txt-to-img" class="btn btn-primary mb-3">Copy from
|
||||
Txt-to-Image</button>
|
||||
<button id="copy-last-img" class="btn btn-primary mb-3">Copy from
|
||||
Last
|
||||
Image
|
||||
Result</button>
|
||||
<button id="upload-img" class="btn btn-primary mb-3">Upload
|
||||
Image</button>
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="strength">Strength</label>
|
||||
<input type="number" class="form-control" id="inputStrength"
|
||||
aria-describedby="inputStrengthHelp" placeholder="0.5" min="0"
|
||||
max="1">
|
||||
<div id="inputStrengthHelp" class="form-text">How semantically
|
||||
consistent with the origional image.
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<button id="newImg2ImgJob" class="btn btn-primary mb-3">Let's Go
|
||||
with Image Below!</button>
|
||||
</div>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="reference-img">
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-6">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
Result
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<ul class="list-group">
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="img2ImgStatus"></li>
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="img2ImgSeed"></li>
|
||||
</ul>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="img2ImgImg">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-inpainting" style="display:none">
|
||||
TBD
|
||||
|
|
@ -107,12 +167,12 @@
|
|||
<div class="card-body">
|
||||
<ul class="list-group">
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="resultStatus"></li>
|
||||
id="txt2ImgStatus"></li>
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="resultSeed"></li>
|
||||
id="txt2ImgSeed"></li>
|
||||
</ul>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="newJobImg">
|
||||
<img class="card-img-bottom" id="txt2ImgImg">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -153,22 +213,34 @@
|
|||
success: function (response) {
|
||||
console.log(response);
|
||||
if (response.jobs.length == 1) {
|
||||
$('#resultStatus').html(response.jobs[0].status)
|
||||
$('#resultSeed').html("seed: " + response.jobs[0].seed)
|
||||
if (response.jobs[0].status == "done") {
|
||||
$('#newJobImg').attr('src', response.jobs[0].img);
|
||||
return;
|
||||
}
|
||||
if (response.jobs[0].status == "failed") {
|
||||
return;
|
||||
if (response.jobs[0].type == 'txt') {
|
||||
$('#txt2ImgStatus').html(response.jobs[0].status);
|
||||
$('#txt2ImgSeed').html("seed: " + response.jobs[0].seed);
|
||||
if (response.jobs[0].status == "done") {
|
||||
$('#txt2ImgImg').attr('src', response.jobs[0].img);
|
||||
return;
|
||||
}
|
||||
if (response.jobs[0].status == "failed") {
|
||||
return;
|
||||
}
|
||||
} else if (response.jobs[0].type == 'img') {
|
||||
$('#img2ImgStatus').html(response.jobs[0].status);
|
||||
$('#img2ImgSeed').html("seed: " + response.jobs[0].seed);
|
||||
if (response.jobs[0].status == "done") {
|
||||
$('#img2ImgImg').attr('src', response.jobs[0].img);
|
||||
return;
|
||||
}
|
||||
if (response.jobs[0].status == "failed") {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1000); // refresh every second
|
||||
setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second
|
||||
},
|
||||
error: function (xhr, status, error) {
|
||||
// Handle error response
|
||||
console.log(xhr.responseText);
|
||||
$('#resultStatus').html('failed');
|
||||
$('#txt2ImgStatus').html('failed');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -189,6 +261,43 @@
|
|||
}
|
||||
});
|
||||
|
||||
// Cache variable to store the selected image data
|
||||
var imageData = null;
|
||||
|
||||
$("#copy-txt-to-img").click(function () {
|
||||
data = $("#txt2ImgImg").attr("src");
|
||||
if (data == null || data == "") {
|
||||
alert("nothing found from txt-to-img result");
|
||||
return;
|
||||
}
|
||||
imageData = data;
|
||||
$("#reference-img").attr("src", imageData);
|
||||
});
|
||||
|
||||
$("#copy-last-img").click(function () {
|
||||
data = $("#img2ImgImg").attr("src");
|
||||
if (data == null || data == "") {
|
||||
alert("nothing found from img-to-img result");
|
||||
return;
|
||||
}
|
||||
imageData = data;
|
||||
$("#reference-img").attr("src", imageData);
|
||||
});
|
||||
|
||||
$("#upload-img").click(function () {
|
||||
var input = $("<input type='file' accept='image/*'>");
|
||||
input.on("change", function () {
|
||||
var reader = new FileReader();
|
||||
reader.onload = function (e) {
|
||||
imageData = e.target.result;
|
||||
$("#reference-img").attr("src", imageData);
|
||||
};
|
||||
reader.readAsDataURL(input[0].files[0]);
|
||||
});
|
||||
input.click();
|
||||
});
|
||||
|
||||
|
||||
$(".nav-link").click(function (e) {
|
||||
e.preventDefault();
|
||||
var target = $(this).attr("href"); // get the href value of the clicked link
|
||||
|
|
@ -200,7 +309,7 @@
|
|||
$(target).show();
|
||||
});
|
||||
|
||||
$('#newJob').click(function (e) {
|
||||
$('#newTxt2ImgJob').click(function (e) {
|
||||
e.preventDefault(); // Prevent the default form submission
|
||||
|
||||
// Gather input field values
|
||||
|
|
@ -211,6 +320,10 @@
|
|||
if (seedVal == "0" || seedVal == "") {
|
||||
seedVal = "0";
|
||||
}
|
||||
var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val())
|
||||
if (isNaN(guidanceScaleVal)) {
|
||||
guidanceScaleVal = 25.0;
|
||||
}
|
||||
var stepsVal = parseInt($('#inputSteps').val());
|
||||
if (isNaN(stepsVal)) {
|
||||
stepsVal = 50;
|
||||
|
|
@ -229,6 +342,11 @@
|
|||
return;
|
||||
}
|
||||
|
||||
if (guidanceScaleVal < 1 || guidanceScaleVal > 30) {
|
||||
alert("guidance scale must be between 1 and 30");
|
||||
return;
|
||||
}
|
||||
|
||||
if (widthVal < 8 || widthVal > 960) {
|
||||
alert("width must be between 8 and 960!");
|
||||
return;
|
||||
|
|
@ -268,6 +386,7 @@
|
|||
'steps': stepsVal,
|
||||
'width': widthVal,
|
||||
'height': heightVal,
|
||||
'guidance_scale': guidanceScaleVal,
|
||||
'neg_prompt': negPromptVal
|
||||
}),
|
||||
success: function (response) {
|
||||
|
|
@ -275,13 +394,95 @@
|
|||
if (response.uuid) {
|
||||
$('#jobuuid').val(response.uuid);
|
||||
}
|
||||
$('#resultStatus').html('submitting new job..');
|
||||
$('#txt2ImgStatus').html('submitting new job..');
|
||||
waitForImage(apikeyVal, response.uuid);
|
||||
},
|
||||
error: function (xhr, status, error) {
|
||||
// Handle error response
|
||||
console.log(xhr.responseText);
|
||||
$('#resultStatus').html('failed');
|
||||
$('#txt2ImgStatus').html('failed');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
$('#newImg2ImgJob').click(function (e) {
|
||||
e.preventDefault(); // Prevent the default form submission
|
||||
|
||||
if (imageData == null) {
|
||||
alert("No image cached")
|
||||
return;
|
||||
}
|
||||
|
||||
// Gather input field values
|
||||
var apikeyVal = $('#apiKey').val();
|
||||
var promptVal = $('#prompt').val();
|
||||
var negPromptVal = $('#negPrompt').val();
|
||||
var seedVal = $('#inputSeed').val();
|
||||
if (seedVal == "0" || seedVal == "") {
|
||||
seedVal = "0";
|
||||
}
|
||||
var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val())
|
||||
if (isNaN(guidanceScaleVal)) {
|
||||
guidanceScaleVal = 25.0;
|
||||
}
|
||||
var stepsVal = parseInt($('#inputSteps').val());
|
||||
if (isNaN(stepsVal)) {
|
||||
stepsVal = 50;
|
||||
}
|
||||
var strengthVal = parseInt($('#inputStrength').val());
|
||||
if (isNaN(strengthVal)) {
|
||||
strengthVal = 0.5;
|
||||
}
|
||||
|
||||
if (promptVal == "") {
|
||||
alert("missing prompt!");
|
||||
return;
|
||||
}
|
||||
|
||||
if (guidanceScaleVal < 1 || guidanceScaleVal > 30) {
|
||||
alert("guidance scale must be between 1 and 30");
|
||||
return;
|
||||
}
|
||||
|
||||
if (strengthVal < 0 || strengthVal > 1) {
|
||||
alert("strength must be between 0 and 1");
|
||||
return;
|
||||
}
|
||||
|
||||
if (stepsVal > 200 || stepsVal < 1) {
|
||||
alert("steps value must be between 1 and 200!");
|
||||
return;
|
||||
}
|
||||
|
||||
// Send POST request using Ajax
|
||||
$.ajax({
|
||||
type: 'POST',
|
||||
url: '/add_job',
|
||||
contentType: 'application/json; charset=utf-8',
|
||||
dataType: 'json',
|
||||
data: JSON.stringify({
|
||||
'api_key': apikeyVal,
|
||||
'type': 'img',
|
||||
'ref_img': imageData,
|
||||
'prompt': promptVal,
|
||||
'seed': seedVal,
|
||||
'steps': stepsVal,
|
||||
'guidance_scale': guidanceScaleVal,
|
||||
'strength': strengthVal,
|
||||
'neg_prompt': negPromptVal
|
||||
}),
|
||||
success: function (response) {
|
||||
console.log(response);
|
||||
if (response.uuid) {
|
||||
$('#jobuuid').val(response.uuid);
|
||||
}
|
||||
$('#img2ImgStatus').html('submitting new job..');
|
||||
waitForImage(apikeyVal, response.uuid);
|
||||
},
|
||||
error: function (xhr, status, error) {
|
||||
// Handle error response
|
||||
console.log(xhr.responseText);
|
||||
$('#img2ImgStatus').html('failed');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from utilities.constants import KEY_GUIDANCE_SCALE
|
|||
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
|
||||
from utilities.constants import KEY_HEIGHT
|
||||
from utilities.constants import VALUE_HEIGHT_DEFAULT
|
||||
from utilities.constants import KEY_STRENGTH
|
||||
from utilities.constants import VALUE_STRENGTH_DEFAULT
|
||||
from utilities.constants import KEY_PREVIEW
|
||||
from utilities.constants import VALUE_PREVIEW_DEFAULT
|
||||
from utilities.constants import KEY_SCHEDULER
|
||||
|
|
@ -139,3 +141,15 @@ class Config:
|
|||
)
|
||||
self.__config[KEY_WIDTH] = value
|
||||
return self
|
||||
|
||||
def get_strength(self) -> float:
|
||||
return float(self.__config.get(KEY_STRENGTH, VALUE_STRENGTH_DEFAULT))
|
||||
|
||||
def set_strength(self, strength: float):
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(
|
||||
KEY_STRENGTH, self.get_strength(), strength
|
||||
)
|
||||
)
|
||||
self.__config[KEY_STRENGTH] = strength
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -20,7 +20,10 @@ KEY_HEIGHT = "HEIGHT"
|
|||
VALUE_HEIGHT_DEFAULT = 512
|
||||
|
||||
KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
|
||||
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
|
||||
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0
|
||||
|
||||
KEY_STRENGTH = "STRENGTH"
|
||||
VALUE_STRENGTH_DEFAULT = 0.5
|
||||
|
||||
KEY_STEPS = "STEPS"
|
||||
VALUE_STEPS_DEFAULT = 50
|
||||
|
|
@ -69,4 +72,6 @@ OPTIONAL_KEYS = [
|
|||
KEY_GUIDANCE_SCALE.lower(),
|
||||
KEY_STEPS.lower(),
|
||||
KEY_SCHEDULER.lower(),
|
||||
KEY_STRENGTH.lower(),
|
||||
REFERENCE_IMG.lower(),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ def load_image(image: Union[str, bytes]) -> Union[Image.Image, None]:
|
|||
return None
|
||||
|
||||
|
||||
def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool = False) -> bool:
|
||||
def save_image(
|
||||
image: Union[bytes, Image.Image], filepath: str, override: bool = False
|
||||
) -> bool:
|
||||
if os.path.isfile(filepath) and not override:
|
||||
return False
|
||||
try:
|
||||
|
|
@ -31,13 +33,15 @@ def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool =
|
|||
|
||||
|
||||
def crop_image(image: Image.Image, boundary: tuple) -> Image.Image:
|
||||
'''
|
||||
"""
|
||||
Crop an image based on boundary defined in boundary tuple.
|
||||
'''
|
||||
"""
|
||||
return image.crop(boundary)
|
||||
|
||||
|
||||
def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = "png") -> str:
|
||||
def image_to_base64(
|
||||
image: Union[bytes, str, Image.Image], image_format: str = "png"
|
||||
) -> str:
|
||||
if isinstance(image, str):
|
||||
# this is a filepath
|
||||
if not os.path.isfile(image):
|
||||
|
|
@ -49,7 +53,18 @@ def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = "
|
|||
rawbytes = io.BytesIO()
|
||||
image.save(rawbytes, format=image_format)
|
||||
image = rawbytes.getvalue()
|
||||
return "data:image/{};base64,".format(image_format) + base64.b64encode(image).decode()
|
||||
return (
|
||||
"data:image/{};base64,".format(image_format) + base64.b64encode(image).decode()
|
||||
)
|
||||
|
||||
|
||||
def base64_to_image(image: str) -> Image.Image:
|
||||
tmp = image.split(",")
|
||||
if len(tmp) > 1:
|
||||
base64parts = tmp[1]
|
||||
else:
|
||||
base64parts = image
|
||||
return Image.open(io.BytesIO(base64.b64decode(base64parts)))
|
||||
|
||||
|
||||
from skimage import io as skimageio
|
||||
|
|
@ -57,7 +72,13 @@ from skimage import transform
|
|||
from skimage import img_as_ubyte
|
||||
|
||||
|
||||
def load_and_transform_image_for_torch(img_filepath: str, dimension: tuple = (), force_rgb: bool = True, transpose: bool = True, use_ubyte: bool = False) -> np.ndarray:
|
||||
def load_and_transform_image_for_torch(
|
||||
img_filepath: str,
|
||||
dimension: tuple = (),
|
||||
force_rgb: bool = True,
|
||||
transpose: bool = True,
|
||||
use_ubyte: bool = False,
|
||||
) -> np.ndarray:
|
||||
img = skimageio.imread(img_filepath)
|
||||
if force_rgb:
|
||||
img = img[:, :, :3]
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from utilities.memory import empty_memory_cache
|
|||
from utilities.model import Model
|
||||
from utilities.times import get_epoch_now
|
||||
from utilities.images import image_to_base64
|
||||
from utilities.images import base64_to_image
|
||||
|
||||
|
||||
class Img2Img:
|
||||
|
|
@ -59,13 +60,14 @@ class Img2Img:
|
|||
self.__logger.info("current seed: {}".format(seed))
|
||||
|
||||
if isinstance(reference_image, str):
|
||||
reference_image
|
||||
reference_image = base64_to_image(reference_image).convert('RGB')
|
||||
|
||||
result = self.model.txt2img_pipeline(
|
||||
result = self.model.img2img_pipeline(
|
||||
prompt=prompt,
|
||||
image=reference_image,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=config.get_guidance_scale(),
|
||||
strength=config.get_strength(),
|
||||
num_inference_steps=config.get_steps(),
|
||||
generator=generator,
|
||||
callback=None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue