stable-diffusion-for-fun/main.py

284 lines
8.2 KiB
Python

import argparse
import copy
import tempfile
import pkgutil
import uuid
from flask import jsonify
from flask import Flask
from flask import render_template
from flask import request
from threading import Event
from threading import Thread
from threading import Lock
from utilities.constants import APIKEY
from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import LOGGER_NAME
from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import REFERENCE_IMG
from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID
from utilities.constants import VALUE_APP
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import VALUE_JOB_FAILED
from utilities.database import Database
from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_var
from utilities.times import wait_for_seconds
from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img
from utilities.img2img import Img2Img
# Flask application that serves both the JSON API and the web page.
app = Flask(__name__)

# Process-wide logger and the database wrapper built on top of it; the
# database is only connected later, inside main().
logger = Logger(name=LOGGER_NAME)
database = Database(logger)

# Held around database calls so request-handler threads and the backend
# worker thread do not use the database concurrently.
memory_lock = Lock()
# Setting this event asks the backend worker loop to finish.
event_termination = Event()

# Passed to Model by load_model(); when true the model is also switched to
# its low-memory mode.
use_gpu = True

# Not referenced anywhere else in this file as seen here — presumably
# leftovers from an earlier in-memory queue; verify before removing.
local_job_stack = []
local_completed_jobs = []
@app.route("/add_job", methods=["POST"])
def add_job():
    """Validate a job submission and enqueue it in the database.

    The JSON body must be an object containing ``APIKEY`` plus every key in
    ``REQUIRED_KEYS``; any other key must appear in ``OPTIONAL_KEYS``.
    Image-to-image jobs must additionally carry ``REFERENCE_IMG``.

    Returns:
        ``{"msg": "", UUID: <job uuid>}`` on success;
        ``("", 401)`` when the API key is missing or unknown;
        ``(json, 404)`` with a ``msg`` for malformed submissions;
        ``(json, 500)`` when the user already has ``MAX_JOB_NUMBER`` pending jobs.
    """
    req = request.get_json()
    # A JSON body of null / a list / a string would turn the membership test
    # into nonsense (or a substring check) and make req[APIKEY] raise.
    if not isinstance(req, dict) or APIKEY not in req:
        logger.error(f"{APIKEY} not present in {req}")
        return "", 401
    with memory_lock:
        user = database.validate_user(req[APIKEY])
    if not user:
        logger.error(f"user not found with {req[APIKEY]}")
        return "", 401

    if any(key not in REQUIRED_KEYS and key not in OPTIONAL_KEYS for key in req):
        return jsonify({"msg": "provided one or more unrecognized keys"}), 404
    if any(required_key not in req for required_key in REQUIRED_KEYS):
        return jsonify({"msg": "missing one or more required keys"}), 404
    if req.get(KEY_JOB_TYPE) == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
        return jsonify({"msg": "missing reference image"}), 404

    # Quota check and insert share one lock acquisition so concurrent requests
    # cannot all pass the check and overshoot the limit together.
    with memory_lock:
        # ">=": a user may hold at most MAX_JOB_NUMBER pending jobs, so the
        # request that would create job number MAX_JOB_NUMBER + 1 is refused.
        if database.count_all_pending_jobs(req[APIKEY]) >= MAX_JOB_NUMBER:
            return (
                jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
                500,
            )
        job_uuid = str(uuid.uuid4())
        logger.info("adding a new job with uuid {}..".format(job_uuid))
        database.insert_new_job(req, job_uuid=job_uuid)
    return jsonify({"msg": "", UUID: job_uuid})
@app.route("/cancel_job", methods=["POST"])
def cancel_job():
    """Cancel a job identified by ``UUID`` for an authenticated user.

    Returns:
        ``{"msg": ...}`` when the database reports the job as cancelled;
        ``("", 401)`` for a missing or unknown API key;
        ``(json, 404)`` when ``UUID`` is missing or no such job exists;
        ``(json, 405)`` when the job exists but could not be cancelled
        (the message says it is not in pending state).
    """
    req = request.get_json()
    # Guard against JSON bodies that are not objects (null, list, string).
    if not isinstance(req, dict) or APIKEY not in req:
        return "", 401
    with memory_lock:
        user = database.validate_user(req[APIKEY])
    if not user:
        return "", 401
    if UUID not in req:
        return jsonify({"msg": "missing uuid"}), 404

    job_uuid = req[UUID]
    logger.info("cancelling job with uuid {}..".format(job_uuid))
    # NOTE(review): the job is not checked against the requesting user, so any
    # valid API key can cancel any job whose uuid it knows.
    # Try the cancel and, if it fails, look the job up under the same lock so
    # the backend cannot change the job's state between the two calls.
    with memory_lock:
        cancelled = database.cancel_job(job_uuid=job_uuid)
        jobs = None if cancelled else database.get_jobs(job_uuid=job_uuid)

    if cancelled:
        msg = "job with uuid {} removed".format(job_uuid)
        return jsonify({"msg": msg})
    if jobs:
        return (
            jsonify(
                {
                    "msg": "job {} is not in pending state, unable to cancel".format(
                        job_uuid
                    )
                }
            ),
            405,
        )
    return (
        jsonify({"msg": "unable to find the job with uuid {}".format(job_uuid)}),
        404,
    )
@app.route("/get_jobs", methods=["POST"])
def get_jobs():
    """Return the job(s) the database holds for the requested ``UUID``.

    Returns:
        ``{"jobs": <database.get_jobs result>}`` on success;
        ``("", 401)`` for a missing or unknown API key;
        ``(json, 404)`` when ``UUID`` is missing (same reply as cancel_job).
    """
    req = request.get_json()
    # Guard against JSON bodies that are not objects (null, list, string).
    if not isinstance(req, dict) or APIKEY not in req:
        return "", 401
    with memory_lock:
        user = database.validate_user(req[APIKEY])
    if not user:
        return "", 401
    # Previously a missing uuid raised KeyError and surfaced as a 500.
    if UUID not in req:
        return jsonify({"msg": "missing uuid"}), 404
    # NOTE(review): results are not filtered by the requesting user.
    with memory_lock:
        jobs = database.get_jobs(job_uuid=req[UUID])
    return jsonify({"jobs": jobs})
@app.route("/")
def index():
    """Render the ``index.html`` template for the site root."""
    page = render_template("index.html")
    return page
def load_model(
    logger: Logger,
    model_name: str = "darkstorm2150/Protogen_x5.8_Official_Release",
    inpainting_model_name: str = "runwayml/stable-diffusion-inpainting",
) -> Model:
    """Construct a Model, configure it for the module's GPU setting, and load it.

    Args:
        logger: logger handed to the Model.
        model_name: identifier of the main model passed to Model. Other
            candidates that have been used here:
            "runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4",
            "stabilityai/stable-diffusion-2-1", "SG161222/Realistic_Vision_V2.0",
            "darkstorm2150/Protogen_x3.4_Official_Release",
            "darkstorm2150/Protogen_x5.8_Official_Release",
            "prompthero/openjourney", "naclbit/trinart_stable_diffusion_v2",
            "hakurei/waifu-diffusion".
        inpainting_model_name: identifier of the inpainting model passed to
            Model, e.g. "runwayml/stable-diffusion-inpainting".

    Returns:
        The Model after ``load_all()`` has been called on it.
    """
    model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
    if use_gpu:
        # Only enabled on GPU; see Model.set_low_memory_mode for what it trades.
        model.set_low_memory_mode()
    model.load_all()
    return model
def _run_backend_job(job, text2img, img2img):
    """Dispatch one job to the matching pipeline and return its result dict.

    Raises:
        ValueError: if the job type is neither txt2img nor img2img.
        Any exception raised while reading the job or running the pipeline.
    """
    prompt = job[KEY_PROMPT]
    negative_prompt = job[KEY_NEG_PROMPT]
    config = Config().set_config(job)
    job_type = job[KEY_JOB_TYPE]
    if job_type == VALUE_JOB_TXT2IMG:
        return text2img.lunch(
            prompt=prompt, negative_prompt=negative_prompt, config=config
        )
    if job_type == VALUE_JOB_IMG2IMG:
        return img2img.lunch(
            prompt=prompt,
            negative_prompt=negative_prompt,
            reference_image=job[REFERENCE_IMG],
            config=config,
        )
    # Without this, an unhandled type (e.g. inpainting) left the result unbound
    # and either crashed the worker or reused the previous job's results.
    raise ValueError("unsupported job type: {}".format(job_type))


def backend(event_termination, db):
    """Worker loop: load the models, then process pending jobs one at a time.

    Roughly once per second, takes the first job returned by
    ``db.get_all_pending_jobs()``, marks it RUNNING, runs it, and marks it
    DONE (storing the result dict) or FAILED. Exits once
    ``event_termination`` is set.

    Args:
        event_termination: threading.Event that stops the loop when set.
        db: database wrapper used for all job reads/writes (main() passes the
            module-level ``database``).
    """
    model = load_model(logger)
    text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
    img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
    text2img.breakfast()
    img2img.breakfast()

    while not event_termination.is_set():
        wait_for_seconds(1)
        with memory_lock:
            pending_jobs = db.get_all_pending_jobs()
        if not pending_jobs:
            continue

        next_job = pending_jobs[0]
        job_uuid = next_job[UUID]
        with memory_lock:
            db.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=job_uuid)

        # Everything that can fail for a bad job sits inside the try, so a
        # broken job is marked FAILED instead of killing the worker and
        # leaving the job stuck in RUNNING.
        try:
            result_dict = _run_backend_job(next_job, text2img, img2img)
        except Exception as e:
            logger.error("job {} failed: {}".format(job_uuid, e))
            with memory_lock:
                db.update_job({KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=job_uuid)
            continue

        with memory_lock:
            db.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=job_uuid)
            db.update_job(result_dict, job_uuid=job_uuid)
    logger.critical("stopped")
def main(db_filepath, is_testing: bool = False):
    """Connect the database and serve the Flask app, with or without the worker.

    Args:
        db_filepath: path passed to ``database.connect``.
        is_testing: when true, serve on port 5000 without starting the backend
            worker; otherwise start the worker and serve on port 8888.
    """
    database.connect(db_filepath)

    if is_testing:
        try:
            app.run(host="0.0.0.0", port=5000)
        except KeyboardInterrupt:
            pass
        finally:
            database.safe_disconnect()
        return

    worker = Thread(target=backend, args=(event_termination, database))
    worker.start()
    try:
        # app.run() can return normally on Ctrl+C (werkzeug's dev server
        # catches KeyboardInterrupt itself), so shutdown must not rely on the
        # exception reaching this frame — hence the finally below.
        app.run(host="0.0.0.0", port=8888)
    except KeyboardInterrupt:
        pass
    finally:
        # Stop the worker and wait for it (it may be mid-job) before
        # disconnecting, so it never writes to a closed database. The thread
        # is non-daemon, so the interpreter would wait for it anyway.
        event_termination.set()
        worker.join()
        database.safe_disconnect()
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    # --testing: main() serves on the test port and skips the backend worker.
    arg_parser.add_argument(
        "--testing",
        action="store_true",
        help="Enable testing mode",
    )
    # --db: location of the SQLite database file handed to main().
    arg_parser.add_argument(
        "--db",
        type=str,
        default="happysd.db",
        help="Path to SQLite database file",
    )
    cli_args = arg_parser.parse_args()
    logger.info(cli_args)
    main(cli_args.db, is_testing=cli_args.testing)