[Manage] Adds new field is_private and auto updates table if new columns added

This commit is contained in:
HappyZ 2023-05-25 12:52:36 -07:00
parent b9e6ce5253
commit 00fdf87345
1 changed files with 77 additions and 48 deletions

View File

@ -18,51 +18,76 @@ def release_lock(lock_file):
lock_fd.close() lock_fd.close()
def create_table_users(c): USERS_TABLE_COLUMNS = [
"""Create the users table if it doesn't exist""" "id INTEGER PRIMARY KEY AUTOINCREMENT",
"username TEXT UNIQUE",
"apikey TEXT",
"quota INT DEFAULT 50",
]
HISTORY_TABLE_COLUMNS = [
"uuid TEXT PRIMARY KEY",
"created_at TIMESTAMP",
"updated_at TIMESTAMP",
"apikey TEXT",
"priority INT",
"type TEXT",
"status TEXT",
"prompt TEXT",
"lang TEXT",
"neg_prompt TEXT",
"seed TEXT",
"ref_img TEXT",
"mask_img TEXT",
"img TEXT",
"width INT",
"height INT",
"guidance_scale FLOAT",
"steps INT",
"scheduler TEXT",
"strength FLOAT",
"base_model TEXT",
"lora_model TEXT",
"is_private BOOLEAN DEFAULT False",
]
def create_or_update_table(c, table_name):
c.execute( c.execute(
"""CREATE TABLE IF NOT EXISTS users "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)
(id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE,
apikey TEXT,
quota INT DEFAULT 50
)"""
) )
existing_table = c.fetchone()
if table_name == "users":
target_columns = USERS_TABLE_COLUMNS
elif table_name == "history":
target_columns = HISTORY_TABLE_COLUMNS
else:
target_columns = []
def create_table_history(c): if existing_table is None:
"""Create the history table if it doesn't exist""" # Table doesn't exist, so create it
c.execute( create_table_query = f"CREATE TABLE {table_name} {', '.join(target_columns)}"
"""CREATE TABLE IF NOT EXISTS history c.execute(create_table_query)
(uuid TEXT PRIMARY KEY, print(f"Table '{table_name}' created successfully.")
created_at TIMESTAMP, else:
updated_at TIMESTAMP, # Table exists, check if any columns are missing
apikey TEXT, c.execute(f"PRAGMA table_info({table_name})")
priority INT, existing_columns = [column[1] for column in c.fetchall()]
type TEXT, missing_columns = [
status TEXT, column
prompt TEXT, for column in target_columns
lang TEXT, if column.strip().split()[0] not in existing_columns
neg_prompt TEXT, ]
seed TEXT,
ref_img TEXT,
mask_img TEXT,
img TEXT,
width INT,
height INT,
guidance_scale FLOAT,
steps INT,
scheduler TEXT,
strength FLOAT,
base_model TEXT,
lora_model TEXT
)"""
)
if len(missing_columns) > 0:
def create_or_update_table(c): # Update the table to add missing columns
create_table_users(c) for column in missing_columns:
create_table_history(c) alter_table_query = (
f"ALTER TABLE {table_name} ADD COLUMN {column.strip()}"
)
c.execute(alter_table_query)
print(f"Column '{column.strip()}' added to table '{table_name}'.")
def modify_table(c, table_name, operation, column_name=None, data_type=None): def modify_table(c, table_name, operation, column_name=None, data_type=None):
@ -138,7 +163,11 @@ def delete_jobs(c, job_uuid="", username=""):
rows = c.fetchall() rows = c.fetchall()
for row in rows: for row in rows:
for filepath in row: for filepath in row:
if filepath is None or 'base64' in filepath or not os.path.isfile(filepath): if (
filepath is None
or "base64" in filepath
or not os.path.isfile(filepath)
):
continue continue
try: try:
os.remove(filepath) os.remove(filepath)
@ -160,7 +189,7 @@ def delete_jobs(c, job_uuid="", username=""):
print(f"nothing is found with {job_uuid}") print(f"nothing is found with {job_uuid}")
return return
for filepath in result: for filepath in result:
if filepath is None or 'base64' in filepath or not os.path.isfile(filepath): if filepath is None or "base64" in filepath or not os.path.isfile(filepath):
continue continue
try: try:
os.remove(filepath) os.remove(filepath)
@ -173,6 +202,7 @@ def delete_jobs(c, job_uuid="", username=""):
) )
print(f"removed {c.rowcount} entries") print(f"removed {c.rowcount} entries")
def show_users(c, username="", details=False): def show_users(c, username="", details=False):
"""Print all users in the users table if username is not specified, """Print all users in the users table if username is not specified,
or only the user with the given username otherwise""" or only the user with the given username otherwise"""
@ -185,7 +215,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details: if details:
c.execute( c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?", "SELECT uuid, created_at, updated_at, type, status, width, height, steps, img, ref_img, mask_img, is_private FROM history WHERE apikey=?",
(user[1],), (user[1],),
) )
rows = c.fetchall() rows = c.fetchall()
@ -202,7 +232,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details: if details:
c.execute( c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?", "SELECT * FROM history WHERE apikey=?",
(user[1],), (user[1],),
) )
rows = c.fetchall() rows = c.fetchall()
@ -230,7 +260,8 @@ def manage(args):
c = conn.cursor() c = conn.cursor()
# Create the users and history tables if they don't exist # Create the users and history tables if they don't exist
create_or_update_table(c) create_or_update_table(c, "users")
create_or_update_table(c, "history")
# Perform the requested action # Perform the requested action
if args.action == "create": if args.action == "create":
@ -279,9 +310,7 @@ def manage(args):
def main(): def main():
# Parse command-line arguments # Parse command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
"--debug", action="store_true", help="Enable debugging mode"
)
subparsers = parser.add_subparsers(dest="action", required=True) subparsers = parser.add_subparsers(dest="action", required=True)
# Sub-parser for the "create" action # Sub-parser for the "create" action