add img2img backend barebone, reorganized components
This commit is contained in:
parent
5cdd1dc85b
commit
4ba9b8a9a1
1
BUILD
1
BUILD
|
|
@ -11,6 +11,7 @@ par_binary(
|
||||||
"//utilities:logger",
|
"//utilities:logger",
|
||||||
"//utilities:model",
|
"//utilities:model",
|
||||||
"//utilities:text2img",
|
"//utilities:text2img",
|
||||||
|
"//utilities:img2img",
|
||||||
"//utilities:envvar",
|
"//utilities:envvar",
|
||||||
"//utilities:times",
|
"//utilities:times",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
31
main.py
31
main.py
|
|
@ -14,14 +14,21 @@ from utilities.constants import API_KEY
|
||||||
from utilities.constants import API_KEY_FOR_DEMO
|
from utilities.constants import API_KEY_FOR_DEMO
|
||||||
from utilities.constants import KEY_APP
|
from utilities.constants import KEY_APP
|
||||||
from utilities.constants import KEY_JOB_STATUS
|
from utilities.constants import KEY_JOB_STATUS
|
||||||
|
from utilities.constants import KEY_JOB_TYPE
|
||||||
from utilities.constants import KEY_PROMPT
|
from utilities.constants import KEY_PROMPT
|
||||||
from utilities.constants import KEY_NEG_PROMPT
|
from utilities.constants import KEY_NEG_PROMPT
|
||||||
from utilities.constants import LOGGER_NAME
|
from utilities.constants import LOGGER_NAME
|
||||||
|
from utilities.constants import LOGGER_NAME_IMG2IMG
|
||||||
|
from utilities.constants import LOGGER_NAME_TXT2IMG
|
||||||
|
from utilities.constants import REFERENCE_IMG
|
||||||
from utilities.constants import MAX_JOB_NUMBER
|
from utilities.constants import MAX_JOB_NUMBER
|
||||||
from utilities.constants import OPTIONAL_KEYS
|
from utilities.constants import OPTIONAL_KEYS
|
||||||
from utilities.constants import REQUIRED_KEYS
|
from utilities.constants import REQUIRED_KEYS
|
||||||
from utilities.constants import UUID
|
from utilities.constants import UUID
|
||||||
from utilities.constants import VALUE_APP
|
from utilities.constants import VALUE_APP
|
||||||
|
from utilities.constants import VALUE_JOB_TXT2IMG
|
||||||
|
from utilities.constants import VALUE_JOB_IMG2IMG
|
||||||
|
from utilities.constants import VALUE_JOB_INPAINTING
|
||||||
from utilities.constants import VALUE_JOB_PENDING
|
from utilities.constants import VALUE_JOB_PENDING
|
||||||
from utilities.constants import VALUE_JOB_RUNNING
|
from utilities.constants import VALUE_JOB_RUNNING
|
||||||
from utilities.constants import VALUE_JOB_DONE
|
from utilities.constants import VALUE_JOB_DONE
|
||||||
|
|
@ -36,7 +43,7 @@ from utilities.text2img import Text2Img
|
||||||
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
fast_web_debugging = False
|
fast_web_debugging = True
|
||||||
memory_lock = Lock()
|
memory_lock = Lock()
|
||||||
event_termination = Event()
|
event_termination = Event()
|
||||||
logger = Logger(name=LOGGER_NAME)
|
logger = Logger(name=LOGGER_NAME)
|
||||||
|
|
@ -65,6 +72,9 @@ def add_job():
|
||||||
if required_key not in req:
|
if required_key not in req:
|
||||||
return jsonify({"msg": "missing one or more required keys"}), 404
|
return jsonify({"msg": "missing one or more required keys"}), 404
|
||||||
|
|
||||||
|
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
|
||||||
|
return jsonify({"msg": "missing reference image"}), 404
|
||||||
|
|
||||||
if len(local_job_stack) > MAX_JOB_NUMBER:
|
if len(local_job_stack) > MAX_JOB_NUMBER:
|
||||||
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
|
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
|
||||||
|
|
||||||
|
|
@ -204,9 +214,11 @@ def load_model(logger: Logger) -> Model:
|
||||||
|
|
||||||
def backend(event_termination):
|
def backend(event_termination):
|
||||||
model = load_model(logger)
|
model = load_model(logger)
|
||||||
text2img = Text2Img(model, logger=logger)
|
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
|
||||||
|
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
|
||||||
|
|
||||||
text2img.breakfast()
|
text2img.breakfast()
|
||||||
|
img2img.breakfast()
|
||||||
|
|
||||||
while not event_termination.is_set():
|
while not event_termination.is_set():
|
||||||
wait_for_seconds(1)
|
wait_for_seconds(1)
|
||||||
|
|
@ -223,9 +235,18 @@ def backend(event_termination):
|
||||||
config = Config().set_config(next_job)
|
config = Config().set_config(next_job)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result_dict = text2img.lunch(
|
if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
|
||||||
prompt=prompt, negative_prompt=negative_prompt, config=config
|
result_dict = text2img.lunch(
|
||||||
)
|
prompt=prompt, negative_prompt=negative_prompt, config=config
|
||||||
|
)
|
||||||
|
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
|
||||||
|
ref_img = next_job[REFERENCE_IMG]
|
||||||
|
result_dict = img2img.lunch(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
reference_image=ref_img,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("text2img.lunch error: {}".format(e))
|
logger.error("text2img.lunch error: {}".format(e))
|
||||||
local_job_stack.pop(0)
|
local_job_stack.pop(0)
|
||||||
|
|
|
||||||
|
|
@ -14,107 +14,111 @@
|
||||||
<body>
|
<body>
|
||||||
<div class="container">
|
<div class="container">
|
||||||
<div class="card mb-3">
|
<div class="card mb-3">
|
||||||
<div class="card-header">
|
|
||||||
<ul class="nav nav-tabs card-header-tabs">
|
|
||||||
<li class="nav-item">
|
|
||||||
<a class="nav-link active" href="#">Text-to-Image</a>
|
|
||||||
</li>
|
|
||||||
<li class="nav-item">
|
|
||||||
<a class="nav-link disabled" href="#">Image-to-Image</a>
|
|
||||||
</li>
|
|
||||||
<li class="nav-item">
|
|
||||||
<a class="nav-link disabled" href="#">Inpainting</a>
|
|
||||||
</li>
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
<div class="card-body">
|
<div class="card-body">
|
||||||
<form>
|
<div class="row mb-3">
|
||||||
<div class="row mb-3">
|
<div class="col-sm-8">
|
||||||
<div class="col-sm-8">
|
<label for="apiKey" class="form-label">API Key</label>
|
||||||
<label for="apiKey" class="form-label">API Key</label>
|
<input type="password" class="form-control" id="apiKey" value="demo">
|
||||||
<input type="password" class="form-control" id="apiKey" value="demo">
|
</div>
|
||||||
|
<div class="col-sm-4">
|
||||||
|
<div class="form-check">
|
||||||
|
<input class="form-check-input" type="checkbox" id="showPreview" disabled>
|
||||||
|
<label class="form-check-label" for="showPreview">
|
||||||
|
Preview Image
|
||||||
|
</label>
|
||||||
</div>
|
</div>
|
||||||
<div class="col-sm-4">
|
</div>
|
||||||
<div class="form-check">
|
</div>
|
||||||
<input class="form-check-input" type="checkbox" id="showPreview" disabled>
|
<div class="form-row mb-3">
|
||||||
<label class="form-check-label" for="showPreview">
|
<label for="prompt" class="form-label">Describe Your Image</label>
|
||||||
Preview Image
|
<input type="text" class="form-control" id="prompt" aria-describedby="promptHelp" value="">
|
||||||
</label>
|
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated. Example:
|
||||||
|
"photo of cute cat, RAW photo, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high
|
||||||
|
quality, film grain, Fujifilm XT3"</div>
|
||||||
|
</div>
|
||||||
|
<div class="form-row mb-3">
|
||||||
|
<label for="negPrompt" class="form-label">Describe What's NOT Your Image</label>
|
||||||
|
<input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp" value="">
|
||||||
|
<div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated.
|
||||||
|
Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon,
|
||||||
|
drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg
|
||||||
|
artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn
|
||||||
|
hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad
|
||||||
|
proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs,
|
||||||
|
missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long
|
||||||
|
neck"</div>
|
||||||
|
</div>
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="form-row">
|
||||||
|
<label for="inputSeed">Seed</label>
|
||||||
|
<input type="text" class="form-control" id="inputSeed" aria-describedby="inputSeedHelp"
|
||||||
|
value="">
|
||||||
|
<div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random
|
||||||
|
seed
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="form-row">
|
||||||
|
<label for="inputSteps">Steps</label>
|
||||||
|
<input type="number" class="form-control" id="inputSteps" aria-describedby="inputStepsHelp"
|
||||||
|
placeholder="default is 50">
|
||||||
|
<div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s
|
||||||
|
(GPU)
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="form-row">
|
||||||
|
<label for="inputWidth">Width</label>
|
||||||
|
<input type="number" class="form-control" id="inputWidth" placeholder="512" min="1"
|
||||||
|
max="1024">
|
||||||
|
</div>
|
||||||
|
<div class="form-row">
|
||||||
|
<label for="inputHeight">Height</label>
|
||||||
|
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
|
||||||
|
max="1024">
|
||||||
|
</div>
|
||||||
|
<button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button>
|
||||||
</div>
|
</div>
|
||||||
<div class="form-row mb-3">
|
<div class="col-md-9">
|
||||||
<label for="prompt" class="form-label">Describe Your Image</label>
|
<div class="card mb-3">
|
||||||
<input type="text" class="form-control" id="prompt" aria-describedby="promptHelp" value="">
|
<div class="card-header">
|
||||||
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated. Example:
|
<ul class="nav nav-tabs card-header-tabs">
|
||||||
"photo of cute cat, RAW photo, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high
|
<li class="nav-item">
|
||||||
quality, film grain, Fujifilm XT3"</div>
|
<a class="nav-link active" href="#card-txt">Text-to-Image</a>
|
||||||
</div>
|
</li>
|
||||||
<div class="form-row mb-3">
|
<li class="nav-item">
|
||||||
<label for="negPrompt" class="form-label">Describe What's NOT Your Image</label>
|
<a class="nav-link" href="#card-img">Image-to-Image</a>
|
||||||
<input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp"
|
</li>
|
||||||
value="">
|
<li class="nav-item">
|
||||||
<div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated.
|
<a class="nav-link" href="#card-inpainting">Inpainting</a>
|
||||||
Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon,
|
</li>
|
||||||
drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg
|
</ul>
|
||||||
artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn
|
</div>
|
||||||
hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad
|
<div class="card-body card-specific" id="card-img" style="display:none">
|
||||||
proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs,
|
img
|
||||||
missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long
|
</div>
|
||||||
neck"</div>
|
<div class="card-body card-specific" id="card-inpainting" style="display:none">
|
||||||
</div>
|
TBD
|
||||||
<div class="row">
|
</div>
|
||||||
<div class="col-md-6">
|
<div class="card-body card-specific" id="card-txt">
|
||||||
<div class="row">
|
<div class="card">
|
||||||
<div class="form-row col-md-6">
|
<div class="card-header">
|
||||||
<label for="inputSeed">Seed</label>
|
Result
|
||||||
<input type="text" class="form-control" id="inputSeed"
|
|
||||||
aria-describedby="inputSeedHelp" value="">
|
|
||||||
<div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random
|
|
||||||
seed
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
<div class="card-body">
|
||||||
<div class="form-row col-md-6">
|
<ul class="list-group">
|
||||||
<label for="inputSteps">Steps</label>
|
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||||
<input type="number" class="form-control" id="inputSteps"
|
id="resultStatus"></li>
|
||||||
aria-describedby="inputStepsHelp" placeholder="default is 50">
|
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||||
<div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s
|
id="resultSeed"></li>
|
||||||
(GPU)
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
<img class="card-img-bottom" id="newJobImg">
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="row">
|
|
||||||
<div class="form-row col-md-6">
|
|
||||||
<label for="inputWidth">Width</label>
|
|
||||||
<input type="number" class="form-control" id="inputWidth" placeholder="512" min="1"
|
|
||||||
max="1024">
|
|
||||||
</div>
|
|
||||||
<div class="form-row col-md-6">
|
|
||||||
<label for="inputHeight">Height</label>
|
|
||||||
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
|
|
||||||
max="1024">
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-6">
|
|
||||||
<div class="card">
|
|
||||||
<div class="card-header">
|
|
||||||
Result
|
|
||||||
</div>
|
|
||||||
<div class="card-body">
|
|
||||||
<ul class="list-group">
|
|
||||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
|
||||||
id="resultStatus"></li>
|
|
||||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
|
||||||
id="resultSeed"></li>
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
<img class="card-img-bottom" id="newJobImg">
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
@ -138,7 +142,6 @@
|
||||||
crossorigin="anonymous"></script>
|
crossorigin="anonymous"></script>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
|
|
||||||
function waitForImage(apikeyVal, uuidValue) {
|
function waitForImage(apikeyVal, uuidValue) {
|
||||||
// Wait until image is done
|
// Wait until image is done
|
||||||
$.ajax({
|
$.ajax({
|
||||||
|
|
@ -186,6 +189,17 @@
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
$(".nav-link").click(function (e) {
|
||||||
|
e.preventDefault();
|
||||||
|
var target = $(this).attr("href"); // get the href value of the clicked link
|
||||||
|
|
||||||
|
// hide all card divs and show the corresponding one
|
||||||
|
$(".card-specific").hide();
|
||||||
|
$(".nav-link").removeClass("active");
|
||||||
|
$(this).addClass("active");
|
||||||
|
$(target).show();
|
||||||
|
});
|
||||||
|
|
||||||
$('#newJob').click(function (e) {
|
$('#newJob').click(function (e) {
|
||||||
e.preventDefault(); // Prevent the default form submission
|
e.preventDefault(); // Prevent the default form submission
|
||||||
|
|
||||||
|
|
@ -248,6 +262,7 @@
|
||||||
dataType: 'json',
|
dataType: 'json',
|
||||||
data: JSON.stringify({
|
data: JSON.stringify({
|
||||||
'api_key': apikeyVal,
|
'api_key': apikeyVal,
|
||||||
|
'type': 'txt',
|
||||||
'prompt': promptVal,
|
'prompt': promptVal,
|
||||||
'seed': seedVal,
|
'seed': seedVal,
|
||||||
'steps': stepsVal,
|
'steps': stepsVal,
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,20 @@ py_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="img2img",
|
||||||
|
srcs=["img2img.py"],
|
||||||
|
deps=[
|
||||||
|
":constants",
|
||||||
|
":config",
|
||||||
|
":logger",
|
||||||
|
":images",
|
||||||
|
":memory",
|
||||||
|
":model",
|
||||||
|
":times",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name="times",
|
name="times",
|
||||||
srcs=["times.py"],
|
srcs=["times.py"],
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from utilities.constants import KEY_OUTPUT_FOLDER
|
from utilities.constants import KEY_OUTPUT_FOLDER
|
||||||
from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT
|
from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT
|
||||||
|
|
@ -37,27 +38,37 @@ class Config:
|
||||||
|
|
||||||
def get_config(self) -> dict:
|
def get_config(self) -> dict:
|
||||||
return self.__config
|
return self.__config
|
||||||
|
|
||||||
def set_config(self, config: dict):
|
def set_config(self, config: dict):
|
||||||
for key in config:
|
for key in config:
|
||||||
if key not in OPTIONAL_KEYS:
|
if key not in OPTIONAL_KEYS:
|
||||||
continue
|
continue
|
||||||
self.__config[key.upper()] = config[key]
|
self.__config[key.upper()] = config[key]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_output_folder(self) -> str:
|
def get_output_folder(self) -> str:
|
||||||
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
|
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
|
||||||
|
|
||||||
def set_output_folder(self, folder:str):
|
def set_output_folder(self, folder: str):
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(
|
||||||
|
KEY_OUTPUT_FOLDER, self.get_output_folder(), folder
|
||||||
|
)
|
||||||
|
)
|
||||||
self.__config[KEY_OUTPUT_FOLDER] = folder
|
self.__config[KEY_OUTPUT_FOLDER] = folder
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_guidance_scale(self) -> float:
|
def get_guidance_scale(self) -> float:
|
||||||
return float(self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT))
|
return float(
|
||||||
|
self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
|
||||||
|
)
|
||||||
|
|
||||||
def set_guidance_scale(self, scale: float):
|
def set_guidance_scale(self, scale: float):
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(
|
||||||
|
KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale
|
||||||
|
)
|
||||||
|
)
|
||||||
self.__config[KEY_GUIDANCE_SCALE] = scale
|
self.__config[KEY_GUIDANCE_SCALE] = scale
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -65,7 +76,9 @@ class Config:
|
||||||
return int(self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT))
|
return int(self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT))
|
||||||
|
|
||||||
def set_height(self, value: int):
|
def set_height(self, value: int):
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value)
|
||||||
|
)
|
||||||
self.__config[KEY_HEIGHT] = value
|
self.__config[KEY_HEIGHT] = value
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -73,7 +86,9 @@ class Config:
|
||||||
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
|
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
|
||||||
|
|
||||||
def set_preview(self, boolean: bool):
|
def set_preview(self, boolean: bool):
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)
|
||||||
|
)
|
||||||
self.__config[KEY_PREVIEW] = boolean
|
self.__config[KEY_PREVIEW] = boolean
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -83,7 +98,11 @@ class Config:
|
||||||
def set_scheduler(self, scheduler: str):
|
def set_scheduler(self, scheduler: str):
|
||||||
if not scheduler:
|
if not scheduler:
|
||||||
scheduler = VALUE_SCHEDULER_DEFAULT
|
scheduler = VALUE_SCHEDULER_DEFAULT
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(
|
||||||
|
KEY_SCHEDULER, self.get_scheduler(), scheduler
|
||||||
|
)
|
||||||
|
)
|
||||||
self.__config[KEY_SCHEDULER] = scheduler
|
self.__config[KEY_SCHEDULER] = scheduler
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -95,7 +114,9 @@ class Config:
|
||||||
return seed
|
return seed
|
||||||
|
|
||||||
def set_seed(self, seed: int):
|
def set_seed(self, seed: int):
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed)
|
||||||
|
)
|
||||||
self.__config[KEY_SEED] = seed
|
self.__config[KEY_SEED] = seed
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -103,7 +124,9 @@ class Config:
|
||||||
return int(self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT))
|
return int(self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT))
|
||||||
|
|
||||||
def set_steps(self, steps: int):
|
def set_steps(self, steps: int):
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps)
|
||||||
|
)
|
||||||
self.__config[KEY_STEPS] = steps
|
self.__config[KEY_STEPS] = steps
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -111,6 +134,8 @@ class Config:
|
||||||
return int(self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT))
|
return int(self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT))
|
||||||
|
|
||||||
def set_width(self, value: int):
|
def set_width(self, value: int):
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
|
self.__logger.info(
|
||||||
|
"{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value)
|
||||||
|
)
|
||||||
self.__config[KEY_WIDTH] = value
|
self.__config[KEY_WIDTH] = value
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
LOGGER_NAME = "main"
|
|
||||||
MAX_JOB_NUMBER = 10
|
|
||||||
|
|
||||||
KEY_APP = "APP"
|
KEY_APP = "APP"
|
||||||
VALUE_APP = "demo"
|
VALUE_APP = "demo"
|
||||||
|
|
||||||
|
LOGGER_NAME = VALUE_APP
|
||||||
|
LOGGER_NAME_TXT2IMG = "txt2img"
|
||||||
|
LOGGER_NAME_IMG2IMG = "img2img"
|
||||||
|
MAX_JOB_NUMBER = 10
|
||||||
|
|
||||||
|
|
||||||
KEY_OUTPUT_FOLDER = "OUTFOLDER"
|
KEY_OUTPUT_FOLDER = "OUTFOLDER"
|
||||||
VALUE_OUTPUT_FOLDER_DEFAULT = ""
|
VALUE_OUTPUT_FOLDER_DEFAULT = ""
|
||||||
|
|
||||||
|
|
@ -20,7 +23,7 @@ KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
|
||||||
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
|
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
|
||||||
|
|
||||||
KEY_STEPS = "STEPS"
|
KEY_STEPS = "STEPS"
|
||||||
VALUE_STEPS_DEFAULT = 100
|
VALUE_STEPS_DEFAULT = 50
|
||||||
|
|
||||||
KEY_SCHEDULER = "SCHEDULER"
|
KEY_SCHEDULER = "SCHEDULER"
|
||||||
VALUE_SCHEDULER_DEFAULT = "Default"
|
VALUE_SCHEDULER_DEFAULT = "Default"
|
||||||
|
|
@ -47,10 +50,16 @@ VALUE_JOB_PENDING = "pending"
|
||||||
VALUE_JOB_RUNNING = "running"
|
VALUE_JOB_RUNNING = "running"
|
||||||
VALUE_JOB_DONE = "done"
|
VALUE_JOB_DONE = "done"
|
||||||
VALUE_JOB_FAILED = "failed"
|
VALUE_JOB_FAILED = "failed"
|
||||||
|
KEY_JOB_TYPE = "type"
|
||||||
|
VALUE_JOB_TXT2IMG = "txt"
|
||||||
|
VALUE_JOB_IMG2IMG = "img"
|
||||||
|
VALUE_JOB_INPAINTING = "inpaint"
|
||||||
|
REFERENCE_IMG = "ref_img"
|
||||||
|
|
||||||
REQUIRED_KEYS = [
|
REQUIRED_KEYS = [
|
||||||
API_KEY.lower(),
|
API_KEY.lower(),
|
||||||
KEY_PROMPT.lower(),
|
KEY_PROMPT.lower(),
|
||||||
|
KEY_JOB_TYPE.lower(),
|
||||||
]
|
]
|
||||||
OPTIONAL_KEYS = [
|
OPTIONAL_KEYS = [
|
||||||
KEY_NEG_PROMPT.lower(),
|
KEY_NEG_PROMPT.lower(),
|
||||||
|
|
@ -60,4 +69,4 @@ OPTIONAL_KEYS = [
|
||||||
KEY_GUIDANCE_SCALE.lower(),
|
KEY_GUIDANCE_SCALE.lower(),
|
||||||
KEY_STEPS.lower(),
|
KEY_STEPS.lower(),
|
||||||
KEY_SCHEDULER.lower(),
|
KEY_SCHEDULER.lower(),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
import torch
|
||||||
|
from typing import Union
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from utilities.constants import BASE64IMAGE
|
||||||
|
from utilities.constants import KEY_SEED
|
||||||
|
from utilities.constants import KEY_WIDTH
|
||||||
|
from utilities.constants import KEY_HEIGHT
|
||||||
|
from utilities.constants import KEY_STEPS
|
||||||
|
from utilities.config import Config
|
||||||
|
from utilities.logger import DummyLogger
|
||||||
|
from utilities.memory import empty_memory_cache
|
||||||
|
from utilities.model import Model
|
||||||
|
from utilities.times import get_epoch_now
|
||||||
|
from utilities.images import image_to_base64
|
||||||
|
|
||||||
|
|
||||||
|
class Img2Img:
|
||||||
|
"""
|
||||||
|
Img2Img class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Model,
|
||||||
|
output_folder: str = "",
|
||||||
|
logger: DummyLogger = DummyLogger(),
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
|
||||||
|
self.__output_folder = output_folder
|
||||||
|
self.__logger = logger
|
||||||
|
|
||||||
|
def brunch(self, prompt: str, negative_prompt: str = ""):
|
||||||
|
self.breakfast()
|
||||||
|
self.lunch(prompt, negative_prompt)
|
||||||
|
|
||||||
|
def breakfast(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def lunch(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
reference_image: Union[Image.Image, None, str] = None,
|
||||||
|
config: Config = Config(),
|
||||||
|
) -> dict:
|
||||||
|
if not prompt:
|
||||||
|
self.__logger.error("no prompt provided, won't proceed")
|
||||||
|
return {}
|
||||||
|
if reference_image is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
self.model.set_img2img_scheduler(config.get_scheduler())
|
||||||
|
|
||||||
|
t = get_epoch_now()
|
||||||
|
seed = config.get_seed()
|
||||||
|
generator = torch.Generator(self.__device).manual_seed(seed)
|
||||||
|
self.__logger.info("current seed: {}".format(seed))
|
||||||
|
|
||||||
|
if isinstance(reference_image, str):
|
||||||
|
reference_image
|
||||||
|
|
||||||
|
result = self.model.txt2img_pipeline(
|
||||||
|
prompt=prompt,
|
||||||
|
image=reference_image,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
guidance_scale=config.get_guidance_scale(),
|
||||||
|
num_inference_steps=config.get_steps(),
|
||||||
|
generator=generator,
|
||||||
|
callback=None,
|
||||||
|
callback_steps=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.__output_folder:
|
||||||
|
out_filepath = "{}/{}.png".format(self.__output_folder, t)
|
||||||
|
result.images[0].save(out_filepath)
|
||||||
|
self.__logger.info("output to file: {}".format(out_filepath))
|
||||||
|
|
||||||
|
empty_memory_cache()
|
||||||
|
|
||||||
|
return {
|
||||||
|
BASE64IMAGE: image_to_base64(result.images[0]),
|
||||||
|
KEY_SEED.lower(): str(seed),
|
||||||
|
KEY_WIDTH.lower(): config.get_width(),
|
||||||
|
KEY_HEIGHT.lower(): config.get_height(),
|
||||||
|
KEY_STEPS.lower(): config.get_steps(),
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue