From db6d3a94a7f2011f584fa450c754eb852bc61b5c Mon Sep 17 00:00:00 2001 From: HappyZ Date: Fri, 5 May 2023 00:09:23 -0700 Subject: [PATCH] transfer to use sqlite3 db instead of internal memory --- .gitignore | 3 + BUILD | 1 + main.py | 211 ++++++++++++++++----------------- manage.py | 155 ++++++++++++++++++++++++ templates/index.html | 6 +- utilities/BUILD | 24 ++-- utilities/config.py | 12 -- utilities/constants.py | 106 +++++++++-------- utilities/database.py | 262 +++++++++++++++++++++++++++++++++++++++++ utilities/img2img.py | 8 +- utilities/text2img.py | 8 +- 11 files changed, 607 insertions(+), 189 deletions(-) create mode 100644 manage.py create mode 100644 utilities/database.py diff --git a/.gitignore b/.gitignore index 7af88c3..b876447 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ __pycache__/ # C extensions *.so +# sqlite3 db +*.db + # Distribution / packaging .Python build/ diff --git a/BUILD b/BUILD index 5093cbb..e400d06 100644 --- a/BUILD +++ b/BUILD @@ -8,6 +8,7 @@ par_binary( srcs=["main.py"], deps=[ "//utilities:constants", + "//utilities:database", "//utilities:logger", "//utilities:model", "//utilities:text2img", diff --git a/main.py b/main.py index 9268dee..e4b7a4b 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import argparse import copy import tempfile import pkgutil @@ -10,8 +11,7 @@ from threading import Event from threading import Thread from threading import Lock -from utilities.constants import API_KEY -from utilities.constants import API_KEY_FOR_DEMO +from utilities.constants import APIKEY from utilities.constants import KEY_APP from utilities.constants import KEY_JOB_STATUS from utilities.constants import KEY_JOB_TYPE @@ -33,6 +33,7 @@ from utilities.constants import VALUE_JOB_PENDING from utilities.constants import VALUE_JOB_RUNNING from utilities.constants import VALUE_JOB_DONE from utilities.constants import VALUE_JOB_FAILED +from utilities.database import Database from utilities.envvar import get_env_var_with_default from utilities.envvar import get_env_var from utilities.times import wait_for_seconds @@ -44,10 +45,10 @@ from utilities.img2img import Img2Img app = Flask(__name__) -app.config['TESTING'] = False memory_lock = Lock() event_termination = Event() logger = Logger(name=LOGGER_NAME) +database = Database(logger) use_gpu = True local_job_stack = [] @@ -57,13 +58,14 @@ local_completed_jobs = [] @app.route("/add_job", methods=["POST"]) def add_job(): req = request.get_json() - if API_KEY not in req: + + if APIKEY not in req: + logger.error(f"{APIKEY} not present in {req}") return "", 401 - if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP: - if req[API_KEY] != API_KEY_FOR_DEMO: - return "", 401 - else: - # TODO: add logic to validate app key with a particular user + with memory_lock: + user = database.validate_user(req[APIKEY]) + if not user: + logger.error(f"user not found with {req[APIKEY]}") return "", 401 for key in req.keys(): @@ -76,71 +78,58 @@ def add_job(): if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req: return jsonify({"msg": "missing reference image"}), 404 - if len(local_job_stack) > MAX_JOB_NUMBER: - return jsonify({"msg": "too many jobs in queue, please wait"}), 500 + if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER: + return ( + jsonify({"msg": "too many jobs in queue, please wait or cancel some"}), + 500, + ) - req[UUID] = str(uuid.uuid4()) - logger.info("adding a new job with uuid {}..".format(req[UUID])) - - req[KEY_JOB_STATUS] = VALUE_JOB_PENDING - req["position"] = len(local_job_stack) + 1 + job_uuid = str(uuid.uuid4()) + logger.info("adding a new job with uuid {}..".format(job_uuid)) with memory_lock: - local_job_stack.append(req) + database.insert_new_job(req, job_uuid=job_uuid) - return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]}) + return jsonify({"msg": "", UUID: job_uuid}) @app.route("/cancel_job", methods=["POST"]) def cancel_job(): req = request.get_json() - if API_KEY not in req: + if APIKEY not in req: return "", 401 - if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP: - if req[API_KEY] != API_KEY_FOR_DEMO: - return "", 401 - else: - # TODO: add logic to validate app key with a particular user + with memory_lock: + user = database.validate_user(req[APIKEY]) + if not user: return "", 401 if UUID not in req: return jsonify({"msg": "missing uuid"}), 404 - logger.info("removing job with uuid {}..".format(req[UUID])) + logger.info("cancelling job with uuid {}..".format(req[UUID])) - cancel_job_position = None with memory_lock: - for job_position in range(len(local_job_stack)): - if local_job_stack[job_position][UUID] == req[UUID]: - cancel_job_position = job_position - break - logger.info("foud {}".format(cancel_job_position)) - if cancel_job_position is not None: - if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]: - return "", 401 - if ( - local_job_stack[cancel_job_position][KEY_JOB_STATUS] - == VALUE_JOB_RUNNING - ): - logger.info( - "job at {} with uuid {} is running and cannot be cancelled".format( - cancel_job_position, req[UUID] + result = database.cancel_job(job_uuid=req[UUID]) + + if result: + msg = "job with uuid {} removed".format(req[UUID]) + return jsonify({"msg": msg}) + + with memory_lock: + jobs = database.get_jobs(job_uuid=req[UUID]) + + if jobs: + return ( + jsonify( + { + "msg": "job {} is not in pending state, unable to cancel".format( + req[UUID] ) - ) - return ( - jsonify( - { - "msg": "job {} is already running, unable to cancel".format( - req[UUID] - ) - } - ), - 405, - ) - del local_job_stack[cancel_job_position] - msg = "job with uuid {} removed".format(req[UUID]) - logger.info(msg) - return jsonify({"msg": msg}) + } + ), + 405, + ) + return ( jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}), 404, @@ -150,37 +139,16 @@ def cancel_job(): @app.route("/get_jobs", methods=["POST"]) def get_jobs(): req = request.get_json() - if API_KEY not in req: + if APIKEY not in req: return "", 401 - if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP: - if req[API_KEY] != API_KEY_FOR_DEMO: - return "", 401 - else: - # TODO: add logic to validate app key with a particular user - return "", 401 - - jobs = [] - - all_job_stack = local_job_stack + local_completed_jobs with memory_lock: - for job_position in range(len(all_job_stack)): - # filter on API_KEY - if all_job_stack[job_position][API_KEY] != req[API_KEY]: - continue - # filter on UUID - if UUID in req and req[UUID] != all_job_stack[job_position][UUID]: - continue - job = copy.deepcopy(all_job_stack[job_position]) - if job[KEY_JOB_STATUS] == VALUE_JOB_DONE: - del job["position"] - del job[API_KEY] - jobs.append(job) + user = database.validate_user(req[APIKEY]) + if not user: + return "", 401 + + with memory_lock: + jobs = database.get_jobs(job_uuid=req[UUID]) - if len(jobs) == 0: - return ( - jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}), - 404, - ) return jsonify({"jobs": jobs}) @@ -213,7 +181,7 @@ def load_model(logger: Logger) -> Model: return model -def backend(event_termination): +def backend(event_termination, db): model = load_model(logger) text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG)) img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG)) @@ -225,15 +193,20 @@ def backend(event_termination): wait_for_seconds(1) with memory_lock: - if len(local_job_stack) == 0: - continue - next_job = local_job_stack[0] - next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING + pending_jobs = database.get_all_pending_jobs() - prompt = next_job[KEY_PROMPT.lower()] - negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "") + if len(pending_jobs) == 0: + continue - config = Config().set_config(next_job) + next_job = pending_jobs[0] + + with memory_lock: + database.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID]) + + prompt = next_job[KEY_PROMPT] + negative_prompt = next_job[KEY_NEG_PROMPT] + + config = Config().set_config(next_job) try: if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG: @@ -250,33 +223,35 @@ def backend(event_termination): ) except BaseException as e: logger.error("text2img.lunch error: {}".format(e)) - local_job_stack.pop(0) - next_job[KEY_JOB_STATUS] = VALUE_JOB_FAILED - local_completed_jobs.append(next_job) + with memory_lock: + database.update_job( + {KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID] + ) + continue with memory_lock: - local_job_stack.pop(0) - next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE - next_job.update(result_dict) - local_completed_jobs.append(next_job) + database.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID]) + database.update_job(result_dict, job_uuid=next_job[UUID]) logger.critical("stopped") - if len(local_job_stack) > 0: - logger.info( - "remaining {} jobs in stack: {}".format( - len(local_job_stack), local_job_stack - ) - ) -def main(): - if app.testing: +def main(db_filepath, is_testing: bool = False): + database.connect(db_filepath) + + if is_testing: try: app.run(host="0.0.0.0", port="5000") except KeyboardInterrupt: pass return - thread = Thread(target=backend, args=(event_termination,)) + thread = Thread( + target=backend, + args=( + event_termination, + database, + ), + ) thread.start() # ugly solution for now # TODO: use a database to track instead of internal memory @@ -285,8 +260,24 @@ def main(): thread.join() except KeyboardInterrupt: event_termination.set() - thread.join(1) + + database.safe_disconnect() + + thread.join(2) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + + # Add an argument to set the 'testing' flag + parser.add_argument("--testing", action="store_true", help="Enable testing mode") + + # Add an argument to set the path of the database file + parser.add_argument( + "--db", type=str, default="happysd.db", help="Path to SQLite database file" + ) + + args = parser.parse_args() + logger.info(args) + + main(args.db, args.testing) diff --git a/manage.py b/manage.py new file mode 100644 index 0000000..a429f73 --- /dev/null +++ b/manage.py @@ -0,0 +1,155 @@ +import argparse +import sqlite3 +import uuid + + +def create_table_users(c): + """Create the users table if it doesn't exist""" + c.execute( + """CREATE TABLE IF NOT EXISTS users + (id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE, + apikey TEXT)""" + ) + + +def create_table_history(c): + """Create the history table if it doesn't exist""" + c.execute( + """CREATE TABLE IF NOT EXISTS history + (uuid TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP, + apikey TEXT, + priority INT, + type TEXT, + status TEXT, + prompt TEXT, + neg_prompt TEXT, + seed TEXT, + ref_img TEXT, + img TEXT, + width INT, + height INT, + guidance_scale FLOAT, + steps INT, + scheduler TEXT, + strength FLOAT, + base_model TEXT, + lora_model TEXT + )""" + ) + + +def create_user(c, username, apikey): + """Create a user with the given username and apikey, or update the apikey if the username already exists""" + c.execute("SELECT * FROM users WHERE username=?", (username,)) + result = c.fetchone() + if result is not None: + raise ValueError(f"found exisitng user {username}, please use update") + else: + c.execute("INSERT INTO users (username, apikey) VALUES (?, ?)", (username, apikey)) + + +def update_user(c, username, apikey): + """Update the apikey for the user with the given username""" + c.execute("SELECT apikey FROM users WHERE username=?", (username,)) + result = c.fetchone() + if result is not None: + old_apikey = result[0] + c.execute("UPDATE history SET apikey=? WHERE apikey=?", (apikey, old_apikey)) + c.execute("UPDATE users SET apikey=? WHERE username=?", (apikey, username)) + else: + raise ValueError("username does not exist! create it first?") + + +def delete_user(c, username): + """Delete the user with the given username, or ignore the operation if the user does not exist""" + c.execute("DELETE FROM history WHERE apikey=(SELECT apikey FROM users WHERE username=?)", (username,)) + c.execute("DELETE FROM users WHERE username=?", (username,)) + +def show_users(c, username="", details=False): + """Print all users in the users table if username is not specified, + or only the user with the given username otherwise""" + if username: + c.execute("SELECT username, apikey FROM users WHERE username=?", (username,)) + user = c.fetchone() + if user: + c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],)) + count = c.fetchone()[0] + print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") + if details: + c.execute("SELECT * FROM history WHERE apikey=?", (user[1],)) + result = c.fetchall() + print(result) + else: + print(f"No user with username '{username}' found") + else: + c.execute("SELECT username, apikey FROM users") + users = c.fetchall() + for user in users: + c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],)) + count = c.fetchone()[0] + print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") + if details: + c.execute("SELECT * FROM history WHERE apikey=?", (user[1],)) + result = c.fetchall() + print(result) + + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="action") + + # Sub-parser for the "create" action + create_parser = subparsers.add_parser("create") + create_parser.add_argument("username") + create_parser.add_argument("apikey") + + # Sub-parser for the "update" action + update_parser = subparsers.add_parser("update") + update_parser.add_argument("username") + update_parser.add_argument("apikey") + + # Sub-parser for the "delete" action + delete_parser = subparsers.add_parser("delete") + delete_parser.add_argument("username") + + # Sub-parser for the "delete" action + list_parser = subparsers.add_parser("list") + list_parser.add_argument("username", nargs="?", default="") + list_parser.add_argument("--details", action="store_true", help="Showing more details") + + args = parser.parse_args() + + # Connect to the database (creates a new file if it doesn't exist) + conn = sqlite3.connect("happysd.db") + c = conn.cursor() + + # Create the users and history tables if they don't exist + create_table_users(c) + create_table_history(c) + + # Perform the requested action + if args.action == "create": + create_user(c, args.username, args.apikey) + print("User created") + elif args.action == "update": + update_user(c, args.username, args.apikey) + print("User updated") + elif args.action == "delete": + delete_user(c, args.username) + print("User deleted") + elif args.action == "list": + show_users(c, args.username, args.details) + + # Commit the changes to the database + conn.commit() + + # Close the connection + conn.close() + + +if __name__ == "__main__": + main() diff --git a/templates/index.html b/templates/index.html index bd7d2d1..423b78a 100644 --- a/templates/index.html +++ b/templates/index.html @@ -213,7 +213,7 @@ url: '/get_jobs', contentType: 'application/json; charset=utf-8', dataType: 'json', - data: JSON.stringify({ 'api_key': apikeyVal, 'uuid': uuidValue }), + data: JSON.stringify({ 'apikey': apikeyVal, 'uuid': uuidValue }), success: function (response) { console.log(response); if (response.jobs.length == 1) { @@ -404,7 +404,7 @@ contentType: 'application/json; charset=utf-8', dataType: 'json', data: JSON.stringify({ - 'api_key': apikeyVal, + 'apikey': apikeyVal, 'type': 'txt', 'prompt': promptVal, 'seed': seedVal, @@ -513,7 +513,7 @@ contentType: 'application/json; charset=utf-8', dataType: 'json', data: JSON.stringify({ - 'api_key': apikeyVal, + 'apikey': apikeyVal, 'type': 'img', 'ref_img': imageData, 'prompt': promptVal, diff --git a/utilities/BUILD b/utilities/BUILD index 5a88b0f..700fb28 100644 --- a/utilities/BUILD +++ b/utilities/BUILD @@ -16,21 +16,29 @@ py_library( srcs=["constants.py"], ) +py_library( + name="database", + srcs=["database.py"], + deps=[ + ":logger", + ], +) + py_library( - name = "envvar", - srcs = ["envvar.py"], + name="envvar", + srcs=["envvar.py"], ) py_test( - name = "envvar_test", - srcs = ["envvar_test.py"], - deps = [":envvar"], + name="envvar_test", + srcs=["envvar_test.py"], + deps=[":envvar"], ) py_library( - name = "images", - srcs = ["images.py"], + name="images", + srcs=["images.py"], ) py_library( @@ -101,4 +109,4 @@ py_test( py_library( name="web", srcs=["web.py"], -) \ No newline at end of file +) diff --git a/utilities/config.py b/utilities/config.py index 7cc0026..1526668 100644 --- a/utilities/config.py +++ b/utilities/config.py @@ -10,8 +10,6 @@ from utilities.constants import KEY_HEIGHT from utilities.constants import VALUE_HEIGHT_DEFAULT from utilities.constants import KEY_STRENGTH from utilities.constants import VALUE_STRENGTH_DEFAULT -from utilities.constants import KEY_PREVIEW -from utilities.constants import VALUE_PREVIEW_DEFAULT from utilities.constants import KEY_SCHEDULER from utilities.constants import VALUE_SCHEDULER_DEFAULT from utilities.constants import VALUE_SCHEDULER_DDIM @@ -84,16 +82,6 @@ class Config: self.__config[KEY_HEIGHT] = value return self - def get_preview(self) -> bool: - return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT) - - def set_preview(self, boolean: bool): - self.__logger.info( - "{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean) - ) - self.__config[KEY_PREVIEW] = boolean - return self - def get_scheduler(self) -> str: return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT) diff --git a/utilities/constants.py b/utilities/constants.py index 5c793a6..211ead0 100644 --- a/utilities/constants.py +++ b/utilities/constants.py @@ -7,71 +7,81 @@ LOGGER_NAME_IMG2IMG = "img2img" MAX_JOB_NUMBER = 10 -KEY_OUTPUT_FOLDER = "OUTFOLDER" + +KEY_OUTPUT_FOLDER = "outfolder" VALUE_OUTPUT_FOLDER_DEFAULT = "" -KEY_SEED = "SEED" -VALUE_SEED_DEFAULT = 0 +# +# Database +# +HISTORY_TABLE_NAME = "history" +USERS_TABLE_NAME = "users" -KEY_WIDTH = "WIDTH" -VALUE_WIDTH_DEFAULT = 512 +# +# REST API Keys +# -KEY_HEIGHT = "HEIGHT" -VALUE_HEIGHT_DEFAULT = 512 +# - input and output +APIKEY = "apikey" -KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE" -VALUE_GUIDANCE_SCALE_DEFAULT = 25.0 +KEY_JOB_TYPE = "type" +VALUE_JOB_TXT2IMG = "txt" # default value for KEY_JOB_TYPE +VALUE_JOB_IMG2IMG = "img" +REFERENCE_IMG = "ref_img" +VALUE_JOB_INPAINTING = "inpaint" -KEY_STRENGTH = "STRENGTH" -VALUE_STRENGTH_DEFAULT = 0.5 - -KEY_STEPS = "STEPS" -VALUE_STEPS_DEFAULT = 50 - -KEY_SCHEDULER = "SCHEDULER" -VALUE_SCHEDULER_DEFAULT = "Default" +KEY_PROMPT = "prompt" +KEY_NEG_PROMPT = "neg_prompt" +KEY_SEED = "seed" +VALUE_SEED_DEFAULT = 0 # default value for KEY_SEED +KEY_WIDTH = "width" +VALUE_WIDTH_DEFAULT = 512 # default value for KEY_WIDTH +KEY_HEIGHT = "height" +VALUE_HEIGHT_DEFAULT = 512 # default value for KEY_HEIGHT +KEY_GUIDANCE_SCALE = "guidance_scale" +VALUE_GUIDANCE_SCALE_DEFAULT = 25.0 # default value for KEY_GUIDANCE_SCALE +KEY_STEPS = "steps" +VALUE_STEPS_DEFAULT = 50 # default value for KEY_STEPS +KEY_SCHEDULER = "scheduler" +VALUE_SCHEDULER_DEFAULT = "Default" # default value for KEY_SCHEDULER VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler" VALUE_SCHEDULER_LMS_DISCRETE = "LMSDiscreteScheduler" VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler" VALUE_SCHEDULER_PNDM = "PNDMScheduler" VALUE_SCHEDULER_DDIM = "DDIMScheduler" +KEY_STRENGTH = "strength" +VALUE_STRENGTH_DEFAULT = 0.5 # default value for KEY_STRENGTH -KEY_PROMPT = "PROMPT" -KEY_NEG_PROMPT = "NEG_PROMPT" +REQUIRED_KEYS = [ + APIKEY, # str + KEY_PROMPT, # str + KEY_JOB_TYPE, # str +] +OPTIONAL_KEYS = [ + KEY_NEG_PROMPT, # str + KEY_SEED, # str + KEY_WIDTH, # int + KEY_HEIGHT, # int + KEY_GUIDANCE_SCALE, # float + KEY_STEPS, # int + KEY_SCHEDULER, # str + KEY_STRENGTH, # float + REFERENCE_IMG, # str (base64) +] -KEY_PREVIEW = "PREVIEW" -VALUE_PREVIEW_DEFAULT = True - -# REST API Keys -API_KEY = "api_key" -API_KEY_FOR_DEMO = "demo" +# - output only UUID = "uuid" - BASE64IMAGE = "img" +KEY_PRIORITY = "priority" KEY_JOB_STATUS = "status" -VALUE_JOB_PENDING = "pending" +VALUE_JOB_PENDING = "pending" # default value for KEY_JOB_STATUS VALUE_JOB_RUNNING = "running" VALUE_JOB_DONE = "done" VALUE_JOB_FAILED = "failed" -KEY_JOB_TYPE = "type" -VALUE_JOB_TXT2IMG = "txt" -VALUE_JOB_IMG2IMG = "img" -VALUE_JOB_INPAINTING = "inpaint" -REFERENCE_IMG = "ref_img" -REQUIRED_KEYS = [ - API_KEY.lower(), - KEY_PROMPT.lower(), - KEY_JOB_TYPE.lower(), -] -OPTIONAL_KEYS = [ - KEY_NEG_PROMPT.lower(), - KEY_SEED.lower(), - KEY_WIDTH.lower(), - KEY_HEIGHT.lower(), - KEY_GUIDANCE_SCALE.lower(), - KEY_STEPS.lower(), - KEY_SCHEDULER.lower(), - KEY_STRENGTH.lower(), - REFERENCE_IMG.lower(), -] +OUTPUT_ONLY_KEYS = [ + UUID, # str + KEY_PRIORITY, # int + BASE64IMAGE, # str (base64) + KEY_JOB_STATUS, # str +] \ No newline at end of file diff --git a/utilities/database.py b/utilities/database.py new file mode 100644 index 0000000..c2a05a9 --- /dev/null +++ b/utilities/database.py @@ -0,0 +1,262 @@ +import os +import datetime +import sqlite3 +import uuid + +from utilities.constants import APIKEY +from utilities.constants import UUID +from utilities.constants import KEY_PRIORITY +from utilities.constants import KEY_JOB_TYPE +from utilities.constants import VALUE_JOB_TXT2IMG +from utilities.constants import VALUE_JOB_IMG2IMG +from utilities.constants import VALUE_JOB_INPAINTING +from utilities.constants import KEY_JOB_STATUS +from utilities.constants import VALUE_JOB_PENDING +from utilities.constants import VALUE_JOB_RUNNING +from utilities.constants import VALUE_JOB_DONE + +from utilities.constants import OUTPUT_ONLY_KEYS +from utilities.constants import OPTIONAL_KEYS +from utilities.constants import REQUIRED_KEYS + +from utilities.constants import REFERENCE_IMG +from utilities.constants import BASE64IMAGE + +from utilities.constants import HISTORY_TABLE_NAME +from utilities.constants import USERS_TABLE_NAME +from utilities.logger import DummyLogger + + +class Database: + """This class represents a SQLite database and assumes single-thread usage.""" + + def __init__(self, logger: DummyLogger = DummyLogger()): + """Initialize the class with a logger instance, but without a database connection or cursor.""" + self.__connect = None # the database connection object + self.__cursor = None # the cursor object for executing SQL statements + self.__logger = logger # the logger object for logging messages + + def connect(self, db_filepath) -> bool: + """ + Connect to the SQLite database file specified by `db_filepath`. + + Returns True if the connection was successful, otherwise False. + """ + if not os.path.isfile(db_filepath): + self.__logger.error(f"{db_filepath} does not exist!") + return False + self.__connect = sqlite3.connect(db_filepath, check_same_thread=False) + self.__cursor = self.__connect.cursor() + self.__logger.info(f"Connected to database {db_filepath}") + return True + + def validate_user(self, apikey: str) -> str: + """ + Validate if the provided API key exists in the users table and return the corresponding + username if found, or an empty string otherwise. + """ + if self.__cursor is None: + self.__logger.error("Did you forget to connect to the database?") + return "" + + query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?" + self.__cursor.execute(query, (apikey,)) + result = self.__cursor.fetchone() + + self.__logger.debug(result) + + if result is not None: + return result[0] # the first column is the username + + return "" + + def get_all_pending_jobs(self, apikey: str = "") -> list: + return self.get_jobs(apikey=apikey, job_status=VALUE_JOB_PENDING) + + def count_all_pending_jobs(self, apikey: str) -> int: + """ + Count the number of pending jobs in the HISTORY_TABLE_NAME table for the specified API key. + + Returns the number of pending jobs found. + """ + if self.__cursor is None: + self.__logger.error("Did you forget to connect to the database?") + return 0 + + # Construct the SQL query string and list of arguments + query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?" + query_args = (apikey, VALUE_JOB_PENDING) + + # Execute the query and return the count + self.__cursor.execute(query_string, query_args) + result = self.__cursor.fetchone() + return result[0] + + def get_jobs(self, job_uuid="", apikey="", job_status="") -> list: + """ + Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters. + + If `job_uuid` or `apikey` or `job_status` is provided, the query will include that filter. + + Returns a list of jobs matching the filters provided. + """ + if self.__cursor is None: + self.__logger.error("Did you forget to connect to the database?") + return [] + + # construct the SQL query string and list of arguments based on the provided filters + query_args = [] + query_filters = [] + if job_uuid: + query_filters.append(f"{UUID} = ?") + query_args.append(job_uuid) + if apikey: + query_filters.append(f"{APIKEY} = ?") + query_args.append(apikey) + if job_status: + query_filters.append(f"{KEY_JOB_STATUS} = ?") + query_args.append(job_status) + columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS + query_string = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}" + if query_filters: + query_string += f" WHERE {' AND '.join(query_filters)}" + + # execute the query and return the results + self.__cursor.execute(query_string, tuple(query_args)) + rows = self.__cursor.fetchall() + + jobs = [] + for row in rows: + job = { + columns[i]: row[i] for i in range(len(columns)) if row[i] is not None + } + jobs.append(job) + + return jobs + + def insert_new_job(self, job_dict: dict, job_uuid="") -> bool: + """ + Insert a new job into the HISTORY_TABLE_NAME table. + + If `job_uuid` is not provided, a new UUID will be generated automatically. + + Returns True if the insertion was successful, otherwise False. + """ + if self.__cursor is None: + self.__logger.error("Did you forget to connect to the database?") + return False + + if not job_uuid: + job_uuid = str(uuid.uuid4()) + self.__logger.info(f"inserting a new job with {job_uuid}") + + values = [job_uuid, VALUE_JOB_PENDING] + columns = [UUID, KEY_JOB_STATUS] + REQUIRED_KEYS + OPTIONAL_KEYS + for column in REQUIRED_KEYS + OPTIONAL_KEYS: + values.append(job_dict.get(column, None)) + + query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({', '.join(['?' for _ in columns])})" + self.__cursor.execute(query, tuple(values)) + self.__connect.commit() + return True + + def update_job(self, job_dict: dict, job_uuid: str) -> bool: + """ + Update an existing job in the HISTORY_TABLE_NAME table with the given `job_uuid`. + + Returns True if the update was successful, otherwise False. + """ + if self.__cursor is None: + self.__logger.error("Did you forget to connect to the database?") + return False + + values = [] + columns = [] + for column in OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS: + value = job_dict.get(column, None) + if value is not None: + columns.append(column) + values.append(value) + + set_clause = ", ".join([f"{column}=?" for column in columns]) + # Add current timestamp to update query + set_clause += ", updated_at=?" + values.append(datetime.datetime.now()) + + query = f"UPDATE {HISTORY_TABLE_NAME} SET {set_clause} WHERE {UUID}=?" + + values.append(job_uuid) + + self.__cursor.execute(query, tuple(values)) + self.__connect.commit() + return True + + def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool: + """Cancel the job with the given job_uuid or apikey. + If job_uuid or apikey is provided, delete corresponding rows from table history if "status" matches "pending". + + Args: + job_uuid (str, optional): Unique job identifier. Defaults to "". + apikey (str, optional): API key associated with the job. Defaults to "". + + Returns: + bool: True if the job was cancelled successfully, False otherwise. + """ + if not job_uuid and not apikey: + self.__logger.error(f"either {UUID} or {APIKEY} must be provided.") + return False + + if job_uuid: + self.__cursor.execute( + f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=? AND {KEY_JOB_STATUS}=?", + ( + job_uuid, + VALUE_JOB_PENDING, + ), + ) + else: + self.__cursor.execute( + f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?", + ( + apikey, + VALUE_JOB_PENDING, + ), + ) + + if self.__cursor.rowcount == 0: + self.__logger.info("No matching rows found.") + return False + else: + self.__logger.info(f"{self.__cursor.rowcount} rows cancelled.") + + self.__connect.commit() + return True + + def delete_job(self, job_uuid: str = "", apikey: str = "") -> bool: + """Delete the job with the given uuid or apikey""" + if job_uuid: + self.__cursor.execute( + f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?", (job_uuid,) + ) + elif apikey: + self.__cursor.execute( + f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=?", (apikey,) + ) + else: + self.logger.error(f"either {UUID} or {APIKEY} must be provided.") + return False + + if self.__cursor.rowcount == 0: + print("No matching rows found.") + else: + self.logger.info(f"{self.__cursor.rowcount} rows deleted.") + self.__connect.commit() + return True + + def safe_disconnect(self): + if self.__connect is not None: + self.__connect.commit() + self.__connect.close() + self.__logger.info("Disconnected from database.") + else: + self.__logger.warn("No database connection to close.") diff --git a/utilities/img2img.py b/utilities/img2img.py index 5288439..aa93873 100644 --- a/utilities/img2img.py +++ b/utilities/img2img.py @@ -84,8 +84,8 @@ class Img2Img: return { BASE64IMAGE: image_to_base64(result.images[0]), - KEY_SEED.lower(): str(seed), - KEY_WIDTH.lower(): config.get_width(), - KEY_HEIGHT.lower(): config.get_height(), - KEY_STEPS.lower(): config.get_steps(), + KEY_SEED: str(seed), + KEY_WIDTH: config.get_width(), + KEY_HEIGHT: config.get_height(), + KEY_STEPS: config.get_steps(), } diff --git a/utilities/text2img.py b/utilities/text2img.py index 3af0a3a..7f89056 100644 --- a/utilities/text2img.py +++ b/utilities/text2img.py @@ -66,8 +66,8 @@ class Text2Img: return { BASE64IMAGE: image_to_base64(result.images[0]), - KEY_SEED.lower(): str(seed), - KEY_WIDTH.lower(): config.get_width(), - KEY_HEIGHT.lower(): config.get_height(), - KEY_STEPS.lower(): config.get_steps(), + KEY_SEED: str(seed), + KEY_WIDTH: config.get_width(), + KEY_HEIGHT: config.get_height(), + KEY_STEPS: config.get_steps(), }