263 lines
9.6 KiB
Python
263 lines
9.6 KiB
Python
import os
|
|
import datetime
|
|
import sqlite3
|
|
import uuid
|
|
|
|
from utilities.constants import APIKEY
|
|
from utilities.constants import UUID
|
|
from utilities.constants import KEY_PRIORITY
|
|
from utilities.constants import KEY_JOB_TYPE
|
|
from utilities.constants import VALUE_JOB_TXT2IMG
|
|
from utilities.constants import VALUE_JOB_IMG2IMG
|
|
from utilities.constants import VALUE_JOB_INPAINTING
|
|
from utilities.constants import KEY_JOB_STATUS
|
|
from utilities.constants import VALUE_JOB_PENDING
|
|
from utilities.constants import VALUE_JOB_RUNNING
|
|
from utilities.constants import VALUE_JOB_DONE
|
|
|
|
from utilities.constants import OUTPUT_ONLY_KEYS
|
|
from utilities.constants import OPTIONAL_KEYS
|
|
from utilities.constants import REQUIRED_KEYS
|
|
|
|
from utilities.constants import REFERENCE_IMG
|
|
from utilities.constants import BASE64IMAGE
|
|
|
|
from utilities.constants import HISTORY_TABLE_NAME
|
|
from utilities.constants import USERS_TABLE_NAME
|
|
from utilities.logger import DummyLogger
|
|
|
|
|
|
class Database:
    """This class represents a SQLite database and assumes single-thread usage."""

    def __init__(self, logger: "DummyLogger | None" = None):
        """Initialize the class with a logger instance, but without a database connection or cursor.

        Args:
            logger: Object used for log messages; this class calls its ``debug``,
                ``info``, ``error`` and ``warn`` methods. When omitted, a fresh
                ``DummyLogger`` is created per instance (instead of one default
                object shared by every ``Database``).
        """
        self.__connect = None  # the sqlite3 connection object, set by connect()
        self.__cursor = None  # the cursor used for executing every SQL statement
        self.__logger = logger if logger is not None else DummyLogger()

    def connect(self, db_filepath) -> bool:
        """
        Connect to the SQLite database file specified by `db_filepath`.

        The file must already exist; this method never creates a new database.
        The connection is opened with ``check_same_thread=False``, so sqlite3 does
        not reject use from a thread other than the one that opened it.

        Returns True if the connection was successful, otherwise False.
        """
        if not os.path.isfile(db_filepath):
            self.__logger.error(f"{db_filepath} does not exist!")
            return False

        self.__connect = sqlite3.connect(db_filepath, check_same_thread=False)
        self.__cursor = self.__connect.cursor()
        self.__logger.info(f"Connected to database {db_filepath}")
        return True

    def validate_user(self, apikey: str) -> str:
        """
        Validate if the provided API key exists in the users table and return the corresponding
        username if found, or an empty string otherwise.

        Also returns an empty string (after logging an error) when not connected.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return ""

        query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?"
        self.__cursor.execute(query, (apikey,))
        result = self.__cursor.fetchone()

        self.__logger.debug(result)

        if result is not None:
            return result[0]  # the first (and only selected) column is the username

        return ""

    def get_all_pending_jobs(self, apikey: str = "") -> list:
        """Return all jobs whose status is pending, optionally limited to one API key.

        An empty `apikey` means "no API key filter"; see `get_jobs` for the row format.
        """
        return self.get_jobs(apikey=apikey, job_status=VALUE_JOB_PENDING)

    def count_all_pending_jobs(self, apikey: str) -> int:
        """
        Count the number of pending jobs in the HISTORY_TABLE_NAME table for the specified API key.

        Returns the number of pending jobs found, or 0 (after logging an error) when
        not connected.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return 0

        # Construct the SQL query string and list of arguments
        query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?"
        query_args = (apikey, VALUE_JOB_PENDING)

        # Execute the query; COUNT(*) always yields exactly one row.
        self.__cursor.execute(query_string, query_args)
        result = self.__cursor.fetchone()
        return result[0]

    def get_jobs(self, job_uuid="", apikey="", job_status="") -> list:
        """
        Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters.

        Each non-empty argument among `job_uuid`, `apikey` and `job_status` adds an
        equality filter; the filters are combined with AND. With no filters, every
        row is returned.

        Returns a list of dicts, one per matching row, mapping column name to value.
        Columns whose value is NULL are omitted from the dict. Returns an empty list
        (after logging an error) when not connected.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return []

        # Construct the WHERE filters and their bound arguments from the provided values.
        query_args = []
        query_filters = []
        if job_uuid:
            query_filters.append(f"{UUID} = ?")
            query_args.append(job_uuid)
        if apikey:
            query_filters.append(f"{APIKEY} = ?")
            query_args.append(apikey)
        if job_status:
            query_filters.append(f"{KEY_JOB_STATUS} = ?")
            query_args.append(job_status)

        columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
        query_string = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
        if query_filters:
            query_string += f" WHERE {' AND '.join(query_filters)}"

        # Execute the query and convert each row into a column-name -> value dict.
        self.__cursor.execute(query_string, tuple(query_args))
        rows = self.__cursor.fetchall()

        jobs = []
        for row in rows:
            # Drop NULL columns so callers only see keys that actually hold a value.
            job = {column: value for column, value in zip(columns, row) if value is not None}
            jobs.append(job)

        return jobs

    def insert_new_job(self, job_dict: dict, job_uuid="") -> bool:
        """
        Insert a new job into the HISTORY_TABLE_NAME table.

        The job is stored with a pending status. Values for every key in
        REQUIRED_KEYS + OPTIONAL_KEYS are read from `job_dict`; missing keys are
        stored as NULL. If `job_uuid` is not provided, a new UUID4 string is
        generated automatically.

        Returns True once the row is inserted and committed, or False (after logging
        an error) when not connected. sqlite3 errors propagate to the caller.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return False

        if not job_uuid:
            job_uuid = str(uuid.uuid4())
        self.__logger.info(f"inserting a new job with {job_uuid}")

        columns = [UUID, KEY_JOB_STATUS] + REQUIRED_KEYS + OPTIONAL_KEYS
        values = [job_uuid, VALUE_JOB_PENDING]
        values.extend(job_dict.get(column) for column in REQUIRED_KEYS + OPTIONAL_KEYS)

        placeholders = ", ".join(["?"] * len(columns))
        query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({placeholders})"
        self.__cursor.execute(query, tuple(values))
        self.__connect.commit()
        return True

    def update_job(self, job_dict: dict, job_uuid: str) -> bool:
        """
        Update an existing job in the HISTORY_TABLE_NAME table with the given `job_uuid`.

        Only keys of OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS whose value in
        `job_dict` is not None are written; the ``updated_at`` column is always set
        to the current local time.

        Returns True once the statement is executed and committed (even if no row has
        that UUID), or False (after logging an error) when not connected. sqlite3
        errors propagate to the caller.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return False

        assignments = []
        values = []
        for column in OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS:
            value = job_dict.get(column)
            if value is not None:
                assignments.append(f"{column}=?")
                values.append(value)

        # Always refresh the modification timestamp. Collecting assignments in a list
        # keeps the SET clause valid even when job_dict contributes no columns.
        assignments.append("updated_at=?")
        # isoformat(" ") is exactly the text sqlite3's default datetime adapter
        # stores; formatting it here avoids that adapter (deprecated in Python 3.12).
        values.append(datetime.datetime.now().isoformat(" "))

        query = f"UPDATE {HISTORY_TABLE_NAME} SET {', '.join(assignments)} WHERE {UUID}=?"
        values.append(job_uuid)

        self.__cursor.execute(query, tuple(values))
        self.__connect.commit()
        return True

    def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool:
        """Cancel the job with the given job_uuid or apikey.

        Deletes the matching rows from the history table, but only those whose
        status is pending. When both arguments are given, `job_uuid` takes precedence
        and `apikey` is ignored.

        Args:
            job_uuid (str, optional): Unique job identifier. Defaults to "".
            apikey (str, optional): API key associated with the job(s). Defaults to "".

        Returns:
            bool: True if at least one pending row was deleted, False if no row
            matched, neither argument was given, or the database is not connected.
        """
        if not job_uuid and not apikey:
            self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
            return False

        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return False

        if job_uuid:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=? AND {KEY_JOB_STATUS}=?",
                (job_uuid, VALUE_JOB_PENDING),
            )
        else:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?",
                (apikey, VALUE_JOB_PENDING),
            )

        if self.__cursor.rowcount == 0:
            self.__logger.info("No matching rows found.")
            return False

        self.__logger.info(f"{self.__cursor.rowcount} rows cancelled.")
        self.__connect.commit()
        return True

    def delete_job(self, job_uuid: str = "", apikey: str = "") -> bool:
        """Delete the job with the given uuid or apikey, regardless of its status.

        When both arguments are given, `job_uuid` takes precedence.

        Returns:
            bool: True once the delete has been executed and committed (also when no
            row matched), False if neither argument was given or the database is
            not connected.
        """
        if not job_uuid and not apikey:
            self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
            return False

        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return False

        if job_uuid:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?", (job_uuid,)
            )
        else:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=?", (apikey,)
            )

        if self.__cursor.rowcount == 0:
            self.__logger.info("No matching rows found.")
        else:
            self.__logger.info(f"{self.__cursor.rowcount} rows deleted.")
        self.__connect.commit()
        return True

    def safe_disconnect(self):
        """Commit pending work and close the database connection, if one is open.

        The connection and cursor handles are cleared afterwards. Other methods then
        log "Did you forget to connect to the database?" instead of touching a
        closed database, and a repeated call only logs a warning.
        """
        if self.__connect is None:
            self.__logger.warn("No database connection to close.")
            return

        self.__connect.commit()
        self.__connect.close()
        self.__connect = None
        self.__cursor = None
        self.__logger.info("Disconnected from database.")
|