add img2img backend barebone, reorganized components

This commit is contained in:
HappyZ 2023-04-30 00:32:20 -07:00
parent 5cdd1dc85b
commit 4ba9b8a9a1
7 changed files with 286 additions and 113 deletions

1
BUILD
View File

@ -11,6 +11,7 @@ par_binary(
"//utilities:logger", "//utilities:logger",
"//utilities:model", "//utilities:model",
"//utilities:text2img", "//utilities:text2img",
"//utilities:img2img",
"//utilities:envvar", "//utilities:envvar",
"//utilities:times", "//utilities:times",
], ],

25
main.py
View File

@ -14,14 +14,21 @@ from utilities.constants import API_KEY
from utilities.constants import API_KEY_FOR_DEMO from utilities.constants import API_KEY_FOR_DEMO
from utilities.constants import KEY_APP from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import KEY_PROMPT from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import LOGGER_NAME from utilities.constants import LOGGER_NAME
from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import REFERENCE_IMG
from utilities.constants import MAX_JOB_NUMBER from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID from utilities.constants import UUID
from utilities.constants import VALUE_APP from utilities.constants import VALUE_APP
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_PENDING from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE from utilities.constants import VALUE_JOB_DONE
@ -36,7 +43,7 @@ from utilities.text2img import Text2Img
app = Flask(__name__) app = Flask(__name__)
fast_web_debugging = False fast_web_debugging = True
memory_lock = Lock() memory_lock = Lock()
event_termination = Event() event_termination = Event()
logger = Logger(name=LOGGER_NAME) logger = Logger(name=LOGGER_NAME)
@ -65,6 +72,9 @@ def add_job():
if required_key not in req: if required_key not in req:
return jsonify({"msg": "missing one or more required keys"}), 404 return jsonify({"msg": "missing one or more required keys"}), 404
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
return jsonify({"msg": "missing reference image"}), 404
if len(local_job_stack) > MAX_JOB_NUMBER: if len(local_job_stack) > MAX_JOB_NUMBER:
return jsonify({"msg": "too many jobs in queue, please wait"}), 500 return jsonify({"msg": "too many jobs in queue, please wait"}), 500
@ -204,9 +214,11 @@ def load_model(logger: Logger) -> Model:
def backend(event_termination): def backend(event_termination):
model = load_model(logger) model = load_model(logger)
text2img = Text2Img(model, logger=logger) text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
text2img.breakfast() text2img.breakfast()
img2img.breakfast()
while not event_termination.is_set(): while not event_termination.is_set():
wait_for_seconds(1) wait_for_seconds(1)
@ -223,9 +235,18 @@ def backend(event_termination):
config = Config().set_config(next_job) config = Config().set_config(next_job)
try: try:
if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
result_dict = text2img.lunch( result_dict = text2img.lunch(
prompt=prompt, negative_prompt=negative_prompt, config=config prompt=prompt, negative_prompt=negative_prompt, config=config
) )
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
ref_img = next_job[REFERENCE_IMG]
result_dict = img2img.lunch(
prompt=prompt,
negative_prompt=negative_prompt,
reference_image=ref_img,
config=config,
)
except BaseException as e: except BaseException as e:
logger.error("text2img.lunch error: {}".format(e)) logger.error("text2img.lunch error: {}".format(e))
local_job_stack.pop(0) local_job_stack.pop(0)

View File

@ -14,21 +14,7 @@
<body> <body>
<div class="container"> <div class="container">
<div class="card mb-3"> <div class="card mb-3">
<div class="card-header">
<ul class="nav nav-tabs card-header-tabs">
<li class="nav-item">
<a class="nav-link active" href="#">Text-to-Image</a>
</li>
<li class="nav-item">
<a class="nav-link disabled" href="#">Image-to-Image</a>
</li>
<li class="nav-item">
<a class="nav-link disabled" href="#">Inpainting</a>
</li>
</ul>
</div>
<div class="card-body"> <div class="card-body">
<form>
<div class="row mb-3"> <div class="row mb-3">
<div class="col-sm-8"> <div class="col-sm-8">
<label for="apiKey" class="form-label">API Key</label> <label for="apiKey" class="form-label">API Key</label>
@ -52,8 +38,7 @@
</div> </div>
<div class="form-row mb-3"> <div class="form-row mb-3">
<label for="negPrompt" class="form-label">Describe What's NOT Your Image</label> <label for="negPrompt" class="form-label">Describe What's NOT Your Image</label>
<input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp" <input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp" value="">
value="">
<div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated. <div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated.
Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon,
drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg
@ -64,40 +49,57 @@
neck"</div> neck"</div>
</div> </div>
<div class="row"> <div class="row">
<div class="col-md-6"> <div class="col-md-3">
<div class="row"> <div class="form-row">
<div class="form-row col-md-6">
<label for="inputSeed">Seed</label> <label for="inputSeed">Seed</label>
<input type="text" class="form-control" id="inputSeed" <input type="text" class="form-control" id="inputSeed" aria-describedby="inputSeedHelp"
aria-describedby="inputSeedHelp" value=""> value="">
<div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random <div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random
seed seed
</div> </div>
</div> </div>
<div class="form-row col-md-6"> <div class="form-row">
<label for="inputSteps">Steps</label> <label for="inputSteps">Steps</label>
<input type="number" class="form-control" id="inputSteps" <input type="number" class="form-control" id="inputSteps" aria-describedby="inputStepsHelp"
aria-describedby="inputStepsHelp" placeholder="default is 50"> placeholder="default is 50">
<div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s <div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s
(GPU) (GPU)
</div> </div>
</div> </div>
</div> <div class="form-row">
<div class="row">
<div class="form-row col-md-6">
<label for="inputWidth">Width</label> <label for="inputWidth">Width</label>
<input type="number" class="form-control" id="inputWidth" placeholder="512" min="1" <input type="number" class="form-control" id="inputWidth" placeholder="512" min="1"
max="1024"> max="1024">
</div> </div>
<div class="form-row col-md-6"> <div class="form-row">
<label for="inputHeight">Height</label> <label for="inputHeight">Height</label>
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1" <input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
max="1024"> max="1024">
</div> </div>
</div>
<button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button> <button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button>
</div> </div>
<div class="col-md-6"> <div class="col-md-9">
<div class="card mb-3">
<div class="card-header">
<ul class="nav nav-tabs card-header-tabs">
<li class="nav-item">
<a class="nav-link active" href="#card-txt">Text-to-Image</a>
</li>
<li class="nav-item">
<a class="nav-link" href="#card-img">Image-to-Image</a>
</li>
<li class="nav-item">
<a class="nav-link" href="#card-inpainting">Inpainting</a>
</li>
</ul>
</div>
<div class="card-body card-specific" id="card-img" style="display:none">
img
</div>
<div class="card-body card-specific" id="card-inpainting" style="display:none">
TBD
</div>
<div class="card-body card-specific" id="card-txt">
<div class="card"> <div class="card">
<div class="card-header"> <div class="card-header">
Result Result
@ -114,7 +116,9 @@
</div> </div>
</div> </div>
</div> </div>
</form>
</div>
</div>
</div> </div>
</div> </div>
@ -138,7 +142,6 @@
crossorigin="anonymous"></script> crossorigin="anonymous"></script>
<script> <script>
function waitForImage(apikeyVal, uuidValue) { function waitForImage(apikeyVal, uuidValue) {
// Wait until image is done // Wait until image is done
$.ajax({ $.ajax({
@ -186,6 +189,17 @@
} }
}); });
$(".nav-link").click(function (e) {
e.preventDefault();
var target = $(this).attr("href"); // get the href value of the clicked link
// hide all card divs and show the corresponding one
$(".card-specific").hide();
$(".nav-link").removeClass("active");
$(this).addClass("active");
$(target).show();
});
$('#newJob').click(function (e) { $('#newJob').click(function (e) {
e.preventDefault(); // Prevent the default form submission e.preventDefault(); // Prevent the default form submission
@ -248,6 +262,7 @@
dataType: 'json', dataType: 'json',
data: JSON.stringify({ data: JSON.stringify({
'api_key': apikeyVal, 'api_key': apikeyVal,
'type': 'txt',
'prompt': promptVal, 'prompt': promptVal,
'seed': seedVal, 'seed': seedVal,
'steps': stepsVal, 'steps': stepsVal,

View File

@ -73,6 +73,20 @@ py_library(
], ],
) )
py_library(
name="img2img",
srcs=["img2img.py"],
deps=[
":constants",
":config",
":logger",
":images",
":memory",
":model",
":times",
],
)
py_library( py_library(
name="times", name="times",
srcs=["times.py"], srcs=["times.py"],

View File

@ -1,5 +1,6 @@
import random import random
import time import time
from typing import Union
from utilities.constants import KEY_OUTPUT_FOLDER from utilities.constants import KEY_OUTPUT_FOLDER
from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT
@ -48,16 +49,26 @@ class Config:
def get_output_folder(self) -> str: def get_output_folder(self) -> str:
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT) return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
def set_output_folder(self, folder:str): def set_output_folder(self, folder: str):
self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder)) self.__logger.info(
"{} changed from {} to {}".format(
KEY_OUTPUT_FOLDER, self.get_output_folder(), folder
)
)
self.__config[KEY_OUTPUT_FOLDER] = folder self.__config[KEY_OUTPUT_FOLDER] = folder
return self return self
def get_guidance_scale(self) -> float: def get_guidance_scale(self) -> float:
return float(self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)) return float(
self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
)
def set_guidance_scale(self, scale: float): def set_guidance_scale(self, scale: float):
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale)) self.__logger.info(
"{} changed from {} to {}".format(
KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale
)
)
self.__config[KEY_GUIDANCE_SCALE] = scale self.__config[KEY_GUIDANCE_SCALE] = scale
return self return self
@ -65,7 +76,9 @@ class Config:
return int(self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)) return int(self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT))
def set_height(self, value: int): def set_height(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value)) self.__logger.info(
"{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value)
)
self.__config[KEY_HEIGHT] = value self.__config[KEY_HEIGHT] = value
return self return self
@ -73,7 +86,9 @@ class Config:
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT) return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
def set_preview(self, boolean: bool): def set_preview(self, boolean: bool):
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)) self.__logger.info(
"{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)
)
self.__config[KEY_PREVIEW] = boolean self.__config[KEY_PREVIEW] = boolean
return self return self
@ -83,7 +98,11 @@ class Config:
def set_scheduler(self, scheduler: str): def set_scheduler(self, scheduler: str):
if not scheduler: if not scheduler:
scheduler = VALUE_SCHEDULER_DEFAULT scheduler = VALUE_SCHEDULER_DEFAULT
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler)) self.__logger.info(
"{} changed from {} to {}".format(
KEY_SCHEDULER, self.get_scheduler(), scheduler
)
)
self.__config[KEY_SCHEDULER] = scheduler self.__config[KEY_SCHEDULER] = scheduler
return self return self
@ -95,7 +114,9 @@ class Config:
return seed return seed
def set_seed(self, seed: int): def set_seed(self, seed: int):
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed)) self.__logger.info(
"{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed)
)
self.__config[KEY_SEED] = seed self.__config[KEY_SEED] = seed
return self return self
@ -103,7 +124,9 @@ class Config:
return int(self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)) return int(self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT))
def set_steps(self, steps: int): def set_steps(self, steps: int):
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps)) self.__logger.info(
"{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps)
)
self.__config[KEY_STEPS] = steps self.__config[KEY_STEPS] = steps
return self return self
@ -111,6 +134,8 @@ class Config:
return int(self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)) return int(self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT))
def set_width(self, value: int): def set_width(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value)) self.__logger.info(
"{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value)
)
self.__config[KEY_WIDTH] = value self.__config[KEY_WIDTH] = value
return self return self

View File

@ -1,9 +1,12 @@
LOGGER_NAME = "main"
MAX_JOB_NUMBER = 10
KEY_APP = "APP" KEY_APP = "APP"
VALUE_APP = "demo" VALUE_APP = "demo"
LOGGER_NAME = VALUE_APP
LOGGER_NAME_TXT2IMG = "txt2img"
LOGGER_NAME_IMG2IMG = "img2img"
MAX_JOB_NUMBER = 10
KEY_OUTPUT_FOLDER = "OUTFOLDER" KEY_OUTPUT_FOLDER = "OUTFOLDER"
VALUE_OUTPUT_FOLDER_DEFAULT = "" VALUE_OUTPUT_FOLDER_DEFAULT = ""
@ -20,7 +23,7 @@ KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0 VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
KEY_STEPS = "STEPS" KEY_STEPS = "STEPS"
VALUE_STEPS_DEFAULT = 100 VALUE_STEPS_DEFAULT = 50
KEY_SCHEDULER = "SCHEDULER" KEY_SCHEDULER = "SCHEDULER"
VALUE_SCHEDULER_DEFAULT = "Default" VALUE_SCHEDULER_DEFAULT = "Default"
@ -47,10 +50,16 @@ VALUE_JOB_PENDING = "pending"
VALUE_JOB_RUNNING = "running" VALUE_JOB_RUNNING = "running"
VALUE_JOB_DONE = "done" VALUE_JOB_DONE = "done"
VALUE_JOB_FAILED = "failed" VALUE_JOB_FAILED = "failed"
KEY_JOB_TYPE = "type"
VALUE_JOB_TXT2IMG = "txt"
VALUE_JOB_IMG2IMG = "img"
VALUE_JOB_INPAINTING = "inpaint"
REFERENCE_IMG = "ref_img"
REQUIRED_KEYS = [ REQUIRED_KEYS = [
API_KEY.lower(), API_KEY.lower(),
KEY_PROMPT.lower(), KEY_PROMPT.lower(),
KEY_JOB_TYPE.lower(),
] ]
OPTIONAL_KEYS = [ OPTIONAL_KEYS = [
KEY_NEG_PROMPT.lower(), KEY_NEG_PROMPT.lower(),

88
utilities/img2img.py Normal file
View File

@ -0,0 +1,88 @@
import torch
from typing import Union
from PIL import Image
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_SEED
from utilities.constants import KEY_WIDTH
from utilities.constants import KEY_HEIGHT
from utilities.constants import KEY_STEPS
from utilities.config import Config
from utilities.logger import DummyLogger
from utilities.memory import empty_memory_cache
from utilities.model import Model
from utilities.times import get_epoch_now
from utilities.images import image_to_base64
class Img2Img:
"""
Img2Img class.
"""
def __init__(
self,
model: Model,
output_folder: str = "",
logger: DummyLogger = DummyLogger(),
):
self.model = model
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
self.__output_folder = output_folder
self.__logger = logger
def brunch(self, prompt: str, negative_prompt: str = ""):
self.breakfast()
self.lunch(prompt, negative_prompt)
def breakfast(self):
pass
def lunch(
self,
prompt: str,
negative_prompt: str = "",
reference_image: Union[Image.Image, None, str] = None,
config: Config = Config(),
) -> dict:
if not prompt:
self.__logger.error("no prompt provided, won't proceed")
return {}
if reference_image is None:
return {}
self.model.set_img2img_scheduler(config.get_scheduler())
t = get_epoch_now()
seed = config.get_seed()
generator = torch.Generator(self.__device).manual_seed(seed)
self.__logger.info("current seed: {}".format(seed))
if isinstance(reference_image, str):
reference_image
result = self.model.txt2img_pipeline(
prompt=prompt,
image=reference_image,
negative_prompt=negative_prompt,
guidance_scale=config.get_guidance_scale(),
num_inference_steps=config.get_steps(),
generator=generator,
callback=None,
callback_steps=10,
)
if self.__output_folder:
out_filepath = "{}/{}.png".format(self.__output_folder, t)
result.images[0].save(out_filepath)
self.__logger.info("output to file: {}".format(out_filepath))
empty_memory_cache()
return {
BASE64IMAGE: image_to_base64(result.images[0]),
KEY_SEED.lower(): str(seed),
KEY_WIDTH.lower(): config.get_width(),
KEY_HEIGHT.lower(): config.get_height(),
KEY_STEPS.lower(): config.get_steps(),
}