add img2img backend barebone, reorganized components
This commit is contained in:
parent
5cdd1dc85b
commit
4ba9b8a9a1
1
BUILD
1
BUILD
|
|
@ -11,6 +11,7 @@ par_binary(
|
|||
"//utilities:logger",
|
||||
"//utilities:model",
|
||||
"//utilities:text2img",
|
||||
"//utilities:img2img",
|
||||
"//utilities:envvar",
|
||||
"//utilities:times",
|
||||
],
|
||||
|
|
|
|||
31
main.py
31
main.py
|
|
@ -14,14 +14,21 @@ from utilities.constants import API_KEY
|
|||
from utilities.constants import API_KEY_FOR_DEMO
|
||||
from utilities.constants import KEY_APP
|
||||
from utilities.constants import KEY_JOB_STATUS
|
||||
from utilities.constants import KEY_JOB_TYPE
|
||||
from utilities.constants import KEY_PROMPT
|
||||
from utilities.constants import KEY_NEG_PROMPT
|
||||
from utilities.constants import LOGGER_NAME
|
||||
from utilities.constants import LOGGER_NAME_IMG2IMG
|
||||
from utilities.constants import LOGGER_NAME_TXT2IMG
|
||||
from utilities.constants import REFERENCE_IMG
|
||||
from utilities.constants import MAX_JOB_NUMBER
|
||||
from utilities.constants import OPTIONAL_KEYS
|
||||
from utilities.constants import REQUIRED_KEYS
|
||||
from utilities.constants import UUID
|
||||
from utilities.constants import VALUE_APP
|
||||
from utilities.constants import VALUE_JOB_TXT2IMG
|
||||
from utilities.constants import VALUE_JOB_IMG2IMG
|
||||
from utilities.constants import VALUE_JOB_INPAINTING
|
||||
from utilities.constants import VALUE_JOB_PENDING
|
||||
from utilities.constants import VALUE_JOB_RUNNING
|
||||
from utilities.constants import VALUE_JOB_DONE
|
||||
|
|
@ -36,7 +43,7 @@ from utilities.text2img import Text2Img
|
|||
|
||||
|
||||
app = Flask(__name__)
|
||||
fast_web_debugging = False
|
||||
fast_web_debugging = True
|
||||
memory_lock = Lock()
|
||||
event_termination = Event()
|
||||
logger = Logger(name=LOGGER_NAME)
|
||||
|
|
@ -65,6 +72,9 @@ def add_job():
|
|||
if required_key not in req:
|
||||
return jsonify({"msg": "missing one or more required keys"}), 404
|
||||
|
||||
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
|
||||
return jsonify({"msg": "missing reference image"}), 404
|
||||
|
||||
if len(local_job_stack) > MAX_JOB_NUMBER:
|
||||
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
|
||||
|
||||
|
|
@ -204,9 +214,11 @@ def load_model(logger: Logger) -> Model:
|
|||
|
||||
def backend(event_termination):
|
||||
model = load_model(logger)
|
||||
text2img = Text2Img(model, logger=logger)
|
||||
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
|
||||
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
|
||||
|
||||
text2img.breakfast()
|
||||
img2img.breakfast()
|
||||
|
||||
while not event_termination.is_set():
|
||||
wait_for_seconds(1)
|
||||
|
|
@ -223,9 +235,18 @@ def backend(event_termination):
|
|||
config = Config().set_config(next_job)
|
||||
|
||||
try:
|
||||
result_dict = text2img.lunch(
|
||||
prompt=prompt, negative_prompt=negative_prompt, config=config
|
||||
)
|
||||
if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
|
||||
result_dict = text2img.lunch(
|
||||
prompt=prompt, negative_prompt=negative_prompt, config=config
|
||||
)
|
||||
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
|
||||
ref_img = next_job[REFERENCE_IMG]
|
||||
result_dict = img2img.lunch(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
reference_image=ref_img,
|
||||
config=config,
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error("text2img.lunch error: {}".format(e))
|
||||
local_job_stack.pop(0)
|
||||
|
|
|
|||
|
|
@ -14,107 +14,111 @@
|
|||
<body>
|
||||
<div class="container">
|
||||
<div class="card mb-3">
|
||||
<div class="card-header">
|
||||
<ul class="nav nav-tabs card-header-tabs">
|
||||
<li class="nav-item">
|
||||
<a class="nav-link active" href="#">Text-to-Image</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link disabled" href="#">Image-to-Image</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link disabled" href="#">Inpainting</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<form>
|
||||
<div class="row mb-3">
|
||||
<div class="col-sm-8">
|
||||
<label for="apiKey" class="form-label">API Key</label>
|
||||
<input type="password" class="form-control" id="apiKey" value="demo">
|
||||
<div class="row mb-3">
|
||||
<div class="col-sm-8">
|
||||
<label for="apiKey" class="form-label">API Key</label>
|
||||
<input type="password" class="form-control" id="apiKey" value="demo">
|
||||
</div>
|
||||
<div class="col-sm-4">
|
||||
<div class="form-check">
|
||||
<input class="form-check-input" type="checkbox" id="showPreview" disabled>
|
||||
<label class="form-check-label" for="showPreview">
|
||||
Preview Image
|
||||
</label>
|
||||
</div>
|
||||
<div class="col-sm-4">
|
||||
<div class="form-check">
|
||||
<input class="form-check-input" type="checkbox" id="showPreview" disabled>
|
||||
<label class="form-check-label" for="showPreview">
|
||||
Preview Image
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-row mb-3">
|
||||
<label for="prompt" class="form-label">Describe Your Image</label>
|
||||
<input type="text" class="form-control" id="prompt" aria-describedby="promptHelp" value="">
|
||||
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated. Example:
|
||||
"photo of cute cat, RAW photo, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high
|
||||
quality, film grain, Fujifilm XT3"</div>
|
||||
</div>
|
||||
<div class="form-row mb-3">
|
||||
<label for="negPrompt" class="form-label">Describe What's NOT Your Image</label>
|
||||
<input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp" value="">
|
||||
<div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated.
|
||||
Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon,
|
||||
drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg
|
||||
artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn
|
||||
hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad
|
||||
proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs,
|
||||
missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long
|
||||
neck"</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col-md-3">
|
||||
<div class="form-row">
|
||||
<label for="inputSeed">Seed</label>
|
||||
<input type="text" class="form-control" id="inputSeed" aria-describedby="inputSeedHelp"
|
||||
value="">
|
||||
<div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random
|
||||
seed
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="inputSteps">Steps</label>
|
||||
<input type="number" class="form-control" id="inputSteps" aria-describedby="inputStepsHelp"
|
||||
placeholder="default is 50">
|
||||
<div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s
|
||||
(GPU)
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="inputWidth">Width</label>
|
||||
<input type="number" class="form-control" id="inputWidth" placeholder="512" min="1"
|
||||
max="1024">
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="inputHeight">Height</label>
|
||||
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
|
||||
max="1024">
|
||||
</div>
|
||||
<button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button>
|
||||
</div>
|
||||
<div class="form-row mb-3">
|
||||
<label for="prompt" class="form-label">Describe Your Image</label>
|
||||
<input type="text" class="form-control" id="prompt" aria-describedby="promptHelp" value="">
|
||||
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated. Example:
|
||||
"photo of cute cat, RAW photo, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high
|
||||
quality, film grain, Fujifilm XT3"</div>
|
||||
</div>
|
||||
<div class="form-row mb-3">
|
||||
<label for="negPrompt" class="form-label">Describe What's NOT Your Image</label>
|
||||
<input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp"
|
||||
value="">
|
||||
<div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated.
|
||||
Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon,
|
||||
drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg
|
||||
artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn
|
||||
hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad
|
||||
proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs,
|
||||
missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long
|
||||
neck"</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col-md-6">
|
||||
<div class="row">
|
||||
<div class="form-row col-md-6">
|
||||
<label for="inputSeed">Seed</label>
|
||||
<input type="text" class="form-control" id="inputSeed"
|
||||
aria-describedby="inputSeedHelp" value="">
|
||||
<div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random
|
||||
seed
|
||||
<div class="col-md-9">
|
||||
<div class="card mb-3">
|
||||
<div class="card-header">
|
||||
<ul class="nav nav-tabs card-header-tabs">
|
||||
<li class="nav-item">
|
||||
<a class="nav-link active" href="#card-txt">Text-to-Image</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link" href="#card-img">Image-to-Image</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link" href="#card-inpainting">Inpainting</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-img" style="display:none">
|
||||
img
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-inpainting" style="display:none">
|
||||
TBD
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-txt">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
Result
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-row col-md-6">
|
||||
<label for="inputSteps">Steps</label>
|
||||
<input type="number" class="form-control" id="inputSteps"
|
||||
aria-describedby="inputStepsHelp" placeholder="default is 50">
|
||||
<div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s
|
||||
(GPU)
|
||||
<div class="card-body">
|
||||
<ul class="list-group">
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="resultStatus"></li>
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="resultSeed"></li>
|
||||
</ul>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="newJobImg">
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="form-row col-md-6">
|
||||
<label for="inputWidth">Width</label>
|
||||
<input type="number" class="form-control" id="inputWidth" placeholder="512" min="1"
|
||||
max="1024">
|
||||
</div>
|
||||
<div class="form-row col-md-6">
|
||||
<label for="inputHeight">Height</label>
|
||||
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
|
||||
max="1024">
|
||||
</div>
|
||||
</div>
|
||||
<button id="newJob" type="submit" class="btn btn-primary">Let's Go!</button>
|
||||
</div>
|
||||
<div class="col-md-6">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
Result
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<ul class="list-group">
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="resultStatus"></li>
|
||||
<li class="list-group-item d-flex justify-content-between align-items-center"
|
||||
id="resultSeed"></li>
|
||||
</ul>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="newJobImg">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
|
@ -138,7 +142,6 @@
|
|||
crossorigin="anonymous"></script>
|
||||
|
||||
<script>
|
||||
|
||||
function waitForImage(apikeyVal, uuidValue) {
|
||||
// Wait until image is done
|
||||
$.ajax({
|
||||
|
|
@ -186,6 +189,17 @@
|
|||
}
|
||||
});
|
||||
|
||||
$(".nav-link").click(function (e) {
|
||||
e.preventDefault();
|
||||
var target = $(this).attr("href"); // get the href value of the clicked link
|
||||
|
||||
// hide all card divs and show the corresponding one
|
||||
$(".card-specific").hide();
|
||||
$(".nav-link").removeClass("active");
|
||||
$(this).addClass("active");
|
||||
$(target).show();
|
||||
});
|
||||
|
||||
$('#newJob').click(function (e) {
|
||||
e.preventDefault(); // Prevent the default form submission
|
||||
|
||||
|
|
@ -248,6 +262,7 @@
|
|||
dataType: 'json',
|
||||
data: JSON.stringify({
|
||||
'api_key': apikeyVal,
|
||||
'type': 'txt',
|
||||
'prompt': promptVal,
|
||||
'seed': seedVal,
|
||||
'steps': stepsVal,
|
||||
|
|
|
|||
|
|
@ -73,6 +73,20 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="img2img",
|
||||
srcs=["img2img.py"],
|
||||
deps=[
|
||||
":constants",
|
||||
":config",
|
||||
":logger",
|
||||
":images",
|
||||
":memory",
|
||||
":model",
|
||||
":times",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="times",
|
||||
srcs=["times.py"],
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import random
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
from utilities.constants import KEY_OUTPUT_FOLDER
|
||||
from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT
|
||||
|
|
@ -37,27 +38,37 @@ class Config:
|
|||
|
||||
def get_config(self) -> dict:
|
||||
return self.__config
|
||||
|
||||
|
||||
def set_config(self, config: dict):
|
||||
for key in config:
|
||||
if key not in OPTIONAL_KEYS:
|
||||
continue
|
||||
self.__config[key.upper()] = config[key]
|
||||
return self
|
||||
|
||||
|
||||
def get_output_folder(self) -> str:
|
||||
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
|
||||
|
||||
def set_output_folder(self, folder:str):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder))
|
||||
def set_output_folder(self, folder: str):
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(
|
||||
KEY_OUTPUT_FOLDER, self.get_output_folder(), folder
|
||||
)
|
||||
)
|
||||
self.__config[KEY_OUTPUT_FOLDER] = folder
|
||||
return self
|
||||
|
||||
def get_guidance_scale(self) -> float:
|
||||
return float(self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT))
|
||||
return float(
|
||||
self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
|
||||
)
|
||||
|
||||
def set_guidance_scale(self, scale: float):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(
|
||||
KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale
|
||||
)
|
||||
)
|
||||
self.__config[KEY_GUIDANCE_SCALE] = scale
|
||||
return self
|
||||
|
||||
|
|
@ -65,7 +76,9 @@ class Config:
|
|||
return int(self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT))
|
||||
|
||||
def set_height(self, value: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value)
|
||||
)
|
||||
self.__config[KEY_HEIGHT] = value
|
||||
return self
|
||||
|
||||
|
|
@ -73,7 +86,9 @@ class Config:
|
|||
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
|
||||
|
||||
def set_preview(self, boolean: bool):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)
|
||||
)
|
||||
self.__config[KEY_PREVIEW] = boolean
|
||||
return self
|
||||
|
||||
|
|
@ -83,7 +98,11 @@ class Config:
|
|||
def set_scheduler(self, scheduler: str):
|
||||
if not scheduler:
|
||||
scheduler = VALUE_SCHEDULER_DEFAULT
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(
|
||||
KEY_SCHEDULER, self.get_scheduler(), scheduler
|
||||
)
|
||||
)
|
||||
self.__config[KEY_SCHEDULER] = scheduler
|
||||
return self
|
||||
|
||||
|
|
@ -95,7 +114,9 @@ class Config:
|
|||
return seed
|
||||
|
||||
def set_seed(self, seed: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed)
|
||||
)
|
||||
self.__config[KEY_SEED] = seed
|
||||
return self
|
||||
|
||||
|
|
@ -103,7 +124,9 @@ class Config:
|
|||
return int(self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT))
|
||||
|
||||
def set_steps(self, steps: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps)
|
||||
)
|
||||
self.__config[KEY_STEPS] = steps
|
||||
return self
|
||||
|
||||
|
|
@ -111,6 +134,8 @@ class Config:
|
|||
return int(self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT))
|
||||
|
||||
def set_width(self, value: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value)
|
||||
)
|
||||
self.__config[KEY_WIDTH] = value
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
LOGGER_NAME = "main"
|
||||
MAX_JOB_NUMBER = 10
|
||||
|
||||
KEY_APP = "APP"
|
||||
VALUE_APP = "demo"
|
||||
|
||||
LOGGER_NAME = VALUE_APP
|
||||
LOGGER_NAME_TXT2IMG = "txt2img"
|
||||
LOGGER_NAME_IMG2IMG = "img2img"
|
||||
MAX_JOB_NUMBER = 10
|
||||
|
||||
|
||||
KEY_OUTPUT_FOLDER = "OUTFOLDER"
|
||||
VALUE_OUTPUT_FOLDER_DEFAULT = ""
|
||||
|
||||
|
|
@ -20,7 +23,7 @@ KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
|
|||
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
|
||||
|
||||
KEY_STEPS = "STEPS"
|
||||
VALUE_STEPS_DEFAULT = 100
|
||||
VALUE_STEPS_DEFAULT = 50
|
||||
|
||||
KEY_SCHEDULER = "SCHEDULER"
|
||||
VALUE_SCHEDULER_DEFAULT = "Default"
|
||||
|
|
@ -47,10 +50,16 @@ VALUE_JOB_PENDING = "pending"
|
|||
VALUE_JOB_RUNNING = "running"
|
||||
VALUE_JOB_DONE = "done"
|
||||
VALUE_JOB_FAILED = "failed"
|
||||
KEY_JOB_TYPE = "type"
|
||||
VALUE_JOB_TXT2IMG = "txt"
|
||||
VALUE_JOB_IMG2IMG = "img"
|
||||
VALUE_JOB_INPAINTING = "inpaint"
|
||||
REFERENCE_IMG = "ref_img"
|
||||
|
||||
REQUIRED_KEYS = [
|
||||
API_KEY.lower(),
|
||||
KEY_PROMPT.lower(),
|
||||
KEY_JOB_TYPE.lower(),
|
||||
]
|
||||
OPTIONAL_KEYS = [
|
||||
KEY_NEG_PROMPT.lower(),
|
||||
|
|
@ -60,4 +69,4 @@ OPTIONAL_KEYS = [
|
|||
KEY_GUIDANCE_SCALE.lower(),
|
||||
KEY_STEPS.lower(),
|
||||
KEY_SCHEDULER.lower(),
|
||||
]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
import torch
|
||||
from typing import Union
|
||||
from PIL import Image
|
||||
|
||||
from utilities.constants import BASE64IMAGE
|
||||
from utilities.constants import KEY_SEED
|
||||
from utilities.constants import KEY_WIDTH
|
||||
from utilities.constants import KEY_HEIGHT
|
||||
from utilities.constants import KEY_STEPS
|
||||
from utilities.config import Config
|
||||
from utilities.logger import DummyLogger
|
||||
from utilities.memory import empty_memory_cache
|
||||
from utilities.model import Model
|
||||
from utilities.times import get_epoch_now
|
||||
from utilities.images import image_to_base64
|
||||
|
||||
|
||||
class Img2Img:
|
||||
"""
|
||||
Img2Img class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
output_folder: str = "",
|
||||
logger: DummyLogger = DummyLogger(),
|
||||
):
|
||||
self.model = model
|
||||
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
|
||||
self.__output_folder = output_folder
|
||||
self.__logger = logger
|
||||
|
||||
def brunch(self, prompt: str, negative_prompt: str = ""):
|
||||
self.breakfast()
|
||||
self.lunch(prompt, negative_prompt)
|
||||
|
||||
def breakfast(self):
|
||||
pass
|
||||
|
||||
def lunch(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
reference_image: Union[Image.Image, None, str] = None,
|
||||
config: Config = Config(),
|
||||
) -> dict:
|
||||
if not prompt:
|
||||
self.__logger.error("no prompt provided, won't proceed")
|
||||
return {}
|
||||
if reference_image is None:
|
||||
return {}
|
||||
|
||||
self.model.set_img2img_scheduler(config.get_scheduler())
|
||||
|
||||
t = get_epoch_now()
|
||||
seed = config.get_seed()
|
||||
generator = torch.Generator(self.__device).manual_seed(seed)
|
||||
self.__logger.info("current seed: {}".format(seed))
|
||||
|
||||
if isinstance(reference_image, str):
|
||||
reference_image
|
||||
|
||||
result = self.model.txt2img_pipeline(
|
||||
prompt=prompt,
|
||||
image=reference_image,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=config.get_guidance_scale(),
|
||||
num_inference_steps=config.get_steps(),
|
||||
generator=generator,
|
||||
callback=None,
|
||||
callback_steps=10,
|
||||
)
|
||||
|
||||
if self.__output_folder:
|
||||
out_filepath = "{}/{}.png".format(self.__output_folder, t)
|
||||
result.images[0].save(out_filepath)
|
||||
self.__logger.info("output to file: {}".format(out_filepath))
|
||||
|
||||
empty_memory_cache()
|
||||
|
||||
return {
|
||||
BASE64IMAGE: image_to_base64(result.images[0]),
|
||||
KEY_SEED.lower(): str(seed),
|
||||
KEY_WIDTH.lower(): config.get_width(),
|
||||
KEY_HEIGHT.lower(): config.get_height(),
|
||||
KEY_STEPS.lower(): config.get_steps(),
|
||||
}
|
||||
Loading…
Reference in New Issue