# stable-diffusion-for-fun/backend.py
#
# 170 lines
# 5.5 KiB
# Python

import argparse
from utilities.constants import LOGGER_NAME_BACKEND
from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import UUID
from utilities.constants import KEY_LANGUAGE
from utilities.constants import VALUE_LANGUAGE_EN
from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import VALUE_JOB_FAILED
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import REFERENCE_IMG
from utilities.translator import translate_prompt
from utilities.config import Config
from utilities.database import Database
from utilities.logger import Logger
from utilities.model import Model
from utilities.text2img import Text2Img
from utilities.img2img import Img2Img
from utilities.times import wait_for_seconds
# Module-wide logger for the backend worker; also handed to the Database
# wrapper below and to load_model() from main().
logger = Logger(name=LOGGER_NAME_BACKEND)
# Shared database handle, created at import time. It is only connected
# later, in main(), via database.connect(args.db).
database = Database(logger)
def load_model(
    logger: Logger,
    use_gpu: bool,
    model_name: str = "SG161222/Realistic_Vision_V2.0",
    inpainting_model_name: str = "runwayml/stable-diffusion-inpainting",
) -> Model:
    """Build and load the diffusion model used by the backend.

    Args:
        logger: Logger passed through to the ``Model`` wrapper.
        use_gpu: Forwarded to ``Model`` as ``use_gpu``; when true the model is
            also switched to low-memory mode before loading.
        model_name: Identifier of the main model. Defaults to the value that
            was previously hard-coded here.
        inpainting_model_name: Identifier of the inpainting model. Defaults to
            the value that was previously hard-coded here.

    Returns:
        The ``Model`` instance after ``load_all()`` has been called on it.
    """
    # model candidates:
    #   "runwayml/stable-diffusion-v1-5"
    #   "CompVis/stable-diffusion-v1-4"
    #   "stabilityai/stable-diffusion-2-1"
    #   "SG161222/Realistic_Vision_V2.0"
    #   "darkstorm2150/Protogen_x3.4_Official_Release"
    #   "darkstorm2150/Protogen_x5.8_Official_Release"
    #   "prompthero/openjourney"
    #   "naclbit/trinart_stable_diffusion_v2"
    #   "hakurei/waifu-diffusion"
    # inpainting model candidates:
    #   "runwayml/stable-diffusion-inpainting"
    model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
    if use_gpu:
        # Low-memory mode is only requested on the GPU path.
        model.set_low_memory_mode()
    model.load_all()
    return model
def _english_prompts(job):
    """Return ``(prompt, negative_prompt)`` for *job*, translated when needed.

    Translation happens only when the job carries ``KEY_LANGUAGE`` and that
    language is not ``VALUE_LANGUAGE_EN``. An empty/falsy negative prompt is
    returned untouched.
    """
    prompt = job[KEY_PROMPT]
    negative_prompt = job[KEY_NEG_PROMPT]
    if KEY_LANGUAGE not in job:
        return prompt, negative_prompt

    language = job[KEY_LANGUAGE]
    logger.info(f"found {language}, translate prompt and negative prompt first")
    if language == VALUE_LANGUAGE_EN:
        return prompt, negative_prompt

    prompt_en = translate_prompt(prompt, language)
    logger.info(f"translated {prompt} to {prompt_en}")
    if negative_prompt:
        negative_prompt_en = translate_prompt(negative_prompt, language)
        logger.info(f"translated {negative_prompt} to {negative_prompt_en}")
        negative_prompt = negative_prompt_en
    return prompt_en, negative_prompt


def _run_job(job, text2img, img2img):
    """Execute one job with the matching pipeline and return its result dict.

    Raises:
        ValueError: If the job type is neither txt2img nor img2img. Raising
            here prevents the caller from storing a previous job's result
            under this job.
    """
    prompt, negative_prompt = _english_prompts(job)
    config = Config().set_config(job)

    job_type = job[KEY_JOB_TYPE]
    if job_type == VALUE_JOB_TXT2IMG:
        return text2img.lunch(
            prompt=prompt, negative_prompt=negative_prompt, config=config
        )
    if job_type == VALUE_JOB_IMG2IMG:
        return img2img.lunch(
            prompt=prompt,
            negative_prompt=negative_prompt,
            reference_image=job[REFERENCE_IMG],
            config=config,
        )
    raise ValueError(f"unsupported job type: {job_type}")


def backend(model, is_debugging: bool):
    """Poll the database once per second and process jobs until interrupted.

    Args:
        model: Loaded model shared by the txt2img and img2img pipelines.
        is_debugging: When true, jobs come from ``database.get_jobs()`` and no
            status or result is written back to the database. When false, one
            pending job is fetched at a time and its status is moved through
            RUNNING to DONE or FAILED.

    The loop ends when a ``KeyboardInterrupt`` arrives while a job is being
    processed. Any other ``Exception`` raised while preparing or running a job
    is logged and the job is marked FAILED, so one bad job does not stop the
    worker.
    """
    text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
    text2img.breakfast()
    img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
    img2img.breakfast()

    while True:
        wait_for_seconds(1)

        if is_debugging:
            pending_jobs = database.get_jobs()
        else:
            pending_jobs = database.get_one_pending_job()
        if not pending_jobs:
            continue
        next_job = pending_jobs[0]

        if not is_debugging:
            database.update_job(
                {KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID]
            )

        try:
            result_dict = _run_job(next_job, text2img, img2img)
        except KeyboardInterrupt:
            break
        except Exception as e:
            # Exception (not BaseException) so SystemExit and friends still
            # propagate instead of being recorded as a failed job.
            logger.error("job processing error: {}".format(e))
            if not is_debugging:
                database.update_job(
                    {KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
                )
            continue

        if not is_debugging:
            database.update_job(
                {KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID]
            )
            database.update_job(result_dict, job_uuid=next_job[UUID])

    logger.critical("stopped")
def main(args):
    """Connect the database, load the model and run the worker loop.

    Args:
        args: Parsed command-line namespace providing ``image_output_folder``,
            ``db``, ``gpu`` and ``debug``.

    The database is disconnected even when model loading or the worker loop
    raises, so the connection is not left open on an error exit.
    """
    database.set_image_output_folder(args.image_output_folder)
    database.connect(args.db)
    try:
        model = load_model(logger, args.gpu)
        backend(model, args.debug)
    finally:
        database.safe_disconnect()
if __name__ == "__main__":
    # Command-line options, registered in this order:
    #   --debug                     passed to backend() as is_debugging
    #   --db                        database path given to database.connect()
    #   --gpu                       passed to load_model() as use_gpu
    #   --image-output-folder / -o  given to database.set_image_output_folder()
    cli_options = (
        (("--debug",), {"action": "store_true", "help": "Enable debug mode"}),
        (
            ("--db",),
            {
                "type": str,
                "default": "happysd.db",
                "help": "Path to SQLite database file",
            },
        ),
        (("--gpu",), {"action": "store_true", "help": "Enable to use GPU device"}),
        (
            ("--image-output-folder", "-o"),
            {"type": str, "default": "", "help": "Path to output images"},
        ),
    )

    cli = argparse.ArgumentParser()
    for flags, settings in cli_options:
        cli.add_argument(*flags, **settings)

    main(cli.parse_args())