Rework Stop/Skip while bulk enhancing

-Rework stop/skip during bulk enhance to use "async_task.should_run" and -"async_task.should_skip" for more reliable stopping and skipping.
-Code cleanup
-Added more logging.
-Removed some emoji's
This commit is contained in:
ChrisColeTech 2024-08-18 13:55:46 -04:00
parent 9f535e8121
commit 7a0b8eebb3
2 changed files with 206 additions and 304 deletions

View File

@ -1,5 +1,5 @@
import threading
import gradio as gr
from extras.inpaint_mask import generate_mask_from_image, SAMOptions
from modules.patch import PatchSettings, patch_settings, patch_all
import modules.config
@ -9,7 +9,6 @@ patch_all()
class AsyncTask:
def __init__(self, args):
from modules.flags import Performance, MetadataScheme, ip_list, disabled
from modules.util import get_enabled_loras
@ -166,16 +165,18 @@ class AsyncTask:
self.bulk_enhance_data_type = args.pop()
self.bulk_enhance_file_explorer = args.pop()
self.bulk_enhance_input_path = args.pop()
self.current_task_id = 0
self.current_progress = None
self.preparation_steps = None
self.all_steps = None
self.total_count = None
self.current_async_task = None
self.should_run = True
self.should_skip = False
async_tasks = []
# Define global variables
current_task_id = 0
current_progress = None
preparation_steps = None
all_steps = None
total_count = None
current_async_task = None
class EarlyReturnException(BaseException):
@ -227,21 +228,8 @@ def worker():
if async_gradio_app.share:
flag += f''' or {async_gradio_app.share_url}'''
print(flag)
except FileNotFoundError:
# Handle the case where the file is not found
gr.Error(
"\n\n💥The specified file was not found. Please check the file path and try again. 📁\n\n ")
except ValueError as ve:
# Handle value errors (e.g., invalid parameters)
gr.Warning(
f"\n\n⚠️ A value error occurred: {ve}. Please check the input values. ⚠️\n\n ")
except Exception as e:
# Handle any other unforeseen errors
gr.Error(
f"\n\n💥 An unexpected error occurred: {e} 💥\n\n ")
print(e)
def progressbar(async_task, number, text):
print(f'[Fooocus] {text}')
@ -341,7 +329,7 @@ def worker():
refiner_swap_method=async_task.refiner_swap_method,
disable_preview=async_task.disable_preview
)
# del positive_cond, negative_cond # Save memory
del positive_cond, negative_cond # Save memory
if inpaint_worker.current_task is not None:
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
current_progress = int(
@ -1032,7 +1020,6 @@ def worker():
if async_task.bulk_enhance_enabled and async_task.bulk_enhance_data_type == 'Files' and len(async_task.bulk_enhance_file_explorer) > 0:
goals.append('bulk_enhance_files')
skip_prompt_processing = True
if async_task.bulk_enhance_enabled and async_task.bulk_enhance_data_type == 'Folder' and async_task.bulk_enhance_input_path:
@ -1143,9 +1130,26 @@ def worker():
preparation_steps, total_count, show_intermediate_results,
persist_image)
# del task_enhance['c'], task_enhance['uc'] # Save memory
del task_enhance['c'], task_enhance['uc'] # Save memory
return current_progress, imgs[0], prompt, negative_prompt
def print_user_skipped(async_task):
print('User skipped')
async_task.last_stop = False
async_task.should_skip = True
progressbar(async_task, 0,
'Image skipped')
print(
"\n\n⚠️ Image skipped ... ⚠️\n\n")
time.sleep(0.1)
def print_user_stopped(async_task):
print('User stopped')
async_task.should_run = False
print(
"\n\n 💥 Processing was interrupted by the user. Please try again. 💥\n\n ")
def enhance_upscale(all_steps, async_task, base_progress, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength, done_steps_inpainting, done_steps_upscaling, enhance_steps,
prompt, negative_prompt, final_scheduler_name, height, img, preparation_steps, switch, tiled,
@ -1172,32 +1176,15 @@ def worker():
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
print_user_skipped(async_task)
# also skip all enhance steps for this image, but add the steps to the progress bar
if async_task.enhance_uov_processing_order == flags.enhancement_uov_before:
done_steps_inpainting += len(
async_task.enhance_ctrls) * enhance_steps
exception_result = 'continue'
else:
print('User stopped')
gr.Error(
"\n\n 💥 Processing was interrupted by the user. Please try again. 💥\n\n ")
except FileNotFoundError:
# Handle the case where the file is not found
gr.Error(
"\n\n 💥 The specified file was not found. Please check the file path and try again. 📁\n\n ")
except ValueError as ve:
# Handle value errors (e.g., invalid parameters)
gr.Warning(
f"\n\n⚠️ A value error occurred: {ve}. Please check the input values. ⚠️\n\n ")
except Exception as e:
# Handle any other unforeseen errors
gr.Error(
f"\n\n 💥 An unexpected error occurred: {e} 💥\n\n ")
exception_result = 'break'
print_user_stopped(async_task)
finally:
done_steps_upscaling += steps
return current_task_id, done_steps_inpainting, done_steps_upscaling, img, exception_result
@ -1224,6 +1211,14 @@ def worker():
current_task_number,
persist_image=True
):
if not async_task.should_run:
stop_processing(async_task, async_task.processing_start_time)
return
if async_task.should_skip:
return
index = current_task_id
progressbar(async_task, current_progress,
@ -1358,28 +1353,12 @@ def worker():
last_enhance_negative_prompt = enhance_negative_prompt_processed
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
print_user_skipped(async_task)
continue
else:
print('User stopped')
gr.Error(
"\n\n💥 Processing was interrupted by the user. Please try again. 💥 \n\n ")
print_user_stopped(async_task)
break
except FileNotFoundError:
# Handle the case where the file is not found
gr.Error(
"\n\n💥 The specified file was not found. Please check the file path and try again. 📁\n\n ")
except ValueError as ve:
# Handle value errors (e.g., invalid parameters)
gr.Warning(
f"\n\n⚠️ A value error occurred: {ve}. Please check the input values. ⚠️\n\n ")
except Exception as e:
# Handle any other unforeseen errors
gr.Error(
f"\n\n💥 An unexpected error occurred: {e} 💥\n\n")
finally:
done_steps_inpainting += enhance_steps
@ -1412,45 +1391,74 @@ def worker():
f'Enhancement image time: {enhancement_image_time:.2f} seconds')
async_task.enhance_stats[-1] += 1
def callback(step, x0, x, total_steps, y):
global current_async_task
global current_task_id
global current_progress
global preparation_steps
global all_steps
global total_count
def loop_image_files(files, bulk_enhance_callback,
async_task,
current_progress,
all_steps,
height,
width,
controlnet_canny_path,
controlnet_cpds_path,
denoising_strength,
final_scheduler_name,
preparation_steps,
switch,
tiled,
use_expansion,
use_style,
use_synthetic_refiner
):
for index, file_name in enumerate(files):
async_task.current_task_id = index
# Check if current_async_task is not None to avoid AttributeError
if current_async_task is None:
raise ValueError("current_async_task is not set")
if not async_task.should_run:
print_user_stopped(async_task)
progressbar(async_task, 0,
'Stopping ...')
break
if step == 0:
current_async_task.callback_steps = 0
if async_task.should_skip:
print_user_skipped(async_task)
async_task.should_skip = False
continue
# Calculate callback steps
current_async_task.callback_steps += (100 -
preparation_steps) / float(all_steps)
# Build full path to the file
image = grh.Image(type='numpy')._format_image(Image.open(
file_name))
image = HWC3(image)
# Append to yields
current_async_task.yields.append([
'preview', (
int(current_progress + current_async_task.callback_steps),
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...',
y
# Add the image to the list
images_to_enhance = [image]
_, height, width = image.shape # Unpack the shape into C, H, W
yield_result(async_task, image, current_progress, async_task.black_out_nsfw, False,
async_task.disable_intermediate_results)
enhance_images(
images_to_enhance,
async_task,
0,
current_progress,
all_steps,
height,
width,
bulk_enhance_callback,
controlnet_canny_path,
controlnet_cpds_path,
denoising_strength,
final_scheduler_name,
preparation_steps,
switch,
tiled,
use_expansion,
use_style,
use_synthetic_refiner,
0
)
])
@torch.no_grad()
@torch.inference_mode()
def handler(async_task: AsyncTask):
global current_async_task
global current_task_id
global current_progress
global preparation_steps
global all_steps
global total_count
current_async_task = async_task
preparation_start_time = time.perf_counter()
async_task.processing = True
@ -1523,8 +1531,8 @@ def worker():
goals = []
tasks = []
current_progress = 1
if async_task.input_image_checkbox or async_task.bulk_enhance_enabled:
input_image = async_task.input_image_checkbox or async_task.bulk_enhance_enabled
if input_image:
base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner = apply_image_input(
async_task, base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path,
goals, inpaint_head_model_path, inpaint_image, inpaint_mask, inpaint_parameterized, ip_adapter_face_path,
@ -1601,15 +1609,6 @@ def worker():
advance_progress=True)
except EarlyReturnException:
return
except ValueError as ve:
# Handle value errors (e.g., invalid parameters)
gr.Warning(
f"\n\n⚠️ A value error occurred: {ve}. Please check the input values. ⚠️\n\n")
except Exception as e:
# Handle any other unforeseen errors
gr.Error(
f" \n\n💥 An unexpected error occurred: {e} 💥 \n\n")
if 'cn' in goals:
apply_control_nets(async_task, height, ip_adapter_face_path,
@ -1623,6 +1622,7 @@ def worker():
# async_task.steps can have value of uov steps here when upscale has been applied
steps, _, _, _ = apply_overrides(
async_task, async_task.steps, height, width)
all_steps = steps * async_task.image_number
image_enhance = async_task.enhance_checkbox or async_task.bulk_enhance_enabled
if image_enhance and async_task.enhance_uov_method != flags.disabled.casefold():
@ -1644,8 +1644,7 @@ def worker():
len(async_task.enhance_ctrls) * enhance_steps
if image_enhance and len(async_task.enhance_ctrls) == 0 and async_task.enhance_uov_method == flags.disabled.casefold():
# Handle value errors (e.g., invalid parameters)
gr.Warning(
print(
f"\n\n⚠️ Warning - Enhancements will be skipped. ⚠️ \n\nNo Enhancements were selected. \n\n Please check the input values. \n\n")
all_steps = max(all_steps, 1)
@ -1670,12 +1669,29 @@ def worker():
processing_start_time = time.perf_counter()
async_task.processing_start_time = time.perf_counter()
preparation_steps = current_progress
total_count = async_task.image_number
async_task.current_task_id = 0
# BULK ENHANCEMENTS #
def bulk_enhance_callback(step, x0, x, total_steps, y):
if step == 0:
async_task.callback_steps = 0
# Calculate callback steps
async_task.callback_steps += (100 -
preparation_steps) / float(all_steps)
# Append to yields
async_task.yields.append([
'preview', (
int(current_progress + async_task.callback_steps),
f'Sampling step {step + 1}/{total_steps}, image {async_task.current_task_id + 1}/{total_count} ...',
y
)
])
show_intermediate_results = len(tasks) > 1 or async_task.should_enhance
persist_image = not async_task.should_enhance or not async_task.save_final_enhanced_image_only
# ENHANCEMENTS #
images_to_enhance = []
if 'enhance' in goals:
async_task.image_number = 1
@ -1693,7 +1709,7 @@ def worker():
all_steps,
height,
width,
callback,
bulk_enhance_callback,
controlnet_canny_path,
controlnet_cpds_path,
denoising_strength,
@ -1708,144 +1724,82 @@ def worker():
)
if 'bulk_enhance_files' in goals:
for file in async_task.bulk_enhance_file_explorer:
try:
files = []
# Open and preprocess the image
image = grh.Image(type='numpy')._format_image(Image.open(
file.orig_name))
image = HWC3(image)
# Add the image to the list
images_to_enhance = [image]
_, height, width = image.shape # Unpack the shape into C, H, W
# input image already provided, processing is skipped
steps = 0
yield_result(async_task, image, current_progress, async_task.black_out_nsfw, False,
async_task.disable_intermediate_results)
enhance_images(
images_to_enhance,
async_task,
0,
current_progress,
all_steps,
height,
width,
callback,
controlnet_canny_path,
controlnet_cpds_path,
denoising_strength,
final_scheduler_name,
preparation_steps,
switch,
tiled,
use_expansion,
use_style,
use_synthetic_refiner,
0
)
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
continue
else:
print('User stopped')
gr.Error(
"\n\n💥Processing was interrupted by the user. Please try again.💥 \n\n")
except FileNotFoundError:
# Handle the case where the file is not found
gr.Error(
"\n\n💥The specified file was not found. Please check the file path and try again. 📁 \n\n")
except ValueError as ve:
# Handle value errors (e.g., invalid parameters)
gr.Warning(
f"\n\nA value error occurred: {ve}. Please check the input values. ⚠️\n\n")
except Exception as e:
# Handle any other unforeseen errors
gr.Error(
f"\n\n💥An unexpected error occurred: {e} 💥\n\n")
for file, index in async_task.bulk_enhance_file_explorer:
files.append(file.orig_name)
total_count = len(files)
loop_image_files(files, bulk_enhance_callback,
async_task,
current_progress,
all_steps,
height,
width,
controlnet_canny_path,
controlnet_cpds_path,
denoising_strength,
final_scheduler_name,
preparation_steps,
switch,
tiled,
use_expansion,
use_style,
use_synthetic_refiner
)
if 'bulk_enhance_folder' in goals:
# Walk through the directory tree
valid_extensions = (".jpg", ".jpeg", ".png",
".bmp", ".tiff", ".webp")
files = []
# Walk through the directory tree
for root, dirs, files_in_dir in os.walk(async_task.bulk_enhance_input_path):
try:
for file_name in files_in_dir:
# Build full path to the file
full_file_path = os.path.join(root, file_name)
# Check if the file has a valid extension
if file_name.lower().endswith(valid_extensions):
files.append(full_file_path)
total_count = len(files)
loop_image_files(files, bulk_enhance_callback,
async_task,
current_progress,
all_steps,
height,
width,
controlnet_canny_path,
controlnet_cpds_path,
denoising_strength,
final_scheduler_name,
preparation_steps,
switch,
tiled,
use_expansion,
use_style,
use_synthetic_refiner
)
for file_name in files_in_dir:
# Build full path to the file
full_file_path = os.path.join(root, file_name)
image = grh.Image(type='numpy')._format_image(Image.open(
full_file_path))
image = HWC3(image)
def callback(step, x0, x, total_steps, y):
if step == 0:
async_task.callback_steps = 0
async_task.callback_steps += (100 -
preparation_steps) / float(all_steps)
async_task.yields.append(['preview', (
int(current_progress + async_task.callback_steps),
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)])
# Add the image to the list
images_to_enhance = [image]
show_intermediate_results = len(tasks) > 1 or async_task.should_enhance
persist_image = not async_task.should_enhance or not async_task.save_final_enhanced_image_only
_, height, width = image.shape # Unpack the shape into C, H, W
steps = 0
yield_result(async_task, image, current_progress, async_task.black_out_nsfw, False,
async_task.disable_intermediate_results)
enhance_images(
images_to_enhance,
async_task,
0,
current_progress,
all_steps,
height,
width,
callback,
controlnet_canny_path,
controlnet_cpds_path,
denoising_strength,
final_scheduler_name,
preparation_steps,
switch,
tiled,
use_expansion,
use_style,
use_synthetic_refiner,
0
)
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
continue
else:
print('User stopped')
gr.Error(
"\n\n💥Processing was interrupted by the user. Please try again.💥\n\n")
except FileNotFoundError:
# Handle the case where the file is not found
gr.Error(
"\n\n💥The specified file was not found. Please check the file path and try again. 📁\n\n")
except ValueError as ve:
# Handle value errors (e.g., invalid parameters)
gr.Warning(
f"\n\nA value error occurred: {ve}. Please check the input values. ⚠️\n\n")
except Exception as e:
# Handle any other unforeseen errors
gr.Error(
f"\n\n💥An unexpected error occurred: {e} 💥\n\n")
# MAIN GENERATION QUEUE #
for current_task_id, task in enumerate(tasks):
current_task_number = current_task_id + 1
progressbar(async_task, current_progress,
f'Preparing task {current_task_number}/{async_task.image_number} ...')
setup(async_task, current_task_number)
f'Preparing task {current_task_id + 1}/{async_task.image_number} ...')
current_task_number = current_task_id + 1
inpaint_worker.current_task = None
patch_samplers(async_task)
execution_start_time = time.perf_counter()
try:
imgs, img_paths, current_progress = process_task(all_steps, async_task, callback, controlnet_canny_path,
controlnet_cpds_path, current_task_id,
@ -1859,16 +1813,14 @@ def worker():
current_progress = int(preparation_steps + (100 - preparation_steps) / float(
all_steps) * async_task.steps * (current_task_id + 1))
# images_to_enhance += imgs
if not async_task.should_enhance:
print(f'[Enhance] Skipping, preconditions aren\'t met')
stop_processing(async_task, processing_start_time)
return
# Immediately enhance each image generated
continue
images_to_enhance = imgs
enhance_images(
imgs,
images_to_enhance,
async_task,
current_task_id,
current_progress,
@ -1891,72 +1843,19 @@ def worker():
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
print_user_skipped(async_task)
continue
else:
print('User stopped')
gr.Error(
"\n\n💥Processing was interrupted by the user. Please try again.💥\n\n")
print_user_stopped(async_task)
break
except FileNotFoundError:
# Handle the case where the file is not found
gr.Error(
"\n\n💥The specified file was not found. Please check the file path and try again. 📁\n\n")
except ValueError as ve:
# Handle value errors (e.g., invalid parameters)
gr.Warning(
f"\n\n💥A value error occurred: {ve}. Please check the input values. ⚠️\n\n")
except Exception as e:
# Handle any other unforeseen errors
gr.Error(
f"\n\n💥An unexpected error occurred: {e} 💥\n\n")
# del task['c'], task['uc'] # Save memory
del task['c'], task['uc'] # Save memory
execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds')
stop_processing(async_task, processing_start_time)
return
def setup(async_task: AsyncTask, current_task_number):
if async_task.performance_selection == Performance.EXTREME_SPEED:
set_lcm_defaults(async_task, current_progress,
advance_progress=True)
elif async_task.performance_selection == Performance.LIGHTNING:
set_lightning_defaults(
async_task, current_progress, advance_progress=True)
elif async_task.performance_selection == Performance.HYPER_SD:
set_hyper_sd_defaults(
async_task, current_progress, advance_progress=True)
width, height = async_task.aspect_ratios_selection.replace('×', ' ').split(' ')[
:2]
width, height = int(width), int(height)
inpaint_worker.current_task = None
controlnet_canny_path = None
controlnet_cpds_path = None
clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None
goals = []
current_progress = current_task_number
# Load or unload CNs
progressbar(async_task, current_progress, 'Loading control models ...')
pipeline.refresh_controlnets(
[controlnet_canny_path, controlnet_cpds_path])
ip_adapter.load_ip_adapter(
clip_vision_path, ip_negative_path, ip_adapter_path)
ip_adapter.load_ip_adapter(
clip_vision_path, ip_negative_path, ip_adapter_face_path)
async_task.steps, switch, width, height = apply_overrides(
async_task, async_task.steps, height, width)
progressbar(async_task, current_progress, 'Initializing ...')
patch_samplers(async_task)
while True:
time.sleep(0.01)
if len(async_tasks) > 0:
@ -1971,9 +1870,6 @@ def worker():
except:
traceback.print_exc()
task.yields.append(['finish', task.results])
gr.Error(
f"\n\n💥An unexpected error occurred: Please try again. 💥 \n\n")
finally:
if pid in modules.patch.patch_settings:
del modules.patch.patch_settings[pid]

View File

@ -199,6 +199,9 @@ with shared.gradio_root:
def stop_clicked(currentTask):
import ldm_patched.modules.model_management as model_management
currentTask.last_stop = 'stop'
currentTask.should_run = False
print(
"\n\n⚠️ Stopping. Please wait ... ⚠️\n\n")
if (currentTask.processing):
model_management.interrupt_current_processing()
return currentTask
@ -206,6 +209,9 @@ with shared.gradio_root:
def skip_clicked(currentTask):
import ldm_patched.modules.model_management as model_management
currentTask.last_stop = 'skip'
currentTask.should_skip = True
print(
"\n\n⚠️ Skipping. Please wait ... ⚠️\n\n")
if (currentTask.processing):
model_management.interrupt_current_processing()
return currentTask