From f5d04126fd34eaa69bcdc95bed9b001749c10001 Mon Sep 17 00:00:00 2001
From: HappyZ
Date: Fri, 28 Apr 2023 23:10:56 -0700
Subject: [PATCH] support cpu mode (very slow) and fix bug for second job submission

---
 BUILD                 |   5 ++
 main.py               |  93 +++++++++++++------
 templates/index.html  | 145 +++++++++++++++++++++++++++++++++++++++
 utilities/BUILD       |   1 +
 utilities/model.py    |   2 +-
 utilities/text2img.py |  19 +++--
 utilities/web.py      | 156 ------------------------------------------
 7 files changed, 213 insertions(+), 208 deletions(-)
 create mode 100644 templates/index.html
 delete mode 100644 utilities/web.py

diff --git a/BUILD b/BUILD
index bf5105a..8a80bc9 100644
--- a/BUILD
+++ b/BUILD
@@ -11,5 +11,10 @@ par_binary(
         "//utilities:logger",
         "//utilities:model",
         "//utilities:text2img",
+        "//utilities:envvar",
+        "//utilities:times",
+    ],
+    data=[
+        "templates/index.html",
     ],
 )
diff --git a/main.py b/main.py
index 17d374f..13d7d24 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,10 @@
 import copy
+import tempfile
+import pkgutil
 import uuid
 from flask import jsonify
 from flask import Flask
+from flask import render_template
 from flask import request
 from threading import Event
 from threading import Thread
@@ -9,7 +12,6 @@ from threading import Lock
 
 from utilities.constants import API_KEY
 from utilities.constants import API_KEY_FOR_DEMO
-from utilities.constants import BASE64IMAGE
 from utilities.constants import KEY_APP
 from utilities.constants import KEY_JOB_STATUS
 from utilities.constants import KEY_PROMPT
@@ -23,7 +25,6 @@ from utilities.constants import VALUE_APP
 from utilities.constants import VALUE_JOB_PENDING
 from utilities.constants import VALUE_JOB_RUNNING
 from utilities.constants import VALUE_JOB_DONE
-from utilities.web import web
 from utilities.envvar import get_env_var_with_default
 from utilities.envvar import get_env_var
 from utilities.times import wait_for_seconds
@@ -33,35 +34,14 @@ from utilities.config import Config
 from utilities.text2img import Text2Img
 
 
-def load_model(logger: Logger) -> Model:
-    # model candidates:
-    # "runwayml/stable-diffusion-v1-5"
-    # "CompVis/stable-diffusion-v1-4"
-    # "stabilityai/stable-diffusion-2-1"
-    # "SG161222/Realistic_Vision_V2.0"
-    # "darkstorm2150/Protogen_x3.4_Official_Release"
-    # "prompthero/openjourney"
-    # "naclbit/trinart_stable_diffusion_v2"
-    # "hakurei/waifu-diffusion"
-    model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
-    # inpainting model candidates:
-    # "runwayml/stable-diffusion-inpainting"
-    inpainting_model_name = "runwayml/stable-diffusion-inpainting"
-
-    model = Model(model_name, inpainting_model_name, logger)
-    model.set_low_memory_mode()
-    model.load_all()
-
-    return model
-
-
 app = Flask(__name__)
 memory_lock = Lock()
 event_termination = Event()
 logger = Logger(name=LOGGER_NAME)
+use_gpu = True
 local_job_stack = []
-local_completed_jobs = {}
+local_completed_jobs = []
 
 
 @app.route("/add_job", methods=["POST"])
@@ -90,11 +70,12 @@ def add_job():
 
     logger.info("adding a new job with uuid {}..".format(req[UUID]))
     req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
+    req["position"] = len(local_job_stack) + 1
 
     with memory_lock:
         local_job_stack.append(req)
 
-    return jsonify({"msg": "", "position": len(local_job_stack), UUID: req[UUID]})
+    return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]})
 
 
 @app.route("/cancel_job", methods=["POST"])
@@ -167,23 +148,20 @@ def get_jobs():
 
     jobs = []
+    all_job_stack = local_job_stack + local_completed_jobs
     with memory_lock:
-        for job_position in range(len(local_job_stack)):
+        for job_position in range(len(all_job_stack)):
             # filter on API_KEY
-            if local_job_stack[job_position][API_KEY] != req[API_KEY]:
+            if all_job_stack[job_position][API_KEY] != req[API_KEY]:
                 continue
             # filter on UUID
-            if UUID in req and req[UUID] != local_job_stack[job_position][UUID]:
+            if UUID in req and req[UUID] != all_job_stack[job_position][UUID]:
                 continue
-            job = copy.deepcopy(local_job_stack[job_position])
+            job = copy.deepcopy(all_job_stack[job_position])
+            if job[KEY_JOB_STATUS] == VALUE_JOB_DONE:
+                del job["position"]
             del job[API_KEY]
-            job["position"] = job_position + 1
             jobs.append(job)
-    all_matching_completed_jobs = local_completed_jobs.get(req[API_KEY], {})
-    if UUID in req:
-        all_matching_completed_jobs = all_matching_completed_jobs.get(req[UUID], {})
-    for key in all_matching_completed_jobs.keys():
-        jobs.append(all_matching_completed_jobs[key])
 
     if len(jobs) == 0:
         return (
@@ -192,22 +170,44 @@
         )
 
     return jsonify({"jobs": jobs})
 
+
 @app.route("/")
 def index():
-    return web()
+    return render_template("index.html")
+
+
+def load_model(logger: Logger) -> Model:
+    # model candidates:
+    # "runwayml/stable-diffusion-v1-5"
+    # "CompVis/stable-diffusion-v1-4"
+    # "stabilityai/stable-diffusion-2-1"
+    # "SG161222/Realistic_Vision_V2.0"
+    # "darkstorm2150/Protogen_x3.4_Official_Release"
+    # "prompthero/openjourney"
+    # "naclbit/trinart_stable_diffusion_v2"
+    # "hakurei/waifu-diffusion"
+    model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
+    # inpainting model candidates:
+    # "runwayml/stable-diffusion-inpainting"
+    inpainting_model_name = "runwayml/stable-diffusion-inpainting"
+
+    model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
+    if use_gpu:
+        model.set_low_memory_mode()
+    model.load_all()
+
+    return model
 
 
 def backend(event_termination):
     model = load_model(logger)
-    text2img = Text2Img(model, output_folder="/tmp", logger=logger)
+    text2img = Text2Img(model, logger=logger)
     text2img.breakfast()
 
-    while True:
+    while not event_termination.is_set():
         wait_for_seconds(1)
-        if event_termination.is_set():
-            break
         with memory_lock:
             if len(local_job_stack) == 0:
                 continue
@@ -219,17 +219,15 @@
 
         config = Config().set_config(next_job)
 
-        base64img = text2img.lunch(
+        result_dict = text2img.lunch(
             prompt=prompt, negative_prompt=negative_prompt, config=config
         )
 
         with memory_lock:
             local_job_stack.pop(0)
             next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
-            next_job[BASE64IMAGE] = base64img
-            if next_job[API_KEY] not in local_completed_jobs:
-                local_completed_jobs[next_job[API_KEY]] = {}
-            local_completed_jobs[next_job[API_KEY]][next_job[UUID]] = next_job
+            next_job.update(result_dict)
+            local_completed_jobs.append(next_job)
 
     logger.critical("stopped")
     if len(local_job_stack) > 0:
@@ -241,12 +239,13 @@
 
 
 def main():
+    # app.run(host="0.0.0.0")
     thread = Thread(target=backend, args=(event_termination,))
     thread.start()
     # ugly solution for now
     # TODO: use a database to track instead of internal memory
     try:
-        app.run(host='0.0.0.0')
+        app.run(host="0.0.0.0")
         thread.join()
     except KeyboardInterrupt:
         event_termination.set()
diff --git a/templates/index.html b/templates/index.html
new file mode 100644
index 0000000..828aa81
--- /dev/null
+++ b/templates/index.html
@@ -0,0 +1,145 @@
+[... 145 lines of page markup: the single-page UI titled "Happy Diffusion (Private Access) | 9pm", with a prompt text area and a negative-prompt text area (each hinting "Less than 77 words otherwise it'll be truncated"), the remaining form controls, and the page's inline script and styles ...]
\ No newline at end of file
diff --git a/utilities/BUILD b/utilities/BUILD
index 20b62ee..f489c5a 100644
--- a/utilities/BUILD
+++ b/utilities/BUILD
@@ -63,6 +63,7 @@ py_library(
     name="text2img",
     srcs=["text2img.py"],
     deps=[
+        ":constants",
         ":config",
         ":logger",
         ":images",
diff --git a/utilities/model.py b/utilities/model.py
index 69918b2..3abec9b 100644
--- a/utilities/model.py
+++ b/utilities/model.py
@@ -32,7 +32,7 @@ class Model:
         self.__use_gpu = True
         logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
         self.__logger = logger
-        self.__torch_dtype = "auto"
+        self.__torch_dtype = torch.float64
 
         # txt2img and img2img are always loaded together
         self.txt2img_pipeline = None
diff --git a/utilities/text2img.py b/utilities/text2img.py
index 910572a..7753711 100644
--- a/utilities/text2img.py
+++ b/utilities/text2img.py
@@ -1,8 +1,12 @@
 import torch
 from typing import Union
 
+from utilities.constants import BASE64IMAGE
+from utilities.constants import KEY_SEED
+from utilities.constants import KEY_WIDTH
+from utilities.constants import KEY_HEIGHT
+from utilities.constants import KEY_STEPS
 from utilities.config import Config
-from utilities.images import save_image
 from utilities.logger import DummyLogger
 from utilities.memory import empty_memory_cache
 from utilities.model import Model
@@ -22,6 +26,7 @@ class Text2Img:
         logger: DummyLogger = DummyLogger(),
     ):
         self.model = model
+        self.__device = "cpu" if not self.model.use_gpu() else "cuda"
        self.__output_folder = output_folder
         self.__logger = logger
 
@@ -32,12 +37,12 @@ class Text2Img:
     def breakfast(self):
         pass
 
-    def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> str:
+    def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> dict:
         self.model.set_txt2img_scheduler(config.get_scheduler())
 
         t = get_epoch_now()
         seed = config.get_seed()
-        generator = torch.Generator("cuda").manual_seed(seed)
+        generator = torch.Generator(self.__device).manual_seed(seed)
         self.__logger.info("current seed: {}".format(seed))
 
         result = self.model.txt2img_pipeline(
@@ -59,4 +64,10 @@
 
         empty_memory_cache()
 
-        return image_to_base64(result.images[0])
+        return {
+            BASE64IMAGE: image_to_base64(result.images[0]),
+            KEY_SEED.lower(): seed,
+            KEY_WIDTH.lower(): config.get_width(),
+            KEY_HEIGHT.lower(): config.get_height(),
+            KEY_STEPS.lower(): config.get_steps(),
+        }
diff --git a/utilities/web.py b/utilities/web.py
deleted file mode 100644
index 87c3ac7..0000000
--- a/utilities/web.py
+++ /dev/null
@@ -1,156 +0,0 @@
-def javascript():
-    return """
-    [... inline <script> block for the page ...]
-    """
-
-def stylesheet():
-    return """
-    [... inline stylesheet for the page ...]
-    """
-
-def content():
-    return """
-    [... form markup: a prompt text area and a negative-prompt text area, each hinting
-     "Less than 77 words otherwise it'll be truncated", plus the remaining controls ...]
-    """
-
-def web():
-    return """
-    [... page skeleton: an HTML head titled "Happy Diffusion (Private Access) | 9pm" with {css} inlined,
-     and a body that interpolates {content} followed by {js} ...]
-    """.format(
-        content=content(),
-        css=stylesheet(),
-        js=javascript(),
-    )