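"""Flask front end for a small Stable Diffusion job service.

Exposes /add_job, /cancel_job and /get_jobs, keeps the pending queue
(local_job_stack) and finished jobs (local_completed_jobs) in process memory
guarded by memory_lock, and runs txt2img / img2img generation on a background
worker thread (see backend()).
"""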
import copy
import tempfile
import pkgutil
import uuid

from flask import jsonify
from flask import Flask
from flask import render_template
from flask import request

from threading import Event
from threading import Thread
from threading import Lock

from utilities.constants import API_KEY
from utilities.constants import API_KEY_FOR_DEMO
from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import LOGGER_NAME
from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import REFERENCE_IMG
from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID
from utilities.constants import VALUE_APP
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import VALUE_JOB_FAILED

from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_var
from utilities.times import wait_for_seconds
from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img

# Img2Img is used by backend() below but was missing from the imports; the
# module path is assumed to mirror utilities.text2img.
from utilities.img2img import Img2Img


app = Flask(__name__)

fast_web_debugging = True
memory_lock = Lock()
event_termination = Event()
logger = Logger(name=LOGGER_NAME)
use_gpu = True

local_job_stack = []
local_completed_jobs = []


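# POST /add_job
# Validates the caller's API key (demo key only, until per-user validation is
# added), rejects unknown or missing JSON fields, requires a reference image
# for img2img jobs, caps the queue at MAX_JOB_NUMBER, then assigns a UUID and
# appends the request to local_job_stack as a pending job.
#
# Illustrative request shape (the real field names are the runtime values of
# the constants in utilities.constants, shown here as <PLACEHOLDERS>):
#   POST /add_job
#   {"<API_KEY>": "...", "<KEY_JOB_TYPE>": "<txt2img>", "<KEY_PROMPT>": "a castle at dusk"}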
@app.route("/add_job", methods=["POST"])
def add_job():
    req = request.get_json()
    if API_KEY not in req:
        return "", 401
    if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
        if req[API_KEY] != API_KEY_FOR_DEMO:
            return "", 401
    else:
        # TODO: add logic to validate app key with a particular user
        return "", 401

    for key in req.keys():
        if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
            return jsonify({"msg": "provided one or more unrecognized keys"}), 404
    for required_key in REQUIRED_KEYS:
        if required_key not in req:
            return jsonify({"msg": "missing one or more required keys"}), 404

    if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
        return jsonify({"msg": "missing reference image"}), 404

    if len(local_job_stack) > MAX_JOB_NUMBER:
        return jsonify({"msg": "too many jobs in queue, please wait"}), 500

    req[UUID] = str(uuid.uuid4())
    logger.info("adding a new job with uuid {}..".format(req[UUID]))

    req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
    req["position"] = len(local_job_stack) + 1

    with memory_lock:
        local_job_stack.append(req)

    return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]})


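# POST /cancel_job
# Authenticates the caller the same way as /add_job, locates the pending job
# by UUID (the caller's API key must match the one stored with the job), and
# removes it from local_job_stack. A job that is already running cannot be
# cancelled and returns 405.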
@app.route("/cancel_job", methods=["POST"])
def cancel_job():
    req = request.get_json()
    if API_KEY not in req:
        return "", 401
    if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
        if req[API_KEY] != API_KEY_FOR_DEMO:
            return "", 401
    else:
        # TODO: add logic to validate app key with a particular user
        return "", 401

    if UUID not in req:
        return jsonify({"msg": "missing uuid"}), 404

    logger.info("removing job with uuid {}..".format(req[UUID]))

    cancel_job_position = None
    with memory_lock:
        for job_position in range(len(local_job_stack)):
            if local_job_stack[job_position][UUID] == req[UUID]:
                cancel_job_position = job_position
                break
        logger.info("found {}".format(cancel_job_position))
        if cancel_job_position is not None:
            if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]:
                return "", 401
            if (
                local_job_stack[cancel_job_position][KEY_JOB_STATUS]
                == VALUE_JOB_RUNNING
            ):
                logger.info(
                    "job at {} with uuid {} is running and cannot be cancelled".format(
                        cancel_job_position, req[UUID]
                    )
                )
                return (
                    jsonify(
                        {
                            "msg": "job {} is already running, unable to cancel".format(
                                req[UUID]
                            )
                        }
                    ),
                    405,
                )
            del local_job_stack[cancel_job_position]
            msg = "job with uuid {} removed".format(req[UUID])
            logger.info(msg)
            return jsonify({"msg": msg})
    return (
        jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
        404,
    )


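# POST /get_jobs
# Returns every job (queued or completed) belonging to the caller's API key,
# optionally filtered to a single UUID. Finished jobs are returned without
# their queue position, and the caller's API key is stripped from each entry.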
@app.route("/get_jobs", methods=["POST"])
def get_jobs():
    req = request.get_json()
    if API_KEY not in req:
        return "", 401
    if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
        if req[API_KEY] != API_KEY_FOR_DEMO:
            return "", 401
    else:
        # TODO: add logic to validate app key with a particular user
        return "", 401

    jobs = []

    with memory_lock:
        # snapshot both lists under the lock so the worker thread cannot
        # mutate them mid-iteration
        all_job_stack = local_job_stack + local_completed_jobs
        for job_position in range(len(all_job_stack)):
            # filter on API_KEY
            if all_job_stack[job_position][API_KEY] != req[API_KEY]:
                continue
            # filter on UUID
            if UUID in req and req[UUID] != all_job_stack[job_position][UUID]:
                continue
            job = copy.deepcopy(all_job_stack[job_position])
            if job[KEY_JOB_STATUS] == VALUE_JOB_DONE:
                del job["position"]
            del job[API_KEY]
            jobs.append(job)

    if len(jobs) == 0:
        return (
            jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
            404,
        )
    return jsonify({"jobs": jobs})


@app.route("/")
def index():
    return render_template("index.html")


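# Builds the Model wrapper around the chosen Stable Diffusion checkpoints.
# When running on GPU the model is switched to low-memory mode before all
# pipelines are loaded.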
def load_model(logger: Logger) -> Model:
    # model candidates:
    # "runwayml/stable-diffusion-v1-5"
    # "CompVis/stable-diffusion-v1-4"
    # "stabilityai/stable-diffusion-2-1"
    # "SG161222/Realistic_Vision_V2.0"
    # "darkstorm2150/Protogen_x3.4_Official_Release"
    # "darkstorm2150/Protogen_x5.8_Official_Release"
    # "prompthero/openjourney"
    # "naclbit/trinart_stable_diffusion_v2"
    # "hakurei/waifu-diffusion"
    model_name = "darkstorm2150/Protogen_x5.8_Official_Release"
    # inpainting model candidates:
    # "runwayml/stable-diffusion-inpainting"
    inpainting_model_name = "runwayml/stable-diffusion-inpainting"

    model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
    if use_gpu:
        model.set_low_memory_mode()
    model.load_all()

    return model


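# Background worker: loads the model once, then polls local_job_stack every
# second. The head job is marked RUNNING, dispatched to Text2Img or Img2Img,
# and moved to local_completed_jobs as DONE (with the result merged in) or
# FAILED. The loop exits when event_termination is set.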
def backend(event_termination):
    model = load_model(logger)
    text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
    img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))

    text2img.breakfast()
    img2img.breakfast()

    while not event_termination.is_set():
        wait_for_seconds(1)

        with memory_lock:
            if len(local_job_stack) == 0:
                continue
            next_job = local_job_stack[0]
            next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING

        prompt = next_job[KEY_PROMPT.lower()]
        negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")

        config = Config().set_config(next_job)

        try:
            if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
                result_dict = text2img.lunch(
                    prompt=prompt, negative_prompt=negative_prompt, config=config
                )
            elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
                ref_img = next_job[REFERENCE_IMG]
                result_dict = img2img.lunch(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    reference_image=ref_img,
                    config=config,
                )
            else:
                # job types without a handler (e.g. inpainting) are treated as
                # failures instead of silently reusing a stale result_dict
                raise ValueError(
                    "unsupported job type {}".format(next_job[KEY_JOB_TYPE])
                )
        except BaseException as e:
            logger.error("job {} failed: {}".format(next_job[UUID], e))
            with memory_lock:
                local_job_stack.pop(0)
                next_job[KEY_JOB_STATUS] = VALUE_JOB_FAILED
                local_completed_jobs.append(next_job)
            # skip the success path below; the failed job is already archived
            continue

        with memory_lock:
            local_job_stack.pop(0)
            next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
            next_job.update(result_dict)
            local_completed_jobs.append(next_job)

    logger.critical("stopped")
    if len(local_job_stack) > 0:
        logger.info(
            "remaining {} jobs in stack: {}".format(
                len(local_job_stack), local_job_stack
            )
        )


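# Entry point: with fast_web_debugging enabled only the Flask app runs (no
# worker thread, no model loading); otherwise the backend worker is started in
# a separate thread and shut down via event_termination on Ctrl+C.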
def main():
    if fast_web_debugging:
        try:
            app.run(host="0.0.0.0")
        except KeyboardInterrupt:
            pass
        return
    thread = Thread(target=backend, args=(event_termination,))
    thread.start()
    # ugly solution for now
    # TODO: use a database to track instead of internal memory
    try:
        app.run(host="0.0.0.0")
        thread.join()
    except KeyboardInterrupt:
        event_termination.set()
        thread.join(1)


if __name__ == "__main__":
    main()