transfer to use sqlite3 db instead of internal memory
Parent: 11948820f8
Commit: db6d3a94a7
.gitignore
@@ -6,6 +6,9 @@ __pycache__/
 # C extensions
 *.so
 
+# sqlite3 db
+*.db
+
 # Distribution / packaging
 .Python
 build/
BUILD (1 line changed)
@@ -8,6 +8,7 @@ par_binary(
     srcs=["main.py"],
     deps=[
         "//utilities:constants",
+        "//utilities:database",
         "//utilities:logger",
         "//utilities:model",
         "//utilities:text2img",
main.py (211 lines changed)
@@ -1,3 +1,4 @@
+import argparse
 import copy
 import tempfile
 import pkgutil
@@ -10,8 +11,7 @@ from threading import Event
 from threading import Thread
 from threading import Lock
 
-from utilities.constants import API_KEY
-from utilities.constants import API_KEY_FOR_DEMO
+from utilities.constants import APIKEY
 from utilities.constants import KEY_APP
 from utilities.constants import KEY_JOB_STATUS
 from utilities.constants import KEY_JOB_TYPE
@@ -33,6 +33,7 @@ from utilities.constants import VALUE_JOB_PENDING
 from utilities.constants import VALUE_JOB_RUNNING
 from utilities.constants import VALUE_JOB_DONE
 from utilities.constants import VALUE_JOB_FAILED
+from utilities.database import Database
 from utilities.envvar import get_env_var_with_default
 from utilities.envvar import get_env_var
 from utilities.times import wait_for_seconds
@@ -44,10 +45,10 @@ from utilities.img2img import Img2Img
 
 
 app = Flask(__name__)
-app.config['TESTING'] = False
 memory_lock = Lock()
 event_termination = Event()
 logger = Logger(name=LOGGER_NAME)
+database = Database(logger)
 use_gpu = True
 
 local_job_stack = []
@@ -57,13 +58,14 @@ local_completed_jobs = []
 @app.route("/add_job", methods=["POST"])
 def add_job():
     req = request.get_json()
-    if API_KEY not in req:
+    if APIKEY not in req:
+        logger.error(f"{APIKEY} not present in {req}")
         return "", 401
-    if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
-        if req[API_KEY] != API_KEY_FOR_DEMO:
-            return "", 401
-    else:
-        # TODO: add logic to validate app key with a particular user
+    with memory_lock:
+        user = database.validate_user(req[APIKEY])
+    if not user:
+        logger.error(f"user not found with {req[APIKEY]}")
         return "", 401
 
     for key in req.keys():
@@ -76,71 +78,58 @@ def add_job():
     if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
         return jsonify({"msg": "missing reference image"}), 404
 
-    if len(local_job_stack) > MAX_JOB_NUMBER:
-        return jsonify({"msg": "too many jobs in queue, please wait"}), 500
+    if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
+        return (
+            jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
+            500,
+        )
 
-    req[UUID] = str(uuid.uuid4())
-    logger.info("adding a new job with uuid {}..".format(req[UUID]))
-
-    req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
-    req["position"] = len(local_job_stack) + 1
+    job_uuid = str(uuid.uuid4())
+    logger.info("adding a new job with uuid {}..".format(job_uuid))
 
     with memory_lock:
-        local_job_stack.append(req)
+        database.insert_new_job(req, job_uuid=job_uuid)
 
-    return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]})
+    return jsonify({"msg": "", UUID: job_uuid})
 
 
 @app.route("/cancel_job", methods=["POST"])
 def cancel_job():
     req = request.get_json()
-    if API_KEY not in req:
+    if APIKEY not in req:
         return "", 401
-    if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
-        if req[API_KEY] != API_KEY_FOR_DEMO:
-            return "", 401
-    else:
-        # TODO: add logic to validate app key with a particular user
+    with memory_lock:
+        user = database.validate_user(req[APIKEY])
+    if not user:
         return "", 401
 
     if UUID not in req:
         return jsonify({"msg": "missing uuid"}), 404
 
-    logger.info("removing job with uuid {}..".format(req[UUID]))
+    logger.info("cancelling job with uuid {}..".format(req[UUID]))
 
-    cancel_job_position = None
     with memory_lock:
-        for job_position in range(len(local_job_stack)):
-            if local_job_stack[job_position][UUID] == req[UUID]:
-                cancel_job_position = job_position
-                break
-    logger.info("foud {}".format(cancel_job_position))
-    if cancel_job_position is not None:
-        if local_job_stack[cancel_job_position][API_KEY] != req[API_KEY]:
-            return "", 401
-        if (
-            local_job_stack[cancel_job_position][KEY_JOB_STATUS]
-            == VALUE_JOB_RUNNING
-        ):
-            logger.info(
-                "job at {} with uuid {} is running and cannot be cancelled".format(
-                    cancel_job_position, req[UUID]
-                )
-            )
-            return (
-                jsonify(
-                    {
-                        "msg": "job {} is already running, unable to cancel".format(
-                            req[UUID]
-                        )
-                    }
-                ),
-                405,
-            )
-        del local_job_stack[cancel_job_position]
-        msg = "job with uuid {} removed".format(req[UUID])
-        logger.info(msg)
-        return jsonify({"msg": msg})
+        result = database.cancel_job(job_uuid=req[UUID])
+    if result:
+        msg = "job with uuid {} removed".format(req[UUID])
+        return jsonify({"msg": msg})
+
+    with memory_lock:
+        jobs = database.get_jobs(job_uuid=req[UUID])
+    if jobs:
+        return (
+            jsonify(
+                {
+                    "msg": "job {} is not in pending state, unable to cancel".format(
+                        req[UUID]
+                    )
+                }
+            ),
+            405,
+        )
     return (
         jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
         404,
@@ -150,37 +139,16 @@ def cancel_job():
 @app.route("/get_jobs", methods=["POST"])
 def get_jobs():
     req = request.get_json()
-    if API_KEY not in req:
+    if APIKEY not in req:
         return "", 401
-    if get_env_var_with_default(KEY_APP, VALUE_APP) == VALUE_APP:
-        if req[API_KEY] != API_KEY_FOR_DEMO:
-            return "", 401
-    else:
-        # TODO: add logic to validate app key with a particular user
-        return "", 401
 
-    jobs = []
-
-    all_job_stack = local_job_stack + local_completed_jobs
     with memory_lock:
-        for job_position in range(len(all_job_stack)):
-            # filter on API_KEY
-            if all_job_stack[job_position][API_KEY] != req[API_KEY]:
-                continue
-            # filter on UUID
-            if UUID in req and req[UUID] != all_job_stack[job_position][UUID]:
-                continue
-            job = copy.deepcopy(all_job_stack[job_position])
-            if job[KEY_JOB_STATUS] == VALUE_JOB_DONE:
-                del job["position"]
-            del job[API_KEY]
-            jobs.append(job)
-
-    if len(jobs) == 0:
-        return (
-            jsonify({"msg": "found no jobs for api_key={}".format(req[API_KEY])}),
-            404,
-        )
+        user = database.validate_user(req[APIKEY])
+    if not user:
+        return "", 401
+
+    with memory_lock:
+        jobs = database.get_jobs(job_uuid=req[UUID])
+
     return jsonify({"jobs": jobs})
@@ -213,7 +181,7 @@ def load_model(logger: Logger) -> Model:
     return model
 
 
-def backend(event_termination):
+def backend(event_termination, db):
     model = load_model(logger)
     text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
     img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
@@ -225,15 +193,20 @@ def backend(event_termination):
         wait_for_seconds(1)
 
         with memory_lock:
-            if len(local_job_stack) == 0:
-                continue
-            next_job = local_job_stack[0]
-            next_job[KEY_JOB_STATUS] = VALUE_JOB_RUNNING
+            pending_jobs = database.get_all_pending_jobs()
 
-        prompt = next_job[KEY_PROMPT.lower()]
-        negative_prompt = next_job.get(KEY_NEG_PROMPT.lower(), "")
+        if len(pending_jobs) == 0:
+            continue
 
-        config = Config().set_config(next_job)
+        next_job = pending_jobs[0]
+
+        with memory_lock:
+            database.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID])
+
+        prompt = next_job[KEY_PROMPT]
+        negative_prompt = next_job[KEY_NEG_PROMPT]
+
+        config = Config().set_config(next_job)
 
         try:
             if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
@@ -250,33 +223,35 @@ def backend(event_termination):
             )
         except BaseException as e:
             logger.error("text2img.lunch error: {}".format(e))
-            local_job_stack.pop(0)
-            next_job[KEY_JOB_STATUS] = VALUE_JOB_FAILED
-            local_completed_jobs.append(next_job)
+            with memory_lock:
+                database.update_job(
+                    {KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
+                )
+            continue
 
         with memory_lock:
-            local_job_stack.pop(0)
-            next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
-            next_job.update(result_dict)
-            local_completed_jobs.append(next_job)
+            database.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID])
+            database.update_job(result_dict, job_uuid=next_job[UUID])
 
     logger.critical("stopped")
-    if len(local_job_stack) > 0:
-        logger.info(
-            "remaining {} jobs in stack: {}".format(
-                len(local_job_stack), local_job_stack
-            )
-        )
 
 
-def main():
-    if app.testing:
+def main(db_filepath, is_testing: bool = False):
+    database.connect(db_filepath)
+
+    if is_testing:
         try:
             app.run(host="0.0.0.0", port="5000")
         except KeyboardInterrupt:
             pass
         return
-    thread = Thread(target=backend, args=(event_termination,))
+    thread = Thread(
+        target=backend,
+        args=(
+            event_termination,
+            database,
+        ),
+    )
     thread.start()
     # ugly solution for now
     # TODO: use a database to track instead of internal memory
@@ -285,8 +260,24 @@ def main():
         thread.join()
     except KeyboardInterrupt:
         event_termination.set()
-        thread.join(1)
+        database.safe_disconnect()
+
+        thread.join(2)
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+
+    # Add an argument to set the 'testing' flag
+    parser.add_argument("--testing", action="store_true", help="Enable testing mode")
+
+    # Add an argument to set the path of the database file
+    parser.add_argument(
+        "--db", type=str, default="happysd.db", help="Path to SQLite database file"
+    )
+
+    args = parser.parse_args()
+    logger.info(args)
+
+    main(args.db, args.testing)
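Note: a minimal client-side sketch of the reworked API, not part of the commit. The request body now carries apikey (validated against the users table) instead of the old api_key/demo key, and /add_job returns only the job's uuid, which /get_jobs and /cancel_job then take. The host, port, and key value below are assumptions for illustration.

# Hypothetical client sketch; assumes the server runs locally on port 5000
# and that "alice-secret-key" was provisioned via the user admin script below.
import requests

BASE = "http://localhost:5000"

resp = requests.post(f"{BASE}/add_job", json={
    "apikey": "alice-secret-key",   # checked via database.validate_user()
    "type": "txt",                  # VALUE_JOB_TXT2IMG
    "prompt": "a watercolor fox",
})
job_uuid = resp.json()["uuid"]      # /add_job no longer returns a queue position

# Poll the job; the backend thread moves status pending -> running -> done/failed.
status = requests.post(f"{BASE}/get_jobs",
                       json={"apikey": "alice-secret-key", "uuid": job_uuid})
print(status.json()["jobs"])

# Only jobs still in "pending" state can be cancelled (deleted from history).
requests.post(f"{BASE}/cancel_job",
              json={"apikey": "alice-secret-key", "uuid": job_uuid})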
@@ -0,0 +1,155 @@ new file: user admin script (filename not shown in this view)
import argparse
import sqlite3
import uuid


def create_table_users(c):
    """Create the users table if it doesn't exist"""
    c.execute(
        """CREATE TABLE IF NOT EXISTS users
        (id INTEGER PRIMARY KEY AUTOINCREMENT,
        username TEXT UNIQUE,
        apikey TEXT)"""
    )


def create_table_history(c):
    """Create the history table if it doesn't exist"""
    c.execute(
        """CREATE TABLE IF NOT EXISTS history
        (uuid TEXT PRIMARY KEY,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP,
        apikey TEXT,
        priority INT,
        type TEXT,
        status TEXT,
        prompt TEXT,
        neg_prompt TEXT,
        seed TEXT,
        ref_img TEXT,
        img TEXT,
        width INT,
        height INT,
        guidance_scale FLOAT,
        steps INT,
        scheduler TEXT,
        strength FLOAT,
        base_model TEXT,
        lora_model TEXT
        )"""
    )


def create_user(c, username, apikey):
    """Create a user with the given username and apikey; raise if the username already exists"""
    c.execute("SELECT * FROM users WHERE username=?", (username,))
    result = c.fetchone()
    if result is not None:
        raise ValueError(f"found existing user {username}, please use update")
    else:
        c.execute("INSERT INTO users (username, apikey) VALUES (?, ?)", (username, apikey))


def update_user(c, username, apikey):
    """Update the apikey for the user with the given username"""
    c.execute("SELECT apikey FROM users WHERE username=?", (username,))
    result = c.fetchone()
    if result is not None:
        old_apikey = result[0]
        c.execute("UPDATE history SET apikey=? WHERE apikey=?", (apikey, old_apikey))
        c.execute("UPDATE users SET apikey=? WHERE username=?", (apikey, username))
    else:
        raise ValueError("username does not exist! create it first?")


def delete_user(c, username):
    """Delete the user with the given username, or ignore the operation if the user does not exist"""
    c.execute("DELETE FROM history WHERE apikey=(SELECT apikey FROM users WHERE username=?)", (username,))
    c.execute("DELETE FROM users WHERE username=?", (username,))


def show_users(c, username="", details=False):
    """Print all users in the users table if username is not specified,
    or only the user with the given username otherwise"""
    if username:
        c.execute("SELECT username, apikey FROM users WHERE username=?", (username,))
        user = c.fetchone()
        if user:
            c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],))
            count = c.fetchone()[0]
            print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
            if details:
                c.execute("SELECT * FROM history WHERE apikey=?", (user[1],))
                result = c.fetchall()
                print(result)
        else:
            print(f"No user with username '{username}' found")
    else:
        c.execute("SELECT username, apikey FROM users")
        users = c.fetchall()
        for user in users:
            c.execute("SELECT COUNT(*) FROM history WHERE apikey=?", (user[1],))
            count = c.fetchone()[0]
            print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
            if details:
                c.execute("SELECT * FROM history WHERE apikey=?", (user[1],))
                result = c.fetchall()
                print(result)


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="action")

    # Sub-parser for the "create" action
    create_parser = subparsers.add_parser("create")
    create_parser.add_argument("username")
    create_parser.add_argument("apikey")

    # Sub-parser for the "update" action
    update_parser = subparsers.add_parser("update")
    update_parser.add_argument("username")
    update_parser.add_argument("apikey")

    # Sub-parser for the "delete" action
    delete_parser = subparsers.add_parser("delete")
    delete_parser.add_argument("username")

    # Sub-parser for the "list" action
    list_parser = subparsers.add_parser("list")
    list_parser.add_argument("username", nargs="?", default="")
    list_parser.add_argument("--details", action="store_true", help="Showing more details")

    args = parser.parse_args()

    # Connect to the database (creates a new file if it doesn't exist)
    conn = sqlite3.connect("happysd.db")
    c = conn.cursor()

    # Create the users and history tables if they don't exist
    create_table_users(c)
    create_table_history(c)

    # Perform the requested action
    if args.action == "create":
        create_user(c, args.username, args.apikey)
        print("User created")
    elif args.action == "update":
        update_user(c, args.username, args.apikey)
        print("User updated")
    elif args.action == "delete":
        delete_user(c, args.username)
        print("User deleted")
    elif args.action == "list":
        show_users(c, args.username, args.details)

    # Commit the changes to the database
    conn.commit()

    # Close the connection
    conn.close()


if __name__ == "__main__":
    main()
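Note: this script doubles as the schema owner: users maps one apikey per username, and every row in history is keyed by that apikey. A minimal smoke-test sketch, assuming the functions above are in scope (the file's real name is not shown in this view, so no import is given):

# Hedged sketch, not part of the commit; exercises the helpers above directly.
import sqlite3

conn = sqlite3.connect("happysd.db")  # same default path as main.py's --db flag
c = conn.cursor()
create_table_users(c)                 # idempotent: CREATE TABLE IF NOT EXISTS
create_table_history(c)
create_user(c, "alice", "alice-secret-key")
show_users(c, "alice", details=True)  # prints username, apikey, and job count
conn.commit()
conn.close()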
frontend script (filename not shown in this view)
@@ -213,7 +213,7 @@
     url: '/get_jobs',
     contentType: 'application/json; charset=utf-8',
     dataType: 'json',
-    data: JSON.stringify({ 'api_key': apikeyVal, 'uuid': uuidValue }),
+    data: JSON.stringify({ 'apikey': apikeyVal, 'uuid': uuidValue }),
     success: function (response) {
         console.log(response);
         if (response.jobs.length == 1) {
@@ -404,7 +404,7 @@
     contentType: 'application/json; charset=utf-8',
     dataType: 'json',
     data: JSON.stringify({
-        'api_key': apikeyVal,
+        'apikey': apikeyVal,
         'type': 'txt',
         'prompt': promptVal,
         'seed': seedVal,
@@ -513,7 +513,7 @@
     contentType: 'application/json; charset=utf-8',
     dataType: 'json',
    data: JSON.stringify({
-        'api_key': apikeyVal,
+        'apikey': apikeyVal,
         'type': 'img',
         'ref_img': imageData,
         'prompt': promptVal,
utilities/BUILD
@@ -16,21 +16,29 @@ py_library(
     srcs=["constants.py"],
 )
 
+py_library(
+    name="database",
+    srcs=["database.py"],
+    deps=[
+        ":logger",
+    ],
+)
+
 py_library(
-    name = "envvar",
-    srcs = ["envvar.py"],
+    name="envvar",
+    srcs=["envvar.py"],
 )
 
 py_test(
-    name = "envvar_test",
-    srcs = ["envvar_test.py"],
-    deps = [":envvar"],
+    name="envvar_test",
+    srcs=["envvar_test.py"],
+    deps=[":envvar"],
 )
 
 py_library(
-    name = "images",
-    srcs = ["images.py"],
+    name="images",
+    srcs=["images.py"],
 )
 
 py_library(
Config class file (name not shown in this view)
@@ -10,8 +10,6 @@ from utilities.constants import KEY_HEIGHT
 from utilities.constants import VALUE_HEIGHT_DEFAULT
 from utilities.constants import KEY_STRENGTH
 from utilities.constants import VALUE_STRENGTH_DEFAULT
-from utilities.constants import KEY_PREVIEW
-from utilities.constants import VALUE_PREVIEW_DEFAULT
 from utilities.constants import KEY_SCHEDULER
 from utilities.constants import VALUE_SCHEDULER_DEFAULT
 from utilities.constants import VALUE_SCHEDULER_DDIM
@@ -84,16 +82,6 @@ class Config:
         self.__config[KEY_HEIGHT] = value
         return self
 
-    def get_preview(self) -> bool:
-        return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
-
-    def set_preview(self, boolean: bool):
-        self.__logger.info(
-            "{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)
-        )
-        self.__config[KEY_PREVIEW] = boolean
-        return self
-
     def get_scheduler(self) -> str:
         return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
utilities/constants.py
@@ -7,71 +7,81 @@ LOGGER_NAME_IMG2IMG = "img2img"
 MAX_JOB_NUMBER = 10
 
 
-KEY_OUTPUT_FOLDER = "OUTFOLDER"
+KEY_OUTPUT_FOLDER = "outfolder"
 VALUE_OUTPUT_FOLDER_DEFAULT = ""
 
-KEY_SEED = "SEED"
-VALUE_SEED_DEFAULT = 0
+#
+# Database
+#
+HISTORY_TABLE_NAME = "history"
+USERS_TABLE_NAME = "users"
 
-KEY_WIDTH = "WIDTH"
-VALUE_WIDTH_DEFAULT = 512
+#
+# REST API Keys
+#
 
-KEY_HEIGHT = "HEIGHT"
-VALUE_HEIGHT_DEFAULT = 512
+# - input and output
+APIKEY = "apikey"
 
-KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
-VALUE_GUIDANCE_SCALE_DEFAULT = 25.0
+KEY_JOB_TYPE = "type"
+VALUE_JOB_TXT2IMG = "txt"  # default value for KEY_JOB_TYPE
+VALUE_JOB_IMG2IMG = "img"
+REFERENCE_IMG = "ref_img"
+VALUE_JOB_INPAINTING = "inpaint"
 
-KEY_STRENGTH = "STRENGTH"
-VALUE_STRENGTH_DEFAULT = 0.5
-
-KEY_STEPS = "STEPS"
-VALUE_STEPS_DEFAULT = 50
-
-KEY_SCHEDULER = "SCHEDULER"
-VALUE_SCHEDULER_DEFAULT = "Default"
+KEY_PROMPT = "prompt"
+KEY_NEG_PROMPT = "neg_prompt"
+KEY_SEED = "seed"
+VALUE_SEED_DEFAULT = 0  # default value for KEY_SEED
+KEY_WIDTH = "width"
+VALUE_WIDTH_DEFAULT = 512  # default value for KEY_WIDTH
+KEY_HEIGHT = "height"
+VALUE_HEIGHT_DEFAULT = 512  # default value for KEY_HEIGHT
+KEY_GUIDANCE_SCALE = "guidance_scale"
+VALUE_GUIDANCE_SCALE_DEFAULT = 25.0  # default value for KEY_GUIDANCE_SCALE
+KEY_STEPS = "steps"
+VALUE_STEPS_DEFAULT = 50  # default value for KEY_STEPS
+KEY_SCHEDULER = "scheduler"
+VALUE_SCHEDULER_DEFAULT = "Default"  # default value for KEY_SCHEDULER
 VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
 VALUE_SCHEDULER_LMS_DISCRETE = "LMSDiscreteScheduler"
 VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
 VALUE_SCHEDULER_PNDM = "PNDMScheduler"
 VALUE_SCHEDULER_DDIM = "DDIMScheduler"
+KEY_STRENGTH = "strength"
+VALUE_STRENGTH_DEFAULT = 0.5  # default value for KEY_STRENGTH
 
-KEY_PROMPT = "PROMPT"
-KEY_NEG_PROMPT = "NEG_PROMPT"
+REQUIRED_KEYS = [
+    APIKEY,  # str
+    KEY_PROMPT,  # str
+    KEY_JOB_TYPE,  # str
+]
+OPTIONAL_KEYS = [
+    KEY_NEG_PROMPT,  # str
+    KEY_SEED,  # str
+    KEY_WIDTH,  # int
+    KEY_HEIGHT,  # int
+    KEY_GUIDANCE_SCALE,  # float
+    KEY_STEPS,  # int
+    KEY_SCHEDULER,  # str
+    KEY_STRENGTH,  # float
+    REFERENCE_IMG,  # str (base64)
+]
 
-KEY_PREVIEW = "PREVIEW"
-VALUE_PREVIEW_DEFAULT = True
-
-# REST API Keys
-API_KEY = "api_key"
-API_KEY_FOR_DEMO = "demo"
+# - output only
 UUID = "uuid"
 
 BASE64IMAGE = "img"
+KEY_PRIORITY = "priority"
 KEY_JOB_STATUS = "status"
-VALUE_JOB_PENDING = "pending"
+VALUE_JOB_PENDING = "pending"  # default value for KEY_JOB_STATUS
 VALUE_JOB_RUNNING = "running"
 VALUE_JOB_DONE = "done"
 VALUE_JOB_FAILED = "failed"
-KEY_JOB_TYPE = "type"
-VALUE_JOB_TXT2IMG = "txt"
-VALUE_JOB_IMG2IMG = "img"
-VALUE_JOB_INPAINTING = "inpaint"
-REFERENCE_IMG = "ref_img"
 
-REQUIRED_KEYS = [
-    API_KEY.lower(),
-    KEY_PROMPT.lower(),
-    KEY_JOB_TYPE.lower(),
-]
-OPTIONAL_KEYS = [
-    KEY_NEG_PROMPT.lower(),
-    KEY_SEED.lower(),
-    KEY_WIDTH.lower(),
-    KEY_HEIGHT.lower(),
-    KEY_GUIDANCE_SCALE.lower(),
-    KEY_STEPS.lower(),
-    KEY_SCHEDULER.lower(),
-    KEY_STRENGTH.lower(),
-    REFERENCE_IMG.lower(),
-]
+OUTPUT_ONLY_KEYS = [
+    UUID,  # str
+    KEY_PRIORITY,  # int
+    BASE64IMAGE,  # str (base64)
+    KEY_JOB_STATUS,  # str
+]
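Note: the three key groups spell out the request/response contract: clients must send every name in REQUIRED_KEYS, may send OPTIONAL_KEYS, and OUTPUT_ONLY_KEYS are filled in by the server; together they also double as the column list of the history table. A hedged sketch of a validator built on that contract (main.py iterates req.keys() itself rather than calling anything like this):

# Illustrative only, not part of the commit.
from utilities.constants import REQUIRED_KEYS, OPTIONAL_KEYS

def validate_request(req: dict) -> list:
    """Return a list of problems with an /add_job payload; empty if well formed."""
    problems = [f"missing required key: {k}" for k in REQUIRED_KEYS if k not in req]
    allowed = set(REQUIRED_KEYS) | set(OPTIONAL_KEYS)
    problems += [f"unknown key: {k}" for k in req if k not in allowed]
    return problems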
@@ -0,0 +1,262 @@ new file: utilities/database.py
import os
import datetime
import sqlite3
import uuid

from utilities.constants import APIKEY
from utilities.constants import UUID
from utilities.constants import KEY_PRIORITY
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE

from utilities.constants import OUTPUT_ONLY_KEYS
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS

from utilities.constants import REFERENCE_IMG
from utilities.constants import BASE64IMAGE

from utilities.constants import HISTORY_TABLE_NAME
from utilities.constants import USERS_TABLE_NAME
from utilities.logger import DummyLogger


class Database:
    """This class represents a SQLite database and assumes single-thread usage."""

    def __init__(self, logger: DummyLogger = DummyLogger()):
        """Initialize the class with a logger instance, but without a database connection or cursor."""
        self.__connect = None  # the database connection object
        self.__cursor = None  # the cursor object for executing SQL statements
        self.__logger = logger  # the logger object for logging messages

    def connect(self, db_filepath) -> bool:
        """
        Connect to the SQLite database file specified by `db_filepath`.

        Returns True if the connection was successful, otherwise False.
        """
        if not os.path.isfile(db_filepath):
            self.__logger.error(f"{db_filepath} does not exist!")
            return False
        self.__connect = sqlite3.connect(db_filepath, check_same_thread=False)
        self.__cursor = self.__connect.cursor()
        self.__logger.info(f"Connected to database {db_filepath}")
        return True

    def validate_user(self, apikey: str) -> str:
        """
        Validate if the provided API key exists in the users table and return the corresponding
        username if found, or an empty string otherwise.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return ""

        query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?"
        self.__cursor.execute(query, (apikey,))
        result = self.__cursor.fetchone()

        self.__logger.debug(result)

        if result is not None:
            return result[0]  # the first column is the username

        return ""

    def get_all_pending_jobs(self, apikey: str = "") -> list:
        return self.get_jobs(apikey=apikey, job_status=VALUE_JOB_PENDING)

    def count_all_pending_jobs(self, apikey: str) -> int:
        """
        Count the number of pending jobs in the HISTORY_TABLE_NAME table for the specified API key.

        Returns the number of pending jobs found.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return 0

        # Construct the SQL query string and list of arguments
        query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?"
        query_args = (apikey, VALUE_JOB_PENDING)

        # Execute the query and return the count
        self.__cursor.execute(query_string, query_args)
        result = self.__cursor.fetchone()
        return result[0]

    def get_jobs(self, job_uuid="", apikey="", job_status="") -> list:
        """
        Get a list of jobs from the HISTORY_TABLE_NAME table based on optional filters.

        If `job_uuid` or `apikey` or `job_status` is provided, the query will include that filter.

        Returns a list of jobs matching the filters provided.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return []

        # construct the SQL query string and list of arguments based on the provided filters
        query_args = []
        query_filters = []
        if job_uuid:
            query_filters.append(f"{UUID} = ?")
            query_args.append(job_uuid)
        if apikey:
            query_filters.append(f"{APIKEY} = ?")
            query_args.append(apikey)
        if job_status:
            query_filters.append(f"{KEY_JOB_STATUS} = ?")
            query_args.append(job_status)
        columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
        query_string = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
        if query_filters:
            query_string += f" WHERE {' AND '.join(query_filters)}"

        # execute the query and return the results
        self.__cursor.execute(query_string, tuple(query_args))
        rows = self.__cursor.fetchall()

        jobs = []
        for row in rows:
            job = {
                columns[i]: row[i] for i in range(len(columns)) if row[i] is not None
            }
            jobs.append(job)

        return jobs

    def insert_new_job(self, job_dict: dict, job_uuid="") -> bool:
        """
        Insert a new job into the HISTORY_TABLE_NAME table.

        If `job_uuid` is not provided, a new UUID will be generated automatically.

        Returns True if the insertion was successful, otherwise False.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return False

        if not job_uuid:
            job_uuid = str(uuid.uuid4())
        self.__logger.info(f"inserting a new job with {job_uuid}")

        values = [job_uuid, VALUE_JOB_PENDING]
        columns = [UUID, KEY_JOB_STATUS] + REQUIRED_KEYS + OPTIONAL_KEYS
        for column in REQUIRED_KEYS + OPTIONAL_KEYS:
            values.append(job_dict.get(column, None))

        query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({', '.join(['?' for _ in columns])})"
        self.__cursor.execute(query, tuple(values))
        self.__connect.commit()
        return True

    def update_job(self, job_dict: dict, job_uuid: str) -> bool:
        """
        Update an existing job in the HISTORY_TABLE_NAME table with the given `job_uuid`.

        Returns True if the update was successful, otherwise False.
        """
        if self.__cursor is None:
            self.__logger.error("Did you forget to connect to the database?")
            return False

        values = []
        columns = []
        for column in OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS:
            value = job_dict.get(column, None)
            if value is not None:
                columns.append(column)
                values.append(value)

        set_clause = ", ".join([f"{column}=?" for column in columns])
        # Add current timestamp to update query
        set_clause += ", updated_at=?"
        values.append(datetime.datetime.now())

        query = f"UPDATE {HISTORY_TABLE_NAME} SET {set_clause} WHERE {UUID}=?"

        values.append(job_uuid)

        self.__cursor.execute(query, tuple(values))
        self.__connect.commit()
        return True

    def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool:
        """Cancel the job with the given job_uuid or apikey.
        If job_uuid or apikey is provided, delete corresponding rows from table history if "status" matches "pending".

        Args:
            job_uuid (str, optional): Unique job identifier. Defaults to "".
            apikey (str, optional): API key associated with the job. Defaults to "".

        Returns:
            bool: True if the job was cancelled successfully, False otherwise.
        """
        if not job_uuid and not apikey:
            self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
            return False

        if job_uuid:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=? AND {KEY_JOB_STATUS}=?",
                (
                    job_uuid,
                    VALUE_JOB_PENDING,
                ),
            )
        else:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?",
                (
                    apikey,
                    VALUE_JOB_PENDING,
                ),
            )

        if self.__cursor.rowcount == 0:
            self.__logger.info("No matching rows found.")
            return False
        else:
            self.__logger.info(f"{self.__cursor.rowcount} rows cancelled.")

        self.__connect.commit()
        return True

    def delete_job(self, job_uuid: str = "", apikey: str = "") -> bool:
        """Delete the job with the given uuid or apikey"""
        if job_uuid:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?", (job_uuid,)
            )
        elif apikey:
            self.__cursor.execute(
                f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=?", (apikey,)
            )
        else:
            self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
            return False

        if self.__cursor.rowcount == 0:
            self.__logger.info("No matching rows found.")
        else:
            self.__logger.info(f"{self.__cursor.rowcount} rows deleted.")
        self.__connect.commit()
        return True

    def safe_disconnect(self):
        if self.__connect is not None:
            self.__connect.commit()
            self.__connect.close()
            self.__logger.info("Disconnected from database.")
        else:
            self.__logger.warn("No database connection to close.")
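Note: a hedged sketch of the Database lifecycle as main.py uses it; the path and payload values are placeholders, and connect() deliberately fails rather than creating a missing file (the admin script above is what bootstraps the schema). The Logger import mirrors main.py and is assumed to be DummyLogger-compatible.

# Not part of the commit; illustrative only.
from utilities.constants import KEY_JOB_STATUS, VALUE_JOB_DONE
from utilities.database import Database
from utilities.logger import Logger

db = Database(Logger(name="db-demo"))
if db.connect("happysd.db"):                        # file must already exist
    job = {"apikey": "alice-secret-key", "prompt": "a foggy pier", "type": "txt"}
    db.insert_new_job(job, job_uuid="1234")         # row starts in "pending" status
    print(db.count_all_pending_jobs("alice-secret-key"))
    db.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid="1234")
    print(db.get_jobs(job_uuid="1234"))
    db.cancel_job(job_uuid="1234")                  # False here: no longer pending
    db.safe_disconnect()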
utilities/img2img.py
@@ -84,8 +84,8 @@ class Img2Img:
         return {
             BASE64IMAGE: image_to_base64(result.images[0]),
-            KEY_SEED.lower(): str(seed),
-            KEY_WIDTH.lower(): config.get_width(),
-            KEY_HEIGHT.lower(): config.get_height(),
-            KEY_STEPS.lower(): config.get_steps(),
+            KEY_SEED: str(seed),
+            KEY_WIDTH: config.get_width(),
+            KEY_HEIGHT: config.get_height(),
+            KEY_STEPS: config.get_steps(),
         }
utilities/text2img.py
@@ -66,8 +66,8 @@ class Text2Img:
         return {
             BASE64IMAGE: image_to_base64(result.images[0]),
-            KEY_SEED.lower(): str(seed),
-            KEY_WIDTH.lower(): config.get_width(),
-            KEY_HEIGHT.lower(): config.get_height(),
-            KEY_STEPS.lower(): config.get_steps(),
+            KEY_SEED: str(seed),
+            KEY_WIDTH: config.get_width(),
+            KEY_HEIGHT: config.get_height(),
+            KEY_STEPS: config.get_steps(),
         }
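Note: dropping the .lower() calls works because the constants are now themselves the lowercase column names of the history table, so the dictionaries returned here can be written back verbatim by database.update_job() in the backend loop. A small illustrative check (not in the commit; the literal values are placeholders):

from utilities.constants import BASE64IMAGE, KEY_SEED, KEY_WIDTH, KEY_HEIGHT, KEY_STEPS

# Every key emitted by Text2Img/Img2Img is a column created by the admin
# script's create_table_history(), so update_job() can persist it as-is.
result_dict = {BASE64IMAGE: "<base64>", KEY_SEED: "42", KEY_WIDTH: 512,
               KEY_HEIGHT: 512, KEY_STEPS: 50}
assert set(result_dict) == {"img", "seed", "width", "height", "steps"}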