diff --git a/BUILD b/BUILD index bf5105a..8a80bc9 100644 --- a/BUILD +++ b/BUILD @@ -11,5 +11,10 @@ par_binary( "//utilities:logger", "//utilities:model", "//utilities:text2img", + "//utilities:envvar", + "//utilities:times", + ], + data=[ + "templates/index.html", ], ) diff --git a/main.py b/main.py index 17d374f..13d7d24 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,10 @@ import copy +import tempfile +import pkgutil import uuid from flask import jsonify from flask import Flask +from flask import render_template from flask import request from threading import Event from threading import Thread @@ -9,7 +12,6 @@ from threading import Lock from utilities.constants import API_KEY from utilities.constants import API_KEY_FOR_DEMO -from utilities.constants import BASE64IMAGE from utilities.constants import KEY_APP from utilities.constants import KEY_JOB_STATUS from utilities.constants import KEY_PROMPT @@ -23,7 +25,6 @@ from utilities.constants import VALUE_APP from utilities.constants import VALUE_JOB_PENDING from utilities.constants import VALUE_JOB_RUNNING from utilities.constants import VALUE_JOB_DONE -from utilities.web import web from utilities.envvar import get_env_var_with_default from utilities.envvar import get_env_var from utilities.times import wait_for_seconds @@ -33,35 +34,14 @@ from utilities.config import Config from utilities.text2img import Text2Img -def load_model(logger: Logger) -> Model: - # model candidates: - # "runwayml/stable-diffusion-v1-5" - # "CompVis/stable-diffusion-v1-4" - # "stabilityai/stable-diffusion-2-1" - # "SG161222/Realistic_Vision_V2.0" - # "darkstorm2150/Protogen_x3.4_Official_Release" - # "prompthero/openjourney" - # "naclbit/trinart_stable_diffusion_v2" - # "hakurei/waifu-diffusion" - model_name = "darkstorm2150/Protogen_x3.4_Official_Release" - # inpainting model candidates: - # "runwayml/stable-diffusion-inpainting" - inpainting_model_name = "runwayml/stable-diffusion-inpainting" - - model = Model(model_name, inpainting_model_name, logger) - model.set_low_memory_mode() - model.load_all() - - return model - - app = Flask(__name__) memory_lock = Lock() event_termination = Event() logger = Logger(name=LOGGER_NAME) +use_gpu = True local_job_stack = [] -local_completed_jobs = {} +local_completed_jobs = [] @app.route("/add_job", methods=["POST"]) @@ -90,11 +70,12 @@ def add_job(): logger.info("adding a new job with uuid {}..".format(req[UUID])) req[KEY_JOB_STATUS] = VALUE_JOB_PENDING + req["position"] = len(local_job_stack) + 1 with memory_lock: local_job_stack.append(req) - return jsonify({"msg": "", "position": len(local_job_stack), UUID: req[UUID]}) + return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]}) @app.route("/cancel_job", methods=["POST"]) @@ -167,23 +148,20 @@ def get_jobs(): jobs = [] + all_job_stack = local_job_stack + local_completed_jobs with memory_lock: - for job_position in range(len(local_job_stack)): + for job_position in range(len(all_job_stack)): # filter on API_KEY - if local_job_stack[job_position][API_KEY] != req[API_KEY]: + if all_job_stack[job_position][API_KEY] != req[API_KEY]: continue # filter on UUID - if UUID in req and req[UUID] != local_job_stack[job_position][UUID]: + if UUID in req and req[UUID] != all_job_stack[job_position][UUID]: continue - job = copy.deepcopy(local_job_stack[job_position]) + job = copy.deepcopy(all_job_stack[job_position]) + if job[KEY_JOB_STATUS] == VALUE_JOB_DONE: + del job["position"] del job[API_KEY] - job["position"] = job_position + 1 jobs.append(job) - all_matching_completed_jobs = local_completed_jobs.get(req[API_KEY], {}) - if UUID in req: - all_matching_completed_jobs = all_matching_completed_jobs.get(req[UUID], {}) - for key in all_matching_completed_jobs.keys(): - jobs.append(all_matching_completed_jobs[key]) if len(jobs) == 0: return ( @@ -192,22 +170,44 @@ def get_jobs(): ) return jsonify({"jobs": jobs}) + @app.route("/") def index(): - return web() + return render_template("index.html") + + +def load_model(logger: Logger) -> Model: + # model candidates: + # "runwayml/stable-diffusion-v1-5" + # "CompVis/stable-diffusion-v1-4" + # "stabilityai/stable-diffusion-2-1" + # "SG161222/Realistic_Vision_V2.0" + # "darkstorm2150/Protogen_x3.4_Official_Release" + # "prompthero/openjourney" + # "naclbit/trinart_stable_diffusion_v2" + # "hakurei/waifu-diffusion" + model_name = "darkstorm2150/Protogen_x3.4_Official_Release" + # inpainting model candidates: + # "runwayml/stable-diffusion-inpainting" + inpainting_model_name = "runwayml/stable-diffusion-inpainting" + + model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu) + if use_gpu: + model.set_low_memory_mode() + model.load_all() + + return model def backend(event_termination): model = load_model(logger) - text2img = Text2Img(model, output_folder="/tmp", logger=logger) + text2img = Text2Img(model, logger=logger) text2img.breakfast() - while True: + while not event_termination.is_set(): wait_for_seconds(1) - if event_termination.is_set(): - break with memory_lock: if len(local_job_stack) == 0: continue @@ -219,17 +219,15 @@ def backend(event_termination): config = Config().set_config(next_job) - base64img = text2img.lunch( + result_dict = text2img.lunch( prompt=prompt, negative_prompt=negative_prompt, config=config ) with memory_lock: local_job_stack.pop(0) next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE - next_job[BASE64IMAGE] = base64img - if next_job[API_KEY] not in local_completed_jobs: - local_completed_jobs[next_job[API_KEY]] = {} - local_completed_jobs[next_job[API_KEY]][next_job[UUID]] = next_job + next_job.update(result_dict) + local_completed_jobs.append(next_job) logger.critical("stopped") if len(local_job_stack) > 0: @@ -241,12 +239,13 @@ def backend(event_termination): def main(): + # app.run(host="0.0.0.0") thread = Thread(target=backend, args=(event_termination,)) thread.start() # ugly solution for now # TODO: use a database to track instead of internal memory try: - app.run(host='0.0.0.0') + app.run(host="0.0.0.0") thread.join() except KeyboardInterrupt: event_termination.set() diff --git a/templates/index.html b/templates/index.html new file mode 100644 index 0000000..828aa81 --- /dev/null +++ b/templates/index.html @@ -0,0 +1,145 @@ + + +
+ + +