separate frontend and backend, adds multilingual support (chinese), improve multiaccess to database
This commit is contained in:
parent
6cd69355ac
commit
7e84144432
24
BUILD
24
BUILD
|
|
@ -4,19 +4,31 @@ load("@subpar//:subpar.bzl", "par_binary")
|
|||
package(default_visibility=["//visibility:public"])
|
||||
|
||||
par_binary(
|
||||
name="main",
|
||||
srcs=["main.py"],
|
||||
name="frontend",
|
||||
srcs=["frontend.py"],
|
||||
deps=[
|
||||
"//utilities:constants",
|
||||
"//utilities:database",
|
||||
"//utilities:logger",
|
||||
"//utilities:model",
|
||||
"//utilities:text2img",
|
||||
"//utilities:img2img",
|
||||
"//utilities:envvar",
|
||||
|
||||
"//utilities:times",
|
||||
],
|
||||
data=[
|
||||
"templates/index.html",
|
||||
],
|
||||
)
|
||||
|
||||
par_binary(
|
||||
name="backend",
|
||||
srcs=["backend.py"],
|
||||
deps=[
|
||||
"//utilities:constants",
|
||||
"//utilities:database",
|
||||
"//utilities:logger",
|
||||
"//utilities:model",
|
||||
"//utilities:text2img",
|
||||
"//utilities:translator",
|
||||
"//utilities:img2img",
|
||||
"//utilities:times",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
283
main.py
283
main.py
|
|
@ -1,283 +0,0 @@
|
|||
import argparse
|
||||
import copy
|
||||
import tempfile
|
||||
import pkgutil
|
||||
import uuid
|
||||
from flask import jsonify
|
||||
from flask import Flask
|
||||
from flask import render_template
|
||||
from flask import request
|
||||
from threading import Event
|
||||
from threading import Thread
|
||||
from threading import Lock
|
||||
|
||||
from utilities.constants import APIKEY
|
||||
from utilities.constants import KEY_APP
|
||||
from utilities.constants import KEY_JOB_STATUS
|
||||
from utilities.constants import KEY_JOB_TYPE
|
||||
from utilities.constants import KEY_PROMPT
|
||||
from utilities.constants import KEY_NEG_PROMPT
|
||||
from utilities.constants import LOGGER_NAME
|
||||
from utilities.constants import LOGGER_NAME_IMG2IMG
|
||||
from utilities.constants import LOGGER_NAME_TXT2IMG
|
||||
from utilities.constants import REFERENCE_IMG
|
||||
from utilities.constants import MAX_JOB_NUMBER
|
||||
from utilities.constants import OPTIONAL_KEYS
|
||||
from utilities.constants import REQUIRED_KEYS
|
||||
from utilities.constants import UUID
|
||||
from utilities.constants import VALUE_APP
|
||||
from utilities.constants import VALUE_JOB_TXT2IMG
|
||||
from utilities.constants import VALUE_JOB_IMG2IMG
|
||||
from utilities.constants import VALUE_JOB_INPAINTING
|
||||
from utilities.constants import VALUE_JOB_PENDING
|
||||
from utilities.constants import VALUE_JOB_RUNNING
|
||||
from utilities.constants import VALUE_JOB_DONE
|
||||
from utilities.constants import VALUE_JOB_FAILED
|
||||
from utilities.database import Database
|
||||
from utilities.envvar import get_env_var_with_default
|
||||
from utilities.envvar import get_env_var
|
||||
from utilities.times import wait_for_seconds
|
||||
from utilities.logger import Logger
|
||||
from utilities.model import Model
|
||||
from utilities.config import Config
|
||||
from utilities.text2img import Text2Img
|
||||
from utilities.img2img import Img2Img
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
memory_lock = Lock()
|
||||
event_termination = Event()
|
||||
logger = Logger(name=LOGGER_NAME)
|
||||
database = Database(logger)
|
||||
use_gpu = True
|
||||
|
||||
local_job_stack = []
|
||||
local_completed_jobs = []
|
||||
|
||||
|
||||
@app.route("/add_job", methods=["POST"])
|
||||
def add_job():
|
||||
req = request.get_json()
|
||||
|
||||
if APIKEY not in req:
|
||||
logger.error(f"{APIKEY} not present in {req}")
|
||||
return "", 401
|
||||
with memory_lock:
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
logger.error(f"user not found with {req[APIKEY]}")
|
||||
return "", 401
|
||||
|
||||
for key in req.keys():
|
||||
if (key not in REQUIRED_KEYS) and (key not in OPTIONAL_KEYS):
|
||||
return jsonify({"msg": "provided one or more unrecognized keys"}), 404
|
||||
for required_key in REQUIRED_KEYS:
|
||||
if required_key not in req:
|
||||
return jsonify({"msg": "missing one or more required keys"}), 404
|
||||
|
||||
if req[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG and REFERENCE_IMG not in req:
|
||||
return jsonify({"msg": "missing reference image"}), 404
|
||||
|
||||
if database.count_all_pending_jobs(req[APIKEY]) > MAX_JOB_NUMBER:
|
||||
return (
|
||||
jsonify({"msg": "too many jobs in queue, please wait or cancel some"}),
|
||||
500,
|
||||
)
|
||||
|
||||
job_uuid = str(uuid.uuid4())
|
||||
logger.info("adding a new job with uuid {}..".format(job_uuid))
|
||||
|
||||
with memory_lock:
|
||||
database.insert_new_job(req, job_uuid=job_uuid)
|
||||
|
||||
return jsonify({"msg": "", UUID: job_uuid})
|
||||
|
||||
|
||||
@app.route("/cancel_job", methods=["POST"])
|
||||
def cancel_job():
|
||||
req = request.get_json()
|
||||
if APIKEY not in req:
|
||||
return "", 401
|
||||
with memory_lock:
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
return "", 401
|
||||
|
||||
if UUID not in req:
|
||||
return jsonify({"msg": "missing uuid"}), 404
|
||||
|
||||
logger.info("cancelling job with uuid {}..".format(req[UUID]))
|
||||
|
||||
with memory_lock:
|
||||
result = database.cancel_job(job_uuid=req[UUID])
|
||||
|
||||
if result:
|
||||
msg = "job with uuid {} removed".format(req[UUID])
|
||||
return jsonify({"msg": msg})
|
||||
|
||||
with memory_lock:
|
||||
jobs = database.get_jobs(job_uuid=req[UUID])
|
||||
|
||||
if jobs:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"msg": "job {} is not in pending state, unable to cancel".format(
|
||||
req[UUID]
|
||||
)
|
||||
}
|
||||
),
|
||||
405,
|
||||
)
|
||||
|
||||
return (
|
||||
jsonify({"msg": "unable to find the job with uuid {}".format(req[UUID])}),
|
||||
404,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/get_jobs", methods=["POST"])
|
||||
def get_jobs():
|
||||
req = request.get_json()
|
||||
if APIKEY not in req:
|
||||
return "", 401
|
||||
with memory_lock:
|
||||
user = database.validate_user(req[APIKEY])
|
||||
if not user:
|
||||
return "", 401
|
||||
|
||||
with memory_lock:
|
||||
jobs = database.get_jobs(job_uuid=req[UUID])
|
||||
|
||||
return jsonify({"jobs": jobs})
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return render_template("index.html")
|
||||
|
||||
|
||||
def load_model(logger: Logger) -> Model:
|
||||
# model candidates:
|
||||
# "runwayml/stable-diffusion-v1-5"
|
||||
# "CompVis/stable-diffusion-v1-4"
|
||||
# "stabilityai/stable-diffusion-2-1"
|
||||
# "SG161222/Realistic_Vision_V2.0"
|
||||
# "darkstorm2150/Protogen_x3.4_Official_Release"
|
||||
# "darkstorm2150/Protogen_x5.8_Official_Release"
|
||||
# "prompthero/openjourney"
|
||||
# "naclbit/trinart_stable_diffusion_v2"
|
||||
# "hakurei/waifu-diffusion"
|
||||
model_name = "darkstorm2150/Protogen_x5.8_Official_Release"
|
||||
# inpainting model candidates:
|
||||
# "runwayml/stable-diffusion-inpainting"
|
||||
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
|
||||
if use_gpu:
|
||||
model.set_low_memory_mode()
|
||||
model.load_all()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def backend(event_termination, db):
|
||||
model = load_model(logger)
|
||||
text2img = Text2Img(model, logger=Logger(name=LOGGER_NAME_TXT2IMG))
|
||||
img2img = Img2Img(model, logger=Logger(name=LOGGER_NAME_IMG2IMG))
|
||||
|
||||
text2img.breakfast()
|
||||
img2img.breakfast()
|
||||
|
||||
while not event_termination.is_set():
|
||||
wait_for_seconds(1)
|
||||
|
||||
with memory_lock:
|
||||
pending_jobs = database.get_all_pending_jobs()
|
||||
|
||||
if len(pending_jobs) == 0:
|
||||
continue
|
||||
|
||||
next_job = pending_jobs[0]
|
||||
|
||||
with memory_lock:
|
||||
database.update_job({KEY_JOB_STATUS: VALUE_JOB_RUNNING}, job_uuid=next_job[UUID])
|
||||
|
||||
prompt = next_job[KEY_PROMPT]
|
||||
negative_prompt = next_job[KEY_NEG_PROMPT]
|
||||
|
||||
config = Config().set_config(next_job)
|
||||
|
||||
try:
|
||||
if next_job[KEY_JOB_TYPE] == VALUE_JOB_TXT2IMG:
|
||||
result_dict = text2img.lunch(
|
||||
prompt=prompt, negative_prompt=negative_prompt, config=config
|
||||
)
|
||||
elif next_job[KEY_JOB_TYPE] == VALUE_JOB_IMG2IMG:
|
||||
ref_img = next_job[REFERENCE_IMG]
|
||||
result_dict = img2img.lunch(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
reference_image=ref_img,
|
||||
config=config,
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error("text2img.lunch error: {}".format(e))
|
||||
with memory_lock:
|
||||
database.update_job(
|
||||
{KEY_JOB_STATUS: VALUE_JOB_FAILED}, job_uuid=next_job[UUID]
|
||||
)
|
||||
continue
|
||||
|
||||
with memory_lock:
|
||||
database.update_job({KEY_JOB_STATUS: VALUE_JOB_DONE}, job_uuid=next_job[UUID])
|
||||
database.update_job(result_dict, job_uuid=next_job[UUID])
|
||||
|
||||
logger.critical("stopped")
|
||||
|
||||
|
||||
def main(db_filepath, is_testing: bool = False):
|
||||
database.connect(db_filepath)
|
||||
|
||||
if is_testing:
|
||||
try:
|
||||
app.run(host="0.0.0.0", port="5000")
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
return
|
||||
thread = Thread(
|
||||
target=backend,
|
||||
args=(
|
||||
event_termination,
|
||||
database,
|
||||
),
|
||||
)
|
||||
thread.start()
|
||||
# ugly solution for now
|
||||
# TODO: use a database to track instead of internal memory
|
||||
try:
|
||||
app.run(host="0.0.0.0", port="8888")
|
||||
thread.join()
|
||||
except KeyboardInterrupt:
|
||||
event_termination.set()
|
||||
|
||||
database.safe_disconnect()
|
||||
|
||||
thread.join(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Add an argument to set the 'testing' flag
|
||||
parser.add_argument("--testing", action="store_true", help="Enable testing mode")
|
||||
|
||||
# Add an argument to set the path of the database file
|
||||
parser.add_argument(
|
||||
"--db", type=str, default="happysd.db", help="Path to SQLite database file"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info(args)
|
||||
|
||||
main(args.db, args.testing)
|
||||
|
|
@ -34,13 +34,14 @@ def create_table_history(c):
|
|||
c.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS history
|
||||
(uuid TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP,
|
||||
apikey TEXT,
|
||||
priority INT,
|
||||
type TEXT,
|
||||
status TEXT,
|
||||
prompt TEXT,
|
||||
lang TEXT,
|
||||
neg_prompt TEXT,
|
||||
seed TEXT,
|
||||
ref_img TEXT,
|
||||
|
|
@ -146,7 +147,7 @@ def show_users(c, username="", details=False):
|
|||
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
|
||||
if details:
|
||||
c.execute(
|
||||
"SELECT uuid, created_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
|
||||
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
|
||||
(user[1],),
|
||||
)
|
||||
rows = c.fetchall()
|
||||
|
|
@ -163,7 +164,7 @@ def show_users(c, username="", details=False):
|
|||
print(f"Username: {user[0]}, API Key: {user[1]}, Number of jobs: {count}")
|
||||
if details:
|
||||
c.execute(
|
||||
"SELECT uuid, created_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
|
||||
"SELECT uuid, created_at, updated_at, type, status, width, height, steps, prompt, neg_prompt FROM history WHERE apikey=?",
|
||||
(user[1],),
|
||||
)
|
||||
rows = c.fetchall()
|
||||
|
|
|
|||
|
|
@ -7,3 +7,5 @@ Pillow
|
|||
scikit-image
|
||||
torch
|
||||
transformers
|
||||
sentencepiece
|
||||
fcntl
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
<html lang="en">
|
||||
<html>
|
||||
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
|
@ -17,75 +17,84 @@
|
|||
<div class="card-body">
|
||||
<div class="row mb-3">
|
||||
<div class="col-sm-8">
|
||||
<label for="apiKey" class="form-label">API Key</label>
|
||||
<input type="password" class="form-control" id="apiKey" value="demo">
|
||||
<label for="apiKey" class="form-label" data-en_XX="API Key" data-zh_CN="API 密钥">API Key</label>
|
||||
<input type="password" class="form-control" id="apiKey" value="">
|
||||
</div>
|
||||
<div class="col-sm-4">
|
||||
<div class="form-check">
|
||||
<input class="form-check-input" type="checkbox" id="showPreview" disabled>
|
||||
<label class="form-check-label" for="showPreview">
|
||||
Preview Image
|
||||
<label class="form-check-label" for="showPreview" data-en_XX="Preview Image"
|
||||
data-zh_CN="预览生成图像">Preview Image
|
||||
</label>
|
||||
</div>
|
||||
<select class="form-select" size="2" id="language">
|
||||
<option value="zh_CN">中文</option>
|
||||
<option selected value="en_XX">English</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-row mb-3">
|
||||
<label for="prompt" class="form-label">Describe Your Image</label>
|
||||
<label for="prompt" class="form-label" data-en_XX="Describe Your Image"
|
||||
data-zh_CN="形容你的图片(提示词)">Describe Your Image (Prompts)</label>
|
||||
<input type="text" class="form-control" id="prompt" aria-describedby="promptHelp" value="">
|
||||
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated. Example:
|
||||
"photo of cute cat, RAW photo, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high
|
||||
quality, film grain, Fujifilm XT3"</div>
|
||||
<div id="promptHelp" class="form-text"
|
||||
data-en_XX="Less than 77 words. Example: photo of a cute cat. Use () to emphasize."
|
||||
data-zh_CN="少于77个词。比如:一张可爱的猫的照片。用括号()强调重要性。">Less than 77 words. Example: photo of a cute cat.
|
||||
Use () to emphasize.</div>
|
||||
</div>
|
||||
<div class="form-row mb-3">
|
||||
<label for="negPrompt" class="form-label">Describe What's NOT Your Image</label>
|
||||
<label for="negPrompt" class="form-label"
|
||||
data-en_XX="Describe What's NOT Your Image (Negative Prompts)"
|
||||
data-zh_CN="反向形容你的图片(反向提示词)">Describe What's NOT Your Image (Negative Prompts)</label>
|
||||
<input type="text" class="form-control" id="negPrompt" aria-describedby="negPromptHelp" value="">
|
||||
<div id="negPromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated.
|
||||
Example: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon,
|
||||
drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg
|
||||
artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn
|
||||
hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad
|
||||
proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs,
|
||||
missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long
|
||||
neck"</div>
|
||||
<div id="negPromptHelp" class="form-text" data-en_XX="Less than 77 words. Optional."
|
||||
data-zh_CN="少于77个词。非必填。">Less than 77 words. Optional.</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col-md-3">
|
||||
<div class="form-row">
|
||||
<label for="inputSeed">Seed</label>
|
||||
<label for="inputSeed" data-en_XX="Seed" data-zh_CN="随机数">Seed</label>
|
||||
<input type="text" class="form-control" id="inputSeed" aria-describedby="inputSeedHelp"
|
||||
value="">
|
||||
<div id="inputSeedHelp" class="form-text">Leave it empty or put 0 to use a random
|
||||
seed
|
||||
<div id="inputSeedHelp" class="form-text"
|
||||
data-en_XX="Leave it empty or set 0 to use random seed" data-zh_CN="非必填。留白或填0使用默认随机数">
|
||||
Leave it empty or set 0 to use a random seed
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="inputSteps">Steps</label>
|
||||
<label for="inputSteps" data-en_XX="Steps" data-zh_CN="迭代次数">Steps</label>
|
||||
<input type="number" class="form-control" id="inputSteps" aria-describedby="inputStepsHelp"
|
||||
placeholder="default is 50">
|
||||
<div id="inputStepsHelp" class="form-text">Each step is about 38s (CPU) or 0.1s
|
||||
(GPU)
|
||||
placeholder="50">
|
||||
<div id="inputStepsHelp" class="form-text"
|
||||
data-en_XX="More steps better image but longer time to generate"
|
||||
data-zh_CN="迭代次数越多图片越好,但生成时间越久">More steps better image but longer time to generate
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="inputWidth">Width</label>
|
||||
<label for="inputWidth" data-en_XX="Width" data-zh_CN="图片宽度">Width</label>
|
||||
<input type="number" class="form-control" id="inputWidth" placeholder="512" min="1"
|
||||
max="1024">
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="inputHeight">Height</label>
|
||||
<label for="inputHeight" data-en_XX="Height" data-zh_CN="图片高度">Height</label>
|
||||
<input type="number" class="form-control" id="inputHeight" placeholder="512" min="1"
|
||||
max="1024">
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="guidanceScale">Guidance Scale</label>
|
||||
<label for="guidanceScale" data-en_XX="Guidance Scale" data-zh_CN="指导强度">Guidance
|
||||
Scale</label>
|
||||
<input type="number" class="form-control" id="inputGuidanceScale"
|
||||
aria-describedby="inputGuidanceScaleHelp" placeholder="25" min="1" max="30">
|
||||
<div id="inputGuidanceScaleHelp" class="form-text">How much guidance to follow from
|
||||
description. 20 strictly follow prompt, 7 creative/artistic.
|
||||
aria-describedby="inputGuidanceScaleHelp" placeholder="12.5" min="1" max="30">
|
||||
<div id="inputGuidanceScaleHelp" class="form-text"
|
||||
data-en_XX="Don't set it to the extremes (1 or 30). 20 means strictly follow prompt, 7 creative/artistic. Lower this number if you see bad images."
|
||||
data-zh_CN="不建议设最低1或最高30。20代表提示词非常重要,7更有创造性。如果看到图片结果较差,适当减少强度。">
|
||||
Don't set it to the extremes (1 or 30). 20 means strictly follow prompt, 7
|
||||
creative/artistic. Lower this number if you see bad images.
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<button id="newTxt2ImgJob" class="btn btn-primary">Let's Go!</button>
|
||||
<button id="newTxt2ImgJob" class="btn btn-primary" data-en_XX="Let's Go!"
|
||||
data-zh_CN="生成图片!">Let's Go!</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-9">
|
||||
|
|
@ -93,54 +102,62 @@
|
|||
<div class="card-header">
|
||||
<ul class="nav nav-pills card-header-pills">
|
||||
<li class="nav-item">
|
||||
<a class="nav-link" href="#card-txt">Text-to-Image</a>
|
||||
<a class="nav-link" href="#card-txt" data-en_XX="Text-to-Image"
|
||||
data-zh_CN="文字->图片">Text-to-Image</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link" href="#card-img">Image-to-Image</a>
|
||||
<a class="nav-link" href="#card-img" data-en_XX="Image-to-Image"
|
||||
data-zh_CN="图片->图片">Image-to-Image</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link" href="#card-inpainting">Inpainting</a>
|
||||
<a class="nav-link" href="#card-inpainting" data-en_XX="Inpainting"
|
||||
data-zh_CN="图片修复">Inpainting</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="card-body card-specific" id="card-img" style="display:none">
|
||||
<div class="row">
|
||||
<div class="col-md-6">
|
||||
<div class="col-md-4">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<div class="card-header" data-en_XX="Reference Image" data-zh_CN="参照图">
|
||||
Reference Image
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="row">
|
||||
<button id="copy-txt-to-img" class="btn btn-primary mb-3">Copy from
|
||||
Txt-to-Image</button>
|
||||
<button id="copy-last-img" class="btn btn-primary mb-3">Copy from
|
||||
Last
|
||||
Image
|
||||
Result</button>
|
||||
<button id="upload-img" class="btn btn-primary mb-3">Upload
|
||||
Image</button>
|
||||
<button id="copy-txt-to-img" class="btn btn-primary mb-3"
|
||||
data-en_XX="Copy from text-to-image"
|
||||
data-zh_CN="从【文字->图片】结果复制">Copy from text-to-image</button>
|
||||
<button id="copy-last-img" class="btn btn-primary mb-3"
|
||||
data-en_XX="Copy from last image result"
|
||||
data-zh_CN="从【图片->图片】结果复制">Copy from last image result</button>
|
||||
<button id="upload-img" class="btn btn-primary mb-3"
|
||||
data-en_XX="Upload image" data-zh_CN="上传一张图片">Upload
|
||||
image</button>
|
||||
</div>
|
||||
<div class="form-row">
|
||||
<label for="strength">Strength</label>
|
||||
<label for="strength" data-en_XX="Strength"
|
||||
data-zh_CN="改变程度">Strength</label>
|
||||
<input type="number" class="form-control" id="inputStrength"
|
||||
aria-describedby="inputStrengthHelp" placeholder="0.5" min="0"
|
||||
max="1">
|
||||
<div id="inputStrengthHelp" class="form-text">How semantically
|
||||
consistent with the origional image.
|
||||
<div id="inputStrengthHelp" class="form-text"
|
||||
data-en_XX="How different from the original image. 0 means the same, 1 means very different."
|
||||
data-zh_CN="和参照图有多么的不同。0指一样,1指非常不一样。">How different from the
|
||||
original image
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<button id="newImg2ImgJob" class="btn btn-primary mb-3">Let's Go
|
||||
with Image Below!</button>
|
||||
<button id="newImg2ImgJob" class="btn btn-primary mb-3"
|
||||
data-en_XX="Let's Go with Image Below!"
|
||||
data-zh_CN="就用下面的图生成!">Let's Go with Image Below!</button>
|
||||
</div>
|
||||
</div>
|
||||
<img class="card-img-bottom" id="reference-img">
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-6">
|
||||
<div class="col-md-8">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<div class="card-header" data-en_XX="Result" data-zh_CN="结果">
|
||||
Result
|
||||
</div>
|
||||
<div class="card-body">
|
||||
|
|
@ -163,7 +180,7 @@
|
|||
</div>
|
||||
<div class="card-body card-specific" id="card-txt" style="display:none">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<div class="card-header" data-en_XX="Result" data-zh_CN="结果">
|
||||
Result
|
||||
</div>
|
||||
<div class="card-body">
|
||||
|
|
@ -191,8 +208,10 @@
|
|||
<label for="jobuuid" class="form-label">Job UUID</label>
|
||||
<input type="jobuuid" class="form-control" id="jobuuid" aria-describedby="">
|
||||
</div>
|
||||
<button id="getjob" type="submit" class="btn btn-primary" disabled>Get Jobs</button>
|
||||
<button id="canceljob" type="submit" class="btn btn-primary" disabled>Cancel Job</button>
|
||||
<button id="getjob" type="submit" class="btn btn-primary" data-en_XX="Get Jobs" data-zh_CN="搜索生成结果"
|
||||
disabled>Get Jobs</button>
|
||||
<button id="canceljob" type="submit" class="btn btn-primary" data-en_XX="Cancel Job" data-zh_CN="取消"
|
||||
disabled>Cancel Job</button>
|
||||
</form>
|
||||
|
||||
<div class="mb-3" id="joblist">
|
||||
|
|
@ -397,7 +416,6 @@
|
|||
return;
|
||||
}
|
||||
|
||||
// Send POST request using Ajax
|
||||
$.ajax({
|
||||
type: 'POST',
|
||||
url: '/add_job',
|
||||
|
|
@ -411,6 +429,7 @@
|
|||
'steps': stepsVal,
|
||||
'width': widthVal,
|
||||
'height': heightVal,
|
||||
'lang': $("#language option:selected").val(),
|
||||
'guidance_scale': guidanceScaleVal,
|
||||
'neg_prompt': negPromptVal
|
||||
}),
|
||||
|
|
@ -521,6 +540,7 @@
|
|||
'steps': stepsVal,
|
||||
'width': widthVal,
|
||||
'height': heightVal,
|
||||
'lang': $("#language option:selected").val(),
|
||||
'guidance_scale': guidanceScaleVal,
|
||||
'strength': strengthVal,
|
||||
'neg_prompt': negPromptVal
|
||||
|
|
@ -541,6 +561,28 @@
|
|||
});
|
||||
});
|
||||
|
||||
// Listen for changes to the select element
|
||||
$("#language").change(function () {
|
||||
// Get the newly selected value
|
||||
var newLanguage = $(this).val();
|
||||
|
||||
// Store the selected value in cache
|
||||
localStorage.setItem("selectedLanguage", newLanguage);
|
||||
|
||||
$("[data-" + newLanguage + "]").each(function () {
|
||||
$(this).text($(this).data(newLanguage.toLowerCase()));
|
||||
});
|
||||
});
|
||||
|
||||
// Get the selected value from cache (if it exists)
|
||||
var cachedLanguage = localStorage.getItem("selectedLanguage");
|
||||
if (cachedLanguage) {
|
||||
// Set the selected value
|
||||
$("#language").val(cachedLanguage);
|
||||
|
||||
// Trigger a change event to update the text and store in cache
|
||||
$("#language").change();
|
||||
}
|
||||
|
||||
});
|
||||
</script>
|
||||
|
|
|
|||
|
|
@ -106,6 +106,11 @@ py_test(
|
|||
deps=[":times"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="translator",
|
||||
srcs=["translator.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="web",
|
||||
srcs=["web.py"],
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
KEY_APP = "APP"
|
||||
VALUE_APP = "demo"
|
||||
|
||||
LOGGER_NAME = VALUE_APP
|
||||
LOGGER_NAME_TXT2IMG = "txt2img"
|
||||
LOGGER_NAME_IMG2IMG = "img2img"
|
||||
LOGGER_NAME_FRONTEND = VALUE_APP + "_fe"
|
||||
LOGGER_NAME_BACKEND = VALUE_APP + "_be"
|
||||
LOGGER_NAME_TXT2IMG = VALUE_APP + "_txt2img"
|
||||
LOGGER_NAME_IMG2IMG = VALUE_APP + "_img2img"
|
||||
MAX_JOB_NUMBER = 10
|
||||
|
||||
|
||||
LOCK_FILEPATH = "/tmp/happysd_db.lock"
|
||||
|
||||
KEY_OUTPUT_FOLDER = "outfolder"
|
||||
VALUE_OUTPUT_FOLDER_DEFAULT = ""
|
||||
|
|
@ -30,6 +31,9 @@ VALUE_JOB_IMG2IMG = "img"
|
|||
REFERENCE_IMG = "ref_img"
|
||||
VALUE_JOB_INPAINTING = "inpaint"
|
||||
|
||||
KEY_LANGUAGE = "lang"
|
||||
VALUE_LANGUAGE_ZH_CN = "zh_CN"
|
||||
VALUE_LANGUAGE_EN = "en_XX"
|
||||
KEY_PROMPT = "prompt"
|
||||
KEY_NEG_PROMPT = "neg_prompt"
|
||||
KEY_SEED = "seed"
|
||||
|
|
@ -39,9 +43,9 @@ VALUE_WIDTH_DEFAULT = 512 # default value for KEY_WIDTH
|
|||
KEY_HEIGHT = "height"
|
||||
VALUE_HEIGHT_DEFAULT = 512 # default value for KEY_HEIGHT
|
||||
KEY_GUIDANCE_SCALE = "guidance_scale"
|
||||
VALUE_GUIDANCE_SCALE_DEFAULT = 25.0 # default value for KEY_GUIDANCE_SCALE
|
||||
VALUE_GUIDANCE_SCALE_DEFAULT = 12.5 # default value for KEY_GUIDANCE_SCALE
|
||||
KEY_STEPS = "steps"
|
||||
VALUE_STEPS_DEFAULT = 50 # default value for KEY_STEPS
|
||||
VALUE_STEPS_DEFAULT = 100 # default value for KEY_STEPS
|
||||
KEY_SCHEDULER = "scheduler"
|
||||
VALUE_SCHEDULER_DEFAULT = "Default" # default value for KEY_SCHEDULER
|
||||
VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
|
||||
|
|
@ -67,6 +71,7 @@ OPTIONAL_KEYS = [
|
|||
KEY_SCHEDULER, # str
|
||||
KEY_STRENGTH, # float
|
||||
REFERENCE_IMG, # str (base64)
|
||||
KEY_LANGUAGE,
|
||||
]
|
||||
|
||||
# - output only
|
||||
|
|
@ -84,4 +89,9 @@ OUTPUT_ONLY_KEYS = [
|
|||
KEY_PRIORITY, # int
|
||||
BASE64IMAGE, # str (base64)
|
||||
KEY_JOB_STATUS, # str
|
||||
]
|
||||
|
||||
SUPPORTED_LANGS = [
|
||||
VALUE_LANGUAGE_ZH_CN,
|
||||
VALUE_LANGUAGE_EN,
|
||||
]
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import datetime
|
||||
import sqlite3
|
||||
import fcntl
|
||||
import uuid
|
||||
|
||||
from utilities.constants import APIKEY
|
||||
|
|
@ -14,6 +15,7 @@ from utilities.constants import KEY_JOB_STATUS
|
|||
from utilities.constants import VALUE_JOB_PENDING
|
||||
from utilities.constants import VALUE_JOB_RUNNING
|
||||
from utilities.constants import VALUE_JOB_DONE
|
||||
from utilities.constants import LOCK_FILEPATH
|
||||
|
||||
from utilities.constants import OUTPUT_ONLY_KEYS
|
||||
from utilities.constants import OPTIONAL_KEYS
|
||||
|
|
@ -27,12 +29,26 @@ from utilities.constants import USERS_TABLE_NAME
|
|||
from utilities.logger import DummyLogger
|
||||
|
||||
|
||||
# Function to acquire a lock on the database file
|
||||
def acquire_lock():
|
||||
lock_fd = open(LOCK_FILEPATH, "w")
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_EX)
|
||||
|
||||
|
||||
# Function to release the lock on the database file
|
||||
def release_lock():
|
||||
lock_fd = open(LOCK_FILEPATH, "w")
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||
lock_fd.close()
|
||||
|
||||
|
||||
class Database:
|
||||
"""This class represents a SQLite database and assumes single-thread usage."""
|
||||
|
||||
def __init__(self, logger: DummyLogger = DummyLogger()):
|
||||
"""Initialize the class with a logger instance, but without a database connection or cursor."""
|
||||
self.__connect = None # the database connection object
|
||||
self.is_connected = False
|
||||
self.__cursor = None # the cursor object for executing SQL statements
|
||||
self.__logger = logger # the logger object for logging messages
|
||||
|
||||
|
|
@ -46,28 +62,32 @@ class Database:
|
|||
self.__logger.error(f"{db_filepath} does not exist!")
|
||||
return False
|
||||
self.__connect = sqlite3.connect(db_filepath, check_same_thread=False)
|
||||
self.__cursor = self.__connect.cursor()
|
||||
self.__logger.info(f"Connected to database {db_filepath}")
|
||||
self.is_connected = True
|
||||
return True
|
||||
|
||||
def get_cursor(self):
|
||||
if not self.is_connected:
|
||||
raise RuntimeError("Did you forget to connect() to the database?")
|
||||
return self.__connect.cursor()
|
||||
|
||||
def commit(self):
|
||||
if not self.is_connected:
|
||||
raise RuntimeError("Did you forget to connect() to the database?")
|
||||
return self.__connect.commit()
|
||||
|
||||
def validate_user(self, apikey: str) -> str:
|
||||
"""
|
||||
Validate if the provided API key exists in the users table and return the corresponding
|
||||
username if found, or an empty string otherwise.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return ""
|
||||
|
||||
query = f"SELECT username FROM {USERS_TABLE_NAME} WHERE {APIKEY}=?"
|
||||
self.__cursor.execute(query, (apikey,))
|
||||
result = self.__cursor.fetchone()
|
||||
|
||||
self.__logger.debug(result)
|
||||
c = self.get_cursor()
|
||||
result = c.execute(query, (apikey,)).fetchone()
|
||||
|
||||
if result is not None:
|
||||
return result[0] # the first column is the username
|
||||
|
||||
return result[0]
|
||||
return ""
|
||||
|
||||
def get_all_pending_jobs(self, apikey: str = "") -> list:
|
||||
|
|
@ -79,17 +99,13 @@ class Database:
|
|||
|
||||
Returns the number of pending jobs found.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return 0
|
||||
|
||||
# Construct the SQL query string and list of arguments
|
||||
query_string = f"SELECT COUNT(*) FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?"
|
||||
query_args = (apikey, VALUE_JOB_PENDING)
|
||||
|
||||
# Execute the query and return the count
|
||||
self.__cursor.execute(query_string, query_args)
|
||||
result = self.__cursor.fetchone()
|
||||
c = self.get_cursor()
|
||||
result = c.execute(query_string, query_args).fetchone()
|
||||
return result[0]
|
||||
|
||||
def get_jobs(self, job_uuid="", apikey="", job_status="") -> list:
|
||||
|
|
@ -100,30 +116,27 @@ class Database:
|
|||
|
||||
Returns a list of jobs matching the filters provided.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return []
|
||||
|
||||
# construct the SQL query string and list of arguments based on the provided filters
|
||||
query_args = []
|
||||
values = []
|
||||
query_filters = []
|
||||
if job_uuid:
|
||||
query_filters.append(f"{UUID} = ?")
|
||||
query_args.append(job_uuid)
|
||||
values.append(job_uuid)
|
||||
if apikey:
|
||||
query_filters.append(f"{APIKEY} = ?")
|
||||
query_args.append(apikey)
|
||||
values.append(apikey)
|
||||
if job_status:
|
||||
query_filters.append(f"{KEY_JOB_STATUS} = ?")
|
||||
query_args.append(job_status)
|
||||
values.append(job_status)
|
||||
|
||||
columns = OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS
|
||||
query_string = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
|
||||
query = f"SELECT {', '.join(columns)} FROM {HISTORY_TABLE_NAME}"
|
||||
if query_filters:
|
||||
query_string += f" WHERE {' AND '.join(query_filters)}"
|
||||
query += f" WHERE {' AND '.join(query_filters)}"
|
||||
|
||||
# execute the query and return the results
|
||||
self.__cursor.execute(query_string, tuple(query_args))
|
||||
rows = self.__cursor.fetchall()
|
||||
c = self.get_cursor()
|
||||
rows = c.execute(query, tuple(values)).fetchall()
|
||||
|
||||
jobs = []
|
||||
for row in rows:
|
||||
|
|
@ -142,22 +155,24 @@ class Database:
|
|||
|
||||
Returns True if the insertion was successful, otherwise False.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return False
|
||||
|
||||
if not job_uuid:
|
||||
job_uuid = str(uuid.uuid4())
|
||||
self.__logger.info(f"inserting a new job with {job_uuid}")
|
||||
|
||||
values = [job_uuid, VALUE_JOB_PENDING]
|
||||
columns = [UUID, KEY_JOB_STATUS] + REQUIRED_KEYS + OPTIONAL_KEYS
|
||||
values = [job_uuid, VALUE_JOB_PENDING, datetime.datetime.now()]
|
||||
columns = [UUID, KEY_JOB_STATUS, "created_at"] + REQUIRED_KEYS + OPTIONAL_KEYS
|
||||
for column in REQUIRED_KEYS + OPTIONAL_KEYS:
|
||||
values.append(job_dict.get(column, None))
|
||||
|
||||
|
||||
query = f"INSERT INTO {HISTORY_TABLE_NAME} ({', '.join(columns)}) VALUES ({', '.join(['?' for _ in columns])})"
|
||||
self.__cursor.execute(query, tuple(values))
|
||||
self.__connect.commit()
|
||||
|
||||
acquire_lock()
|
||||
try:
|
||||
c = self.get_cursor()
|
||||
c.execute(query, tuple(values))
|
||||
self.commit()
|
||||
finally:
|
||||
release_lock()
|
||||
return True
|
||||
|
||||
def update_job(self, job_dict: dict, job_uuid: str) -> bool:
|
||||
|
|
@ -166,10 +181,6 @@ class Database:
|
|||
|
||||
Returns True if the update was successful, otherwise False.
|
||||
"""
|
||||
if self.__cursor is None:
|
||||
self.__logger.error("Did you forget to connect to the database?")
|
||||
return False
|
||||
|
||||
values = []
|
||||
columns = []
|
||||
for column in OUTPUT_ONLY_KEYS + REQUIRED_KEYS + OPTIONAL_KEYS:
|
||||
|
|
@ -182,13 +193,17 @@ class Database:
|
|||
# Add current timestamp to update query
|
||||
set_clause += ", updated_at=?"
|
||||
values.append(datetime.datetime.now())
|
||||
values.append(job_uuid)
|
||||
|
||||
query = f"UPDATE {HISTORY_TABLE_NAME} SET {set_clause} WHERE {UUID}=?"
|
||||
|
||||
values.append(job_uuid)
|
||||
|
||||
self.__cursor.execute(query, tuple(values))
|
||||
self.__connect.commit()
|
||||
acquire_lock()
|
||||
try:
|
||||
c = self.get_cursor()
|
||||
c.execute(query, tuple(values))
|
||||
self.commit()
|
||||
finally:
|
||||
release_lock()
|
||||
return True
|
||||
|
||||
def cancel_job(self, job_uuid: str = "", apikey: str = "") -> bool:
|
||||
|
|
@ -202,61 +217,48 @@ class Database:
|
|||
Returns:
|
||||
bool: True if the job was cancelled successfully, False otherwise.
|
||||
"""
|
||||
return self.delete_job(
|
||||
job_uuid=job_uuid, apikey=apikey, status=VALUE_JOB_PENDING
|
||||
)
|
||||
|
||||
def delete_job(
|
||||
self, job_uuid: str = "", apikey: str = "", status: str = ""
|
||||
) -> bool:
|
||||
if not job_uuid and not apikey:
|
||||
self.__logger.error(f"either {UUID} or {APIKEY} must be provided.")
|
||||
return False
|
||||
|
||||
query = f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?"
|
||||
if status:
|
||||
query += f" AND {KEY_JOB_STATUS}=?"
|
||||
values = []
|
||||
if job_uuid:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=? AND {KEY_JOB_STATUS}=?",
|
||||
(
|
||||
job_uuid,
|
||||
VALUE_JOB_PENDING,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=? AND {KEY_JOB_STATUS}=?",
|
||||
(
|
||||
apikey,
|
||||
VALUE_JOB_PENDING,
|
||||
),
|
||||
)
|
||||
values.append(job_uuid)
|
||||
elif apikey:
|
||||
values.append(apikey)
|
||||
if status:
|
||||
values.append(status)
|
||||
|
||||
if self.__cursor.rowcount == 0:
|
||||
rows_removed = 0
|
||||
|
||||
acquire_lock()
|
||||
try:
|
||||
c = self.get_cursor()
|
||||
c.execute(query, tuple(values))
|
||||
rows_removed = c.rowcount
|
||||
self.commit()
|
||||
finally:
|
||||
release_lock()
|
||||
|
||||
if rows_removed == 0:
|
||||
self.__logger.info("No matching rows found.")
|
||||
return False
|
||||
else:
|
||||
self.__logger.info(f"{self.__cursor.rowcount} rows cancelled.")
|
||||
|
||||
self.__connect.commit()
|
||||
return True
|
||||
|
||||
def delete_job(self, job_uuid: str = "", apikey: str = "") -> bool:
|
||||
"""Delete the job with the given uuid or apikey"""
|
||||
if job_uuid:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {UUID}=?", (job_uuid,)
|
||||
)
|
||||
elif apikey:
|
||||
self.__cursor.execute(
|
||||
f"DELETE FROM {HISTORY_TABLE_NAME} WHERE {APIKEY}=?", (apikey,)
|
||||
)
|
||||
else:
|
||||
self.logger.error(f"either {UUID} or {APIKEY} must be provided.")
|
||||
return False
|
||||
|
||||
if self.__cursor.rowcount == 0:
|
||||
print("No matching rows found.")
|
||||
else:
|
||||
self.logger.info(f"{self.__cursor.rowcount} rows deleted.")
|
||||
self.__connect.commit()
|
||||
self.__logger.info(f"{rows_removed} rows removed.")
|
||||
return True
|
||||
|
||||
def safe_disconnect(self):
|
||||
if self.__connect is not None:
|
||||
self.__connect.commit()
|
||||
self.__connect.close()
|
||||
self.__logger.info("Disconnected from database.")
|
||||
else:
|
||||
self.__logger.warn("No database connection to close.")
|
||||
if not self.is_connected:
|
||||
raise RuntimeError("Did you forget to connect() to the database?")
|
||||
self.commit()
|
||||
self.__connect.close()
|
||||
self.__logger.info("Disconnected from database.")
|
||||
|
|
|
|||
|
|
@ -37,7 +37,9 @@ class Text2Img:
|
|||
def breakfast(self):
|
||||
pass
|
||||
|
||||
def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> dict:
|
||||
def lunch(
|
||||
self, prompt: str, negative_prompt: str = "", config: Config = Config()
|
||||
) -> dict:
|
||||
self.model.set_txt2img_scheduler(config.get_scheduler())
|
||||
|
||||
t = get_epoch_now()
|
||||
|
|
|
|||
Loading…
Reference in New Issue