adds web API
This commit is contained in:
parent
b0e07eecd4
commit
b145e42375
233
main.py
233
main.py
|
|
@ -1,11 +1,37 @@
|
|||
import copy
|
||||
import uuid
|
||||
from flask import jsonify
|
||||
from flask import Flask
|
||||
from flask import request
|
||||
from threading import Event
|
||||
from threading import Thread
|
||||
from threading import Lock
|
||||
|
||||
from utilities.constants import API_KEY
|
||||
from utilities.constants import API_KEY_FOR_DEMO
|
||||
from utilities.constants import BASE64IMAGE
|
||||
from utilities.constants import KEY_APP
|
||||
from utilities.constants import KEY_JOB_STATUS
|
||||
from utilities.constants import KEY_PROMPT
|
||||
from utilities.constants import KEY_NEG_PROMPT
|
||||
from utilities.constants import LOGGER_NAME
|
||||
from utilities.constants import MAX_JOB_NUMBER
|
||||
from utilities.constants import OPTIONAL_KEYS
|
||||
from utilities.constants import REQUIRED_KEYS
|
||||
from utilities.constants import UUID
|
||||
from utilities.constants import VALUE_APP
|
||||
from utilities.constants import VALUE_JOB_PENDING
|
||||
from utilities.constants import VALUE_JOB_RUNNING
|
||||
from utilities.envvar import get_env_var_with_default
|
||||
from utilities.envvar import get_env_var
|
||||
from utilities.times import wait_for_seconds
|
||||
from utilities.logger import Logger
|
||||
from utilities.model import Model
|
||||
from utilities.config import Config
|
||||
from utilities.text2img import Text2Img
|
||||
|
||||
|
||||
def prepare(logger: Logger) -> [Model, Config]:
|
||||
def load_model(logger: Logger) -> Model:
|
||||
# model candidates:
|
||||
# "runwayml/stable-diffusion-v1-5"
|
||||
# "CompVis/stable-diffusion-v1-4"
|
||||
|
|
@ -24,32 +50,201 @@ def prepare(logger: Logger) -> [Model, Config]:
|
|||
model.set_low_memory_mode()
|
||||
model.load_all()
|
||||
|
||||
config = Config()
|
||||
config.set_output_folder("/tmp/")
|
||||
return model, config
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
logger = Logger(name=LOGGER_NAME)
|
||||
app = Flask(__name__)
|
||||
memory_lock = Lock()
|
||||
event_termination = Event()
|
||||
logger = Logger(name=LOGGER_NAME)
|
||||
|
||||
model, config = prepare(logger)
|
||||
text2img = Text2Img(model, config)
|
||||
local_job_stack = []
|
||||
local_completed_jobs = {}
|
||||
|
||||
|
||||
@app.route("/add_job", methods=["POST"])
|
||||
def add_job():
|
||||
req = request.get_json()
|
||||
if API_KEY not in req:
|
||||
return "", 401
|
||||
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
|
||||
if req[API_KEY] != API_KEY_FOR_DEMO:
|
||||
return "", 401
|
||||
else:
|
||||
# TODO: add logic to validate app key with a particular user
|
||||
return "", 401
|
||||
|
||||
for key in req.keys():
|
||||
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
|
||||
return jsonify({"msg": "provided one or more unrecognized keys"}), 404
|
||||
for required_key in REQUIRED_KEYS:
|
||||
if required_key not in req:
|
||||
return jsonify({"msg": "missing one or more required keys"}), 404
|
||||
|
||||
if len(local_job_stack) > MAX_JOB_NUMBER:
|
||||
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
|
||||
|
||||
req[UUID] = str(uuid.uuid4())
|
||||
logger.info("adding a new job with uuid {}..".format(req[UUID]))
|
||||
|
||||
req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
|
||||
|
||||
with memory_lock:
|
||||
local_job_stack.append(req)
|
||||
|
||||
return jsonify({"msg": "", "position": len(local_job_stack), UUID: req[UUID]})
|
||||
|
||||
|
||||
@app.route("/cancel_job", methods=["POST"])
|
||||
def cancel_job():
|
||||
req = request.get_json()
|
||||
if API_KEY not in req:
|
||||
return "", 401
|
||||
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
|
||||
if req[API_KEY] != API_KEY_FOR_DEMO:
|
||||
return "", 401
|
||||
else:
|
||||
# TODO: add logic to validate app key with a particular user
|
||||
return "", 401
|
||||
|
||||
if UUID not in req:
|
||||
return jsonify({"msg": "missing uuid"}), 404
|
||||
|
||||
logger.info("removing job with uuid {}..".format(req[UUID]))
|
||||
|
||||
cancel_job_position = None
|
||||
with memory_lock:
|
||||
for job_position in range(len(local_job_stack)):
|
||||
if local_job_stack[job_position][UUID] == req[UUID]:
|
||||
cancel_job_position = job_position
|
||||
break
|
||||
logger.info("foud {}".format(cancel_job_position))
|
||||
if cancel_job_position is not None:
|
||||
if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]:
|
||||
return "", 401
|
||||
if (
|
||||
local_job_stack[cancel_job_position][KEY_JOB_STATUS]
|
||||
== VALUE_JOB_RUNNING
|
||||
):
|
||||
logger.info(
|
||||
"job at {} with uuid {} is running and cannot be cancelled".format(
|
||||
cancel_job_position, req[UUID]
|
||||
)
|
||||
)
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"msg": "job {} is already running, unable to cancel".format(
|
||||
req[UUID]
|
||||
)
|
||||
}
|
||||
),
|
||||
405,
|
||||
)
|
||||
del local_job_stack[cancel_job_position]
|
||||
msg = "job with uuid {} removed".format(req[UUID])
|
||||
logger.info(msg)
|
||||
return jsonify({"msg": msg})
|
||||
return (
|
||||
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
|
||||
404,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/get_jobs", methods=["POST"])
|
||||
def get_jobs():
|
||||
req = request.get_json()
|
||||
if API_KEY not in req:
|
||||
return "", 401
|
||||
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
|
||||
if req[API_KEY] != API_KEY_FOR_DEMO:
|
||||
return "", 401
|
||||
else:
|
||||
# TODO: add logic to validate app key with a particular user
|
||||
return "", 401
|
||||
|
||||
jobs = []
|
||||
|
||||
with memory_lock:
|
||||
for job_position in range(len(local_job_stack)):
|
||||
# filter on API_KEY
|
||||
if local_job_stack[job_position][API_KEY] != req[API_KEY]:
|
||||
continue
|
||||
# filter on UUID
|
||||
if UUID in req and req[UUID] != local_job_stack[job_position][UUID]:
|
||||
continue
|
||||
job = copy.deepcopy(local_job_stack[job_position])
|
||||
del job[API_KEY]
|
||||
job["position"] = job_position + 1
|
||||
jobs.append(job)
|
||||
all_matching_completed_jobs = local_completed_jobs.get(req[API_KEY], {})
|
||||
if UUID in req:
|
||||
all_matching_completed_jobs = all_matching_completed_jobs.get(req[UUID])
|
||||
for key in all_matching_completed_jobs.keys():
|
||||
jobs.append(all_matching_completed_jobs[key])
|
||||
|
||||
if len(jobs) == 0:
|
||||
return (
|
||||
jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
|
||||
404,
|
||||
)
|
||||
return jsonify({"jobs": jobs})
|
||||
|
||||
|
||||
def backend(event_termination):
|
||||
model = load_model(logger)
|
||||
text2img = Text2Img(model, output_folder="/tmp", logger=logger)
|
||||
|
||||
text2img.breakfast()
|
||||
|
||||
while True:
|
||||
try:
|
||||
prompt = input("Write prompt: ")
|
||||
if not prompt:
|
||||
prompt = "man riding a horse in space"
|
||||
negative_prompt = input("Write negative prompt: ")
|
||||
if not negative_prompt:
|
||||
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
||||
text2img.lunch(prompt=prompt, negative_prompt=negative_prompt)
|
||||
except KeyboardInterrupt:
|
||||
wait_for_seconds(1)
|
||||
|
||||
if event_termination.is_set():
|
||||
break
|
||||
except BaseException:
|
||||
raise
|
||||
with memory_lock:
|
||||
if len(local_job_stack) == 0:
|
||||
continue
|
||||
next_job = local_job_stack[0]
|
||||
next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING
|
||||
|
||||
prompt = next_job[KEY_PROMPT.lower()]
|
||||
negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")
|
||||
|
||||
config = Config().set_config(next_job)
|
||||
|
||||
base64img = text2img.lunch(
|
||||
prompt=prompt, negative_prompt=negative_prompt, config=config
|
||||
)
|
||||
|
||||
with memory_lock:
|
||||
local_job_stack.pop(0)
|
||||
next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
|
||||
next_job[BASE64IMAGE] = base64img
|
||||
if next_job[API_KEY] not in local_completed_jobs:
|
||||
local_completed_jobs[next_job[API_KEY]] = {}
|
||||
local_completed_jobs[next_job[API_KEY]][next_job[UUID]] = next_job
|
||||
|
||||
logger.critical("stopped")
|
||||
if len(local_job_stack) > 0:
|
||||
logger.info(
|
||||
"remaining {} jobs in stack: {}".format(
|
||||
len(local_job_stack), local_job_stack
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
thread = Thread(target=backend, args=(event_termination,))
|
||||
thread.start()
|
||||
# ugly solution for now
|
||||
# TODO: use a database to track instead of internal memory
|
||||
try:
|
||||
app.run()
|
||||
thread.join()
|
||||
except KeyboardInterrupt:
|
||||
event_termination.set()
|
||||
thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ accelerate
|
|||
colorlog
|
||||
diffusers
|
||||
numpy
|
||||
Flask
|
||||
Pillow
|
||||
scikit-image
|
||||
torch
|
||||
|
|
|
|||
|
|
@ -16,6 +16,18 @@ py_library(
|
|||
srcs=["constants.py"],
|
||||
)
|
||||
|
||||
|
||||
py_library(
|
||||
name = "envvar",
|
||||
srcs = ["envvar.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "envvar_test",
|
||||
srcs = ["envvar_test.py"],
|
||||
deps = [":envvar"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "images",
|
||||
srcs = ["images.py"],
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from utilities.constants import KEY_STEPS
|
|||
from utilities.constants import VALUE_STEPS_DEFAULT
|
||||
from utilities.constants import KEY_WIDTH
|
||||
from utilities.constants import VALUE_WIDTH_DEFAULT
|
||||
from utilities.constants import OPTIONAL_KEYS
|
||||
from utilities.logger import DummyLogger
|
||||
|
||||
|
||||
|
|
@ -37,12 +38,20 @@ class Config:
|
|||
def get_config(self) -> dict:
|
||||
return self.__config
|
||||
|
||||
def set_config(self, config: dict):
|
||||
for key in config:
|
||||
if key not in OPTIONAL_KEYS:
|
||||
continue
|
||||
self.__config[key.upper()] = config[key]
|
||||
return self
|
||||
|
||||
def get_output_folder(self) -> str:
|
||||
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
|
||||
|
||||
def set_output_folder(self, folder:str):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder))
|
||||
self.__config[KEY_OUTPUT_FOLDER] = folder
|
||||
return self
|
||||
|
||||
def get_guidance_scale(self) -> float:
|
||||
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
|
||||
|
|
@ -50,6 +59,7 @@ class Config:
|
|||
def set_guidance_scale(self, scale: float):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
|
||||
self.__config[KEY_GUIDANCE_SCALE] = scale
|
||||
return self
|
||||
|
||||
def get_height(self) -> int:
|
||||
return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)
|
||||
|
|
@ -57,6 +67,7 @@ class Config:
|
|||
def set_height(self, value: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
|
||||
self.__config[KEY_HEIGHT] = value
|
||||
return self
|
||||
|
||||
def get_preview(self) -> bool:
|
||||
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
|
||||
|
|
@ -64,6 +75,7 @@ class Config:
|
|||
def set_preview(self, boolean: bool):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
|
||||
self.__config[KEY_PREVIEW] = boolean
|
||||
return self
|
||||
|
||||
def get_scheduler(self) -> str:
|
||||
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
|
||||
|
|
@ -73,6 +85,7 @@ class Config:
|
|||
scheduler = VALUE_SCHEDULER_DEFAULT
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
|
||||
self.__config[KEY_SCHEDULER] = scheduler
|
||||
return self
|
||||
|
||||
def get_seed(self) -> int:
|
||||
seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
|
||||
|
|
@ -84,6 +97,7 @@ class Config:
|
|||
def set_seed(self, seed: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
|
||||
self.__config[KEY_SEED] = seed
|
||||
return self
|
||||
|
||||
def get_steps(self) -> int:
|
||||
return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)
|
||||
|
|
@ -91,6 +105,7 @@ class Config:
|
|||
def set_steps(self, steps: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
|
||||
self.__config[KEY_STEPS] = steps
|
||||
return self
|
||||
|
||||
def get_width(self) -> int:
|
||||
return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)
|
||||
|
|
@ -98,3 +113,4 @@ class Config:
|
|||
def set_width(self, value: int):
|
||||
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
|
||||
self.__config[KEY_WIDTH] = value
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
LOGGER_NAME = "main"
|
||||
MAX_JOB_NUMBER = 10
|
||||
|
||||
KEY_APP = "APP"
|
||||
VALUE_APP = "demo"
|
||||
|
||||
KEY_OUTPUT_FOLDER = "OUTFOLDER"
|
||||
VALUE_OUTPUT_FOLDER_DEFAULT = ""
|
||||
|
|
@ -26,5 +30,33 @@ VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
|
|||
VALUE_SCHEDULER_PNDM = "PNDMScheduler"
|
||||
VALUE_SCHEDULER_DDIM = "DDIMScheduler"
|
||||
|
||||
KEY_PROMPT = "PROMPT"
|
||||
KEY_NEG_PROMPT = "NEG_PROMPT"
|
||||
|
||||
KEY_PREVIEW = "PREVIEW"
|
||||
VALUE_PREVIEW_DEFAULT = True
|
||||
|
||||
# REST API Keys
|
||||
API_KEY = "api_key"
|
||||
API_KEY_FOR_DEMO = "demo"
|
||||
UUID = "uuid"
|
||||
|
||||
BASE64IMAGE = "img"
|
||||
KEY_JOB_STATUS = "status"
|
||||
VALUE_JOB_PENDING = "pending"
|
||||
VALUE_JOB_RUNNING = "running"
|
||||
VALUE_JOB_DONE = "done"
|
||||
|
||||
REQUIRED_KEYS = [
|
||||
API_KEY.lower(),
|
||||
KEY_PROMPT.lower(),
|
||||
]
|
||||
OPTIONAL_KEYS = [
|
||||
KEY_NEG_PROMPT.lower(),
|
||||
KEY_SEED.lower(),
|
||||
KEY_WIDTH.lower(),
|
||||
KEY_HEIGHT.lower(),
|
||||
KEY_GUIDANCE_SCALE.lower(),
|
||||
KEY_STEPS.lower(),
|
||||
KEY_SCHEDULER.lower(),
|
||||
]
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_env_vars() -> dict:
|
||||
return dict(os.environ)
|
||||
|
||||
|
||||
def get_env_var(key: str) -> Optional[str]:
|
||||
return os.getenv(key)
|
||||
|
||||
|
||||
def get_env_var_with_default(key: str, default_value: str) -> str:
|
||||
return os.getenv(key, default_value)
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
from utilities.envvar import get_env_var
|
||||
from utilities.envvar import get_env_var_with_default
|
||||
from utilities.envvar import get_env_vars
|
||||
|
||||
|
||||
class TestEnvVar(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
self.key = "TEST_ENV_VAR"
|
||||
self.value = "1234"
|
||||
os.environ[self.key] = self.value
|
||||
|
||||
def test_existed_vars(self):
|
||||
env_vars = get_env_vars()
|
||||
self.assertTrue(self.key in env_vars)
|
||||
self.assertTrue(get_env_var(self.key) == self.value)
|
||||
|
||||
def test_nonexisted_vars(self):
|
||||
nonexist_key = "TEST_ENV_VAR_RANDOM_STUFF"
|
||||
self.assertTrue(get_env_var(nonexist_key) is None)
|
||||
self.assertTrue(get_env_var_with_default(
|
||||
nonexist_key, nonexist_key) == nonexist_key)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(self):
|
||||
del os.environ[self.key]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -7,6 +7,7 @@ from utilities.logger import DummyLogger
|
|||
from utilities.memory import empty_memory_cache
|
||||
from utilities.model import Model
|
||||
from utilities.times import get_epoch_now
|
||||
from utilities.images import image_to_base64
|
||||
|
||||
|
||||
class Text2Img:
|
||||
|
|
@ -17,45 +18,45 @@ class Text2Img:
|
|||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
config: Union[Config, None],
|
||||
output_folder: str = "",
|
||||
logger: DummyLogger = DummyLogger(),
|
||||
):
|
||||
self.model = model
|
||||
self.config = config
|
||||
self.__output_folder = output_folder
|
||||
self.__logger = logger
|
||||
|
||||
def update_config(self, config: Config):
|
||||
self.config = config
|
||||
|
||||
def brunch(self, prompt: str, negative_prompt: str = ""):
|
||||
self.breakfast()
|
||||
self.lunch(prompt, negative_prompt)
|
||||
|
||||
def breakfast(self):
|
||||
self.model.set_txt2img_scheduler(self.config.get_scheduler())
|
||||
|
||||
def lunch(self, prompt: str, negative_prompt: str = ""):
|
||||
t = get_epoch_now()
|
||||
seed = self.config.get_seed()
|
||||
self.__logger.info("current seed: {}".format(seed))
|
||||
pass
|
||||
|
||||
def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> str:
|
||||
self.model.set_txt2img_scheduler(config.get_scheduler())
|
||||
|
||||
t = get_epoch_now()
|
||||
seed = config.get_seed()
|
||||
generator = torch.Generator("cuda").manual_seed(seed)
|
||||
self.__logger.info("current seed: {}".format(seed))
|
||||
|
||||
result = self.model.txt2img_pipeline(
|
||||
prompt=prompt,
|
||||
width=self.config.get_width(),
|
||||
height=self.config.get_height(),
|
||||
width=config.get_width(),
|
||||
height=config.get_height(),
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=self.config.get_guidance_scale(),
|
||||
num_inference_steps=self.config.get_steps(),
|
||||
guidance_scale=config.get_guidance_scale(),
|
||||
num_inference_steps=config.get_steps(),
|
||||
generator=generator,
|
||||
callback=None,
|
||||
callback_steps=10,
|
||||
)
|
||||
|
||||
out_filepath = "{}/{}.png".format(self.config.get_output_folder(), t)
|
||||
result.images[0].save(out_filepath)
|
||||
|
||||
self.__logger.info("output to file: {}".format(out_filepath))
|
||||
if self.__output_folder:
|
||||
out_filepath = "{}/{}.png".format(self.__output_folder, t)
|
||||
result.images[0].save(out_filepath)
|
||||
self.__logger.info("output to file: {}".format(out_filepath))
|
||||
|
||||
empty_memory_cache()
|
||||
|
||||
return image_to_base64(result.images[0])
|
||||
|
|
|
|||
Loading…
Reference in New Issue