diff --git a/manage_db.py b/manage_db.py index 969cbc0..222bf70 100644 --- a/manage_db.py +++ b/manage_db.py @@ -18,51 +18,76 @@ def release_lock(lock_file): lock_fd.close() -def create_table_users(c): - """Create the users table if it doesn't exist""" +USERS_TABLE_COLUMNS = [ + "id INTEGER PRIMARY KEY AUTOINCREMENT", + "username TEXT UNIQUE", + "apikey TEXT", + "quota INT DEFAULT 50", +] + +HISTORY_TABLE_COLUMNS = [ + "uuid TEXT PRIMARY KEY", + "created_at TIMESTAMP", + "updated_at TIMESTAMP", + "apikey TEXT", + "priority INT", + "type TEXT", + "status TEXT", + "prompt TEXT", + "lang TEXT", + "neg_prompt TEXT", + "seed TEXT", + "ref_img TEXT", + "mask_img TEXT", + "img TEXT", + "width INT", + "height INT", + "guidance_scale FLOAT", + "steps INT", + "scheduler TEXT", + "strength FLOAT", + "base_model TEXT", + "lora_model TEXT", + "is_private BOOLEAN DEFAULT False", +] + + +def create_or_update_table(c, table_name): c.execute( - """CREATE TABLE IF NOT EXISTS users - (id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT UNIQUE, - apikey TEXT, - quota INT DEFAULT 50 - )""" + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,) ) + existing_table = c.fetchone() + if table_name == "users": + target_columns = USERS_TABLE_COLUMNS + elif table_name == "history": + target_columns = HISTORY_TABLE_COLUMNS + else: + target_columns = [] -def create_table_history(c): - """Create the history table if it doesn't exist""" - c.execute( - """CREATE TABLE IF NOT EXISTS history - (uuid TEXT PRIMARY KEY, - created_at TIMESTAMP, - updated_at TIMESTAMP, - apikey TEXT, - priority INT, - type TEXT, - status TEXT, - prompt TEXT, - lang TEXT, - neg_prompt TEXT, - seed TEXT, - ref_img TEXT, - mask_img TEXT, - img TEXT, - width INT, - height INT, - guidance_scale FLOAT, - steps INT, - scheduler TEXT, - strength FLOAT, - base_model TEXT, - lora_model TEXT - )""" - ) + if existing_table is None: + # Table doesn't exist, so create it + create_table_query = f"CREATE TABLE {table_name} {', '.join(target_columns)}" + c.execute(create_table_query) + print(f"Table '{table_name}' created successfully.") + else: + # Table exists, check if any columns are missing + c.execute(f"PRAGMA table_info({table_name})") + existing_columns = [column[1] for column in c.fetchall()] + missing_columns = [ + column + for column in target_columns + if column.strip().split()[0] not in existing_columns + ] - -def create_or_update_table(c): - create_table_users(c) - create_table_history(c) + if len(missing_columns) > 0: + # Update the table to add missing columns + for column in missing_columns: + alter_table_query = ( + f"ALTER TABLE {table_name} ADD COLUMN {column.strip()}" + ) + c.execute(alter_table_query) + print(f"Column '{column.strip()}' added to table '{table_name}'.") def modify_table(c, table_name, operation, column_name=None, data_type=None): @@ -138,7 +163,11 @@ def delete_jobs(c, job_uuid="", username=""): rows = c.fetchall() for row in rows: for filepath in row: - if filepath is None or 'base64' in filepath or not os.path.isfile(filepath): + if ( + filepath is None + or "base64" in filepath + or not os.path.isfile(filepath) + ): continue try: os.remove(filepath) @@ -160,7 +189,7 @@ def delete_jobs(c, job_uuid="", username=""): print(f"nothing is found with {job_uuid}") return for filepath in result: - if filepath is None or 'base64' in filepath or not os.path.isfile(filepath): + if filepath is None or "base64" in filepath or not os.path.isfile(filepath): continue try: os.remove(filepath) @@ -173,6 +202,7 @@ def delete_jobs(c, job_uuid="", username=""): ) print(f"removed {c.rowcount} entries") + def show_users(c, username="", details=False): """Print all users in the users table if username is not specified, or only the user with the given username otherwise""" @@ -185,7 +215,7 @@ def show_users(c, username="", details=False): print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") if details: c.execute( - "SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?", + "SELECT uuid, created_at, updated_at, type, status, width, height, steps, img, ref_img, mask_img, is_private FROM history WHERE apikey=?", (user[1],), ) rows = c.fetchall() @@ -202,7 +232,7 @@ def show_users(c, username="", details=False): print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}") if details: c.execute( - "SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt, img, ref_img, mask_img FROM history WHERE apikey=?", + "SELECT * FROM history WHERE apikey=?", (user[1],), ) rows = c.fetchall() @@ -230,7 +260,8 @@ def manage(args): c = conn.cursor() # Create the users and history tables if they don't exist - create_or_update_table(c) + create_or_update_table(c, "users") + create_or_update_table(c, "history") # Perform the requested action if args.action == "create": @@ -279,9 +310,7 @@ def manage(args): def main(): # Parse command-line arguments parser = argparse.ArgumentParser() - parser.add_argument( - "--debug", action="store_true", help="Enable debugging mode" - ) + parser.add_argument("--debug", action="store_true", help="Enable debugging mode") subparsers = parser.add_subparsers(dest="action", required=True) # Sub-parser for the "create" action