diff --git a/modules/async_worker.py b/modules/async_worker.py index 0ca051d0..e6cc8a49 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -7,6 +7,7 @@ class AsyncTask: self.yields = [] self.results = [] self.last_stop = False + self.processing = False async_tasks = [] @@ -115,6 +116,7 @@ def worker(): @torch.inference_mode() def handler(async_task): execution_start_time = time.perf_counter() + async_task.processing = True args = async_task.args args.reverse() @@ -660,6 +662,8 @@ def worker(): execution_start_time = time.perf_counter() try: + if async_task.last_stop is not False: + fcbh.model_management.interrupt_current_processing() positive_cond, negative_cond = task['c'], task['uc'] if 'cn' in goals: @@ -732,7 +736,7 @@ def worker(): execution_time = time.perf_counter() - execution_start_time print(f'Generating and saving time: {execution_time:.2f} seconds') - + async_task.processing = False return while True: diff --git a/webui.py b/webui.py index e43a6c96..3dde192c 100644 --- a/webui.py +++ b/webui.py @@ -20,13 +20,15 @@ from modules.private_logger import get_current_html_path from modules.ui_gradio_extensions import reload_javascript from modules.auth import auth_enabled, check_auth -currentTask = gr.State() +def get_task(*args): + args = list(args) + currentTask = args.pop(0) + currentTask = worker.AsyncTask(args=args) + return currentTask -def generate_clicked(*args): +def generate_clicked(task): # outputs=[progress_html, progress_window, progress_gallery, gallery] - execution_start_time = time.perf_counter() - currentTask.value = task = worker.AsyncTask(args=list(args)) finished = False yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \ @@ -82,6 +84,7 @@ shared.gradio_root = gr.Blocks( css=modules.html.css).queue() with shared.gradio_root: + currentTask = gr.State(worker.AsyncTask(args=[])) with gr.Row(): with gr.Column(scale=2): with gr.Row(): @@ -108,21 +111,24 @@ with shared.gradio_root: skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False) stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False) - def stop_clicked(): + def stop_clicked(currentTask): import fcbh.model_management as model_management - currentTask.value.last_stop = 'stop' - model_management.interrupt_current_processing() - return [gr.update(interactive=False)] * 2 + currentTask.last_stop = 'stop' + if (currentTask.processing): + model_management.interrupt_current_processing() + return [gr.update(interactive=False)] * 2, currentTask - def skip_clicked(): + def skip_clicked(currentTask): import fcbh.model_management as model_management - currentTask.value.last_stop = 'skip' - model_management.interrupt_current_processing() - return + currentTask.last_stop = 'skip' + if (currentTask.processing): + model_management.interrupt_current_processing() + return currentTask - stop_button.click(stop_clicked, outputs=[skip_button, stop_button], + stop_button.click(stop_clicked, inputs=currentTask, outputs=[skip_button, stop_button, currentTask], queue=False, show_progress=False, _js='cancelGenerateForever') - skip_button.click(skip_clicked, queue=False, show_progress=False) + skip_button.click(skip_clicked, inputs=currentTask, outputs=currentTask, + queue=False, show_progress=False) with gr.Row(elem_classes='advanced_check_row'): input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check') advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check') @@ -428,7 +434,7 @@ with shared.gradio_root: .then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False) ctrls = [ - prompt, negative_prompt, style_selections, + currentTask, prompt, negative_prompt, style_selections, performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale ] @@ -441,7 +447,8 @@ with shared.gradio_root: generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \ .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \ .then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \ - .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \ + .then(fn=get_task, inputs=ctrls, outputs=currentTask) \ + .then(fn=generate_clicked, inputs=currentTask, outputs=[progress_html, progress_window, progress_gallery, gallery]) \ .then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \ .then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')