diff --git a/main.py b/main.py index 3314d9f..9268dee 100644 --- a/main.py +++ b/main.py @@ -44,7 +44,7 @@ from utilities.img2img import Img2Img app = Flask(__name__) -fast_web_debugging = False +app.config['TESTING'] = False memory_lock = Lock() event_termination = Event() logger = Logger(name=LOGGER_NAME) @@ -270,9 +270,9 @@ def backend(event_termination): def main(): - if fast_web_debugging: + if app.testing: try: - app.run(host="0.0.0.0") + app.run(host="0.0.0.0", port="5000") except KeyboardInterrupt: pass return @@ -281,7 +281,7 @@ def main(): # ugly solution for now # TODO: use a database to track instead of internal memory try: - app.run(host="0.0.0.0") + app.run(host="0.0.0.0", port="8888") thread.join() except KeyboardInterrupt: event_termination.set() diff --git a/templates/index.html b/templates/index.html index 3a62a77..e4e967f 100644 --- a/templates/index.html +++ b/templates/index.html @@ -76,12 +76,22 @@ - +
+ + +
How much guidance to follow from + description. 20 strictly follow prompt, 7 creative/artistic. +
+
+
+ +
-
@@ -153,22 +213,34 @@ success: function (response) { console.log(response); if (response.jobs.length == 1) { - $('#resultStatus').html(response.jobs[0].status) - $('#resultSeed').html("seed: " + response.jobs[0].seed) - if (response.jobs[0].status == "done") { - $('#newJobImg').attr('src', response.jobs[0].img); - return; - } - if (response.jobs[0].status == "failed") { - return; + if (response.jobs[0].type == 'txt') { + $('#txt2ImgStatus').html(response.jobs[0].status); + $('#txt2ImgSeed').html("seed: " + response.jobs[0].seed); + if (response.jobs[0].status == "done") { + $('#txt2ImgImg').attr('src', response.jobs[0].img); + return; + } + if (response.jobs[0].status == "failed") { + return; + } + } else if (response.jobs[0].type == 'img') { + $('#img2ImgStatus').html(response.jobs[0].status); + $('#img2ImgSeed').html("seed: " + response.jobs[0].seed); + if (response.jobs[0].status == "done") { + $('#img2ImgImg').attr('src', response.jobs[0].img); + return; + } + if (response.jobs[0].status == "failed") { + return; + } } } - setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1000); // refresh every second + setTimeout(function () { waitForImage(apikeyVal, uuidValue); }, 1500); // refresh every second }, error: function (xhr, status, error) { // Handle error response console.log(xhr.responseText); - $('#resultStatus').html('failed'); + $('#txt2ImgStatus').html('failed'); } }); } @@ -189,6 +261,43 @@ } }); + // Cache variable to store the selected image data + var imageData = null; + + $("#copy-txt-to-img").click(function () { + data = $("#txt2ImgImg").attr("src"); + if (data == null || data == "") { + alert("nothing found from txt-to-img result"); + return; + } + imageData = data; + $("#reference-img").attr("src", imageData); + }); + + $("#copy-last-img").click(function () { + data = $("#img2ImgImg").attr("src"); + if (data == null || data == "") { + alert("nothing found from img-to-img result"); + return; + } + imageData = data; + $("#reference-img").attr("src", imageData); + }); + + $("#upload-img").click(function () { + var input = $(""); + input.on("change", function () { + var reader = new FileReader(); + reader.onload = function (e) { + imageData = e.target.result; + $("#reference-img").attr("src", imageData); + }; + reader.readAsDataURL(input[0].files[0]); + }); + input.click(); + }); + + $(".nav-link").click(function (e) { e.preventDefault(); var target = $(this).attr("href"); // get the href value of the clicked link @@ -200,7 +309,7 @@ $(target).show(); }); - $('#newJob').click(function (e) { + $('#newTxt2ImgJob').click(function (e) { e.preventDefault(); // Prevent the default form submission // Gather input field values @@ -211,6 +320,10 @@ if (seedVal == "0" || seedVal == "") { seedVal = "0"; } + var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val()) + if (isNaN(guidanceScaleVal)) { + guidanceScaleVal = 25.0; + } var stepsVal = parseInt($('#inputSteps').val()); if (isNaN(stepsVal)) { stepsVal = 50; @@ -229,6 +342,11 @@ return; } + if (guidanceScaleVal < 1 || guidanceScaleVal > 30) { + alert("guidance scale must be between 1 and 30"); + return; + } + if (widthVal < 8 || widthVal > 960) { alert("width must be between 8 and 960!"); return; @@ -268,6 +386,7 @@ 'steps': stepsVal, 'width': widthVal, 'height': heightVal, + 'guidance_scale': guidanceScaleVal, 'neg_prompt': negPromptVal }), success: function (response) { @@ -275,13 +394,95 @@ if (response.uuid) { $('#jobuuid').val(response.uuid); } - $('#resultStatus').html('submitting new job..'); + $('#txt2ImgStatus').html('submitting new job..'); waitForImage(apikeyVal, response.uuid); }, error: function (xhr, status, error) { // Handle error response console.log(xhr.responseText); - $('#resultStatus').html('failed'); + $('#txt2ImgStatus').html('failed'); + } + }); + }); + + $('#newImg2ImgJob').click(function (e) { + e.preventDefault(); // Prevent the default form submission + + if (imageData == null) { + alert("No image cached") + return; + } + + // Gather input field values + var apikeyVal = $('#apiKey').val(); + var promptVal = $('#prompt').val(); + var negPromptVal = $('#negPrompt').val(); + var seedVal = $('#inputSeed').val(); + if (seedVal == "0" || seedVal == "") { + seedVal = "0"; + } + var guidanceScaleVal = parseFloat($('#inputGuidanceScale').val()) + if (isNaN(guidanceScaleVal)) { + guidanceScaleVal = 25.0; + } + var stepsVal = parseInt($('#inputSteps').val()); + if (isNaN(stepsVal)) { + stepsVal = 50; + } + var strengthVal = parseInt($('#inputStrength').val()); + if (isNaN(strengthVal)) { + strengthVal = 0.5; + } + + if (promptVal == "") { + alert("missing prompt!"); + return; + } + + if (guidanceScaleVal < 1 || guidanceScaleVal > 30) { + alert("guidance scale must be between 1 and 30"); + return; + } + + if (strengthVal < 0 || strengthVal > 1) { + alert("strength must be between 0 and 1"); + return; + } + + if (stepsVal > 200 || stepsVal < 1) { + alert("steps value must be between 1 and 200!"); + return; + } + + // Send POST request using Ajax + $.ajax({ + type: 'POST', + url: '/add_job', + contentType: 'application/json; charset=utf-8', + dataType: 'json', + data: JSON.stringify({ + 'api_key': apikeyVal, + 'type': 'img', + 'ref_img': imageData, + 'prompt': promptVal, + 'seed': seedVal, + 'steps': stepsVal, + 'guidance_scale': guidanceScaleVal, + 'strength': strengthVal, + 'neg_prompt': negPromptVal + }), + success: function (response) { + console.log(response); + if (response.uuid) { + $('#jobuuid').val(response.uuid); + } + $('#img2ImgStatus').html('submitting new job..'); + waitForImage(apikeyVal, response.uuid); + }, + error: function (xhr, status, error) { + // Handle error response + console.log(xhr.responseText); + $('#img2ImgStatus').html('failed'); } }); }); diff --git a/utilities/config.py b/utilities/config.py index c5f3114..7cc0026 100644 --- a/utilities/config.py +++ b/utilities/config.py @@ -8,6 +8,8 @@ from utilities.constants import KEY_GUIDANCE_SCALE from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT from utilities.constants import KEY_HEIGHT from utilities.constants import VALUE_HEIGHT_DEFAULT +from utilities.constants import KEY_STRENGTH +from utilities.constants import VALUE_STRENGTH_DEFAULT from utilities.constants import KEY_PREVIEW from utilities.constants import VALUE_PREVIEW_DEFAULT from utilities.constants import KEY_SCHEDULER @@ -139,3 +141,15 @@ class Config: ) self.__config[KEY_WIDTH] = value return self + + def get_strength(self) -> float: + return float(self.__config.get(KEY_STRENGTH, VALUE_STRENGTH_DEFAULT)) + + def set_strength(self, strength: float): + self.__logger.info( + "{} changed from {} to {}".format( + KEY_STRENGTH, self.get_strength(), strength + ) + ) + self.__config[KEY_STRENGTH] = strength + return self diff --git a/utilities/constants.py b/utilities/constants.py index bc688fd..5c793a6 100644 --- a/utilities/constants.py +++ b/utilities/constants.py @@ -20,7 +20,10 @@ KEY_HEIGHT = "HEIGHT" VALUE_HEIGHT_DEFAULT = 512 KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE" -VALUE_GUIDANCE_SCALE_DEFAULT = 15.0 +VALUE_GUIDANCE_SCALE_DEFAULT = 25.0 + +KEY_STRENGTH = "STRENGTH" +VALUE_STRENGTH_DEFAULT = 0.5 KEY_STEPS = "STEPS" VALUE_STEPS_DEFAULT = 50 @@ -69,4 +72,6 @@ OPTIONAL_KEYS = [ KEY_GUIDANCE_SCALE.lower(), KEY_STEPS.lower(), KEY_SCHEDULER.lower(), + KEY_STRENGTH.lower(), + REFERENCE_IMG.lower(), ] diff --git a/utilities/images.py b/utilities/images.py index 6ca2d5a..a5a4338 100644 --- a/utilities/images.py +++ b/utilities/images.py @@ -15,7 +15,9 @@ def load_image(image: Union[str, bytes]) -> Union[Image.Image, None]: return None -def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool = False) -> bool: +def save_image( + image: Union[bytes, Image.Image], filepath: str, override: bool = False +) -> bool: if os.path.isfile(filepath) and not override: return False try: @@ -31,13 +33,15 @@ def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool = def crop_image(image: Image.Image, boundary: tuple) -> Image.Image: - ''' + """ Crop an image based on boundary defined in boundary tuple. - ''' + """ return image.crop(boundary) -def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = "png") -> str: +def image_to_base64( + image: Union[bytes, str, Image.Image], image_format: str = "png" +) -> str: if isinstance(image, str): # this is a filepath if not os.path.isfile(image): @@ -49,7 +53,18 @@ def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = " rawbytes = io.BytesIO() image.save(rawbytes, format=image_format) image = rawbytes.getvalue() - return "data:image/{};base64,".format(image_format) + base64.b64encode(image).decode() + return ( + "data:image/{};base64,".format(image_format) + base64.b64encode(image).decode() + ) + + +def base64_to_image(image: str) -> Image.Image: + tmp = image.split(",") + if len(tmp) > 1: + base64parts = tmp[1] + else: + base64parts = image + return Image.open(io.BytesIO(base64.b64decode(base64parts))) from skimage import io as skimageio @@ -57,7 +72,13 @@ from skimage import transform from skimage import img_as_ubyte -def load_and_transform_image_for_torch(img_filepath: str, dimension: tuple = (), force_rgb: bool = True, transpose: bool = True, use_ubyte: bool = False) -> np.ndarray: +def load_and_transform_image_for_torch( + img_filepath: str, + dimension: tuple = (), + force_rgb: bool = True, + transpose: bool = True, + use_ubyte: bool = False, +) -> np.ndarray: img = skimageio.imread(img_filepath) if force_rgb: img = img[:, :, :3] diff --git a/utilities/img2img.py b/utilities/img2img.py index 7ce4a9a..3adb28a 100644 --- a/utilities/img2img.py +++ b/utilities/img2img.py @@ -13,6 +13,7 @@ from utilities.memory import empty_memory_cache from utilities.model import Model from utilities.times import get_epoch_now from utilities.images import image_to_base64 +from utilities.images import base64_to_image class Img2Img: @@ -59,13 +60,14 @@ class Img2Img: self.__logger.info("current seed: {}".format(seed)) if isinstance(reference_image, str): - reference_image + reference_image = base64_to_image(reference_image).convert('RGB') - result = self.model.txt2img_pipeline( + result = self.model.img2img_pipeline( prompt=prompt, image=reference_image, negative_prompt=negative_prompt, guidance_scale=config.get_guidance_scale(), + strength=config.get_strength(), num_inference_steps=config.get_steps(), generator=generator, callback=None,