wip: refactor code to make it more efficient

now first processes all tasks and then does enhancements
This commit is contained in:
Manuel Schmid 2024-06-17 21:40:59 +02:00
parent 24d66f6f77
commit 3567c04918
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
1 changed files with 143 additions and 100 deletions

View File

@ -268,7 +268,7 @@ def worker():
def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, current_task_id,
denoising_strength, final_scheduler_name, goals, initial_latent, switch, positive_cond,
negative_cond, task, tasks, tiled, use_expansion, width, height):
negative_cond, task, tasks, tiled, use_expansion, width, height, base_progress, total_count):
if async_task.last_stop is not False:
ldm_patched.modules.model_management.interrupt_current_processing()
if 'cn' in goals:
@ -301,13 +301,13 @@ def worker():
del positive_cond, negative_cond # Save memory
if inpaint_worker.current_task is not None:
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(
current_progress = int(base_progress + (100 - base_progress) * float(
(current_task_id + 1) * async_task.steps) / float(all_steps))
if modules.config.default_black_out_nsfw or async_task.black_out_nsfw:
progressbar(async_task, current_progress, 'Checking for NSFW content ...')
imgs = default_censor(imgs)
progressbar(async_task, current_progress,
f'Saving image {current_task_id + 1}/{async_task.image_number} to system ...')
f'Saving image {current_task_id + 1}/{total_count} to system ...')
img_paths = save_and_log(async_task, height, imgs, task, use_expansion, width)
yield_result(async_task, img_paths, async_task.black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
@ -907,6 +907,10 @@ def worker():
prompt = prompt + '\n' + fallback_prompt
return prompt
def stop_processing(async_task, processing_start_time):
async_task.processing = False
processing_time = time.perf_counter() - processing_start_time
print(f'Processing time (total): {processing_time:.2f} seconds')
@torch.no_grad()
@torch.inference_mode()
@ -1041,6 +1045,9 @@ def worker():
all_steps = async_task.steps * async_task.image_number
if async_task.enhance_checkbox and len(async_task.enhance_ctrls) != 0:
all_steps += async_task.image_number * len(async_task.enhance_ctrls) * async_task.steps
print(f'[Parameters] Denoising Strength = {denoising_strength}')
if isinstance(initial_latent, dict) and 'samples' in initial_latent:
@ -1060,11 +1067,17 @@ def worker():
processing_start_time = time.perf_counter()
base_progress = int(flags.preparation_step_count)
current_progress = base_progress
total_count = async_task.image_number
def callback(step, x0, x, total_steps, y):
done_steps = current_task_id * async_task.steps + step
async_task.yields.append(['preview', (
int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(done_steps) / float(all_steps)),
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{async_task.image_number} ...', y)])
int(base_progress + (100 - base_progress) * float(done_steps) / float(all_steps)),
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)])
generated_imgs = {}
for current_task_id, task in enumerate(tasks):
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(
@ -1079,99 +1092,10 @@ def worker():
current_task_id, denoising_strength,
final_scheduler_name, goals, initial_latent,
switch, task['c'], task['uc'], task,
tasks, tiled, use_expansion, width, height)
tasks, tiled, use_expansion, width, height,
flags.preparation_step_count, async_task.image_number)
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0:
print(f'[Enhance] Skipping, preconditions aren\'t met')
continue
# enhance
progressbar(async_task, current_progress, 'Processing enhance ...')
for img in imgs:
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
if enhance_mask_model == 'sam':
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, mask_model=enhance_mask_model, sam_options=SAMOptions(
dino_prompt=enhance_mask_dino_prompt_text,
dino_box_threshold=enhance_mask_box_threshold,
dino_text_threshold=enhance_mask_text_threshold,
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
dino_debug=async_task.debugging_dino,
max_num_boxes=enhance_mask_sam_max_num_boxes,
model_type=enhance_mask_sam_model
))
if len(mask.shape) == 3:
mask = mask[:, :, 0]
if int(async_task.inpaint_erode_or_dilate) != 0:
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
if async_task.debugging_enhance_masks_checkbox:
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
yield_result(async_task, mask, async_task.black_out_nsfw, False, async_task.disable_intermediate_results)
print(f'[Enhance] {dino_detection_count} boxes detected')
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
if enhance_mask_model == 'sam' and (dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue
base_model_additional_loras_enhance = []
inpaint_head_model_path_enhance = None
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
if inpaint_parameterized_enhance:
progressbar(async_task, current_progress, 'Downloading inpainter ...')
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
async_task.inpaint_engine)
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
progressbar(async_task, current_progress, 'Preparing enhance prompts ...')
enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts, 'prompt')
enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt, async_task.translate_prompts, 'negative prompt')
if not inpaint_parameterized_enhance and enhance_prompt == async_task.prompt and enhance_negative_prompt == async_task.negative_prompt:
task_enhance = task.copy()
tasks_enhance = tasks.copy()
else:
tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt,
enhance_negative_prompt,
base_model_additional_loras_enhance,
1, True,
use_expansion, use_style,
use_synthetic_refiner)
task_enhance = tasks_enhance[0]
# TODO could support vary, upscale and CN in the future
# if 'cn' in goals:
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
if async_task.freeu_enabled:
apply_freeu(async_task)
patch_samplers(async_task)
goals_enhance = ['inpaint']
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
async_task, None, inpaint_head_model_path_enhance, img, mask,
inpaint_parameterized_enhance, enhance_inpaint_strength,
enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent,
current_progress, True)
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback,
controlnet_canny_path, controlnet_cpds_path,
current_task_id, enhance_inpaint_strength,
final_scheduler_name, goals_enhance,
initial_latent_enhance, switch,
task_enhance['c'], task_enhance['uc'],
task_enhance, tasks_enhance, tiled,
use_expansion, width_enhance, height_enhance)
# reset and prepare next iteration
img = imgs2[0]
generated_imgs[current_task_id] = imgs
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
@ -1186,10 +1110,129 @@ def worker():
execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds')
async_task.processing = False
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0:
print(f'[Enhance] Skipping, preconditions aren\'t met')
stop_processing(async_task, processing_start_time)
return
processing_time = time.perf_counter() - processing_start_time
print(f'Processing time (total): {processing_time:.2f} seconds')
# enhance
progressbar(async_task, current_progress, 'Processing enhance ...')
total_count = sum([len(imgs) for _, imgs in generated_imgs.items()]) * len(async_task.enhance_ctrls)
base_progress = current_progress
for generated_imgs_idx, (current_task_id, imgs) in enumerate(generated_imgs.items()):
for imgs_idx, img in enumerate(imgs):
for enhance_ctrls_idx, (enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field) in enumerate(async_task.enhance_ctrls):
current_task_id = generated_imgs_idx + imgs_idx + enhance_ctrls_idx
current_progress = int(base_progress + (100 - base_progress) * float(
current_task_id * async_task.steps) / float(all_steps))
progressbar(async_task, current_progress,
f'Preparing enhancement {current_task_id + 1}/{total_count} ...')
enhancement_task_start_time = time.perf_counter()
if enhance_mask_model == 'sam':
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(
img, mask_model=enhance_mask_model, sam_options=SAMOptions(
dino_prompt=enhance_mask_dino_prompt_text,
dino_box_threshold=enhance_mask_box_threshold,
dino_text_threshold=enhance_mask_text_threshold,
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
dino_debug=async_task.debugging_dino,
max_num_boxes=enhance_mask_sam_max_num_boxes,
model_type=enhance_mask_sam_model
))
if len(mask.shape) == 3:
mask = mask[:, :, 0]
if int(async_task.inpaint_erode_or_dilate) != 0:
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
if async_task.debugging_enhance_masks_checkbox:
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
yield_result(async_task, mask, async_task.black_out_nsfw, False,
async_task.disable_intermediate_results)
print(f'[Enhance] {dino_detection_count} boxes detected')
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
if enhance_mask_model == 'sam' and (
dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue
base_model_additional_loras_enhance = []
inpaint_head_model_path_enhance = None
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
if inpaint_parameterized_enhance:
progressbar(async_task, current_progress, 'Downloading inpainter ...')
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
async_task.inpaint_engine)
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
progressbar(async_task, current_progress, 'Preparing enhance prompts ...')
enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts,
'prompt')
enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt,
async_task.translate_prompts, 'negative prompt')
# positive and negative conditioning aren't available here anymore
# if not inpaint_parameterized_enhance and enhance_prompt == async_task.prompt and enhance_negative_prompt == async_task.negative_prompt:
# task_enhance = task.copy()
# tasks_enhance = tasks.copy()
# else:
tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt,
enhance_negative_prompt,
base_model_additional_loras_enhance,
1, True,
use_expansion, use_style,
use_synthetic_refiner)
task_enhance = tasks_enhance[0]
# TODO could support vary, upscale and CN in the future
# if 'cn' in goals:
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
if async_task.freeu_enabled:
apply_freeu(async_task)
patch_samplers(async_task)
goals_enhance = ['inpaint']
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
async_task, None, inpaint_head_model_path_enhance, img, mask,
inpaint_parameterized_enhance, enhance_inpaint_strength,
enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent,
current_progress, True)
try:
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback,
controlnet_canny_path, controlnet_cpds_path,
current_task_id, enhance_inpaint_strength,
final_scheduler_name, goals_enhance,
initial_latent_enhance, switch,
task_enhance['c'], task_enhance['uc'],
task_enhance, tasks_enhance, tiled,
use_expansion, width_enhance, height_enhance,
current_progress, total_count)
img = imgs2[0]
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
continue
else:
print('User stopped')
break
del task_enhance['c'], task_enhance['uc'] # Save memory
enhancement_task_time = time.perf_counter() - enhancement_task_start_time
print(f'Enhancement time: {enhancement_task_time:.2f} seconds')
stop_processing(async_task, processing_start_time)
return
while True:
time.sleep(0.01)