transfer to use sqlite3 db instead of internal memory

This commit is contained in:
HappyZ 2023-05-05 00:09:23 -07:00
parent 11948820f8
commit db6d3a94a7
11 changed files with 607 additions and 189 deletions

3
.gitignore vendored
View File

@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so
# sqlite3 db
*.db
# Distribution / packaging
.Python
build/

1
BUILD
View File

@ -8,6 +8,7 @@ par_binary(
srcs=["main.py"],
deps=[
"//utilities:constants",
"//utilities:database",
"//utilities:logger",
"//utilities:model",
"//utilities:text2img",

193
main.py
View File

@ -1,3 +1,4 @@
import argparse
import copy
import tempfile
import pkgutil
@ -10,8 +11,7 @@ from threading import Event
from threading import Thread
from threading import Lock
from utilities.constants import API_KEY
from utilities.constants import API_KEY_FOR_DEMO
from utilities.constants import APIKEY
from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_JOB_TYPE
@ -33,6 +33,7 @@ from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import VALUE_JOB_FAILED
from utilities.database import Database
from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_var
from utilities.times import wait_for_seconds
@ -44,10 +45,10 @@ from utilities.img2img import Img2Img
app = Flask(__name__)
app.config['TESTING'] = False
memory_lock = Lock()
event_termination = Event()
logger = Logger(name=LOGGER_NAME)
database = Database(logger)
use_gpu = True
local_job_stack = []
@ -57,13 +58,14 @@ local_completed_jobs = []
@app.route("/add_job", methods=["POST"])
def add_job():
req = request.get_json()
if API_KEY not in req:
if APIKEY not in req:
logger.error(f"{APIKEY} not present in {req}")
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
with memory_lock:
user = database.validate_user(req[APIKEY])
if not user:
logger.error(f"user not found with {req[APIKEY]}")
return "", 401
for key in req.keys():
@ -76,71 +78,58 @@ def add_job():
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
return jsonify({"msg": "missing reference image"}), 404
if len(local_job_stack) > MAX_JOB_NUMBER:
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
return (
jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
500,
)
req[UUID] = str(uuid.uuid4())
logger.info("adding a new job with uuid {}..".format(req[UUID]))
req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
req["position"] = len(local_job_stack) + 1
job_uuid = str(uuid.uuid4())
logger.info("adding a new job with uuid {}..".format(job_uuid))
with memory_lock:
local_job_stack.append(req)
database.insert_new_job(req, job_uuid=job_uuid)
return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]})
return jsonify({"msg": "", UUID: job_uuid})
@app.route("/cancel_job", methods=["POST"])
def cancel_job():
req = request.get_json()
if API_KEY not in req:
if APIKEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
with memory_lock:
user = database.validate_user(req[APIKEY])
if not user:
return "", 401
if UUID not in req:
return jsonify({"msg": "missing uuid"}), 404
logger.info("removing job with uuid {}..".format(req[UUID]))
logger.info("cancelling job with uuid {}..".format(req[UUID]))
cancel_job_position = None
with memory_lock:
for job_position in range(len(local_job_stack)):
if local_job_stack[job_position][UUID] == req[UUID]:
cancel_job_position = job_position
break
logger.info("foud {}".format(cancel_job_position))
if cancel_job_position is not None:
if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]:
return "", 401
if (
local_job_stack[cancel_job_position][KEY_JOB_STATUS]
== VALUE_JOB_RUNNING
):
logger.info(
"job at {} with uuid {} is running and cannot be cancelled".format(
cancel_job_position, req[UUID]
)
)
result = database.cancel_job(job_uuid=req[UUID])
if result:
msg = "job with uuid {} removed".format(req[UUID])
return jsonify({"msg": msg})
with memory_lock:
jobs = database.get_jobs(job_uuid=req[UUID])
if jobs:
return (
jsonify(
{
"msg": "job {} is already running, unable to cancel".format(
"msg": "job {} is not in pending state, unable to cancel".format(
req[UUID]
)
}
),
405,
)
del local_job_stack[cancel_job_position]
msg = "job with uuid {} removed".format(req[UUID])
logger.info(msg)
return jsonify({"msg": msg})
return (
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
404,
@ -150,37 +139,16 @@ def cancel_job():
@app.route("/get_jobs", methods=["POST"])
def get_jobs():
req = request.get_json()
if API_KEY not in req:
if APIKEY not in req:
return "", 401
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
if req[API_KEY] != API_KEY_FOR_DEMO:
return "", 401
else:
# TODO: add logic to validate app key with a particular user
return "", 401
jobs = []
all_job_stack = local_job_stack + local_completed_jobs
with memory_lock:
for job_position in range(len(all_job_stack)):
# filter on API_KEY
if all_job_stack[job_position][API_KEY] != req[API_KEY]:
continue
# filter on UUID
if UUID in req and req[UUID] != all_job_stack[job_position][UUID]:
continue
job = copy.deepcopy(all_job_stack[job_position])
if job[KEY_JOB_STATUS] == VALUE_JOB_DONE:
del job["position"]
del job[API_KEY]
jobs.append(job)
user = database.validate_user(req[APIKEY])
if not user:
return "", 401
with memory_lock:
jobs = database.get_jobs(job_uuid=req[UUID])
if len(jobs) == 0:
return (
jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
404,
)
return jsonify({"jobs": jobs})
@ -213,7 +181,7 @@ def load_model(logger: Logger) -> Model:
return model
def backend(event_termination):
def backend(event_termination, db):
model = load_model(logger)
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
@ -225,13 +193,18 @@ def backend(event_termination):
wait_for_seconds(1)
with memory_lock:
if len(local_job_stack) == 0:
continue
next_job = local_job_stack[0]
next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING
pending_jobs = database.get_all_pending_jobs()
prompt = next_job[KEY_PROMPT.lower()]
negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")
if len(pending_jobs) == 0:
continue
next_job = pending_jobs[0]
with memory_lock:
database.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID])
prompt = next_job[KEY_PROMPT]
negative_prompt = next_job[KEY_NEG_PROMPT]
config = Config().set_config(next_job)
@ -250,33 +223,35 @@ def backend(event_termination):
)
except BaseException as e:
logger.error("text2img.lunch error: {}".format(e))
local_job_stack.pop(0)
next_job[KEY_JOB_STATUS] = VALUE_JOB_FAILED
local_completed_jobs.append(next_job)
with memory_lock:
database.update_job(
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
)
continue
with memory_lock:
local_job_stack.pop(0)
next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
next_job.update(result_dict)
local_completed_jobs.append(next_job)
database.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID])
database.update_job(result_dict, job_uuid=next_job[UUID])
logger.critical("stopped")
if len(local_job_stack) > 0:
logger.info(
"remaining {} jobs in stack: {}".format(
len(local_job_stack), local_job_stack
)
)
def main():
if app.testing:
def main(db_filepath, is_testing: bool = False):
database.connect(db_filepath)
if is_testing:
try:
app.run(host="0.0.0.0", port="5000")
except KeyboardInterrupt:
pass
return
thread = Thread(target=backend, args=(event_termination,))
thread = Thread(
target=backend,
args=(
event_termination,
database,
),
)
thread.start()
# ugly solution for now
# TODO: use a database to track instead of internal memory
@ -285,8 +260,24 @@ def main():
thread.join()
except KeyboardInterrupt:
event_termination.set()
thread.join(1)
database.safe_disconnect()
thread.join(2)
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
# Add an argument to set the 'testing' flag
parser.add_argument("--testing", action="store_true", help="Enable testing mode")
# Add an argument to set the path of the database file
parser.add_argument(
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
)
args = parser.parse_args()
logger.info(args)
main(args.db, args.testing)

155
manage.py Normal file
View File

@ -0,0 +1,155 @@
import argparse
import sqlite3
import uuid
def create_table_users(c):
"""Create the users table if it doesn't exist"""
c.execute(
"""CREATE TABLE IF NOT EXISTS users
(id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE,
apikey TEXT)"""
)
def create_table_history(c):
"""Create the history table if it doesn't exist"""
c.execute(
"""CREATE TABLE IF NOT EXISTS history
(uuid TEXT PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP,
apikey TEXT,
priority INT,
type TEXT,
status TEXT,
prompt TEXT,
neg_prompt TEXT,
seed TEXT,
ref_img TEXT,
img TEXT,
width INT,
height INT,
guidance_scale FLOAT,
steps INT,
scheduler TEXT,
strength FLOAT,
base_model TEXT,
lora_model TEXT
)"""
)
def create_user(c, username, apikey):
"""Create a user with the given username and apikey, or update the apikey if the username already exists"""
c.execute("SELECT * FROM users WHERE username=?", (username,))
result = c.fetchone()
if result is not None:
raise ValueError(f"found exisitng user {username}, please use update")
else:
c.execute("INSERT INTO users (username, apikey) VALUES (?, ?)", (username, apikey))
def update_user(c, username, apikey):
"""Update the apikey for the user with the given username"""
c.execute("SELECT apikey FROM users WHERE username=?", (username,))
result = c.fetchone()
if result is not None:
old_apikey = result[0]
c.execute("UPDATE history SET apikey=? WHERE apikey=?", (apikey, old_apikey))
c.execute("UPDATE users SET apikey=? WHERE username=?", (apikey, username))
else:
raise ValueError("username does not exist! create it first?")
def delete_user(c, username):
"""Delete the user with the given username, or ignore the operation if the user does not exist"""
c.execute("DELETE FROM history WHERE apikey=(SELECT apikey FROM users WHERE username=?)", (username,))
c.execute("DELETE FROM users WHERE username=?", (username,))
def show_users(c, username="", details=False):
"""Print all users in the users table if username is not specified,
or only the user with the given username otherwise"""
if username:
c.execute("SELECT username, apikey FROM users WHERE username=?", (username,))
user = c.fetchone()
if user:
c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],))
count = c.fetchone()[0]
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute("SELECT * FROM history WHERE apikey=?", (user[1],))
result = c.fetchall()
print(result)
else:
print(f"No user with username '{username}' found")
else:
c.execute("SELECT username, apikey FROM users")
users = c.fetchall()
for user in users:
c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],))
count = c.fetchone()[0]
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute("SELECT * FROM history WHERE apikey=?", (user[1],))
result = c.fetchall()
print(result)
def main():
# Parse command-line arguments
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="action")
# Sub-parser for the "create" action
create_parser = subparsers.add_parser("create")
create_parser.add_argument("username")
create_parser.add_argument("apikey")
# Sub-parser for the "update" action
update_parser = subparsers.add_parser("update")
update_parser.add_argument("username")
update_parser.add_argument("apikey")
# Sub-parser for the "delete" action
delete_parser = subparsers.add_parser("delete")
delete_parser.add_argument("username")
# Sub-parser for the "delete" action
list_parser = subparsers.add_parser("list")
list_parser.add_argument("username", nargs="?", default="")
list_parser.add_argument("--details", action="store_true", help="Showing more details")
args = parser.parse_args()
# Connect to the database (creates a new file if it doesn't exist)
conn = sqlite3.connect("happysd.db")
c = conn.cursor()
# Create the users and history tables if they don't exist
create_table_users(c)
create_table_history(c)
# Perform the requested action
if args.action == "create":
create_user(c, args.username, args.apikey)
print("User created")
elif args.action == "update":
update_user(c, args.username, args.apikey)
print("User updated")
elif args.action == "delete":
delete_user(c, args.username)
print("User deleted")
elif args.action == "list":
show_users(c, args.username, args.details)
# Commit the changes to the database
conn.commit()
# Close the connection
conn.close()
if __name__ == "__main__":
main()

View File

@ -213,7 +213,7 @@
url: '/get_jobs',
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({ 'api_key': apikeyVal, 'uuid': uuidValue }),
data: JSON.stringify({ 'apikey': apikeyVal, 'uuid': uuidValue }),
success: function (response) {
console.log(response);
if (response.jobs.length == 1) {
@ -404,7 +404,7 @@
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({
'api_key': apikeyVal,
'apikey': apikeyVal,
'type': 'txt',
'prompt': promptVal,
'seed': seedVal,
@ -513,7 +513,7 @@
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({
'api_key': apikeyVal,
'apikey': apikeyVal,
'type': 'img',
'ref_img': imageData,
'prompt': promptVal,

View File

@ -16,21 +16,29 @@ py_library(
srcs=["constants.py"],
)
py_library(
name="database",
srcs=["database.py"],
deps=[
":logger",
],
)
py_library(
name = "envvar",
srcs = ["envvar.py"],
name="envvar",
srcs=["envvar.py"],
)
py_test(
name = "envvar_test",
srcs = ["envvar_test.py"],
deps = [":envvar"],
name="envvar_test",
srcs=["envvar_test.py"],
deps=[":envvar"],
)
py_library(
name = "images",
srcs = ["images.py"],
name="images",
srcs=["images.py"],
)
py_library(

View File

@ -10,8 +10,6 @@ from utilities.constants import KEY_HEIGHT
from utilities.constants import VALUE_HEIGHT_DEFAULT
from utilities.constants import KEY_STRENGTH
from utilities.constants import VALUE_STRENGTH_DEFAULT
from utilities.constants import KEY_PREVIEW
from utilities.constants import VALUE_PREVIEW_DEFAULT
from utilities.constants import KEY_SCHEDULER
from utilities.constants import VALUE_SCHEDULER_DEFAULT
from utilities.constants import VALUE_SCHEDULER_DDIM
@ -84,16 +82,6 @@ class Config:
self.__config[KEY_HEIGHT] = value
return self
def get_preview(self) -> bool:
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
def set_preview(self, boolean: bool):
self.__logger.info(
"{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)
)
self.__config[KEY_PREVIEW] = boolean
return self
def get_scheduler(self) -> str:
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)

View File

@ -7,71 +7,81 @@ LOGGER_NAME_IMG2IMG = "img2img"
MAX_JOB_NUMBER = 10
KEY_OUTPUT_FOLDER = "OUTFOLDER"
KEY_OUTPUT_FOLDER = "outfolder"
VALUE_OUTPUT_FOLDER_DEFAULT = ""
KEY_SEED = "SEED"
VALUE_SEED_DEFAULT = 0
#
# Database
#
HISTORY_TABLE_NAME = "history"
USERS_TABLE_NAME = "users"
KEY_WIDTH = "WIDTH"
VALUE_WIDTH_DEFAULT = 512
#
# REST API Keys
#
KEY_HEIGHT = "HEIGHT"
VALUE_HEIGHT_DEFAULT = 512
# - input and output
APIKEY = "apikey"
KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0
KEY_JOB_TYPE = "type"
VALUE_JOB_TXT2IMG = "txt" # default value for KEY_JOB_TYPE
VALUE_JOB_IMG2IMG = "img"
REFERENCE_IMG = "ref_img"
VALUE_JOB_INPAINTING = "inpaint"
KEY_STRENGTH = "STRENGTH"
VALUE_STRENGTH_DEFAULT = 0.5
KEY_STEPS = "STEPS"
VALUE_STEPS_DEFAULT = 50
KEY_SCHEDULER = "SCHEDULER"
VALUE_SCHEDULER_DEFAULT = "Default"
KEY_PROMPT = "prompt"
KEY_NEG_PROMPT = "neg_prompt"
KEY_SEED = "seed"
VALUE_SEED_DEFAULT = 0 # default value for KEY_SEED
KEY_WIDTH = "width"
VALUE_WIDTH_DEFAULT = 512 # default value for KEY_WIDTH
KEY_HEIGHT = "height"
VALUE_HEIGHT_DEFAULT = 512 # default value for KEY_HEIGHT
KEY_GUIDANCE_SCALE = "guidance_scale"
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0 # default value for KEY_GUIDANCE_SCALE
KEY_STEPS = "steps"
VALUE_STEPS_DEFAULT = 50 # default value for KEY_STEPS
KEY_SCHEDULER = "scheduler"
VALUE_SCHEDULER_DEFAULT = "Default" # default value for KEY_SCHEDULER
VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
VALUE_SCHEDULER_LMS_DISCRETE = "LMSDiscreteScheduler"
VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
VALUE_SCHEDULER_PNDM = "PNDMScheduler"
VALUE_SCHEDULER_DDIM = "DDIMScheduler"
KEY_STRENGTH = "strength"
VALUE_STRENGTH_DEFAULT = 0.5 # default value for KEY_STRENGTH
KEY_PROMPT = "PROMPT"
KEY_NEG_PROMPT = "NEG_PROMPT"
REQUIRED_KEYS = [
APIKEY, # str
KEY_PROMPT, # str
KEY_JOB_TYPE, # str
]
OPTIONAL_KEYS = [
KEY_NEG_PROMPT, # str
KEY_SEED, # str
KEY_WIDTH, # int
KEY_HEIGHT, # int
KEY_GUIDANCE_SCALE, # float
KEY_STEPS, # int
KEY_SCHEDULER, # str
KEY_STRENGTH, # float
REFERENCE_IMG, # str (base64)
]
KEY_PREVIEW = "PREVIEW"
VALUE_PREVIEW_DEFAULT = True
# REST API Keys
API_KEY = "api_key"
API_KEY_FOR_DEMO = "demo"
# - output only
UUID = "uuid"
BASE64IMAGE = "img"
KEY_PRIORITY = "priority"
KEY_JOB_STATUS = "status"
VALUE_JOB_PENDING = "pending"
VALUE_JOB_PENDING = "pending" # default value for KEY_JOB_STATUS
VALUE_JOB_RUNNING = "running"
VALUE_JOB_DONE = "done"
VALUE_JOB_FAILED = "failed"
KEY_JOB_TYPE = "type"
VALUE_JOB_TXT2IMG = "txt"
VALUE_JOB_IMG2IMG = "img"
VALUE_JOB_INPAINTING = "inpaint"
REFERENCE_IMG = "ref_img"
REQUIRED_KEYS = [
API_KEY.lower(),
KEY_PROMPT.lower(),
KEY_JOB_TYPE.lower(),
]
OPTIONAL_KEYS = [
KEY_NEG_PROMPT.lower(),
KEY_SEED.lower(),
KEY_WIDTH.lower(),
KEY_HEIGHT.lower(),
KEY_GUIDANCE_SCALE.lower(),
KEY_STEPS.lower(),
KEY_SCHEDULER.lower(),
KEY_STRENGTH.lower(),
REFERENCE_IMG.lower(),
OUTPUT_ONLY_KEYS = [
UUID, # str
KEY_PRIORITY, # int
BASE64IMAGE, # str (base64)
KEY_JOB_STATUS, # str
]

262
utilities/database.py Normal file
View File

@ -0,0 +1,262 @@
import os
import datetime
import sqlite3
import uuid
from utilities.constants import APIKEY
from utilities.constants import UUID
from utilities.constants import KEY_PRIORITY
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import OUTPUT_ONLY_KEYS
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import REFERENCE_IMG
from utilities.constants import BASE64IMAGE
from utilities.constants import HISTORY_TABLE_NAME
from utilities.constants import USERS_TABLE_NAME
from utilities.logger import DummyLogger
class Database:
"""This class represents a SQLite database and assumes single-thread usage."""
def __init__(self, logger: DummyLogger = DummyLogger()):
"""Initialize the class with a logger instance, but without a database connection or cursor."""
self.__connect = None # the database connection object
self.__cursor = None # the cursor object for executing SQL statements
self.__logger = logger # the logger object for logging messages
def connect(self, db_filepath) -> bool:
"""
Connect to the SQLite database file specified by `db_filepath`.
Returns True if the connection was successful, otherwise False.
"""
if not os.path.isfile(db_filepath):
self.__logger.error(f"{db_filepath} does not exist!")
return False
self.__connect = sqlite3.connect(db_filepath, check_same_thread=False)
self.__cursor = self.__connect.cursor()
self.__logger.info(f"Connected to database {db_filepath}")
return True
def validate_user(self, apikey: str) -> str:
"""
Validate if the provided API key exists in the users table and return the corresponding
username if found, or an empty string otherwise.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return ""
query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?"
self.__cursor.execute(query, (apikey,))
result = self.__cursor.fetchone()
self.__logger.debug(result)
if result is not None:
return result[0] # the first column is the username
return ""
def get_all_pending_jobs(self, apikey: str = "") -> list:
return self.get_jobs(apikey=apikey, job_status=VALUE_JOB_PENDING)
def count_all_pending_jobs(self, apikey: str) -> int:
"""
Count the number of pending jobs in the HISTORY_TABLE_NAME table for the specified API key.
Returns the number of pending jobs found.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return 0
# Construct the SQL query string and list of arguments
query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?"
query_args = (apikey, VALUE_JOB_PENDING)
# Execute the query and return the count
self.__cursor.execute(query_string, query_args)
result = self.__cursor.fetchone()
return result[0]
def get_jobs(self, job_uuid="", apikey="", job_status="") -> list:
"""
Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters.
If `job_uuid` or `apikey` or `job_status` is provided, the query will include that filter.
Returns a list of jobs matching the filters provided.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return []
# construct the SQL query string and list of arguments based on the provided filters
query_args = []
query_filters = []
if job_uuid:
query_filters.append(f"{UUID} = ?")
query_args.append(job_uuid)
if apikey:
query_filters.append(f"{APIKEY} = ?")
query_args.append(apikey)
if job_status:
query_filters.append(f"{KEY_JOB_STATUS} = ?")
query_args.append(job_status)
columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
query_string = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
if query_filters:
query_string += f" WHERE {' AND '.join(query_filters)}"
# execute the query and return the results
self.__cursor.execute(query_string, tuple(query_args))
rows = self.__cursor.fetchall()
jobs = []
for row in rows:
job = {
columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
}
jobs.append(job)
return jobs
def insert_new_job(self, job_dict: dict, job_uuid="") -> bool:
"""
Insert a new job into the HISTORY_TABLE_NAME table.
If `job_uuid` is not provided, a new UUID will be generated automatically.
Returns True if the insertion was successful, otherwise False.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return False
if not job_uuid:
job_uuid = str(uuid.uuid4())
self.__logger.info(f"inserting a new job with {job_uuid}")
values = [job_uuid, VALUE_JOB_PENDING]
columns = [UUID, KEY_JOB_STATUS] + REQUIRED_KEYS + OPTIONAL_KEYS
for column in REQUIRED_KEYS + OPTIONAL_KEYS:
values.append(job_dict.get(column, None))
query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({', '.join(['?' for _ in columns])})"
self.__cursor.execute(query, tuple(values))
self.__connect.commit()
return True
def update_job(self, job_dict: dict, job_uuid: str) -> bool:
"""
Update an existing job in the HISTORY_TABLE_NAME table with the given `job_uuid`.
Returns True if the update was successful, otherwise False.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return False
values = []
columns = []
for column in OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS:
value = job_dict.get(column, None)
if value is not None:
columns.append(column)
values.append(value)
set_clause = ", ".join([f"{column}=?" for column in columns])
# Add current timestamp to update query
set_clause += ", updated_at=?"
values.append(datetime.datetime.now())
query = f"UPDATE {HISTORY_TABLE_NAME} SET {set_clause} WHERE {UUID}=?"
values.append(job_uuid)
self.__cursor.execute(query, tuple(values))
self.__connect.commit()
return True
def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool:
"""Cancel the job with the given job_uuid or apikey.
If job_uuid or apikey is provided, delete corresponding rows from table history if "status" matches "pending".
Args:
job_uuid (str, optional): Unique job identifier. Defaults to "".
apikey (str, optional): API key associated with the job. Defaults to "".
Returns:
bool: True if the job was cancelled successfully, False otherwise.
"""
if not job_uuid and not apikey:
self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
return False
if job_uuid:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=? AND {KEY_JOB_STATUS}=?",
(
job_uuid,
VALUE_JOB_PENDING,
),
)
else:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?",
(
apikey,
VALUE_JOB_PENDING,
),
)
if self.__cursor.rowcount == 0:
self.__logger.info("No matching rows found.")
return False
else:
self.__logger.info(f"{self.__cursor.rowcount} rows cancelled.")
self.__connect.commit()
return True
def delete_job(self, job_uuid: str = "", apikey: str = "") -> bool:
"""Delete the job with the given uuid or apikey"""
if job_uuid:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?", (job_uuid,)
)
elif apikey:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=?", (apikey,)
)
else:
self.logger.error(f"either {UUID} or {APIKEY} must be provided.")
return False
if self.__cursor.rowcount == 0:
print("No matching rows found.")
else:
self.logger.info(f"{self.__cursor.rowcount} rows deleted.")
self.__connect.commit()
return True
def safe_disconnect(self):
if self.__connect is not None:
self.__connect.commit()
self.__connect.close()
self.__logger.info("Disconnected from database.")
else:
self.__logger.warn("No database connection to close.")

View File

@ -84,8 +84,8 @@ class Img2Img:
return {
BASE64IMAGE: image_to_base64(result.images[0]),
KEY_SEED.lower(): str(seed),
KEY_WIDTH.lower(): config.get_width(),
KEY_HEIGHT.lower(): config.get_height(),
KEY_STEPS.lower(): config.get_steps(),
KEY_SEED: str(seed),
KEY_WIDTH: config.get_width(),
KEY_HEIGHT: config.get_height(),
KEY_STEPS: config.get_steps(),
}

View File

@ -66,8 +66,8 @@ class Text2Img:
return {
BASE64IMAGE: image_to_base64(result.images[0]),
KEY_SEED.lower(): str(seed),
KEY_WIDTH.lower(): config.get_width(),
KEY_HEIGHT.lower(): config.get_height(),
KEY_STEPS.lower(): config.get_steps(),
KEY_SEED: str(seed),
KEY_WIDTH: config.get_width(),
KEY_HEIGHT: config.get_height(),
KEY_STEPS: config.get_steps(),
}