"""HTTP front-end and background worker for a text-to-image job queue.

The Flask app exposes three POST endpoints (``/add_job``, ``/cancel_job``,
``/get_jobs``) that manipulate an in-memory job list, while a single
background thread (``backend``) takes jobs from the front of that list one at
a time and runs them through ``Text2Img``.

All state lives in process memory (``local_job_stack`` and
``local_completed_jobs``) and is guarded by ``memory_lock``.
"""
import copy
import uuid
from threading import Event
from threading import Lock
from threading import Thread

from flask import Flask
from flask import jsonify
from flask import request

from utilities.constants import API_KEY
from utilities.constants import API_KEY_FOR_DEMO
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import LOGGER_NAME
from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID
from utilities.constants import VALUE_APP
# NOTE(review): VALUE_JOB_DONE was referenced by backend() but never imported,
# which raised NameError after the first finished job. It is assumed to live
# next to VALUE_JOB_PENDING / VALUE_JOB_RUNNING -- confirm in constants.py.
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_var
from utilities.times import wait_for_seconds
from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img


def load_model(logger: Logger) -> Model:
    """Build the ``Model`` used by the worker and load its weights.

    Args:
        logger: Logger handed to the ``Model`` constructor.

    Returns:
        The ``Model`` after ``set_low_memory_mode()`` and ``load_all()`` have
        been called on it.
    """
    # model candidates:
    # "runwayml/stable-diffusion-v1-5"
    # "CompVis/stable-diffusion-v1-4"
    # "stabilityai/stable-diffusion-2-1"
    # "SG161222/Realistic_Vision_V2.0"
    # "darkstorm2150/Protogen_x3.4_Official_Release"
    # "prompthero/openjourney"
    # "naclbit/trinart_stable_diffusion_v2"
    # "hakurei/waifu-diffusion"
    model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
    # inpainting model candidates:
    # "runwayml/stable-diffusion-inpainting"
    inpainting_model_name = "runwayml/stable-diffusion-inpainting"

    model = Model(model_name, inpainting_model_name, logger)
    model.set_low_memory_mode()
    model.load_all()

    return model


app = Flask(__name__)
# Guards local_job_stack and local_completed_jobs, which are shared between
# the request handlers and the backend thread.
memory_lock = Lock()
# Set by main() to ask the backend loop to exit.
event_termination = Event()
logger = Logger(name=LOGGER_NAME)

# Pending/running jobs in submission order; the backend always works on [0].
local_job_stack = []
# Finished jobs, keyed by api key and then by job uuid.
local_completed_jobs = {}


def _auth_error(req):
    """Return an error response if *req* fails the API-key check, else None.

    A body that is not a JSON object, or that lacks the API key, is rejected
    with 401. Only the demo key is accepted, and only while the app runs in
    its default (``VALUE_APP``) mode.
    """
    if not isinstance(req, dict) or API_KEY not in req:
        return "", 401
    if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
        if req[API_KEY] != API_KEY_FOR_DEMO:
            return "", 401
    else:
        # TODO: add logic to validate app key with a particular user
        return "", 401
    return None


@app.route("/add_job", methods=["POST"])
def add_job():
    """Validate a job request and append it to the queue.

    Returns:
        On success, a JSON body with an empty ``msg``, the job's 1-based
        ``position`` in the queue and its generated uuid. Otherwise 401 for a
        failed key check, 404 for unrecognized or missing keys, or 500 when
        the queue is full.
    """
    req = request.get_json()
    auth_error = _auth_error(req)
    if auth_error is not None:
        return auth_error

    if any(
        (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS) for key in req
    ):
        return jsonify({"msg": "provided one or more unrecognized keys"}), 404
    if any(required_key not in req for required_key in REQUIRED_KEYS):
        return jsonify({"msg": "missing one or more required keys"}), 404

    job_uuid = str(uuid.uuid4())

    # The capacity check, the append and the position read happen under one
    # lock acquisition so that concurrent requests cannot overshoot the limit
    # or report a position that belongs to another job.
    with memory_lock:
        if len(local_job_stack) > MAX_JOB_NUMBER:
            return jsonify({"msg": "too many jobs in queue, please wait"}), 500

        req[UUID] = job_uuid
        logger.info("adding a new job with uuid {}..".format(req[UUID]))
        req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
        local_job_stack.append(req)
        position = len(local_job_stack)

    return jsonify({"msg": "", "position": position, UUID: job_uuid})


@app.route("/cancel_job", methods=["POST"])
def cancel_job():
    """Remove a queued job identified by uuid.

    Returns:
        A JSON ``msg`` on success. Otherwise 401 when the key check fails or
        the job was submitted with a different key, 404 when the uuid is
        missing or unknown, or 405 when the job is already running.
    """
    req = request.get_json()
    auth_error = _auth_error(req)
    if auth_error is not None:
        return auth_error

    if UUID not in req:
        return jsonify({"msg": "missing uuid"}), 404

    logger.info("removing job with uuid {}..".format(req[UUID]))

    cancel_job_position = None
    with memory_lock:
        for job_position, queued_job in enumerate(local_job_stack):
            if queued_job[UUID] == req[UUID]:
                cancel_job_position = job_position
                break
        logger.info("found {}".format(cancel_job_position))
        if cancel_job_position is not None:
            target_job = local_job_stack[cancel_job_position]
            if target_job[API_KEY] != req[API_KEY]:
                return "", 401
            if target_job[KEY_JOB_STATUS] == VALUE_JOB_RUNNING:
                logger.info(
                    "job at {} with uuid {} is running and cannot be cancelled".format(
                        cancel_job_position, req[UUID]
                    )
                )
                return (
                    jsonify(
                        {
                            "msg": "job {} is already running, unable to cancel".format(
                                req[UUID]
                            )
                        }
                    ),
                    405,
                )
            del local_job_stack[cancel_job_position]
            msg = "job with uuid {} removed".format(req[UUID])
            logger.info(msg)
            return jsonify({"msg": msg})

    return (
        jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
        404,
    )


@app.route("/get_jobs", methods=["POST"])
def get_jobs():
    """List the caller's queued and completed jobs.

    When the request carries a uuid, only the job with that uuid is returned.
    Queued jobs are returned as copies with the API key removed and a 1-based
    ``position`` added; completed jobs are returned as stored.

    Returns:
        A JSON body ``{"jobs": [...]}``, 401 when the key check fails, or 404
        when nothing matches.
    """
    req = request.get_json()
    auth_error = _auth_error(req)
    if auth_error is not None:
        return auth_error

    jobs = []
    with memory_lock:
        for job_position, queued_job in enumerate(local_job_stack):
            # filter on API_KEY
            if queued_job[API_KEY] != req[API_KEY]:
                continue
            # filter on UUID
            if UUID in req and req[UUID] != queued_job[UUID]:
                continue
            job = copy.deepcopy(queued_job)
            del job[API_KEY]
            job["position"] = job_position + 1
            jobs.append(job)

        # Completed jobs are read under the same lock because the backend
        # thread inserts into this mapping while holding it.
        completed_for_key = local_completed_jobs.get(req[API_KEY], {})
        if UUID in req:
            # A uuid selects at most one completed job; append the job itself
            # rather than iterating over its fields.
            completed_job = completed_for_key.get(req[UUID])
            if completed_job is not None:
                jobs.append(completed_job)
        else:
            jobs.extend(completed_for_key.values())

    if len(jobs) == 0:
        return (
            jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
            404,
        )

    return jsonify({"jobs": jobs})


def backend(event_termination):
    """Worker loop: run queued jobs one at a time until asked to stop.

    Loads the model once, then polls the queue every second. The job at the
    front of the queue is marked running, executed outside the lock, and then
    moved into ``local_completed_jobs`` together with its output.

    Args:
        event_termination: ``threading.Event``; the loop exits at the next
            poll after it is set.
    """
    model = load_model(logger)
    text2img = Text2Img(model, output_folder="/tmp", logger=logger)

    text2img.breakfast()

    while True:
        wait_for_seconds(1)

        if event_termination.is_set():
            break

        with memory_lock:
            if len(local_job_stack) == 0:
                continue
            next_job = local_job_stack[0]
            # Marking the job as running stops cancel_job() from removing it,
            # so it is still at index 0 when it is popped below.
            next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING

        prompt = next_job[KEY_PROMPT.lower()]
        negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")

        config = Config().set_config(next_job)

        # Generation runs without the lock so the HTTP handlers stay responsive.
        base64img = text2img.lunch(
            prompt=prompt, negative_prompt=negative_prompt, config=config
        )

        with memory_lock:
            local_job_stack.pop(0)
            next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
            next_job[BASE64IMAGE] = base64img
            if next_job[API_KEY] not in local_completed_jobs:
                local_completed_jobs[next_job[API_KEY]] = {}
            local_completed_jobs[next_job[API_KEY]][next_job[UUID]] = next_job

    logger.critical("stopped")
    if len(local_job_stack) > 0:
        logger.info(
            "remaining {} jobs in stack: {}".format(
                len(local_job_stack), local_job_stack
            )
        )


def main():
    """Start the backend thread, serve HTTP, and shut the thread down on exit."""
    thread = Thread(target=backend, args=(event_termination,))
    thread.start()
    # ugly solution for now
    # TODO: use a database to track instead of internal memory
    try:
        app.run()
    except KeyboardInterrupt:
        pass
    finally:
        # Always signal the worker, including when app.run() returns normally.
        # Joining without setting the event would block forever because the
        # backend loop only exits once the event is set.
        event_termination.set()
        thread.join()


if __name__ == "__main__":
    main()