[Manage] Adds new field is_private and auto updates table if new columns added

This commit is contained in:
HappyZ 2023-05-25 12:52:36 -07:00
parent b9e6ce5253
commit 00fdf87345
1 changed files with 77 additions and 48 deletions

View File

@ -18,51 +18,76 @@ def release_lock(lock_file):
lock_fd.close()
def create_table_users(c):
"""Create the users table if it doesn't exist"""
USERS_TABLE_COLUMNS = [
"id INTEGER PRIMARY KEY AUTOINCREMENT",
"username TEXT UNIQUE",
"apikey TEXT",
"quota INT DEFAULT 50",
]
HISTORY_TABLE_COLUMNS = [
"uuid TEXT PRIMARY KEY",
"created_at TIMESTAMP",
"updated_at TIMESTAMP",
"apikey TEXT",
"priority INT",
"type TEXT",
"status TEXT",
"prompt TEXT",
"lang TEXT",
"neg_prompt TEXT",
"seed TEXT",
"ref_img TEXT",
"mask_img TEXT",
"img TEXT",
"width INT",
"height INT",
"guidance_scale FLOAT",
"steps INT",
"scheduler TEXT",
"strength FLOAT",
"base_model TEXT",
"lora_model TEXT",
"is_private BOOLEAN DEFAULT False",
]
def create_or_update_table(c, table_name):
c.execute(
"""CREATE TABLE IF NOT EXISTS users
(id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE,
apikey TEXT,
quota INT DEFAULT 50
)"""
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)
)
existing_table = c.fetchone()
if table_name == "users":
target_columns = USERS_TABLE_COLUMNS
elif table_name == "history":
target_columns = HISTORY_TABLE_COLUMNS
else:
target_columns = []
def create_table_history(c):
"""Create the history table if it doesn't exist"""
c.execute(
"""CREATE TABLE IF NOT EXISTS history
(uuid TEXT PRIMARY KEY,
created_at TIMESTAMP,
updated_at TIMESTAMP,
apikey TEXT,
priority INT,
type TEXT,
status TEXT,
prompt TEXT,
lang TEXT,
neg_prompt TEXT,
seed TEXT,
ref_img TEXT,
mask_img TEXT,
img TEXT,
width INT,
height INT,
guidance_scale FLOAT,
steps INT,
scheduler TEXT,
strength FLOAT,
base_model TEXT,
lora_model TEXT
)"""
)
if existing_table is None:
# Table doesn't exist, so create it
create_table_query = f"CREATE TABLE {table_name} {', '.join(target_columns)}"
c.execute(create_table_query)
print(f"Table '{table_name}' created successfully.")
else:
# Table exists, check if any columns are missing
c.execute(f"PRAGMA table_info({table_name})")
existing_columns = [column[1] for column in c.fetchall()]
missing_columns = [
column
for column in target_columns
if column.strip().split()[0] not in existing_columns
]
def create_or_update_table(c):
create_table_users(c)
create_table_history(c)
if len(missing_columns) > 0:
# Update the table to add missing columns
for column in missing_columns:
alter_table_query = (
f"ALTER TABLE {table_name} ADD COLUMN {column.strip()}"
)
c.execute(alter_table_query)
print(f"Column '{column.strip()}' added to table '{table_name}'.")
def modify_table(c, table_name, operation, column_name=None, data_type=None):
@ -138,7 +163,11 @@ def delete_jobs(c, job_uuid="", username=""):
rows = c.fetchall()
for row in rows:
for filepath in row:
if filepath is None or 'base64' in filepath or not os.path.isfile(filepath):
if (
filepath is None
or "base64" in filepath
or not os.path.isfile(filepath)
):
continue
try:
os.remove(filepath)
@ -160,7 +189,7 @@ def delete_jobs(c, job_uuid="", username=""):
print(f"nothing is found with {job_uuid}")
return
for filepath in result:
if filepath is None or 'base64' in filepath or not os.path.isfile(filepath):
if filepath is None or "base64" in filepath or not os.path.isfile(filepath):
continue
try:
os.remove(filepath)
@ -173,6 +202,7 @@ def delete_jobs(c, job_uuid="", username=""):
)
print(f"removed {c.rowcount} entries")
def show_users(c, username="", details=False):
"""Print all users in the users table if username is not specified,
or only the user with the given username otherwise"""
@ -185,7 +215,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, img, ref_img, mask_img, is_private FROM history WHERE apikey=?",
(user[1],),
)
rows = c.fetchall()
@ -202,7 +232,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute(
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?",
"SELECT * FROM history WHERE apikey=?",
(user[1],),
)
rows = c.fetchall()
@ -230,7 +260,8 @@ def manage(args):
c = conn.cursor()
# Create the users and history tables if they don't exist
create_or_update_table(c)
create_or_update_table(c, "users")
create_or_update_table(c, "history")
# Perform the requested action
if args.action == "create":
@ -279,9 +310,7 @@ def manage(args):
def main():
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--debug", action="store_true", help="Enable debugging mode"
)
parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
subparsers = parser.add_subparsers(dest="action", required=True)
# Sub-parser for the "create" action