adds web API

HappyZ 2023-04-28 17:00:18 -07:00
parent b0e07eecd4
commit b145e42375
8 changed files with 343 additions and 38 deletions

main.py

@@ -1,11 +1,37 @@
+import copy
+import uuid
+from flask import jsonify
+from flask import Flask
+from flask import request
+from threading import Event
+from threading import Thread
+from threading import Lock
+from utilities.constants import API_KEY
+from utilities.constants import API_KEY_FOR_DEMO
+from utilities.constants import BASE64IMAGE
+from utilities.constants import KEY_APP
+from utilities.constants import KEY_JOB_STATUS
+from utilities.constants import KEY_PROMPT
+from utilities.constants import KEY_NEG_PROMPT
 from utilities.constants import LOGGER_NAME
+from utilities.constants import MAX_JOB_NUMBER
+from utilities.constants import OPTIONAL_KEYS
+from utilities.constants import REQUIRED_KEYS
+from utilities.constants import UUID
+from utilities.constants import VALUE_APP
+from utilities.constants import VALUE_JOB_DONE
+from utilities.constants import VALUE_JOB_PENDING
+from utilities.constants import VALUE_JOB_RUNNING
+from utilities.envvar import get_env_var_with_default
+from utilities.envvar import get_env_var
+from utilities.times import wait_for_seconds
 from utilities.logger import Logger
 from utilities.model import Model
 from utilities.config import Config
 from utilities.text2img import Text2Img
 
 
-def prepare(logger: Logger) -> [Model, Config]:
+def load_model(logger: Logger) -> Model:
     # model candidates:
     # "runwayml/stable-diffusion-v1-5"
     # "CompVis/stable-diffusion-v1-4"
@@ -24,32 +50,201 @@ def prepare(logger: Logger) -> [Model, Config]:
     model.set_low_memory_mode()
     model.load_all()
-    config = Config()
-    config.set_output_folder("/tmp/")
-    return model, config
+    return model
 
 
-def main():
-    logger = Logger(name=LOGGER_NAME)
-    model, config = prepare(logger)
-    text2img = Text2Img(model, config)
+app = Flask(__name__)
+memory_lock = Lock()
+event_termination = Event()
+logger = Logger(name=LOGGER_NAME)
+local_job_stack = []
+local_completed_jobs = {}
+
+
@app.route("/add_job", methods=["POST"])
def add_job():
req = request.get_json()
if API_KEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
return "", 401
for key in req.keys():
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
return jsonify({"msg": "provided one or more unrecognized keys"}), 404
for required_key in REQUIRED_KEYS:
if required_key not in req:
return jsonify({"msg": "missing one or more required keys"}), 404
if len(local_job_stack) > MAX_JOB_NUMBER:
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
req[UUID] = str(uuid.uuid4())
logger.info("adding a new job with uuid {}..".format(req[UUID]))
req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
with memory_lock:
local_job_stack.append(req)
return jsonify({"msg": "", "position": len(local_job_stack), UUID: req[UUID]})
@app.route("/cancel_job", methods=["POST"])
def cancel_job():
req = request.get_json()
if API_KEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
return "", 401
if UUID not in req:
return jsonify({"msg": "missing uuid"}), 404
logger.info("removing job with uuid {}..".format(req[UUID]))
cancel_job_position = None
with memory_lock:
for job_position in range(len(local_job_stack)):
if local_job_stack[job_position][UUID] == req[UUID]:
cancel_job_position = job_position
break
logger.info("foud {}".format(cancel_job_position))
if cancel_job_position is not None:
if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]:
return "", 401
if (
local_job_stack[cancel_job_position][KEY_JOB_STATUS]
== VALUE_JOB_RUNNING
):
logger.info(
"job at {} with uuid {} is running and cannot be cancelled".format(
cancel_job_position, req[UUID]
)
)
return (
jsonify(
{
"msg": "job {} is already running, unable to cancel".format(
req[UUID]
)
}
),
405,
)
del local_job_stack[cancel_job_position]
msg = "job with uuid {} removed".format(req[UUID])
logger.info(msg)
return jsonify({"msg": msg})
return (
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
404,
)
@app.route("/get_jobs", methods=["POST"])
def get_jobs():
req = request.get_json()
if API_KEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
return "", 401
jobs = []
with memory_lock:
for job_position in range(len(local_job_stack)):
# filter on API_KEY
if local_job_stack[job_position][API_KEY] != req[API_KEY]:
continue
# filter on UUID
if UUID in req and req[UUID] != local_job_stack[job_position][UUID]:
continue
job = copy.deepcopy(local_job_stack[job_position])
del job[API_KEY]
job["position"] = job_position + 1
jobs.append(job)
all_matching_completed_jobs = local_completed_jobs.get(req[API_KEY], {})
if UUID in req:
all_matching_completed_jobs = all_matching_completed_jobs.get(req[UUID])
for key in all_matching_completed_jobs.keys():
jobs.append(all_matching_completed_jobs[key])
if len(jobs) == 0:
return (
jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
404,
)
return jsonify({"jobs": jobs})
+def backend(event_termination):
+    model = load_model(logger)
+    text2img = Text2Img(model, output_folder="/tmp", logger=logger)
     text2img.breakfast()
     while True:
-        try:
-            prompt = input("Write prompt: ")
-            if not prompt:
-                prompt = "man riding a horse in space"
-            negative_prompt = input("Write negative prompt: ")
-            if not negative_prompt:
-                negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
-            text2img.lunch(prompt=prompt, negative_prompt=negative_prompt)
-        except KeyboardInterrupt:
+        wait_for_seconds(1)
+        if event_termination.is_set():
             break
-        except BaseException:
-            raise
+        with memory_lock:
+            if len(local_job_stack) == 0:
+                continue
+            next_job = local_job_stack[0]
+            next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING
+        prompt = next_job[KEY_PROMPT.lower()]
+        negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")
+        config = Config().set_config(next_job)
+        base64img = text2img.lunch(
+            prompt=prompt, negative_prompt=negative_prompt, config=config
+        )
+        with memory_lock:
+            local_job_stack.pop(0)
+            next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
+            next_job[BASE64IMAGE] = base64img
+            if next_job[API_KEY] not in local_completed_jobs:
+                local_completed_jobs[next_job[API_KEY]] = {}
+            local_completed_jobs[next_job[API_KEY]][next_job[UUID]] = next_job
+    logger.critical("stopped")
+    if len(local_job_stack) > 0:
+        logger.info(
+            "remaining {} jobs in stack: {}".format(
+                len(local_job_stack), local_job_stack
+            )
+        )
+
+
+def main():
+    thread = Thread(target=backend, args=(event_termination,))
+    thread.start()
+    # ugly solution for now
+    # TODO: use a database to track instead of internal memory
+    try:
+        app.run()
+        thread.join()
+    except KeyboardInterrupt:
+        event_termination.set()
+        thread.join()
 
 
 if __name__ == "__main__":
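Taken together, the new endpoints form a small in-memory job queue in front of the text2img worker: POST /add_job to enqueue, /get_jobs to poll, /cancel_job to drop a pending job. A rough client sketch of that flow, not part of the commit, assuming the Flask dev server defaults (http://127.0.0.1:5000), the demo APP mode where the accepted api_key is "demo", and the requests package:

import time

import requests

BASE = "http://127.0.0.1:5000"

# enqueue a job; the response echoes the queue position and the job's uuid
resp = requests.post(BASE + "/add_job", json={
    "api_key": "demo",
    "prompt": "man riding a horse in space",
})
job_uuid = resp.json()["uuid"]

# poll until the job shows up as completed ("status" switches to "done")
while True:
    jobs = requests.post(BASE + "/get_jobs", json={
        "api_key": "demo",
        "uuid": job_uuid,
    }).json().get("jobs", [])
    if jobs and jobs[0].get("status") == "done":
        break
    time.sleep(2)

# the base64-encoded image is returned under the "img" key
b64_image = jobs[0]["img"]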

requirements.txt

@@ -2,6 +2,7 @@ accelerate
 colorlog
 diffusers
 numpy
+Flask
 Pillow
 scikit-image
 torch

utilities/BUILD

@@ -16,6 +16,18 @@ py_library(
srcs=["constants.py"], srcs=["constants.py"],
) )
py_library(
name = "envvar",
srcs = ["envvar.py"],
)
py_test(
name = "envvar_test",
srcs = ["envvar_test.py"],
deps = [":envvar"],
)
py_library( py_library(
name = "images", name = "images",
srcs = ["images.py"], srcs = ["images.py"],

utilities/config.py

@@ -22,6 +22,7 @@ from utilities.constants import KEY_STEPS
 from utilities.constants import VALUE_STEPS_DEFAULT
 from utilities.constants import KEY_WIDTH
 from utilities.constants import VALUE_WIDTH_DEFAULT
+from utilities.constants import OPTIONAL_KEYS
 from utilities.logger import DummyLogger
@@ -37,12 +38,20 @@ class Config:
     def get_config(self) -> dict:
         return self.__config
 
+    def set_config(self, config: dict):
+        for key in config:
+            if key not in OPTIONAL_KEYS:
+                continue
+            self.__config[key.upper()] = config[key]
+        return self
+
     def get_output_folder(self) -> str:
         return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
 
     def set_output_folder(self, folder:str):
         self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder))
         self.__config[KEY_OUTPUT_FOLDER] = folder
+        return self
 
     def get_guidance_scale(self) -> float:
         return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
@@ -50,6 +59,7 @@ class Config:
     def set_guidance_scale(self, scale: float):
         self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
         self.__config[KEY_GUIDANCE_SCALE] = scale
+        return self
 
     def get_height(self) -> int:
         return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)
@@ -57,6 +67,7 @@ class Config:
     def set_height(self, value: int):
         self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
         self.__config[KEY_HEIGHT] = value
+        return self
 
     def get_preview(self) -> bool:
         return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
@@ -64,6 +75,7 @@ class Config:
     def set_preview(self, boolean: bool):
         self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
         self.__config[KEY_PREVIEW] = boolean
+        return self
 
     def get_scheduler(self) -> str:
         return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
@@ -73,6 +85,7 @@ class Config:
             scheduler = VALUE_SCHEDULER_DEFAULT
         self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
         self.__config[KEY_SCHEDULER] = scheduler
+        return self
 
     def get_seed(self) -> int:
         seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
@@ -84,6 +97,7 @@ class Config:
     def set_seed(self, seed: int):
         self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
         self.__config[KEY_SEED] = seed
+        return self
 
     def get_steps(self) -> int:
         return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)
@@ -91,6 +105,7 @@ class Config:
     def set_steps(self, steps: int):
         self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
         self.__config[KEY_STEPS] = steps
+        return self
 
     def get_width(self) -> int:
         return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)
@@ -98,3 +113,4 @@ class Config:
     def set_width(self, value: int):
         self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
         self.__config[KEY_WIDTH] = value
+        return self
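Since every setter now ends with return self, a Config can be built fluently and handed to lunch() in one expression, and set_config() bulk-applies the optional request keys. A small illustrative sketch, with arbitrary values rather than repo defaults:

from utilities.config import Config

config = (
    Config()
    .set_output_folder("/tmp")
    .set_seed(42)
    .set_width(512)
    .set_height(512)
    .set_steps(30)
)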

utilities/constants.py

@@ -1,4 +1,8 @@
 LOGGER_NAME = "main"
+MAX_JOB_NUMBER = 10
+
+KEY_APP = "APP"
+VALUE_APP = "demo"
 
 KEY_OUTPUT_FOLDER = "OUTFOLDER"
 VALUE_OUTPUT_FOLDER_DEFAULT = ""
@@ -26,5 +30,33 @@ VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
 VALUE_SCHEDULER_PNDM = "PNDMScheduler"
 VALUE_SCHEDULER_DDIM = "DDIMScheduler"
 
+KEY_PROMPT = "PROMPT"
+KEY_NEG_PROMPT = "NEG_PROMPT"
+
 KEY_PREVIEW = "PREVIEW"
 VALUE_PREVIEW_DEFAULT = True
+
+# REST API Keys
+API_KEY = "api_key"
+API_KEY_FOR_DEMO = "demo"
+UUID = "uuid"
+BASE64IMAGE = "img"
+
+KEY_JOB_STATUS = "status"
+VALUE_JOB_PENDING = "pending"
+VALUE_JOB_RUNNING = "running"
+VALUE_JOB_DONE = "done"
+
+REQUIRED_KEYS = [
+    API_KEY.lower(),
+    KEY_PROMPT.lower(),
+]
+OPTIONAL_KEYS = [
+    KEY_NEG_PROMPT.lower(),
+    KEY_SEED.lower(),
+    KEY_WIDTH.lower(),
+    KEY_HEIGHT.lower(),
+    KEY_GUIDANCE_SCALE.lower(),
+    KEY_STEPS.lower(),
+    KEY_SCHEDULER.lower(),
+]
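REQUIRED_KEYS and OPTIONAL_KEYS define the lowercase JSON field names the API accepts. A minimal /add_job body in demo mode would presumably look like the sketch below; only fields whose KEY_* values appear in this diff are spelled out, the remaining optional names depend on constants defined earlier in the file:

payload = {
    "api_key": "demo",  # API_KEY, matched against API_KEY_FOR_DEMO
    "prompt": "man riding a horse in space",  # KEY_PROMPT.lower(), required
    "neg_prompt": "blurry, low quality",  # KEY_NEG_PROMPT.lower(), optional
}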

utilities/envvar.py (new file)

@@ -0,0 +1,14 @@
+import os
+from typing import Optional
+
+
+def get_env_vars() -> dict:
+    return dict(os.environ)
+
+
+def get_env_var(key: str) -> Optional[str]:
+    return os.getenv(key)
+
+
+def get_env_var_with_default(key: str, default_value: str) -> str:
+    return os.getenv(key, default_value)

utilities/envvar_test.py (new file)

@@ -0,0 +1,34 @@
+import os
+import unittest
+
+from utilities.envvar import get_env_var
+from utilities.envvar import get_env_var_with_default
+from utilities.envvar import get_env_vars
+
+
+class TestEnvVar(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.key = "TEST_ENV_VAR"
+        self.value = "1234"
+        os.environ[self.key] = self.value
+
+    def test_existed_vars(self):
+        env_vars = get_env_vars()
+        self.assertTrue(self.key in env_vars)
+        self.assertTrue(get_env_var(self.key) == self.value)
+
+    def test_nonexisted_vars(self):
+        nonexist_key = "TEST_ENV_VAR_RANDOM_STUFF"
+        self.assertTrue(get_env_var(nonexist_key) is None)
+        self.assertTrue(get_env_var_with_default(
+            nonexist_key, nonexist_key) == nonexist_key)
+
+    @classmethod
+    def tearDownClass(self):
+        del os.environ[self.key]
+
+
+if __name__ == '__main__':
+    unittest.main()

utilities/text2img.py

@@ -7,6 +7,7 @@ from utilities.logger import DummyLogger
 from utilities.memory import empty_memory_cache
 from utilities.model import Model
 from utilities.times import get_epoch_now
+from utilities.images import image_to_base64
 
 
 class Text2Img:
@@ -17,45 +18,45 @@ class Text2Img:
     def __init__(
         self,
         model: Model,
-        config: Union[Config, None],
+        output_folder: str = "",
         logger: DummyLogger = DummyLogger(),
     ):
         self.model = model
-        self.config = config
+        self.__output_folder = output_folder
         self.__logger = logger
 
-    def update_config(self, config: Config):
-        self.config = config
-
     def brunch(self, prompt: str, negative_prompt: str = ""):
         self.breakfast()
         self.lunch(prompt, negative_prompt)
 
     def breakfast(self):
-        self.model.set_txt2img_scheduler(self.config.get_scheduler())
+        pass
 
-    def lunch(self, prompt: str, negative_prompt: str = ""):
+    def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> str:
+        self.model.set_txt2img_scheduler(config.get_scheduler())
         t = get_epoch_now()
-        seed = self.config.get_seed()
+        seed = config.get_seed()
+        self.__logger.info("current seed: {}".format(seed))
         generator = torch.Generator("cuda").manual_seed(seed)
-        self.__logger.info("current seed: {}".format(seed))
         result = self.model.txt2img_pipeline(
             prompt=prompt,
-            width=self.config.get_width(),
-            height=self.config.get_height(),
+            width=config.get_width(),
+            height=config.get_height(),
             negative_prompt=negative_prompt,
-            guidance_scale=self.config.get_guidance_scale(),
-            num_inference_steps=self.config.get_steps(),
+            guidance_scale=config.get_guidance_scale(),
+            num_inference_steps=config.get_steps(),
             generator=generator,
             callback=None,
             callback_steps=10,
         )
-        out_filepath = "{}/{}.png".format(self.config.get_output_folder(), t)
-        result.images[0].save(out_filepath)
-        self.__logger.info("output to file: {}".format(out_filepath))
+        if self.__output_folder:
+            out_filepath = "{}/{}.png".format(self.__output_folder, t)
+            result.images[0].save(out_filepath)
+            self.__logger.info("output to file: {}".format(out_filepath))
         empty_memory_cache()
+        return image_to_base64(result.images[0])
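With the config now passed per call and the image returned as base64, Text2Img can be driven directly much like backend() in main.py does. A hedged sketch, where load_model is the helper added to main.py and the prompt values are arbitrary:

from main import load_model
from utilities.config import Config
from utilities.logger import Logger
from utilities.text2img import Text2Img

logger = Logger(name="main")
text2img = Text2Img(load_model(logger), output_folder="/tmp", logger=logger)
text2img.breakfast()
base64img = text2img.lunch(
    prompt="man riding a horse in space",
    negative_prompt="blurry, low quality",
    config=Config().set_steps(30).set_seed(42),
)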