adds web API

This commit is contained in:
HappyZ 2023-04-28 17:00:18 -07:00
parent b0e07eecd4
commit b145e42375
8 changed files with 343 additions and 38 deletions

231
main.py
View File

@ -1,11 +1,37 @@
import copy
import uuid
from flask import jsonify
from flask import Flask
from flask import request
from threading import Event
from threading import Thread
from threading import Lock
from utilities.constants import API_KEY
from utilities.constants import API_KEY_FOR_DEMO
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import LOGGER_NAME
from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID
from utilities.constants import VALUE_APP
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_var
from utilities.times import wait_for_seconds
from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img
def prepare(logger: Logger) -> [Model, Config]:
def load_model(logger: Logger) -> Model:
# model candidates:
# "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4"
@ -24,32 +50,201 @@ def prepare(logger: Logger) -> [Model, Config]:
model.set_low_memory_mode()
model.load_all()
config = Config()
config.set_output_folder("/tmp/")
return model, config
return model
def main():
app = Flask(__name__)
memory_lock = Lock()
event_termination = Event()
logger = Logger(name=LOGGER_NAME)
model, config = prepare(logger)
text2img = Text2Img(model, config)
local_job_stack = []
local_completed_jobs = {}
@app.route("/add_job", methods=["POST"])
def add_job():
req = request.get_json()
if API_KEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
return "", 401
for key in req.keys():
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
return jsonify({"msg": "provided one or more unrecognized keys"}), 404
for required_key in REQUIRED_KEYS:
if required_key not in req:
return jsonify({"msg": "missing one or more required keys"}), 404
if len(local_job_stack) > MAX_JOB_NUMBER:
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
req[UUID] = str(uuid.uuid4())
logger.info("adding a new job with uuid {}..".format(req[UUID]))
req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
with memory_lock:
local_job_stack.append(req)
return jsonify({"msg": "", "position": len(local_job_stack), UUID: req[UUID]})
@app.route("/cancel_job", methods=["POST"])
def cancel_job():
req = request.get_json()
if API_KEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
return "", 401
if UUID not in req:
return jsonify({"msg": "missing uuid"}), 404
logger.info("removing job with uuid {}..".format(req[UUID]))
cancel_job_position = None
with memory_lock:
for job_position in range(len(local_job_stack)):
if local_job_stack[job_position][UUID] == req[UUID]:
cancel_job_position = job_position
break
logger.info("foud {}".format(cancel_job_position))
if cancel_job_position is not None:
if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]:
return "", 401
if (
local_job_stack[cancel_job_position][KEY_JOB_STATUS]
== VALUE_JOB_RUNNING
):
logger.info(
"job at {} with uuid {} is running and cannot be cancelled".format(
cancel_job_position, req[UUID]
)
)
return (
jsonify(
{
"msg": "job {} is already running, unable to cancel".format(
req[UUID]
)
}
),
405,
)
del local_job_stack[cancel_job_position]
msg = "job with uuid {} removed".format(req[UUID])
logger.info(msg)
return jsonify({"msg": msg})
return (
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
404,
)
@app.route("/get_jobs", methods=["POST"])
def get_jobs():
req = request.get_json()
if API_KEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
return "", 401
jobs = []
with memory_lock:
for job_position in range(len(local_job_stack)):
# filter on API_KEY
if local_job_stack[job_position][API_KEY] != req[API_KEY]:
continue
# filter on UUID
if UUID in req and req[UUID] != local_job_stack[job_position][UUID]:
continue
job = copy.deepcopy(local_job_stack[job_position])
del job[API_KEY]
job["position"] = job_position + 1
jobs.append(job)
all_matching_completed_jobs = local_completed_jobs.get(req[API_KEY], {})
if UUID in req:
all_matching_completed_jobs = all_matching_completed_jobs.get(req[UUID])
for key in all_matching_completed_jobs.keys():
jobs.append(all_matching_completed_jobs[key])
if len(jobs) == 0:
return (
jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
404,
)
return jsonify({"jobs": jobs})
def backend(event_termination):
model = load_model(logger)
text2img = Text2Img(model, output_folder="/tmp", logger=logger)
text2img.breakfast()
while True:
try:
prompt = input("Write prompt: ")
if not prompt:
prompt = "man riding a horse in space"
negative_prompt = input("Write negative prompt: ")
if not negative_prompt:
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
text2img.lunch(prompt=prompt, negative_prompt=negative_prompt)
except KeyboardInterrupt:
wait_for_seconds(1)
if event_termination.is_set():
break
except BaseException:
raise
with memory_lock:
if len(local_job_stack) == 0:
continue
next_job = local_job_stack[0]
next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING
prompt = next_job[KEY_PROMPT.lower()]
negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")
config = Config().set_config(next_job)
base64img = text2img.lunch(
prompt=prompt, negative_prompt=negative_prompt, config=config
)
with memory_lock:
local_job_stack.pop(0)
next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
next_job[BASE64IMAGE] = base64img
if next_job[API_KEY] not in local_completed_jobs:
local_completed_jobs[next_job[API_KEY]] = {}
local_completed_jobs[next_job[API_KEY]][next_job[UUID]] = next_job
logger.critical("stopped")
if len(local_job_stack) > 0:
logger.info(
"remaining {} jobs in stack: {}".format(
len(local_job_stack), local_job_stack
)
)
def main():
thread = Thread(target=backend, args=(event_termination,))
thread.start()
# ugly solution for now
# TODO: use a database to track instead of internal memory
try:
app.run()
thread.join()
except KeyboardInterrupt:
event_termination.set()
thread.join()
if __name__ == "__main__":

View File

@ -2,6 +2,7 @@ accelerate
colorlog
diffusers
numpy
Flask
Pillow
scikit-image
torch

View File

@ -16,6 +16,18 @@ py_library(
srcs=["constants.py"],
)
py_library(
name = "envvar",
srcs = ["envvar.py"],
)
py_test(
name = "envvar_test",
srcs = ["envvar_test.py"],
deps = [":envvar"],
)
py_library(
name = "images",
srcs = ["images.py"],

View File

@ -22,6 +22,7 @@ from utilities.constants import KEY_STEPS
from utilities.constants import VALUE_STEPS_DEFAULT
from utilities.constants import KEY_WIDTH
from utilities.constants import VALUE_WIDTH_DEFAULT
from utilities.constants import OPTIONAL_KEYS
from utilities.logger import DummyLogger
@ -37,12 +38,20 @@ class Config:
def get_config(self) -> dict:
return self.__config
def set_config(self, config: dict):
for key in config:
if key not in OPTIONAL_KEYS:
continue
self.__config[key.upper()] = config[key]
return self
def get_output_folder(self) -> str:
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
def set_output_folder(self, folder:str):
self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder))
self.__config[KEY_OUTPUT_FOLDER] = folder
return self
def get_guidance_scale(self) -> float:
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
@ -50,6 +59,7 @@ class Config:
def set_guidance_scale(self, scale: float):
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
self.__config[KEY_GUIDANCE_SCALE] = scale
return self
def get_height(self) -> int:
return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)
@ -57,6 +67,7 @@ class Config:
def set_height(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
self.__config[KEY_HEIGHT] = value
return self
def get_preview(self) -> bool:
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
@ -64,6 +75,7 @@ class Config:
def set_preview(self, boolean: bool):
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
self.__config[KEY_PREVIEW] = boolean
return self
def get_scheduler(self) -> str:
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
@ -73,6 +85,7 @@ class Config:
scheduler = VALUE_SCHEDULER_DEFAULT
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
self.__config[KEY_SCHEDULER] = scheduler
return self
def get_seed(self) -> int:
seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
@ -84,6 +97,7 @@ class Config:
def set_seed(self, seed: int):
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
self.__config[KEY_SEED] = seed
return self
def get_steps(self) -> int:
return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)
@ -91,6 +105,7 @@ class Config:
def set_steps(self, steps: int):
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
self.__config[KEY_STEPS] = steps
return self
def get_width(self) -> int:
return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)
@ -98,3 +113,4 @@ class Config:
def set_width(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
self.__config[KEY_WIDTH] = value
return self

View File

@ -1,4 +1,8 @@
LOGGER_NAME = "main"
MAX_JOB_NUMBER = 10
KEY_APP = "APP"
VALUE_APP = "demo"
KEY_OUTPUT_FOLDER = "OUTFOLDER"
VALUE_OUTPUT_FOLDER_DEFAULT = ""
@ -26,5 +30,33 @@ VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
VALUE_SCHEDULER_PNDM = "PNDMScheduler"
VALUE_SCHEDULER_DDIM = "DDIMScheduler"
KEY_PROMPT = "PROMPT"
KEY_NEG_PROMPT = "NEG_PROMPT"
KEY_PREVIEW = "PREVIEW"
VALUE_PREVIEW_DEFAULT = True
# REST API Keys
API_KEY = "api_key"
API_KEY_FOR_DEMO = "demo"
UUID = "uuid"
BASE64IMAGE = "img"
KEY_JOB_STATUS = "status"
VALUE_JOB_PENDING = "pending"
VALUE_JOB_RUNNING = "running"
VALUE_JOB_DONE = "done"
REQUIRED_KEYS = [
API_KEY.lower(),
KEY_PROMPT.lower(),
]
OPTIONAL_KEYS = [
KEY_NEG_PROMPT.lower(),
KEY_SEED.lower(),
KEY_WIDTH.lower(),
KEY_HEIGHT.lower(),
KEY_GUIDANCE_SCALE.lower(),
KEY_STEPS.lower(),
KEY_SCHEDULER.lower(),
]

14
utilities/envvar.py Normal file
View File

@ -0,0 +1,14 @@
import os
from typing import Optional
def get_env_vars() -> dict:
return dict(os.environ)
def get_env_var(key: str) -> Optional[str]:
return os.getenv(key)
def get_env_var_with_default(key: str, default_value: str) -> str:
return os.getenv(key, default_value)

34
utilities/envvar_test.py Normal file
View File

@ -0,0 +1,34 @@
import os
import unittest
from utilities.envvar import get_env_var
from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_vars
class TestEnvVar(unittest.TestCase):
@classmethod
def setUpClass(self):
self.key = "TEST_ENV_VAR"
self.value = "1234"
os.environ[self.key] = self.value
def test_existed_vars(self):
env_vars = get_env_vars()
self.assertTrue(self.key in env_vars)
self.assertTrue(get_env_var(self.key) == self.value)
def test_nonexisted_vars(self):
nonexist_key = "TEST_ENV_VAR_RANDOM_STUFF"
self.assertTrue(get_env_var(nonexist_key) is None)
self.assertTrue(get_env_var_with_default(
nonexist_key, nonexist_key) == nonexist_key)
@classmethod
def tearDownClass(self):
del os.environ[self.key]
if __name__ == '__main__':
unittest.main()

View File

@ -7,6 +7,7 @@ from utilities.logger import DummyLogger
from utilities.memory import empty_memory_cache
from utilities.model import Model
from utilities.times import get_epoch_now
from utilities.images import image_to_base64
class Text2Img:
@ -17,45 +18,45 @@ class Text2Img:
def __init__(
self,
model: Model,
config: Union[Config, None],
output_folder: str = "",
logger: DummyLogger = DummyLogger(),
):
self.model = model
self.config = config
self.__output_folder = output_folder
self.__logger = logger
def update_config(self, config: Config):
self.config = config
def brunch(self, prompt: str, negative_prompt: str = ""):
self.breakfast()
self.lunch(prompt, negative_prompt)
def breakfast(self):
self.model.set_txt2img_scheduler(self.config.get_scheduler())
pass
def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> str:
self.model.set_txt2img_scheduler(config.get_scheduler())
def lunch(self, prompt: str, negative_prompt: str = ""):
t = get_epoch_now()
seed = self.config.get_seed()
self.__logger.info("current seed: {}".format(seed))
seed = config.get_seed()
generator = torch.Generator("cuda").manual_seed(seed)
self.__logger.info("current seed: {}".format(seed))
result = self.model.txt2img_pipeline(
prompt=prompt,
width=self.config.get_width(),
height=self.config.get_height(),
width=config.get_width(),
height=config.get_height(),
negative_prompt=negative_prompt,
guidance_scale=self.config.get_guidance_scale(),
num_inference_steps=self.config.get_steps(),
guidance_scale=config.get_guidance_scale(),
num_inference_steps=config.get_steps(),
generator=generator,
callback=None,
callback_steps=10,
)
out_filepath = "{}/{}.png".format(self.config.get_output_folder(), t)
if self.__output_folder:
out_filepath = "{}/{}.png".format(self.__output_folder, t)
result.images[0].save(out_filepath)
self.__logger.info("output to file: {}".format(out_filepath))
empty_memory_cache()
return image_to_base64(result.images[0])