introduce state for task skipping/stopping

This commit is contained in:
Manuel Schmid 2023-11-18 20:46:09 +01:00 committed by Manuel Schmid
parent 83efbbd2fa
commit ebad9ea976
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 28 additions and 17 deletions

View File

@ -7,6 +7,7 @@ class AsyncTask:
self.yields = []
self.results = []
self.last_stop = False
self.processing = False
async_tasks = []
@ -115,6 +116,7 @@ def worker():
@torch.inference_mode()
def handler(async_task):
execution_start_time = time.perf_counter()
async_task.processing = True
args = async_task.args
args.reverse()
@ -660,6 +662,8 @@ def worker():
execution_start_time = time.perf_counter()
try:
if async_task.last_stop is not False:
fcbh.model_management.interrupt_current_processing()
positive_cond, negative_cond = task['c'], task['uc']
if 'cn' in goals:
@ -732,7 +736,7 @@ def worker():
execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds')
async_task.processing = False
return
while True:

View File

@ -20,13 +20,15 @@ from modules.private_logger import get_current_html_path
from modules.ui_gradio_extensions import reload_javascript
from modules.auth import auth_enabled, check_auth
currentTask = gr.State()
def get_task(*args):
args = list(args)
currentTask = args.pop(0)
currentTask = worker.AsyncTask(args=args)
return currentTask
def generate_clicked(*args):
def generate_clicked(task):
# outputs=[progress_html, progress_window, progress_gallery, gallery]
execution_start_time = time.perf_counter()
currentTask.value = task = worker.AsyncTask(args=list(args))
finished = False
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
@ -82,6 +84,7 @@ shared.gradio_root = gr.Blocks(
css=modules.html.css).queue()
with shared.gradio_root:
currentTask = gr.State(worker.AsyncTask(args=[]))
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
@ -108,21 +111,24 @@ with shared.gradio_root:
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False)
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
def stop_clicked():
def stop_clicked(currentTask):
import fcbh.model_management as model_management
currentTask.value.last_stop = 'stop'
model_management.interrupt_current_processing()
return [gr.update(interactive=False)] * 2
currentTask.last_stop = 'stop'
if (currentTask.processing):
model_management.interrupt_current_processing()
return [gr.update(interactive=False)] * 2, currentTask
def skip_clicked():
def skip_clicked(currentTask):
import fcbh.model_management as model_management
currentTask.value.last_stop = 'skip'
model_management.interrupt_current_processing()
return
currentTask.last_stop = 'skip'
if (currentTask.processing):
model_management.interrupt_current_processing()
return currentTask
stop_button.click(stop_clicked, outputs=[skip_button, stop_button],
stop_button.click(stop_clicked, inputs=currentTask, outputs=[skip_button, stop_button, currentTask],
queue=False, show_progress=False, _js='cancelGenerateForever')
skip_button.click(skip_clicked, queue=False, show_progress=False)
skip_button.click(skip_clicked, inputs=currentTask, outputs=currentTask,
queue=False, show_progress=False)
with gr.Row(elem_classes='advanced_check_row'):
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
@ -428,7 +434,7 @@ with shared.gradio_root:
.then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False)
ctrls = [
prompt, negative_prompt, style_selections,
currentTask, prompt, negative_prompt, style_selections,
performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale
]
@ -441,7 +447,8 @@ with shared.gradio_root:
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
.then(fn=get_task, inputs=ctrls, outputs=currentTask) \
.then(fn=generate_clicked, inputs=currentTask, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
.then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')