diff --git a/modules/async_worker.py b/modules/async_worker.py index 65618038..98a82d1f 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -6,6 +6,8 @@ class AsyncTask: self.args = args self.yields = [] self.results = [] + self.last_stop = False + self.processing = False async_tasks = [] @@ -119,6 +121,7 @@ def worker(): @torch.inference_mode() def handler(async_task): execution_start_time = time.perf_counter() + async_task.processing = True args = async_task.args args.reverse() @@ -664,6 +667,8 @@ def worker(): execution_start_time = time.perf_counter() try: + if async_task.last_stop is not False: + fcbh.model_management.interrupt_current_processing() positive_cond, negative_cond = task['c'], task['uc'] if 'cn' in goals: @@ -727,19 +732,16 @@ def worker(): yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) except fcbh.model_management.InterruptProcessingException as e: - if shared.last_stop == 'skip': + if async_task.last_stop == 'skip': print('User skipped') continue - elif shared.last_stop == 'stop_previous': - print('Previous task stopped') - break else: print('User stopped') break execution_time = time.perf_counter() - execution_start_time print(f'Generating and saving time: {execution_time:.2f} seconds') - + async_task.processing = False return while True: diff --git a/shared.py b/shared.py index 269809e3..21a2a864 100644 --- a/shared.py +++ b/shared.py @@ -1,2 +1 @@ -gradio_root = None -last_stop = None +gradio_root = None \ No newline at end of file diff --git a/webui.py b/webui.py index d770cbae..cd769754 100644 --- a/webui.py +++ b/webui.py @@ -21,18 +21,20 @@ from modules.private_logger import get_current_html_path from modules.ui_gradio_extensions import reload_javascript from modules.auth import auth_enabled, check_auth -def generate_clicked(*args): - # worker_outputs=[progress_html, progress_window, progress_gallery, gallery] +def get_task(*args): + args = list(args) + currentTask = args.pop(0) + currentTask = worker.AsyncTask(args=args) + return currentTask + +def generate_clicked(task): + # outputs=[progress_html, progress_window, progress_gallery, gallery] execution_start_time = time.perf_counter() - task = worker.AsyncTask(args=list(args)) finished = False yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \ gr.update(visible=True, value=None), \ gr.update(visible=False, value=None), \ - gr.update(visible=False), \ - gr.update(visible=True, interactive=True), \ - gr.update(visible=True, interactive=True), \ gr.update(visible=False) worker.async_tasks.append(task) @@ -53,26 +55,17 @@ def generate_clicked(*args): yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \ gr.update(visible=True, value=image) if image is not None and not (modules.config.default_black_out_nsfw and modules.config.default_hide_preview_if_black_out_nsfw) else gr.update(), \ gr.update(), \ - gr.update(visible=False), \ - gr.update(visible=True), \ - gr.update(visible=True), \ gr.update(visible=False) if flag == 'results': yield gr.update(visible=True), \ gr.update(visible=True), \ gr.update(visible=True, value=product), \ - gr.update(visible=False), \ - gr.update(visible=True), \ - gr.update(visible=True), \ gr.update(visible=False) if flag == 'finish': yield gr.update(visible=False), \ gr.update(visible=False), \ gr.update(visible=False), \ - gr.update(visible=True, value=product), \ - gr.update(visible=False), \ - gr.update(visible=False), \ - gr.update(visible=True) + gr.update(visible=True, value=product) finished = True execution_time = time.perf_counter() - execution_start_time @@ -92,6 +85,7 @@ shared.gradio_root = gr.Blocks( css=modules.html.css).queue() with shared.gradio_root: + currentTask = gr.State(worker.AsyncTask(args=[])) with gr.Row(): with gr.Column(scale=2): with gr.Row(): @@ -118,19 +112,22 @@ with shared.gradio_root: skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False) stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False) - def stop_clicked(): - shared.last_stop = 'stop' - model_management.interrupt_current_processing() - return [gr.update(interactive=False)] * 2 + def stop_clicked(currentTask): + import fcbh.model_management as model_management + currentTask.last_stop = 'stop' + if (currentTask.processing): + model_management.interrupt_current_processing() + return currentTask - def skip_clicked(): - shared.last_stop = 'skip' - model_management.interrupt_current_processing() - return + def skip_clicked(currentTask): + import fcbh.model_management as model_management + currentTask.last_stop = 'skip' + if (currentTask.processing): + model_management.interrupt_current_processing() + return currentTask - stop_button.click(stop_clicked, outputs=[skip_button, stop_button], - queue=False, show_progress=False, _js='cancelGenerateForever') - skip_button.click(skip_clicked, queue=False, show_progress=False) + stop_button.click(stop_clicked, inputs=currentTask, outputs=currentTask, queue=False, show_progress=False, _js='cancelGenerateForever') + skip_button.click(skip_clicked, inputs=currentTask, outputs=currentTask, queue=False, show_progress=False) with gr.Row(elem_classes='advanced_check_row'): input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check') advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check') @@ -454,7 +451,7 @@ with shared.gradio_root: .then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False) ctrls = [ - prompt, negative_prompt, style_selections, + currentTask, prompt, negative_prompt, style_selections, performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale ] @@ -464,10 +461,12 @@ with shared.gradio_root: ctrls += [outpaint_selections, inpaint_input_image] ctrls += ip_ctrls - generate_button.click(lambda: (gr.update(visible=True, interactive=False), gr.update(visible=True, interactive=False), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \ + generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \ .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \ .then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \ - .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery, stop_button, skip_button, generate_button]) \ + .then(fn=get_task, inputs=ctrls, outputs=currentTask) \ + .then(fn=generate_clicked, inputs=currentTask, outputs=[progress_html, progress_window, progress_gallery, gallery]) \ + .then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \ .then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed') def dump_default_english_config():