stable-diffusion-for-fun/utilities/database.py

358 lines
13 KiB
Python

import os
import datetime
import sqlite3
import fcntl
import uuid
from utilities.constants import APIKEY
from utilities.constants import UUID
from utilities.constants import KEY_PRIORITY
from utilities.constants import KEY_IS_PRIVATE
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_RESTORATION
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import LOCK_FILEPATH
from utilities.constants import OUTPUT_ONLY_KEYS
from utilities.constants import ANONYMOUS_KEYS
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import INTERNAL_KEYS
from utilities.constants import REFERENCE_IMG
from utilities.constants import MASK_IMG
from utilities.constants import BASE64IMAGE
from utilities.constants import HISTORY_TABLE_NAME
from utilities.constants import USERS_TABLE_NAME
from utilities.logger import DummyLogger
from utilities.times import get_epoch_now
from utilities.times import epoch_to_string
from utilities.images import save_image
# Function to acquire a lock on the database file
def acquire_lock():
lock_fd = open(LOCK_FILEPATH, "w")
fcntl.flock(lock_fd, fcntl.LOCK_EX)
# Function to release the lock on the database file
def release_lock():
lock_fd = open(LOCK_FILEPATH, "w")
fcntl.flock(lock_fd, fcntl.LOCK_UN)
lock_fd.close()
class Database:
"""This class represents a SQLite database and assumes single-thread usage."""
def __init__(self, logger: DummyLogger = DummyLogger(), image_folderpath=""):
"""Initialize the class with a logger instance, but without a database connection or cursor."""
self.__connect = None # the database connection object
self.is_connected = False
self.__cursor = None # the cursor object for executing SQL statements
self.__logger = logger # the logger object for logging messages
self.__image_output_folder = ""
self.set_image_output_folder(image_folderpath)
def set_image_output_folder(self, image_folderpath):
self.__image_output_folder = image_folderpath # output image to a folder instead of storing it in sqlite3
if image_folderpath:
try:
os.makedirs(image_folderpath, exist_ok=True)
except OSError as err:
self.__logger.warn(f"{image_folderpath} failed to create: {err}")
self.__image_output_folder = ""
def connect(self, db_filepath) -> bool:
"""
Connect to the SQLite database file specified by `db_filepath`.
Returns True if the connection was successful, otherwise False.
"""
if not os.path.isfile(db_filepath):
self.__logger.error(f"{db_filepath} does not exist!")
return False
self.__connect = sqlite3.connect(db_filepath, check_same_thread=False)
self.__logger.info(f"Connected to database {db_filepath}")
self.is_connected = True
return True
def get_cursor(self):
if not self.is_connected:
raise RuntimeError("Did you forget to connect() to the database?")
return self.__connect.cursor()
def commit(self):
if not self.is_connected:
raise RuntimeError("Did you forget to connect() to the database?")
return self.__connect.commit()
def validate_user(self, apikey: str) -> str:
"""
Validate if the provided API key exists in the users table and return the corresponding
username if found, or an empty string otherwise.
"""
query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?"
c = self.get_cursor()
result = c.execute(query, (apikey,)).fetchone()
if result is not None:
return result[0]
return ""
def get_one_pending_job(self, apikey: str = "") -> list:
return self.get_jobs(apikey=apikey, job_status=VALUE_JOB_PENDING, limit_count=1)
def count_all_pending_jobs(self, apikey: str) -> int:
"""
Count the number of pending jobs in the HISTORY_TABLE_NAME table for the specified API key.
Returns the number of pending jobs found.
"""
# Construct the SQL query string and list of arguments
query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?"
query_args = (apikey, VALUE_JOB_PENDING)
# Execute the query and return the count
c = self.get_cursor()
result = c.execute(query_string, query_args).fetchone()
return result[0]
def get_random_jobs(self, limit_count=0) -> list:
query = f"SELECT {', '.join(ANONYMOUS_KEYS)} FROM {HISTORY_TABLE_NAME} WHERE rowid IN (SELECT rowid FROM {HISTORY_TABLE_NAME} WHERE {KEY_JOB_STATUS} = ? AND {KEY_IS_PRIVATE} = ? AND {KEY_JOB_TYPE} IN (?, ?, ?) ORDER BY RANDOM() LIMIT ?)"
# execute the query and return the results
c = self.get_cursor()
rows = c.execute(
query,
(
VALUE_JOB_DONE,
False,
VALUE_JOB_IMG2IMG,
VALUE_JOB_INPAINTING,
VALUE_JOB_TXT2IMG,
limit_count,
),
).fetchall()
jobs = []
for row in rows:
job = {
ANONYMOUS_KEYS[i]: row[i]
for i in range(len(ANONYMOUS_KEYS))
if row[i] is not None
}
jobs.append(job)
return jobs
def get_jobs(
self, job_uuid="", apikey="", job_status="", job_types=[], limit_count=0
) -> list:
"""
Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters.
If `job_uuid` or `apikey` or `job_status` or `job_type` is provided, the query will include that filter.
Returns a list of jobs matching the filters provided.
"""
# construct the SQL query string and list of arguments based on the provided filters
values = []
query_filters = []
if job_uuid:
query_filters.append(f"{UUID} = ?")
values.append(job_uuid)
if apikey:
query_filters.append(f"{APIKEY} = ?")
values.append(apikey)
if job_status:
query_filters.append(f"{KEY_JOB_STATUS} = ?")
values.append(job_status)
if job_types:
query_filters.append(
f"{KEY_JOB_TYPE} IN ({', '.join(['?' for _ in job_types])})"
)
values += job_types
columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
query = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
if query_filters:
query += f" WHERE {' AND '.join(query_filters)}"
query += f" ORDER BY created_at DESC"
if limit_count:
query += f" LIMIT {limit_count}"
# execute the query and return the results
c = self.get_cursor()
rows = c.execute(query, tuple(values)).fetchall()
jobs = []
for row in rows:
job = {
columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
}
jobs.append(job)
return jobs
def insert_new_job(self, job_dict: dict, job_uuid="") -> bool:
"""
Insert a new job into the HISTORY_TABLE_NAME table.
If `job_uuid` is not provided, a new UUID will be generated automatically.
Returns True if the insertion was successful, otherwise False.
"""
if not job_uuid:
job_uuid = str(uuid.uuid4())
self.__logger.info(f"inserting a new job with {job_uuid}")
current_epoch = get_epoch_now()
# store image to job_dict if has one
if (
self.__image_output_folder
and REFERENCE_IMG in job_dict
and "base64" in job_dict[REFERENCE_IMG]
):
ref_img_filepath = f"{self.__image_output_folder}/{current_epoch}_ref.png"
self.__logger.info(f"saving reference image to {ref_img_filepath}")
if save_image(job_dict[REFERENCE_IMG], ref_img_filepath):
job_dict[REFERENCE_IMG] = ref_img_filepath
if (
self.__image_output_folder
and MASK_IMG in job_dict
and "base64" in job_dict[MASK_IMG]
):
mask_img_filepath = f"{self.__image_output_folder}/{current_epoch}_mask.png"
self.__logger.info(f"saving mask image to {mask_img_filepath}")
if save_image(job_dict[MASK_IMG], mask_img_filepath):
job_dict[MASK_IMG] = mask_img_filepath
values = [job_uuid, VALUE_JOB_PENDING, datetime.datetime.now()]
columns = [UUID, KEY_JOB_STATUS, "created_at"] + REQUIRED_KEYS + OPTIONAL_KEYS
for column in REQUIRED_KEYS + OPTIONAL_KEYS:
values.append(job_dict.get(column, None))
query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({', '.join(['?' for _ in columns])})"
acquire_lock()
try:
c = self.get_cursor()
c.execute(query, tuple(values))
self.commit()
finally:
release_lock()
return True
def update_job(self, job_dict: dict, job_uuid: str) -> bool:
"""
Update an existing job in the HISTORY_TABLE_NAME table with the given `job_uuid`.
Returns True if the update was successful, otherwise False.
"""
if not job_dict:
return False
# store image to job_dict if has one
if (
self.__image_output_folder
and BASE64IMAGE in job_dict
and "base64" in job_dict[BASE64IMAGE]
):
out_img_filepath = f"{self.__image_output_folder}/{get_epoch_now()}_out.png"
self.__logger.info(f"saving output image to {out_img_filepath}")
if save_image(job_dict[BASE64IMAGE], out_img_filepath):
job_dict[BASE64IMAGE] = out_img_filepath
values = []
columns = []
for column in INTERNAL_KEYS + OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS:
value = job_dict.get(column, None)
if value is not None:
columns.append(column)
values.append(value)
set_clause = ", ".join([f"{column}=?" for column in columns])
# Add current timestamp to update query
set_clause += ", updated_at=?"
values.append(datetime.datetime.now())
values.append(job_uuid)
query = f"UPDATE {HISTORY_TABLE_NAME} SET {set_clause} WHERE {UUID}=?"
acquire_lock()
try:
c = self.get_cursor()
c.execute(query, tuple(values))
self.commit()
finally:
release_lock()
return True
def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool:
"""Cancel the job with the given job_uuid or apikey.
If job_uuid or apikey is provided, delete corresponding rows from table history if "status" matches "pending".
Args:
job_uuid (str, optional): Unique job identifier. Defaults to "".
apikey (str, optional): API key associated with the job. Defaults to "".
Returns:
bool: True if the job was cancelled successfully, False otherwise.
"""
return self.delete_job(
job_uuid=job_uuid, apikey=apikey, status=VALUE_JOB_PENDING
)
def delete_job(
self, job_uuid: str = "", apikey: str = "", status: str = ""
) -> bool:
if not job_uuid and not apikey:
self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
return False
query = f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?"
if status:
query += f" AND {KEY_JOB_STATUS}=?"
values = []
if job_uuid:
values.append(job_uuid)
elif apikey:
values.append(apikey)
if status:
values.append(status)
rows_removed = 0
acquire_lock()
try:
c = self.get_cursor()
c.execute(query, tuple(values))
rows_removed = c.rowcount
self.commit()
finally:
release_lock()
if rows_removed == 0:
self.__logger.info("No matching rows found.")
return False
self.__logger.info(f"{rows_removed} rows removed.")
return True
def safe_disconnect(self):
if not self.is_connected:
raise RuntimeError("Did you forget to connect() to the database?")
self.commit()
self.__connect.close()
self.__logger.info("Disconnected from database.")