transfer to use sqlite3 db instead of internal memory
This commit is contained in:
parent
11948820f8
commit
db6d3a94a7
|
|
@ -6,6 +6,9 @@ __pycache__/
|
|||
# C extensions
|
||||
*.so
|
||||
|
||||
# sqlite3 db
|
||||
*.db
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
|
|
|
|||
1
BUILD
1
BUILD
|
|
@ -8,6 +8,7 @@ par_binary(
|
|||
srcs=["main.py"],
|
||||
deps=[
|
||||
"//utilities:constants",
|
||||
"//utilities:database",
|
||||
"//utilities:logger",
|
||||
"//utilities:model",
|
||||
"//utilities:text2img",
|
||||
|
|
|
|||
193
main.py
193
main.py
|
|
@ -1,3 +1,4 @@
|
|||
import argparse
|
||||
import copy
|
||||
import tempfile
|
||||
import pkgutil
|
||||
|
|
@ -10,8 +11,7 @@ from threading import Event
|
|||
from threading import Thread
|
||||
from threading import Lock
|
||||
|
||||
from utilities.constants import API_KEY
|
||||
from utilities.constants import API_KEY_FOR_DEMO
|
||||
from utilities.constants import APIKEY
|
||||
from utilities.constants import KEY_APP
|
||||
from utilities.constants import KEY_JOB_STATUS
|
||||
from utilities.constants import KEY_JOB_TYPE
|
||||
|
|
@ -33,6 +33,7 @@ from utilities.constants import VALUE_JOB_PENDING
|
|||
from utilities.constants import VALUE_JOB_RUNNING
|
||||
from utilities.constants import VALUE_JOB_DONE
|
||||
from utilities.constants import VALUE_JOB_FAILED
|
||||
from utilities.database import Database
|
||||
from utilities.envvar import get_env_var_with_default
|
||||
from utilities.envvar import get_env_var
|
||||
from utilities.times import wait_for_seconds
|
||||
|
|
@ -44,10 +45,10 @@ from utilities.img2img import Img2Img
|
|||
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['TESTING'] = False
|
||||
memory_lock = Lock()
|
||||
event_termination = Event()
|
||||
logger = Logger(name=LOGGER_NAME)
|
||||
database = Database(logger)
|
||||
use_gpu = True
|
||||
|
||||
local_job_stack = []
|
||||
|
|
@ -57,13 +58,14 @@ local_completed_jobs = []
|
|||
@app.route("/add_job", methods=["POST"])
|
||||
def add_job():
|
||||
req = request.get_json()
|
||||
if API_KEY not in req:
|
||||
|
||||
if APIKEY not in req:
|
||||
logger.error(f"{APIKEY} not present in {req}")
|
||||
return "", 401
|
||||
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
|
||||
if req[API_KEY] != API_KEY_FOR_DEMO:
|
||||
return "", 401
|
||||
else:
|
||||
# TODO: add logic to validate app key with a particular user
|
||||
with memory_lock:
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
logger.error(f"user not found with {req[APIKEY]}")
|
||||
return "", 401
|
||||
|
||||
for key in req.keys():
|
||||
|
|
@ -76,71 +78,58 @@ def add_job():
|
|||
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
|
||||
return jsonify({"msg": "missing reference image"}), 404
|
||||
|
||||
if len(local_job_stack) > MAX_JOB_NUMBER:
|
||||
return jsonify({"msg": "too many jobs in queue, please wait"}), 500
|
||||
if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
|
||||
return (
|
||||
jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
|
||||
500,
|
||||
)
|
||||
|
||||
req[UUID] = str(uuid.uuid4())
|
||||
logger.info("adding a new job with uuid {}..".format(req[UUID]))
|
||||
|
||||
req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
|
||||
req["position"] = len(local_job_stack) + 1
|
||||
job_uuid = str(uuid.uuid4())
|
||||
logger.info("adding a new job with uuid {}..".format(job_uuid))
|
||||
|
||||
with memory_lock:
|
||||
local_job_stack.append(req)
|
||||
database.insert_new_job(req, job_uuid=job_uuid)
|
||||
|
||||
return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]})
|
||||
return jsonify({"msg": "", UUID: job_uuid})
|
||||
|
||||
|
||||
@app.route("/cancel_job", methods=["POST"])
|
||||
def cancel_job():
|
||||
req = request.get_json()
|
||||
if API_KEY not in req:
|
||||
if APIKEY not in req:
|
||||
return "", 401
|
||||
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
|
||||
if req[API_KEY] != API_KEY_FOR_DEMO:
|
||||
return "", 401
|
||||
else:
|
||||
# TODO: add logic to validate app key with a particular user
|
||||
with memory_lock:
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
return "", 401
|
||||
|
||||
if UUID not in req:
|
||||
return jsonify({"msg": "missing uuid"}), 404
|
||||
|
||||
logger.info("removing job with uuid {}..".format(req[UUID]))
|
||||
logger.info("cancelling job with uuid {}..".format(req[UUID]))
|
||||
|
||||
cancel_job_position = None
|
||||
with memory_lock:
|
||||
for job_position in range(len(local_job_stack)):
|
||||
if local_job_stack[job_position][UUID] == req[UUID]:
|
||||
cancel_job_position = job_position
|
||||
break
|
||||
logger.info("foud {}".format(cancel_job_position))
|
||||
if cancel_job_position is not None:
|
||||
if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]:
|
||||
return "", 401
|
||||
if (
|
||||
local_job_stack[cancel_job_position][KEY_JOB_STATUS]
|
||||
== VALUE_JOB_RUNNING
|
||||
):
|
||||
logger.info(
|
||||
"job at {} with uuid {} is running and cannot be cancelled".format(
|
||||
cancel_job_position, req[UUID]
|
||||
)
|
||||
)
|
||||
result = database.cancel_job(job_uuid=req[UUID])
|
||||
|
||||
if result:
|
||||
msg = "job with uuid {} removed".format(req[UUID])
|
||||
return jsonify({"msg": msg})
|
||||
|
||||
with memory_lock:
|
||||
jobs = database.get_jobs(job_uuid=req[UUID])
|
||||
|
||||
if jobs:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"msg": "job {} is already running, unable to cancel".format(
|
||||
"msg": "job {} is not in pending state, unable to cancel".format(
|
||||
req[UUID]
|
||||
)
|
||||
}
|
||||
),
|
||||
405,
|
||||
)
|
||||
del local_job_stack[cancel_job_position]
|
||||
msg = "job with uuid {} removed".format(req[UUID])
|
||||
logger.info(msg)
|
||||
return jsonify({"msg": msg})
|
||||
|
||||
return (
|
||||
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
|
||||
404,
|
||||
|
|
@ -150,37 +139,16 @@ def cancel_job():
|
|||
@app.route("/get_jobs", methods=["POST"])
|
||||
def get_jobs():
|
||||
req = request.get_json()
|
||||
if API_KEY not in req:
|
||||
if APIKEY not in req:
|
||||
return "", 401
|
||||
if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
|
||||
if req[API_KEY] != API_KEY_FOR_DEMO:
|
||||
return "", 401
|
||||
else:
|
||||
# TODO: add logic to validate app key with a particular user
|
||||
return "", 401
|
||||
|
||||
jobs = []
|
||||
|
||||
all_job_stack = local_job_stack + local_completed_jobs
|
||||
with memory_lock:
|
||||
for job_position in range(len(all_job_stack)):
|
||||
# filter on API_KEY
|
||||
if all_job_stack[job_position][API_KEY] != req[API_KEY]:
|
||||
continue
|
||||
# filter on UUID
|
||||
if UUID in req and req[UUID] != all_job_stack[job_position][UUID]:
|
||||
continue
|
||||
job = copy.deepcopy(all_job_stack[job_position])
|
||||
if job[KEY_JOB_STATUS] == VALUE_JOB_DONE:
|
||||
del job["position"]
|
||||
del job[API_KEY]
|
||||
jobs.append(job)
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
return "", 401
|
||||
|
||||
with memory_lock:
|
||||
jobs = database.get_jobs(job_uuid=req[UUID])
|
||||
|
||||
if len(jobs) == 0:
|
||||
return (
|
||||
jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
|
||||
404,
|
||||
)
|
||||
return jsonify({"jobs": jobs})
|
||||
|
||||
|
||||
|
|
@ -213,7 +181,7 @@ def load_model(logger: Logger) -> Model:
|
|||
return model
|
||||
|
||||
|
||||
def backend(event_termination):
|
||||
def backend(event_termination, db):
|
||||
model = load_model(logger)
|
||||
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
|
||||
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
|
||||
|
|
@ -225,13 +193,18 @@ def backend(event_termination):
|
|||
wait_for_seconds(1)
|
||||
|
||||
with memory_lock:
|
||||
if len(local_job_stack) == 0:
|
||||
continue
|
||||
next_job = local_job_stack[0]
|
||||
next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING
|
||||
pending_jobs = database.get_all_pending_jobs()
|
||||
|
||||
prompt = next_job[KEY_PROMPT.lower()]
|
||||
negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")
|
||||
if len(pending_jobs) == 0:
|
||||
continue
|
||||
|
||||
next_job = pending_jobs[0]
|
||||
|
||||
with memory_lock:
|
||||
database.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID])
|
||||
|
||||
prompt = next_job[KEY_PROMPT]
|
||||
negative_prompt = next_job[KEY_NEG_PROMPT]
|
||||
|
||||
config = Config().set_config(next_job)
|
||||
|
||||
|
|
@ -250,33 +223,35 @@ def backend(event_termination):
|
|||
)
|
||||
except BaseException as e:
|
||||
logger.error("text2img.lunch error: {}".format(e))
|
||||
local_job_stack.pop(0)
|
||||
next_job[KEY_JOB_STATUS] = VALUE_JOB_FAILED
|
||||
local_completed_jobs.append(next_job)
|
||||
with memory_lock:
|
||||
database.update_job(
|
||||
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
|
||||
)
|
||||
continue
|
||||
|
||||
with memory_lock:
|
||||
local_job_stack.pop(0)
|
||||
next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
|
||||
next_job.update(result_dict)
|
||||
local_completed_jobs.append(next_job)
|
||||
database.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID])
|
||||
database.update_job(result_dict, job_uuid=next_job[UUID])
|
||||
|
||||
logger.critical("stopped")
|
||||
if len(local_job_stack) > 0:
|
||||
logger.info(
|
||||
"remaining {} jobs in stack: {}".format(
|
||||
len(local_job_stack), local_job_stack
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if app.testing:
|
||||
def main(db_filepath, is_testing: bool = False):
|
||||
database.connect(db_filepath)
|
||||
|
||||
if is_testing:
|
||||
try:
|
||||
app.run(host="0.0.0.0", port="5000")
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
return
|
||||
thread = Thread(target=backend, args=(event_termination,))
|
||||
thread = Thread(
|
||||
target=backend,
|
||||
args=(
|
||||
event_termination,
|
||||
database,
|
||||
),
|
||||
)
|
||||
thread.start()
|
||||
# ugly solution for now
|
||||
# TODO: use a database to track instead of internal memory
|
||||
|
|
@ -285,8 +260,24 @@ def main():
|
|||
thread.join()
|
||||
except KeyboardInterrupt:
|
||||
event_termination.set()
|
||||
thread.join(1)
|
||||
|
||||
database.safe_disconnect()
|
||||
|
||||
thread.join(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Add an argument to set the 'testing' flag
|
||||
parser.add_argument("--testing", action="store_true", help="Enable testing mode")
|
||||
|
||||
# Add an argument to set the path of the database file
|
||||
parser.add_argument(
|
||||
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info(args)
|
||||
|
||||
main(args.db, args.testing)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,155 @@
|
|||
import argparse
|
||||
import sqlite3
|
||||
import uuid
|
||||
|
||||
|
||||
def create_table_users(c):
|
||||
"""Create the users table if it doesn't exist"""
|
||||
c.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS users
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE,
|
||||
apikey TEXT)"""
|
||||
)
|
||||
|
||||
|
||||
def create_table_history(c):
|
||||
"""Create the history table if it doesn't exist"""
|
||||
c.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS history
|
||||
(uuid TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP,
|
||||
apikey TEXT,
|
||||
priority INT,
|
||||
type TEXT,
|
||||
status TEXT,
|
||||
prompt TEXT,
|
||||
neg_prompt TEXT,
|
||||
seed TEXT,
|
||||
ref_img TEXT,
|
||||
img TEXT,
|
||||
width INT,
|
||||
height INT,
|
||||
guidance_scale FLOAT,
|
||||
steps INT,
|
||||
scheduler TEXT,
|
||||
strength FLOAT,
|
||||
base_model TEXT,
|
||||
lora_model TEXT
|
||||
)"""
|
||||
)
|
||||
|
||||
|
||||
def create_user(c, username, apikey):
|
||||
"""Create a user with the given username and apikey, or update the apikey if the username already exists"""
|
||||
c.execute("SELECT * FROM users WHERE username=?", (username,))
|
||||
result = c.fetchone()
|
||||
if result is not None:
|
||||
raise ValueError(f"found exisitng user {username}, please use update")
|
||||
else:
|
||||
c.execute("INSERT INTO users (username, apikey) VALUES (?, ?)", (username, apikey))
|
||||
|
||||
|
||||
def update_user(c, username, apikey):
|
||||
"""Update the apikey for the user with the given username"""
|
||||
c.execute("SELECT apikey FROM users WHERE username=?", (username,))
|
||||
result = c.fetchone()
|
||||
if result is not None:
|
||||
old_apikey = result[0]
|
||||
c.execute("UPDATE history SET apikey=? WHERE apikey=?", (apikey, old_apikey))
|
||||
c.execute("UPDATE users SET apikey=? WHERE username=?", (apikey, username))
|
||||
else:
|
||||
raise ValueError("username does not exist! create it first?")
|
||||
|
||||
|
||||
def delete_user(c, username):
|
||||
"""Delete the user with the given username, or ignore the operation if the user does not exist"""
|
||||
c.execute("DELETE FROM history WHERE apikey=(SELECT apikey FROM users WHERE username=?)", (username,))
|
||||
c.execute("DELETE FROM users WHERE username=?", (username,))
|
||||
|
||||
def show_users(c, username="", details=False):
|
||||
"""Print all users in the users table if username is not specified,
|
||||
or only the user with the given username otherwise"""
|
||||
if username:
|
||||
c.execute("SELECT username, apikey FROM users WHERE username=?", (username,))
|
||||
user = c.fetchone()
|
||||
if user:
|
||||
c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],))
|
||||
count = c.fetchone()[0]
|
||||
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
|
||||
if details:
|
||||
c.execute("SELECT * FROM history WHERE apikey=?", (user[1],))
|
||||
result = c.fetchall()
|
||||
print(result)
|
||||
else:
|
||||
print(f"No user with username '{username}' found")
|
||||
else:
|
||||
c.execute("SELECT username, apikey FROM users")
|
||||
users = c.fetchall()
|
||||
for user in users:
|
||||
c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],))
|
||||
count = c.fetchone()[0]
|
||||
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
|
||||
if details:
|
||||
c.execute("SELECT * FROM history WHERE apikey=?", (user[1],))
|
||||
result = c.fetchall()
|
||||
print(result)
|
||||
|
||||
|
||||
def main():
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="action")
|
||||
|
||||
# Sub-parser for the "create" action
|
||||
create_parser = subparsers.add_parser("create")
|
||||
create_parser.add_argument("username")
|
||||
create_parser.add_argument("apikey")
|
||||
|
||||
# Sub-parser for the "update" action
|
||||
update_parser = subparsers.add_parser("update")
|
||||
update_parser.add_argument("username")
|
||||
update_parser.add_argument("apikey")
|
||||
|
||||
# Sub-parser for the "delete" action
|
||||
delete_parser = subparsers.add_parser("delete")
|
||||
delete_parser.add_argument("username")
|
||||
|
||||
# Sub-parser for the "delete" action
|
||||
list_parser = subparsers.add_parser("list")
|
||||
list_parser.add_argument("username", nargs="?", default="")
|
||||
list_parser.add_argument("--details", action="store_true", help="Showing more details")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Connect to the database (creates a new file if it doesn't exist)
|
||||
conn = sqlite3.connect("happysd.db")
|
||||
c = conn.cursor()
|
||||
|
||||
# Create the users and history tables if they don't exist
|
||||
create_table_users(c)
|
||||
create_table_history(c)
|
||||
|
||||
# Perform the requested action
|
||||
if args.action == "create":
|
||||
create_user(c, args.username, args.apikey)
|
||||
print("User created")
|
||||
elif args.action == "update":
|
||||
update_user(c, args.username, args.apikey)
|
||||
print("User updated")
|
||||
elif args.action == "delete":
|
||||
delete_user(c, args.username)
|
||||
print("User deleted")
|
||||
elif args.action == "list":
|
||||
show_users(c, args.username, args.details)
|
||||
|
||||
# Commit the changes to the database
|
||||
conn.commit()
|
||||
|
||||
# Close the connection
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -213,7 +213,7 @@
|
|||
url: '/get_jobs',
|
||||
contentType: 'application/json; charset=utf-8',
|
||||
dataType: 'json',
|
||||
data: JSON.stringify({ 'api_key': apikeyVal, 'uuid': uuidValue }),
|
||||
data: JSON.stringify({ 'apikey': apikeyVal, 'uuid': uuidValue }),
|
||||
success: function (response) {
|
||||
console.log(response);
|
||||
if (response.jobs.length == 1) {
|
||||
|
|
@ -404,7 +404,7 @@
|
|||
contentType: 'application/json; charset=utf-8',
|
||||
dataType: 'json',
|
||||
data: JSON.stringify({
|
||||
'api_key': apikeyVal,
|
||||
'apikey': apikeyVal,
|
||||
'type': 'txt',
|
||||
'prompt': promptVal,
|
||||
'seed': seedVal,
|
||||
|
|
@ -513,7 +513,7 @@
|
|||
contentType: 'application/json; charset=utf-8',
|
||||
dataType: 'json',
|
||||
data: JSON.stringify({
|
||||
'api_key': apikeyVal,
|
||||
'apikey': apikeyVal,
|
||||
'type': 'img',
|
||||
'ref_img': imageData,
|
||||
'prompt': promptVal,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,14 @@ py_library(
|
|||
srcs=["constants.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="database",
|
||||
srcs=["database.py"],
|
||||
deps=[
|
||||
":logger",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
py_library(
|
||||
name="envvar",
|
||||
|
|
|
|||
|
|
@ -10,8 +10,6 @@ from utilities.constants import KEY_HEIGHT
|
|||
from utilities.constants import VALUE_HEIGHT_DEFAULT
|
||||
from utilities.constants import KEY_STRENGTH
|
||||
from utilities.constants import VALUE_STRENGTH_DEFAULT
|
||||
from utilities.constants import KEY_PREVIEW
|
||||
from utilities.constants import VALUE_PREVIEW_DEFAULT
|
||||
from utilities.constants import KEY_SCHEDULER
|
||||
from utilities.constants import VALUE_SCHEDULER_DEFAULT
|
||||
from utilities.constants import VALUE_SCHEDULER_DDIM
|
||||
|
|
@ -84,16 +82,6 @@ class Config:
|
|||
self.__config[KEY_HEIGHT] = value
|
||||
return self
|
||||
|
||||
def get_preview(self) -> bool:
|
||||
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
|
||||
|
||||
def set_preview(self, boolean: bool):
|
||||
self.__logger.info(
|
||||
"{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)
|
||||
)
|
||||
self.__config[KEY_PREVIEW] = boolean
|
||||
return self
|
||||
|
||||
def get_scheduler(self) -> str:
|
||||
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,71 +7,81 @@ LOGGER_NAME_IMG2IMG = "img2img"
|
|||
MAX_JOB_NUMBER = 10
|
||||
|
||||
|
||||
KEY_OUTPUT_FOLDER = "OUTFOLDER"
|
||||
|
||||
KEY_OUTPUT_FOLDER = "outfolder"
|
||||
VALUE_OUTPUT_FOLDER_DEFAULT = ""
|
||||
|
||||
KEY_SEED = "SEED"
|
||||
VALUE_SEED_DEFAULT = 0
|
||||
#
|
||||
# Database
|
||||
#
|
||||
HISTORY_TABLE_NAME = "history"
|
||||
USERS_TABLE_NAME = "users"
|
||||
|
||||
KEY_WIDTH = "WIDTH"
|
||||
VALUE_WIDTH_DEFAULT = 512
|
||||
#
|
||||
# REST API Keys
|
||||
#
|
||||
|
||||
KEY_HEIGHT = "HEIGHT"
|
||||
VALUE_HEIGHT_DEFAULT = 512
|
||||
# - input and output
|
||||
APIKEY = "apikey"
|
||||
|
||||
KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
|
||||
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0
|
||||
KEY_JOB_TYPE = "type"
|
||||
VALUE_JOB_TXT2IMG = "txt" # default value for KEY_JOB_TYPE
|
||||
VALUE_JOB_IMG2IMG = "img"
|
||||
REFERENCE_IMG = "ref_img"
|
||||
VALUE_JOB_INPAINTING = "inpaint"
|
||||
|
||||
KEY_STRENGTH = "STRENGTH"
|
||||
VALUE_STRENGTH_DEFAULT = 0.5
|
||||
|
||||
KEY_STEPS = "STEPS"
|
||||
VALUE_STEPS_DEFAULT = 50
|
||||
|
||||
KEY_SCHEDULER = "SCHEDULER"
|
||||
VALUE_SCHEDULER_DEFAULT = "Default"
|
||||
KEY_PROMPT = "prompt"
|
||||
KEY_NEG_PROMPT = "neg_prompt"
|
||||
KEY_SEED = "seed"
|
||||
VALUE_SEED_DEFAULT = 0 # default value for KEY_SEED
|
||||
KEY_WIDTH = "width"
|
||||
VALUE_WIDTH_DEFAULT = 512 # default value for KEY_WIDTH
|
||||
KEY_HEIGHT = "height"
|
||||
VALUE_HEIGHT_DEFAULT = 512 # default value for KEY_HEIGHT
|
||||
KEY_GUIDANCE_SCALE = "guidance_scale"
|
||||
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0 # default value for KEY_GUIDANCE_SCALE
|
||||
KEY_STEPS = "steps"
|
||||
VALUE_STEPS_DEFAULT = 50 # default value for KEY_STEPS
|
||||
KEY_SCHEDULER = "scheduler"
|
||||
VALUE_SCHEDULER_DEFAULT = "Default" # default value for KEY_SCHEDULER
|
||||
VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
|
||||
VALUE_SCHEDULER_LMS_DISCRETE = "LMSDiscreteScheduler"
|
||||
VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
|
||||
VALUE_SCHEDULER_PNDM = "PNDMScheduler"
|
||||
VALUE_SCHEDULER_DDIM = "DDIMScheduler"
|
||||
KEY_STRENGTH = "strength"
|
||||
VALUE_STRENGTH_DEFAULT = 0.5 # default value for KEY_STRENGTH
|
||||
|
||||
KEY_PROMPT = "PROMPT"
|
||||
KEY_NEG_PROMPT = "NEG_PROMPT"
|
||||
REQUIRED_KEYS = [
|
||||
APIKEY, # str
|
||||
KEY_PROMPT, # str
|
||||
KEY_JOB_TYPE, # str
|
||||
]
|
||||
OPTIONAL_KEYS = [
|
||||
KEY_NEG_PROMPT, # str
|
||||
KEY_SEED, # str
|
||||
KEY_WIDTH, # int
|
||||
KEY_HEIGHT, # int
|
||||
KEY_GUIDANCE_SCALE, # float
|
||||
KEY_STEPS, # int
|
||||
KEY_SCHEDULER, # str
|
||||
KEY_STRENGTH, # float
|
||||
REFERENCE_IMG, # str (base64)
|
||||
]
|
||||
|
||||
KEY_PREVIEW = "PREVIEW"
|
||||
VALUE_PREVIEW_DEFAULT = True
|
||||
|
||||
# REST API Keys
|
||||
API_KEY = "api_key"
|
||||
API_KEY_FOR_DEMO = "demo"
|
||||
# - output only
|
||||
UUID = "uuid"
|
||||
|
||||
BASE64IMAGE = "img"
|
||||
KEY_PRIORITY = "priority"
|
||||
KEY_JOB_STATUS = "status"
|
||||
VALUE_JOB_PENDING = "pending"
|
||||
VALUE_JOB_PENDING = "pending" # default value for KEY_JOB_STATUS
|
||||
VALUE_JOB_RUNNING = "running"
|
||||
VALUE_JOB_DONE = "done"
|
||||
VALUE_JOB_FAILED = "failed"
|
||||
KEY_JOB_TYPE = "type"
|
||||
VALUE_JOB_TXT2IMG = "txt"
|
||||
VALUE_JOB_IMG2IMG = "img"
|
||||
VALUE_JOB_INPAINTING = "inpaint"
|
||||
REFERENCE_IMG = "ref_img"
|
||||
|
||||
REQUIRED_KEYS = [
|
||||
API_KEY.lower(),
|
||||
KEY_PROMPT.lower(),
|
||||
KEY_JOB_TYPE.lower(),
|
||||
]
|
||||
OPTIONAL_KEYS = [
|
||||
KEY_NEG_PROMPT.lower(),
|
||||
KEY_SEED.lower(),
|
||||
KEY_WIDTH.lower(),
|
||||
KEY_HEIGHT.lower(),
|
||||
KEY_GUIDANCE_SCALE.lower(),
|
||||
KEY_STEPS.lower(),
|
||||
KEY_SCHEDULER.lower(),
|
||||
KEY_STRENGTH.lower(),
|
||||
REFERENCE_IMG.lower(),
|
||||
OUTPUT_ONLY_KEYS = [
|
||||
UUID, # str
|
||||
KEY_PRIORITY, # int
|
||||
BASE64IMAGE, # str (base64)
|
||||
KEY_JOB_STATUS, # str
|
||||
]
|
||||
|
|
@ -0,0 +1,262 @@
|
|||
import os
|
||||
import datetime
|
||||
import sqlite3
|
||||
import uuid
|
||||
|
||||
from utilities.constants import APIKEY
|
||||
from utilities.constants import UUID
|
||||
from utilities.constants import KEY_PRIORITY
|
||||
from utilities.constants import KEY_JOB_TYPE
|
||||
from utilities.constants import VALUE_JOB_TXT2IMG
|
||||
from utilities.constants import VALUE_JOB_IMG2IMG
|
||||
from utilities.constants import VALUE_JOB_INPAINTING
|
||||
from utilities.constants import KEY_JOB_STATUS
|
||||
from utilities.constants import VALUE_JOB_PENDING
|
||||
from utilities.constants import VALUE_JOB_RUNNING
|
||||
from utilities.constants import VALUE_JOB_DONE
|
||||
|
||||
from utilities.constants import OUTPUT_ONLY_KEYS
|
||||
from utilities.constants import OPTIONAL_KEYS
|
||||
from utilities.constants import REQUIRED_KEYS
|
||||
|
||||
from utilities.constants import REFERENCE_IMG
|
||||
from utilities.constants import BASE64IMAGE
|
||||
|
||||
from utilities.constants import HISTORY_TABLE_NAME
|
||||
from utilities.constants import USERS_TABLE_NAME
|
||||
from utilities.logger import DummyLogger
|
||||
|
||||
|
||||
class Database:
|
||||
"""This class represents a SQLite database and assumes single-thread usage."""
|
||||
|
||||
def __init__(self, logger: DummyLogger = DummyLogger()):
|
||||
"""Initialize the class with a logger instance, but without a database connection or cursor."""
|
||||
self.__connect = None # the database connection object
|
||||
self.__cursor = None # the cursor object for executing SQL statements
|
||||
self.__logger = logger # the logger object for logging messages
|
||||
|
||||
def connect(self, db_filepath) -> bool:
|
||||
"""
|
||||
Connect to the SQLite database file specified by `db_filepath`.
|
||||
|
||||
Returns True if the connection was successful, otherwise False.
|
||||
"""
|
||||
if not os.path.isfile(db_filepath):
|
||||
self.__logger.error(f"{db_filepath} does not exist!")
|
||||
return False
|
||||
self.__connect = sqlite3.connect(db_filepath, check_same_thread=False)
|
||||
self.__cursor = self.__connect.cursor()
|
||||
self.__logger.info(f"Connected to database {db_filepath}")
|
||||
return True
|
||||
|
||||
def validate_user(self, apikey: str) -> str:
|
||||
"""
|
||||
Validate if the provided API key exists in the users table and return the corresponding
|
||||
username if found, or an empty string otherwise.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return ""
|
||||
|
||||
query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?"
|
||||
self.__cursor.execute(query, (apikey,))
|
||||
result = self.__cursor.fetchone()
|
||||
|
||||
self.__logger.debug(result)
|
||||
|
||||
if result is not None:
|
||||
return result[0] # the first column is the username
|
||||
|
||||
return ""
|
||||
|
||||
def get_all_pending_jobs(self, apikey: str = "") -> list:
|
||||
return self.get_jobs(apikey=apikey, job_status=VALUE_JOB_PENDING)
|
||||
|
||||
def count_all_pending_jobs(self, apikey: str) -> int:
|
||||
"""
|
||||
Count the number of pending jobs in the HISTORY_TABLE_NAME table for the specified API key.
|
||||
|
||||
Returns the number of pending jobs found.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return 0
|
||||
|
||||
# Construct the SQL query string and list of arguments
|
||||
query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?"
|
||||
query_args = (apikey, VALUE_JOB_PENDING)
|
||||
|
||||
# Execute the query and return the count
|
||||
self.__cursor.execute(query_string, query_args)
|
||||
result = self.__cursor.fetchone()
|
||||
return result[0]
|
||||
|
||||
def get_jobs(self, job_uuid="", apikey="", job_status="") -> list:
|
||||
"""
|
||||
Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters.
|
||||
|
||||
If `job_uuid` or `apikey` or `job_status` is provided, the query will include that filter.
|
||||
|
||||
Returns a list of jobs matching the filters provided.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return []
|
||||
|
||||
# construct the SQL query string and list of arguments based on the provided filters
|
||||
query_args = []
|
||||
query_filters = []
|
||||
if job_uuid:
|
||||
query_filters.append(f"{UUID} = ?")
|
||||
query_args.append(job_uuid)
|
||||
if apikey:
|
||||
query_filters.append(f"{APIKEY} = ?")
|
||||
query_args.append(apikey)
|
||||
if job_status:
|
||||
query_filters.append(f"{KEY_JOB_STATUS} = ?")
|
||||
query_args.append(job_status)
|
||||
columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
|
||||
query_string = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
|
||||
if query_filters:
|
||||
query_string += f" WHERE {' AND '.join(query_filters)}"
|
||||
|
||||
# execute the query and return the results
|
||||
self.__cursor.execute(query_string, tuple(query_args))
|
||||
rows = self.__cursor.fetchall()
|
||||
|
||||
jobs = []
|
||||
for row in rows:
|
||||
job = {
|
||||
columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
|
||||
}
|
||||
jobs.append(job)
|
||||
|
||||
return jobs
|
||||
|
||||
def insert_new_job(self, job_dict: dict, job_uuid="") -> bool:
|
||||
"""
|
||||
Insert a new job into the HISTORY_TABLE_NAME table.
|
||||
|
||||
If `job_uuid` is not provided, a new UUID will be generated automatically.
|
||||
|
||||
Returns True if the insertion was successful, otherwise False.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return False
|
||||
|
||||
if not job_uuid:
|
||||
job_uuid = str(uuid.uuid4())
|
||||
self.__logger.info(f"inserting a new job with {job_uuid}")
|
||||
|
||||
values = [job_uuid, VALUE_JOB_PENDING]
|
||||
columns = [UUID, KEY_JOB_STATUS] + REQUIRED_KEYS + OPTIONAL_KEYS
|
||||
for column in REQUIRED_KEYS + OPTIONAL_KEYS:
|
||||
values.append(job_dict.get(column, None))
|
||||
|
||||
query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({', '.join(['?' for _ in columns])})"
|
||||
self.__cursor.execute(query, tuple(values))
|
||||
self.__connect.commit()
|
||||
return True
|
||||
|
||||
def update_job(self, job_dict: dict, job_uuid: str) -> bool:
|
||||
"""
|
||||
Update an existing job in the HISTORY_TABLE_NAME table with the given `job_uuid`.
|
||||
|
||||
Returns True if the update was successful, otherwise False.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return False
|
||||
|
||||
values = []
|
||||
columns = []
|
||||
for column in OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS:
|
||||
value = job_dict.get(column, None)
|
||||
if value is not None:
|
||||
columns.append(column)
|
||||
values.append(value)
|
||||
|
||||
set_clause = ", ".join([f"{column}=?" for column in columns])
|
||||
# Add current timestamp to update query
|
||||
set_clause += ", updated_at=?"
|
||||
values.append(datetime.datetime.now())
|
||||
|
||||
query = f"UPDATE {HISTORY_TABLE_NAME} SET {set_clause} WHERE {UUID}=?"
|
||||
|
||||
values.append(job_uuid)
|
||||
|
||||
self.__cursor.execute(query, tuple(values))
|
||||
self.__connect.commit()
|
||||
return True
|
||||
|
||||
def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool:
|
||||
"""Cancel the job with the given job_uuid or apikey.
|
||||
If job_uuid or apikey is provided, delete corresponding rows from table history if "status" matches "pending".
|
||||
|
||||
Args:
|
||||
job_uuid (str, optional): Unique job identifier. Defaults to "".
|
||||
apikey (str, optional): API key associated with the job. Defaults to "".
|
||||
|
||||
Returns:
|
||||
bool: True if the job was cancelled successfully, False otherwise.
|
||||
"""
|
||||
if not job_uuid and not apikey:
|
||||
self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
|
||||
return False
|
||||
|
||||
if job_uuid:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=? AND {KEY_JOB_STATUS}=?",
|
||||
(
|
||||
job_uuid,
|
||||
VALUE_JOB_PENDING,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?",
|
||||
(
|
||||
apikey,
|
||||
VALUE_JOB_PENDING,
|
||||
),
|
||||
)
|
||||
|
||||
if self.__cursor.rowcount == 0:
|
||||
self.__logger.info("No matching rows found.")
|
||||
return False
|
||||
else:
|
||||
self.__logger.info(f"{self.__cursor.rowcount} rows cancelled.")
|
||||
|
||||
self.__connect.commit()
|
||||
return True
|
||||
|
||||
def delete_job(self, job_uuid: str = "", apikey: str = "") -> bool:
|
||||
"""Delete the job with the given uuid or apikey"""
|
||||
if job_uuid:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?", (job_uuid,)
|
||||
)
|
||||
elif apikey:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=?", (apikey,)
|
||||
)
|
||||
else:
|
||||
self.logger.error(f"either {UUID} or {APIKEY} must be provided.")
|
||||
return False
|
||||
|
||||
if self.__cursor.rowcount == 0:
|
||||
print("No matching rows found.")
|
||||
else:
|
||||
self.logger.info(f"{self.__cursor.rowcount} rows deleted.")
|
||||
self.__connect.commit()
|
||||
return True
|
||||
|
||||
def safe_disconnect(self):
|
||||
if self.__connect is not None:
|
||||
self.__connect.commit()
|
||||
self.__connect.close()
|
||||
self.__logger.info("Disconnected from database.")
|
||||
else:
|
||||
self.__logger.warn("No database connection to close.")
|
||||
|
|
@ -84,8 +84,8 @@ class Img2Img:
|
|||
|
||||
return {
|
||||
BASE64IMAGE: image_to_base64(result.images[0]),
|
||||
KEY_SEED.lower(): str(seed),
|
||||
KEY_WIDTH.lower(): config.get_width(),
|
||||
KEY_HEIGHT.lower(): config.get_height(),
|
||||
KEY_STEPS.lower(): config.get_steps(),
|
||||
KEY_SEED: str(seed),
|
||||
KEY_WIDTH: config.get_width(),
|
||||
KEY_HEIGHT: config.get_height(),
|
||||
KEY_STEPS: config.get_steps(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,8 +66,8 @@ class Text2Img:
|
|||
|
||||
return {
|
||||
BASE64IMAGE: image_to_base64(result.images[0]),
|
||||
KEY_SEED.lower(): str(seed),
|
||||
KEY_WIDTH.lower(): config.get_width(),
|
||||
KEY_HEIGHT.lower(): config.get_height(),
|
||||
KEY_STEPS.lower(): config.get_steps(),
|
||||
KEY_SEED: str(seed),
|
||||
KEY_WIDTH: config.get_width(),
|
||||
KEY_HEIGHT: config.get_height(),
|
||||
KEY_STEPS: config.get_steps(),
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue