separate frontend and backend, adds multilingual support (chinese), improve multiaccess to database

This commit is contained in:
HappyZ 2023-05-07 21:43:14 -07:00
parent 6cd69355ac
commit 7e84144432
9 changed files with 240 additions and 447 deletions

24
BUILD
View File

@ -4,19 +4,31 @@ load("@subpar//:subpar.bzl", "par_binary")
package(default_visibility=["//visibility:public"])
par_binary(
name="main",
srcs=["main.py"],
name="frontend",
srcs=["frontend.py"],
deps=[
"//utilities:constants",
"//utilities:database",
"//utilities:logger",
"//utilities:model",
"//utilities:text2img",
"//utilities:img2img",
"//utilities:envvar",
"//utilities:times",
],
data=[
"templates/index.html",
],
)
par_binary(
name="backend",
srcs=["backend.py"],
deps=[
"//utilities:constants",
"//utilities:database",
"//utilities:logger",
"//utilities:model",
"//utilities:text2img",
"//utilities:translator",
"//utilities:img2img",
"//utilities:times",
],
)

283
main.py
View File

@ -1,283 +0,0 @@
import argparse
import copy
import tempfile
import pkgutil
import uuid
from flask import jsonify
from flask import Flask
from flask import render_template
from flask import request
from threading import Event
from threading import Thread
from threading import Lock
from utilities.constants import APIKEY
from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_JOB_TYPE
from utilities.constants import KEY_PROMPT
from utilities.constants import KEY_NEG_PROMPT
from utilities.constants import LOGGER_NAME
from utilities.constants import LOGGER_NAME_IMG2IMG
from utilities.constants import LOGGER_NAME_TXT2IMG
from utilities.constants import REFERENCE_IMG
from utilities.constants import MAX_JOB_NUMBER
from utilities.constants import OPTIONAL_KEYS
from utilities.constants import REQUIRED_KEYS
from utilities.constants import UUID
from utilities.constants import VALUE_APP
from utilities.constants import VALUE_JOB_TXT2IMG
from utilities.constants import VALUE_JOB_IMG2IMG
from utilities.constants import VALUE_JOB_INPAINTING
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import VALUE_JOB_FAILED
from utilities.database import Database
from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_var
from utilities.times import wait_for_seconds
from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img
from utilities.img2img import Img2Img
app = Flask(__name__)
memory_lock = Lock()
event_termination = Event()
logger = Logger(name=LOGGER_NAME)
database = Database(logger)
use_gpu = True
local_job_stack = []
local_completed_jobs = []
@app.route("/add_job", methods=["POST"])
def add_job():
req = request.get_json()
if APIKEY not in req:
logger.error(f"{APIKEY} not present in {req}")
return "", 401
with memory_lock:
user = database.validate_user(req[APIKEY])
if not user:
logger.error(f"user not found with {req[APIKEY]}")
return "", 401
for key in req.keys():
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
return jsonify({"msg": "provided one or more unrecognized keys"}), 404
for required_key in REQUIRED_KEYS:
if required_key not in req:
return jsonify({"msg": "missing one or more required keys"}), 404
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
return jsonify({"msg": "missing reference image"}), 404
if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
return (
jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
500,
)
job_uuid = str(uuid.uuid4())
logger.info("adding a new job with uuid {}..".format(job_uuid))
with memory_lock:
database.insert_new_job(req, job_uuid=job_uuid)
return jsonify({"msg": "", UUID: job_uuid})
@app.route("/cancel_job", methods=["POST"])
def cancel_job():
req = request.get_json()
if APIKEY not in req:
return "", 401
with memory_lock:
user = database.validate_user(req[APIKEY])
if not user:
return "", 401
if UUID not in req:
return jsonify({"msg": "missing uuid"}), 404
logger.info("cancelling job with uuid {}..".format(req[UUID]))
with memory_lock:
result = database.cancel_job(job_uuid=req[UUID])
if result:
msg = "job with uuid {} removed".format(req[UUID])
return jsonify({"msg": msg})
with memory_lock:
jobs = database.get_jobs(job_uuid=req[UUID])
if jobs:
return (
jsonify(
{
"msg": "job {} is not in pending state, unable to cancel".format(
req[UUID]
)
}
),
405,
)
return (
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
404,
)
@app.route("/get_jobs", methods=["POST"])
def get_jobs():
req = request.get_json()
if APIKEY not in req:
return "", 401
with memory_lock:
user = database.validate_user(req[APIKEY])
if not user:
return "", 401
with memory_lock:
jobs = database.get_jobs(job_uuid=req[UUID])
return jsonify({"jobs": jobs})
@app.route("/")
def index():
return render_template("index.html")
def load_model(logger: Logger) -> Model:
# model candidates:
# "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4"
# "stabilityai/stable-diffusion-2-1"
# "SG161222/Realistic_Vision_V2.0"
# "darkstorm2150/Protogen_x3.4_Official_Release"
# "darkstorm2150/Protogen_x5.8_Official_Release"
# "prompthero/openjourney"
# "naclbit/trinart_stable_diffusion_v2"
# "hakurei/waifu-diffusion"
model_name = "darkstorm2150/Protogen_x5.8_Official_Release"
# inpainting model candidates:
# "runwayml/stable-diffusion-inpainting"
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
if use_gpu:
model.set_low_memory_mode()
model.load_all()
return model
def backend(event_termination, db):
model = load_model(logger)
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
text2img.breakfast()
img2img.breakfast()
while not event_termination.is_set():
wait_for_seconds(1)
with memory_lock:
pending_jobs = database.get_all_pending_jobs()
if len(pending_jobs) == 0:
continue
next_job = pending_jobs[0]
with memory_lock:
database.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID])
prompt = next_job[KEY_PROMPT]
negative_prompt = next_job[KEY_NEG_PROMPT]
config = Config().set_config(next_job)
try:
if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
result_dict = text2img.lunch(
prompt=prompt, negative_prompt=negative_prompt, config=config
)
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
ref_img = next_job[REFERENCE_IMG]
result_dict = img2img.lunch(
prompt=prompt,
negative_prompt=negative_prompt,
reference_image=ref_img,
config=config,
)
except BaseException as e:
logger.error("text2img.lunch error: {}".format(e))
with memory_lock:
database.update_job(
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
)
continue
with memory_lock:
database.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID])
database.update_job(result_dict, job_uuid=next_job[UUID])
logger.critical("stopped")
def main(db_filepath, is_testing: bool = False):
database.connect(db_filepath)
if is_testing:
try:
app.run(host="0.0.0.0", port="5000")
except KeyboardInterrupt:
pass
return
thread = Thread(
target=backend,
args=(
event_termination,
database,
),
)
thread.start()
# ugly solution for now
# TODO: use a database to track instead of internal memory
try:
app.run(host="0.0.0.0", port="8888")
thread.join()
except KeyboardInterrupt:
event_termination.set()
database.safe_disconnect()
thread.join(2)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Add an argument to set the 'testing' flag
parser.add_argument("--testing", action="store_true", help="Enable testing mode")
# Add an argument to set the path of the database file
parser.add_argument(
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
)
args = parser.parse_args()
logger.info(args)
main(args.db, args.testing)

View File

@ -34,13 +34,14 @@ def create_table_history(c):
c.execute(
"""CREATE TABLE IF NOT EXISTS history
(uuid TEXT PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
created_at TIMESTAMP,
updated_at TIMESTAMP,
apikey TEXT,
priority INT,
type TEXT,
status TEXT,
prompt TEXT,
lang TEXT,
neg_prompt TEXT,
seed TEXT,
ref_img TEXT,
@ -146,7 +147,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute(
"SELECT uuid, created_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
(user[1],),
)
rows = c.fetchall()
@ -163,7 +164,7 @@ def show_users(c, username="", details=False):
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
if details:
c.execute(
"SELECT uuid, created_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
(user[1],),
)
rows = c.fetchall()

View File

@ -7,3 +7,5 @@ Pillow
scikit-image
torch
transformers
sentencepiece
fcntl

View File

@ -1,4 +1,4 @@
<html lang="en">
<html>
<head>
<meta charset="utf-8">
@ -17,75 +17,84 @@
<div class="card-body">
<div class="row mb-3">
<div class="col-sm-8">
<label for="apiKey" class="form-label">API Key</label>
<input type="password" class="form-control" id="apiKey" value="demo">
<label for="apiKey" class="form-label" data-en_XX="API Key" data-zh_CN="API 密钥">API Key</label>
<input type="password" class="form-control" id="apiKey" value="">
</div>
<div class="col-sm-4">
<div class="form-check">
<input class="form-check-input" type="checkbox" id="showPreview" disabled>
<label class="form-check-label" for="showPreview">
Preview Image
<label class="form-check-label" for="showPreview" data-en_XX="Preview Image"
data-zh_CN="预览生成图像">Preview Image
</label>
</div>
<select class="form-select" size="2" id="language">
<option value="zh_CN">中文</option>
<option selected value="en_XX">English</option>
</select>
</div>
</div>
<div class="form-row mb-3">
<label for="prompt" class="form-label">Describe Your Image</label>
<label for="prompt" class="form-label" data-en_XX="Describe Your Image"
data-zh_CN="形容你的图片(提示词)">Describe Your Image (Prompts)</label>
<input type="text" class="form-control" id="prompt" aria-describedby="promptHelp" value="">
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated. Example:
"photo of cute cat, RAW photo, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high
quality, film grain, Fujifilm XT3"</div>
<div id="promptHelp" class="form-text"
data-en_XX="Less than 77 words. Example: photo of a cute cat. Use () to emphasize."
data-zh_CN="少于77个词。比如一张可爱的猫的照片。用括号强调重要性。">Less than 77 words. Example: photo of a cute cat.
Use () to emphasize.</div>
</div>
<div class="form-row mb-3">
<label for="negPrompt" class="form-label">Describe What's NOT Your Image</label>
<label for="negPrompt" class="form-label"
data-en_XX="Describe What's NOT Your Image (Negative Prompts)"
data-zh_CN="反向形容你的图片(反向提示词)">Describe What's NOT Your Image (Negative Prompts)</label>
<input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp" value="">
<div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated.
Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon,
drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg
artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn
hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad
proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs,
missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long
neck"</div>
<div id="negPromptHelp" class="form-text" data-en_XX="Less than 77 words. Optional."
data-zh_CN="少于77个词。非必填。">Less than 77 words. Optional.</div>
</div>
<div class="row">
<div class="col-md-3">
<div class="form-row">
<label for="inputSeed">Seed</label>
<label for="inputSeed" data-en_XX="Seed" data-zh_CN="随机数">Seed</label>
<input type="text" class="form-control" id="inputSeed" aria-describedby="inputSeedHelp"
value="">
<div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random
seed
<div id="inputSeedHelp" class="form-text"
data-en_XX="Leave it empty or set 0 to use random seed" data-zh_CN="非必填。留白或填0使用默认随机数">
Leave it empty or set 0 to use a random seed
</div>
</div>
<div class="form-row">
<label for="inputSteps">Steps</label>
<label for="inputSteps" data-en_XX="Steps" data-zh_CN="迭代次数">Steps</label>
<input type="number" class="form-control" id="inputSteps" aria-describedby="inputStepsHelp"
placeholder="default is 50">
<div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s
(GPU)
placeholder="50">
<div id="inputStepsHelp" class="form-text"
data-en_XX="More steps better image but longer time to generate"
data-zh_CN="迭代次数越多图片越好,但生成时间越久">More steps better image but longer time to generate
</div>
</div>
<div class="form-row">
<label for="inputWidth">Width</label>
<label for="inputWidth" data-en_XX="Width" data-zh_CN="图片宽度">Width</label>
<input type="number" class="form-control" id="inputWidth" placeholder="512" min="1"
max="1024">
</div>
<div class="form-row">
<label for="inputHeight">Height</label>
<label for="inputHeight" data-en_XX="Height" data-zh_CN="图片高度">Height</label>
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
max="1024">
</div>
<div class="form-row">
<label for="guidanceScale">Guidance Scale</label>
<label for="guidanceScale" data-en_XX="Guidance Scale" data-zh_CN="指导强度">Guidance
Scale</label>
<input type="number" class="form-control" id="inputGuidanceScale"
aria-describedby="inputGuidanceScaleHelp" placeholder="25" min="1" max="30">
<div id="inputGuidanceScaleHelp" class="form-text">How much guidance to follow from
description. 20 strictly follow prompt, 7 creative/artistic.
aria-describedby="inputGuidanceScaleHelp" placeholder="12.5" min="1" max="30">
<div id="inputGuidanceScaleHelp" class="form-text"
data-en_XX="Don't set it to the extremes (1 or 30). 20 means strictly follow prompt, 7 creative/artistic. Lower this number if you see bad images."
data-zh_CN="不建议设最低1或最高30。20代表提示词非常重要7更有创造性。如果看到图片结果较差适当减少强度。">
Don't set it to the extremes (1 or 30). 20 means strictly follow prompt, 7
creative/artistic. Lower this number if you see bad images.
</div>
</div>
<div class="row">
<button id="newTxt2ImgJob" class="btn btn-primary">Let's Go!</button>
<button id="newTxt2ImgJob" class="btn btn-primary" data-en_XX="Let's Go!"
data-zh_CN="生成图片!">Let's Go!</button>
</div>
</div>
<div class="col-md-9">
@ -93,54 +102,62 @@
<div class="card-header">
<ul class="nav nav-pills card-header-pills">
<li class="nav-item">
<a class="nav-link" href="#card-txt">Text-to-Image</a>
<a class="nav-link" href="#card-txt" data-en_XX="Text-to-Image"
data-zh_CN="文字->图片">Text-to-Image</a>
</li>
<li class="nav-item">
<a class="nav-link" href="#card-img">Image-to-Image</a>
<a class="nav-link" href="#card-img" data-en_XX="Image-to-Image"
data-zh_CN="图片->图片">Image-to-Image</a>
</li>
<li class="nav-item">
<a class="nav-link" href="#card-inpainting">Inpainting</a>
<a class="nav-link" href="#card-inpainting" data-en_XX="Inpainting"
data-zh_CN="图片修复">Inpainting</a>
</li>
</ul>
</div>
<div class="card-body card-specific" id="card-img" style="display:none">
<div class="row">
<div class="col-md-6">
<div class="col-md-4">
<div class="card">
<div class="card-header">
<div class="card-header" data-en_XX="Reference Image" data-zh_CN="参照图">
Reference Image
</div>
<div class="card-body">
<div class="row">
<button id="copy-txt-to-img" class="btn btn-primary mb-3">Copy from
Txt-to-Image</button>
<button id="copy-last-img" class="btn btn-primary mb-3">Copy from
Last
Image
Result</button>
<button id="upload-img" class="btn btn-primary mb-3">Upload
Image</button>
<button id="copy-txt-to-img" class="btn btn-primary mb-3"
data-en_XX="Copy from text-to-image"
data-zh_CN="从【文字->图片】结果复制">Copy from text-to-image</button>
<button id="copy-last-img" class="btn btn-primary mb-3"
data-en_XX="Copy from last image result"
data-zh_CN="从【图片->图片】结果复制">Copy from last image result</button>
<button id="upload-img" class="btn btn-primary mb-3"
data-en_XX="Upload image" data-zh_CN="上传一张图片">Upload
image</button>
</div>
<div class="form-row">
<label for="strength">Strength</label>
<label for="strength" data-en_XX="Strength"
data-zh_CN="改变程度">Strength</label>
<input type="number" class="form-control" id="inputStrength"
aria-describedby="inputStrengthHelp" placeholder="0.5" min="0"
max="1">
<div id="inputStrengthHelp" class="form-text">How semantically
consistent with the origional image.
<div id="inputStrengthHelp" class="form-text"
data-en_XX="How different from the original image. 0 means the same, 1 means very different."
data-zh_CN="和参照图有多么的不同。0指一样1指非常不一样。">How different from the
original image
</div>
</div>
<div class="row">
<button id="newImg2ImgJob" class="btn btn-primary mb-3">Let's Go
with Image Below!</button>
<button id="newImg2ImgJob" class="btn btn-primary mb-3"
data-en_XX="Let's Go with Image Below!"
data-zh_CN="就用下面的图生成!">Let's Go with Image Below!</button>
</div>
</div>
<img class="card-img-bottom" id="reference-img">
</div>
</div>
<div class="col-md-6">
<div class="col-md-8">
<div class="card">
<div class="card-header">
<div class="card-header" data-en_XX="Result" data-zh_CN="结果">
Result
</div>
<div class="card-body">
@ -163,7 +180,7 @@
</div>
<div class="card-body card-specific" id="card-txt" style="display:none">
<div class="card">
<div class="card-header">
<div class="card-header" data-en_XX="Result" data-zh_CN="结果">
Result
</div>
<div class="card-body">
@ -191,8 +208,10 @@
<label for="jobuuid" class="form-label">Job UUID</label>
<input type="jobuuid" class="form-control" id="jobuuid" aria-describedby="">
</div>
<button id="getjob" type="submit" class="btn btn-primary" disabled>Get Jobs</button>
<button id="canceljob" type="submit" class="btn btn-primary" disabled>Cancel Job</button>
<button id="getjob" type="submit" class="btn btn-primary" data-en_XX="Get Jobs" data-zh_CN="搜索生成结果"
disabled>Get Jobs</button>
<button id="canceljob" type="submit" class="btn btn-primary" data-en_XX="Cancel Job" data-zh_CN="取消"
disabled>Cancel Job</button>
</form>
<div class="mb-3" id="joblist">
@ -397,7 +416,6 @@
return;
}
// Send POST request using Ajax
$.ajax({
type: 'POST',
url: '/add_job',
@ -411,6 +429,7 @@
'steps': stepsVal,
'width': widthVal,
'height': heightVal,
'lang': $("#language option:selected").val(),
'guidance_scale': guidanceScaleVal,
'neg_prompt': negPromptVal
}),
@ -521,6 +540,7 @@
'steps': stepsVal,
'width': widthVal,
'height': heightVal,
'lang': $("#language option:selected").val(),
'guidance_scale': guidanceScaleVal,
'strength': strengthVal,
'neg_prompt': negPromptVal
@ -541,6 +561,28 @@
});
});
// Listen for changes to the select element
$("#language").change(function () {
// Get the newly selected value
var newLanguage = $(this).val();
// Store the selected value in cache
localStorage.setItem("selectedLanguage", newLanguage);
$("[data-" + newLanguage + "]").each(function () {
$(this).text($(this).data(newLanguage.toLowerCase()));
});
});
// Get the selected value from cache (if it exists)
var cachedLanguage = localStorage.getItem("selectedLanguage");
if (cachedLanguage) {
// Set the selected value
$("#language").val(cachedLanguage);
// Trigger a change event to update the text and store in cache
$("#language").change();
}
});
</script>

View File

@ -106,6 +106,11 @@ py_test(
deps=[":times"],
)
py_library(
name="translator",
srcs=["translator.py"],
)
py_library(
name="web",
srcs=["web.py"],

View File

@ -1,12 +1,13 @@
KEY_APP = "APP"
VALUE_APP = "demo"
LOGGER_NAME = VALUE_APP
LOGGER_NAME_TXT2IMG = "txt2img"
LOGGER_NAME_IMG2IMG = "img2img"
LOGGER_NAME_FRONTEND = VALUE_APP + "_fe"
LOGGER_NAME_BACKEND = VALUE_APP + "_be"
LOGGER_NAME_TXT2IMG = VALUE_APP + "_txt2img"
LOGGER_NAME_IMG2IMG = VALUE_APP + "_img2img"
MAX_JOB_NUMBER = 10
LOCK_FILEPATH = "/tmp/happysd_db.lock"
KEY_OUTPUT_FOLDER = "outfolder"
VALUE_OUTPUT_FOLDER_DEFAULT = ""
@ -30,6 +31,9 @@ VALUE_JOB_IMG2IMG = "img"
REFERENCE_IMG = "ref_img"
VALUE_JOB_INPAINTING = "inpaint"
KEY_LANGUAGE = "lang"
VALUE_LANGUAGE_ZH_CN = "zh_CN"
VALUE_LANGUAGE_EN = "en_XX"
KEY_PROMPT = "prompt"
KEY_NEG_PROMPT = "neg_prompt"
KEY_SEED = "seed"
@ -39,9 +43,9 @@ VALUE_WIDTH_DEFAULT = 512 # default value for KEY_WIDTH
KEY_HEIGHT = "height"
VALUE_HEIGHT_DEFAULT = 512 # default value for KEY_HEIGHT
KEY_GUIDANCE_SCALE = "guidance_scale"
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0 # default value for KEY_GUIDANCE_SCALE
VALUE_GUIDANCE_SCALE_DEFAULT = 12.5 # default value for KEY_GUIDANCE_SCALE
KEY_STEPS = "steps"
VALUE_STEPS_DEFAULT = 50 # default value for KEY_STEPS
VALUE_STEPS_DEFAULT = 100 # default value for KEY_STEPS
KEY_SCHEDULER = "scheduler"
VALUE_SCHEDULER_DEFAULT = "Default" # default value for KEY_SCHEDULER
VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
@ -67,6 +71,7 @@ OPTIONAL_KEYS = [
KEY_SCHEDULER, # str
KEY_STRENGTH, # float
REFERENCE_IMG, # str (base64)
KEY_LANGUAGE,
]
# - output only
@ -84,4 +89,9 @@ OUTPUT_ONLY_KEYS = [
KEY_PRIORITY, # int
BASE64IMAGE, # str (base64)
KEY_JOB_STATUS, # str
]
SUPPORTED_LANGS = [
VALUE_LANGUAGE_ZH_CN,
VALUE_LANGUAGE_EN,
]

View File

@ -1,6 +1,7 @@
import os
import datetime
import sqlite3
import fcntl
import uuid
from utilities.constants import APIKEY
@ -14,6 +15,7 @@ from utilities.constants import KEY_JOB_STATUS
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE
from utilities.constants import LOCK_FILEPATH
from utilities.constants import OUTPUT_ONLY_KEYS
from utilities.constants import OPTIONAL_KEYS
@ -27,12 +29,26 @@ from utilities.constants import USERS_TABLE_NAME
from utilities.logger import DummyLogger
# Function to acquire a lock on the database file
def acquire_lock():
lock_fd = open(LOCK_FILEPATH, "w")
fcntl.flock(lock_fd, fcntl.LOCK_EX)
# Function to release the lock on the database file
def release_lock():
lock_fd = open(LOCK_FILEPATH, "w")
fcntl.flock(lock_fd, fcntl.LOCK_UN)
lock_fd.close()
class Database:
"""This class represents a SQLite database and assumes single-thread usage."""
def __init__(self, logger: DummyLogger = DummyLogger()):
"""Initialize the class with a logger instance, but without a database connection or cursor."""
self.__connect = None # the database connection object
self.is_connected = False
self.__cursor = None # the cursor object for executing SQL statements
self.__logger = logger # the logger object for logging messages
@ -46,28 +62,32 @@ class Database:
self.__logger.error(f"{db_filepath} does not exist!")
return False
self.__connect = sqlite3.connect(db_filepath, check_same_thread=False)
self.__cursor = self.__connect.cursor()
self.__logger.info(f"Connected to database {db_filepath}")
self.is_connected = True
return True
def get_cursor(self):
if not self.is_connected:
raise RuntimeError("Did you forget to connect() to the database?")
return self.__connect.cursor()
def commit(self):
if not self.is_connected:
raise RuntimeError("Did you forget to connect() to the database?")
return self.__connect.commit()
def validate_user(self, apikey: str) -> str:
"""
Validate if the provided API key exists in the users table and return the corresponding
username if found, or an empty string otherwise.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return ""
query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?"
self.__cursor.execute(query, (apikey,))
result = self.__cursor.fetchone()
self.__logger.debug(result)
c = self.get_cursor()
result = c.execute(query, (apikey,)).fetchone()
if result is not None:
return result[0] # the first column is the username
return result[0]
return ""
def get_all_pending_jobs(self, apikey: str = "") -> list:
@ -79,17 +99,13 @@ class Database:
Returns the number of pending jobs found.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return 0
# Construct the SQL query string and list of arguments
query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?"
query_args = (apikey, VALUE_JOB_PENDING)
# Execute the query and return the count
self.__cursor.execute(query_string, query_args)
result = self.__cursor.fetchone()
c = self.get_cursor()
result = c.execute(query_string, query_args).fetchone()
return result[0]
def get_jobs(self, job_uuid="", apikey="", job_status="") -> list:
@ -100,30 +116,27 @@ class Database:
Returns a list of jobs matching the filters provided.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return []
# construct the SQL query string and list of arguments based on the provided filters
query_args = []
values = []
query_filters = []
if job_uuid:
query_filters.append(f"{UUID} = ?")
query_args.append(job_uuid)
values.append(job_uuid)
if apikey:
query_filters.append(f"{APIKEY} = ?")
query_args.append(apikey)
values.append(apikey)
if job_status:
query_filters.append(f"{KEY_JOB_STATUS} = ?")
query_args.append(job_status)
values.append(job_status)
columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
query_string = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
query = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
if query_filters:
query_string += f" WHERE {' AND '.join(query_filters)}"
query += f" WHERE {' AND '.join(query_filters)}"
# execute the query and return the results
self.__cursor.execute(query_string, tuple(query_args))
rows = self.__cursor.fetchall()
c = self.get_cursor()
rows = c.execute(query, tuple(values)).fetchall()
jobs = []
for row in rows:
@ -142,22 +155,24 @@ class Database:
Returns True if the insertion was successful, otherwise False.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return False
if not job_uuid:
job_uuid = str(uuid.uuid4())
self.__logger.info(f"inserting a new job with {job_uuid}")
values = [job_uuid, VALUE_JOB_PENDING]
columns = [UUID, KEY_JOB_STATUS] + REQUIRED_KEYS + OPTIONAL_KEYS
values = [job_uuid, VALUE_JOB_PENDING, datetime.datetime.now()]
columns = [UUID, KEY_JOB_STATUS, "created_at"] + REQUIRED_KEYS + OPTIONAL_KEYS
for column in REQUIRED_KEYS + OPTIONAL_KEYS:
values.append(job_dict.get(column, None))
query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({', '.join(['?' for _ in columns])})"
self.__cursor.execute(query, tuple(values))
self.__connect.commit()
acquire_lock()
try:
c = self.get_cursor()
c.execute(query, tuple(values))
self.commit()
finally:
release_lock()
return True
def update_job(self, job_dict: dict, job_uuid: str) -> bool:
@ -166,10 +181,6 @@ class Database:
Returns True if the update was successful, otherwise False.
"""
if self.__cursor is None:
self.__logger.error("Did you forget to connect to the database?")
return False
values = []
columns = []
for column in OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS:
@ -182,13 +193,17 @@ class Database:
# Add current timestamp to update query
set_clause += ", updated_at=?"
values.append(datetime.datetime.now())
values.append(job_uuid)
query = f"UPDATE {HISTORY_TABLE_NAME} SET {set_clause} WHERE {UUID}=?"
values.append(job_uuid)
self.__cursor.execute(query, tuple(values))
self.__connect.commit()
acquire_lock()
try:
c = self.get_cursor()
c.execute(query, tuple(values))
self.commit()
finally:
release_lock()
return True
def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool:
@ -202,61 +217,48 @@ class Database:
Returns:
bool: True if the job was cancelled successfully, False otherwise.
"""
return self.delete_job(
job_uuid=job_uuid, apikey=apikey, status=VALUE_JOB_PENDING
)
def delete_job(
self, job_uuid: str = "", apikey: str = "", status: str = ""
) -> bool:
if not job_uuid and not apikey:
self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
return False
query = f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?"
if status:
query += f" AND {KEY_JOB_STATUS}=?"
values = []
if job_uuid:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=? AND {KEY_JOB_STATUS}=?",
(
job_uuid,
VALUE_JOB_PENDING,
),
)
else:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?",
(
apikey,
VALUE_JOB_PENDING,
),
)
values.append(job_uuid)
elif apikey:
values.append(apikey)
if status:
values.append(status)
if self.__cursor.rowcount == 0:
rows_removed = 0
acquire_lock()
try:
c = self.get_cursor()
c.execute(query, tuple(values))
rows_removed = c.rowcount
self.commit()
finally:
release_lock()
if rows_removed == 0:
self.__logger.info("No matching rows found.")
return False
else:
self.__logger.info(f"{self.__cursor.rowcount} rows cancelled.")
self.__connect.commit()
return True
def delete_job(self, job_uuid: str = "", apikey: str = "") -> bool:
"""Delete the job with the given uuid or apikey"""
if job_uuid:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?", (job_uuid,)
)
elif apikey:
self.__cursor.execute(
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=?", (apikey,)
)
else:
self.logger.error(f"either {UUID} or {APIKEY} must be provided.")
return False
if self.__cursor.rowcount == 0:
print("No matching rows found.")
else:
self.logger.info(f"{self.__cursor.rowcount} rows deleted.")
self.__connect.commit()
self.__logger.info(f"{rows_removed} rows removed.")
return True
def safe_disconnect(self):
if self.__connect is not None:
self.__connect.commit()
self.__connect.close()
self.__logger.info("Disconnected from database.")
else:
self.__logger.warn("No database connection to close.")
if not self.is_connected:
raise RuntimeError("Did you forget to connect() to the database?")
self.commit()
self.__connect.close()
self.__logger.info("Disconnected from database.")

View File

@ -37,7 +37,9 @@ class Text2Img:
def breakfast(self):
pass
def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> dict:
def lunch(
self, prompt: str, negative_prompt: str = "", config: Config = Config()
) -> dict:
self.model.set_txt2img_scheduler(config.get_scheduler())
t = get_epoch_now()