support cpu mode (very slow) and fix bug for second job submission

This commit is contained in:
HappyZ 2023-04-28 23:10:56 -07:00
parent 209a4033a5
commit f5d04126fd
7 changed files with 213 additions and 208 deletions

5
BUILD
View File

@ -11,5 +11,10 @@ par_binary(
"//utilities:logger",
"//utilities:model",
"//utilities:text2img",
"//utilities:envvar",
"//utilities:times",
],
data=[
"templates/index.html",
],
)

93
main.py
View File

@ -1,7 +1,10 @@
import copy
import tempfile
import pkgutil
import uuid
from flask import jsonify
from flask import Flask
from flask import render_template
from flask import request
from threading import Event
from threading import Thread
@ -9,7 +12,6 @@ from threading import Lock
from utilities.constants import API_KEY
from utilities.constants import API_KEY_FOR_DEMO
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_APP
from utilities.constants import KEY_JOB_STATUS
from utilities.constants import KEY_PROMPT
@ -23,7 +25,6 @@ from utilities.constants import VALUE_APP
from utilities.constants import VALUE_JOB_PENDING
from utilities.constants import VALUE_JOB_RUNNING
from utilities.constants import VALUE_JOB_DONE
from utilities.web import web
from utilities.envvar import get_env_var_with_default
from utilities.envvar import get_env_var
from utilities.times import wait_for_seconds
@ -33,35 +34,14 @@ from utilities.config import Config
from utilities.text2img import Text2Img
def load_model(logger: Logger) -> Model:
# model candidates:
# "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4"
# "stabilityai/stable-diffusion-2-1"
# "SG161222/Realistic_Vision_V2.0"
# "darkstorm2150/Protogen_x3.4_Official_Release"
# "prompthero/openjourney"
# "naclbit/trinart_stable_diffusion_v2"
# "hakurei/waifu-diffusion"
model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
# inpainting model candidates:
# "runwayml/stable-diffusion-inpainting"
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger)
model.set_low_memory_mode()
model.load_all()
return model
app = Flask(__name__)
memory_lock = Lock()
event_termination = Event()
logger = Logger(name=LOGGER_NAME)
use_gpu = True
local_job_stack = []
local_completed_jobs = {}
local_completed_jobs = []
@app.route("/add_job", methods=["POST"])
@ -90,11 +70,12 @@ def add_job():
logger.info("adding a new job with uuid {}..".format(req[UUID]))
req[KEY_JOB_STATUS] = VALUE_JOB_PENDING
req["position"] = len(local_job_stack) + 1
with memory_lock:
local_job_stack.append(req)
return jsonify({"msg": "", "position": len(local_job_stack), UUID: req[UUID]})
return jsonify({"msg": "", "position": req["position"], UUID: req[UUID]})
@app.route("/cancel_job", methods=["POST"])
@ -167,23 +148,20 @@ def get_jobs():
jobs = []
all_job_stack = local_job_stack + local_completed_jobs
with memory_lock:
for job_position in range(len(local_job_stack)):
for job_position in range(len(all_job_stack)):
# filter on API_KEY
if local_job_stack[job_position][API_KEY] != req[API_KEY]:
if all_job_stack[job_position][API_KEY] != req[API_KEY]:
continue
# filter on UUID
if UUID in req and req[UUID] != local_job_stack[job_position][UUID]:
if UUID in req and req[UUID] != all_job_stack[job_position][UUID]:
continue
job = copy.deepcopy(local_job_stack[job_position])
job = copy.deepcopy(all_job_stack[job_position])
if job[KEY_JOB_STATUS] == VALUE_JOB_DONE:
del job["position"]
del job[API_KEY]
job["position"] = job_position + 1
jobs.append(job)
all_matching_completed_jobs = local_completed_jobs.get(req[API_KEY], {})
if UUID in req:
all_matching_completed_jobs = all_matching_completed_jobs.get(req[UUID], {})
for key in all_matching_completed_jobs.keys():
jobs.append(all_matching_completed_jobs[key])
if len(jobs) == 0:
return (
@ -192,22 +170,44 @@ def get_jobs():
)
return jsonify({"jobs": jobs})
@app.route("/")
def index():
return web()
return render_template("index.html")
def load_model(logger: Logger) -> Model:
# model candidates:
# "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4"
# "stabilityai/stable-diffusion-2-1"
# "SG161222/Realistic_Vision_V2.0"
# "darkstorm2150/Protogen_x3.4_Official_Release"
# "prompthero/openjourney"
# "naclbit/trinart_stable_diffusion_v2"
# "hakurei/waifu-diffusion"
model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
# inpainting model candidates:
# "runwayml/stable-diffusion-inpainting"
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger, use_gpu=use_gpu)
if use_gpu:
model.set_low_memory_mode()
model.load_all()
return model
def backend(event_termination):
model = load_model(logger)
text2img = Text2Img(model, output_folder="/tmp", logger=logger)
text2img = Text2Img(model, logger=logger)
text2img.breakfast()
while True:
while not event_termination.is_set():
wait_for_seconds(1)
if event_termination.is_set():
break
with memory_lock:
if len(local_job_stack) == 0:
continue
@ -219,17 +219,15 @@ def backend(event_termination):
config = Config().set_config(next_job)
base64img = text2img.lunch(
result_dict = text2img.lunch(
prompt=prompt, negative_prompt=negative_prompt, config=config
)
with memory_lock:
local_job_stack.pop(0)
next_job[KEY_JOB_STATUS] = VALUE_JOB_DONE
next_job[BASE64IMAGE] = base64img
if next_job[API_KEY] not in local_completed_jobs:
local_completed_jobs[next_job[API_KEY]] = {}
local_completed_jobs[next_job[API_KEY]][next_job[UUID]] = next_job
next_job.update(result_dict)
local_completed_jobs.append(next_job)
logger.critical("stopped")
if len(local_job_stack) > 0:
@ -241,12 +239,13 @@ def backend(event_termination):
def main():
# app.run(host="0.0.0.0")
thread = Thread(target=backend, args=(event_termination,))
thread.start()
# ugly solution for now
# TODO: use a database to track instead of internal memory
try:
app.run(host='0.0.0.0')
app.run(host="0.0.0.0")
thread.join()
except KeyboardInterrupt:
event_termination.set()

145
templates/index.html Normal file
View File

@ -0,0 +1,145 @@
<html lang="en">
<head>
<meta charset="utf-8">
<title>Happy Diffusion (Private Access) | 9pm</title>
<meta name="description" content="Stable Diffusion Online">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.2.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-Zenh87qX5JnK2Jl0vWa8Ck2rdkQ2Bzep5IDxbcnCeuOxjzrPF/et3URy9Bv1WTRi" crossorigin="anonymous">
</head>
<body>
<div class="container">
<form>
<div class="mb-3">
<label for="apikey" class="form-label">API Key</label>
<input type="apikey" class="form-control" id="apikey" value="demo">
</div>
<div class="mb-3">
<label for="prompt" class="form-label">Prompt</label>
<input type="prompt" class="form-control" id="prompt" aria-describedby="promptHelp">
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated</div>
</div>
<div class="mb-3">
<label for="negprompt" class="form-label">Negative Prompt</label>
<input type="negprompt" class="form-control" id="negprompt" aria-describedby="negpromptHelp">
<div id="negpromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated</div>
</div>
<button id="newjob" type="submit" class="btn btn-primary">Submit New Job</button>
</form>
<div class="card mb-3">
<div class="card-body" id="newjobresult"></div>
<img id="newjobresultimg" class="card-img-bottom" />
</div>
<form>
<div class="mb-3">
<label for="jobuuid" class="form-label">Job UUID</label>
<input type="jobuuid" class="form-control" id="jobuuid" aria-describedby="">
</div>
<button id="getjob" type="submit" class="btn btn-primary">Get Jobs</button>
<button id="canceljob" type="submit" class="btn btn-primary">Cancel Job</button>
</form>
<div class="mb-3" id="joblist">
</div>
</div>
<script src="https://code.jquery.com/jquery-3.6.1.min.js"
integrity="sha256-o88AwQnZB+VDvE9tvIXrMQaPlFFSUTR+nldQm1LuPXQ=" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.2/dist/js/bootstrap.bundle.min.js"
integrity="sha384-OERcA2EqjJCMA+/3y+gxIOqMEjwtxJY7qPCqsdltbNJuaOe923+mo//f6V8Qbsw3"
crossorigin="anonymous"></script>
<script>
function waitForImage(uuidValue) {
// Wait until image is done
$.ajax({
type: 'POST',
url: '/get_jobs',
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({ 'api_key': 'demo', 'uuid': uuidValue }),
success: function (response) {
console.log(response);
if (response.jobs.length == 1) {
if (response.jobs[0].status == "done") {
$('#newjobresult').html('');
$('#newjobresultimg').attr('src', response.jobs[0].img);
return;
} else {
$('#newjobresult').append("<p>current status: " + response.jobs[0].status + "</p>")
}
}
setTimeout(function () { waitForImage(uuidValue); }, 1000); // refresh every second
},
error: function (xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
$('#newjobresult').html('<p>failed to run, see console error for more details<p>');
}
});
}
$(document).ready(function () {
console.log("--- csrf token set ---");
var csrftoken = $("[name=csrfmiddlewaretoken]").val();
function csrfSafeMethod(method) {
// these HTTP methods do not require CSRF protection
return (/^(GET|HEAD|OPTIONS|TRACE)$/.test(method));
}
$.ajaxSetup({
beforeSend: function (xhr, settings) {
if (!csrfSafeMethod(settings.type) && !this.crossDomain) {
xhr.setRequestHeader("X-CSRFToken", csrftoken);
}
}
});
$('#newjob').click(function (e) {
e.preventDefault(); // Prevent the default form submission
// Gather input field values
var apikeyVal = $('#apikey').val();
var promptVal = $('#prompt').val();
var negPromptVal = $('#negprompt').val();
if (promptVal == "") {
alert("needs to write a prompt!");
return;
}
// Send POST request using Ajax
$.ajax({
type: 'POST',
url: '/add_job',
contentType: 'application/json; charset=utf-8',
dataType: 'json',
data: JSON.stringify({ 'api_key': apikeyVal, 'prompt': promptVal, 'neg_prompt': negPromptVal }),
success: function (response) {
console.log(response);
if (response.uuid) {
$('#jobuuid').val(response.uuid);
}
$('#newjobresult').html('<p>waiting for result...<p>');
waitForImage(response.uuid);
},
error: function (xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
}
});
});
});
</script>
</body>
</html>

View File

@ -63,6 +63,7 @@ py_library(
name="text2img",
srcs=["text2img.py"],
deps=[
":constants",
":config",
":logger",
":images",

View File

@ -32,7 +32,7 @@ class Model:
self.__use_gpu = True
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
self.__logger = logger
self.__torch_dtype = "auto"
self.__torch_dtype = torch.float64
# txt2img and img2img are always loaded together
self.txt2img_pipeline = None

View File

@ -1,8 +1,12 @@
import torch
from typing import Union
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_SEED
from utilities.constants import KEY_WIDTH
from utilities.constants import KEY_HEIGHT
from utilities.constants import KEY_STEPS
from utilities.config import Config
from utilities.images import save_image
from utilities.logger import DummyLogger
from utilities.memory import empty_memory_cache
from utilities.model import Model
@ -22,6 +26,7 @@ class Text2Img:
logger: DummyLogger = DummyLogger(),
):
self.model = model
self.__device = "cpu" if not self.model.use_gpu() else "cuda"
self.__output_folder = output_folder
self.__logger = logger
@ -32,12 +37,12 @@ class Text2Img:
def breakfast(self):
pass
def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> str:
def lunch(self, prompt: str, negative_prompt: str = "", config: Config = Config()) -> dict:
self.model.set_txt2img_scheduler(config.get_scheduler())
t = get_epoch_now()
seed = config.get_seed()
generator = torch.Generator("cuda").manual_seed(seed)
generator = torch.Generator(self.__device).manual_seed(seed)
self.__logger.info("current seed: {}".format(seed))
result = self.model.txt2img_pipeline(
@ -59,4 +64,10 @@ class Text2Img:
empty_memory_cache()
return image_to_base64(result.images[0])
return {
BASE64IMAGE: image_to_base64(result.images[0]),
KEY_SEED.lower(): seed,
KEY_WIDTH.lower(): config.get_width(),
KEY_HEIGHT.lower(): config.get_height(),
KEY_STEPS.lower(): config.get_steps(),
}

View File

@ -1,156 +0,0 @@
def javascript():
return """
<script>
function waitForImage(uuidValue) {
// Wait until image is done
$.ajax({
type: 'POST',
url: '/get_jobs',
contentType: 'application/json; charset=utf-8',
dataType:'json',
data: JSON.stringify({'api_key': 'demo', 'uuid': uuidValue}),
success: function(response) {
console.log(response);
$('#newjobresult').html('<p>waiting for result...<p>');
if (response.jobs.length == 1) {
if (response.jobs[0].status == "done") {
$('#newjobresultimg').attr('src', response.jobs[0].img);
return;
} else {
$('#newjobresult').append("<p>current status: " + response.jobs[0].status + "</p>")
}
}
setTimeout(waitForImage, 1000); // refresh every second
},
error: function(xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
}
});
}
$( document ).ready(function() {
console.log( "--- csrf token set ---" );
var csrftoken = $("[name=csrfmiddlewaretoken]").val();
function csrfSafeMethod(method) {
// these HTTP methods do not require CSRF protection
return (/^(GET|HEAD|OPTIONS|TRACE)$/.test(method));
}
$.ajaxSetup({
beforeSend: function(xhr, settings) {
if (!csrfSafeMethod(settings.type) && !this.crossDomain) {
xhr.setRequestHeader("X-CSRFToken", csrftoken);
}
}
});
$('#newjob').click(function(e) {
e.preventDefault(); // Prevent the default form submission
// Gather input field values
var promptVal = $('#prompt').val();
var negPromptVal = $('#negprompt').val();
if (promptVal == "") {
alert("needs to write a prompt!");
return;
}
// Send POST request using Ajax
$.ajax({
type: 'POST',
url: '/add_job',
contentType: 'application/json; charset=utf-8',
dataType:'json',
data: JSON.stringify({'api_key': 'demo', 'prompt': promptVal, 'neg_prompt': negPromptVal}),
success: function(response) {
console.log(response);
if (response.uuid) {
$('#jobuuid').val(response.uuid);
}
waitForImage(response.uuid);
},
error: function(xhr, status, error) {
// Handle error response
console.log(xhr.responseText);
}
});
});
});
</script>
"""
def stylesheet():
return """
"""
def content():
return """
<form>
<div class="mb-3">
<label for="prompt" class="form-label">Prompt</label>
<input type="prompt" class="form-control" id="prompt" aria-describedby="promptHelp">
<div id="promptHelp" class="form-text">Less than 77 words otherwise it'll be truncated</div>
</div>
<div class="mb-3">
<label for="negprompt" class="form-label">Negative Prompt</label>
<input type="negprompt" class="form-control" id="negprompt" aria-describedby="negpromptHelp">
<div id="negpromptHelp" class="form-text">Less than 77 words otherwise it'll be truncated</div>
</div>
<button id="newjob" type="submit" class="btn btn-primary">Submit New Job</button>
</form>
<div class="card mb-3">
<div class="card-body" id="newjobresult"></div>
<img id="newjobresultimg" class="card-img-bottom" />
</div>
<form>
<div class="mb-3">
<label for="jobuuid" class="form-label">Job UUID</label>
<input type="jobuuid" class="form-control" id="jobuuid" aria-describedby="">
</div>
<button id="getjob" type="submit" class="btn btn-primary">Get Jobs</button>
<button id="canceljob" type="submit" class="btn btn-primary">Cancel Job</button>
</form>
<div class="mb-3" id="joblist">
</div>
"""
def web():
return """
<html lang="en">
<head>
<meta charset="utf-8">
<title>Happy Diffusion (Private Access) | 9pm</title>
<meta name="description" content="Stable Diffusion Online">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.2.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-Zenh87qX5JnK2Jl0vWa8Ck2rdkQ2Bzep5IDxbcnCeuOxjzrPF/et3URy9Bv1WTRi" crossorigin="anonymous">
{css}
</head>
<body>
<div class="container">{content}</div>
<script src="https://code.jquery.com/jquery-3.6.1.min.js"
integrity="sha256-o88AwQnZB+VDvE9tvIXrMQaPlFFSUTR+nldQm1LuPXQ=" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.2/dist/js/bootstrap.bundle.min.js"
integrity="sha384-OERcA2EqjJCMA+/3y+gxIOqMEjwtxJY7qPCqsdltbNJuaOe923+mo//f6V8Qbsw3"
crossorigin="anonymous"></script>
{js}
</body>
</html>
""".format(
content=content(),
css=stylesheet(),
js=javascript(),
)