wip: refactor code to make it more efficient
now first processes all tasks and then does enhancements
This commit is contained in:
parent
24d66f6f77
commit
3567c04918
|
|
@ -268,7 +268,7 @@ def worker():
|
|||
|
||||
def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, current_task_id,
|
||||
denoising_strength, final_scheduler_name, goals, initial_latent, switch, positive_cond,
|
||||
negative_cond, task, tasks, tiled, use_expansion, width, height):
|
||||
negative_cond, task, tasks, tiled, use_expansion, width, height, base_progress, total_count):
|
||||
if async_task.last_stop is not False:
|
||||
ldm_patched.modules.model_management.interrupt_current_processing()
|
||||
if 'cn' in goals:
|
||||
|
|
@ -301,13 +301,13 @@ def worker():
|
|||
del positive_cond, negative_cond # Save memory
|
||||
if inpaint_worker.current_task is not None:
|
||||
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
|
||||
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(
|
||||
current_progress = int(base_progress + (100 - base_progress) * float(
|
||||
(current_task_id + 1) * async_task.steps) / float(all_steps))
|
||||
if modules.config.default_black_out_nsfw or async_task.black_out_nsfw:
|
||||
progressbar(async_task, current_progress, 'Checking for NSFW content ...')
|
||||
imgs = default_censor(imgs)
|
||||
progressbar(async_task, current_progress,
|
||||
f'Saving image {current_task_id + 1}/{async_task.image_number} to system ...')
|
||||
f'Saving image {current_task_id + 1}/{total_count} to system ...')
|
||||
img_paths = save_and_log(async_task, height, imgs, task, use_expansion, width)
|
||||
yield_result(async_task, img_paths, async_task.black_out_nsfw, False,
|
||||
do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
|
||||
|
|
@ -907,6 +907,10 @@ def worker():
|
|||
prompt = prompt + '\n' + fallback_prompt
|
||||
return prompt
|
||||
|
||||
def stop_processing(async_task, processing_start_time):
|
||||
async_task.processing = False
|
||||
processing_time = time.perf_counter() - processing_start_time
|
||||
print(f'Processing time (total): {processing_time:.2f} seconds')
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
|
|
@ -1041,6 +1045,9 @@ def worker():
|
|||
|
||||
all_steps = async_task.steps * async_task.image_number
|
||||
|
||||
if async_task.enhance_checkbox and len(async_task.enhance_ctrls) != 0:
|
||||
all_steps += async_task.image_number * len(async_task.enhance_ctrls) * async_task.steps
|
||||
|
||||
print(f'[Parameters] Denoising Strength = {denoising_strength}')
|
||||
|
||||
if isinstance(initial_latent, dict) and 'samples' in initial_latent:
|
||||
|
|
@ -1060,11 +1067,17 @@ def worker():
|
|||
|
||||
processing_start_time = time.perf_counter()
|
||||
|
||||
base_progress = int(flags.preparation_step_count)
|
||||
current_progress = base_progress
|
||||
total_count = async_task.image_number
|
||||
|
||||
def callback(step, x0, x, total_steps, y):
|
||||
done_steps = current_task_id * async_task.steps + step
|
||||
async_task.yields.append(['preview', (
|
||||
int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(done_steps) / float(all_steps)),
|
||||
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{async_task.image_number} ...', y)])
|
||||
int(base_progress + (100 - base_progress) * float(done_steps) / float(all_steps)),
|
||||
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)])
|
||||
|
||||
generated_imgs = {}
|
||||
|
||||
for current_task_id, task in enumerate(tasks):
|
||||
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(
|
||||
|
|
@ -1079,99 +1092,10 @@ def worker():
|
|||
current_task_id, denoising_strength,
|
||||
final_scheduler_name, goals, initial_latent,
|
||||
switch, task['c'], task['uc'], task,
|
||||
tasks, tiled, use_expansion, width, height)
|
||||
tasks, tiled, use_expansion, width, height,
|
||||
flags.preparation_step_count, async_task.image_number)
|
||||
|
||||
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0:
|
||||
print(f'[Enhance] Skipping, preconditions aren\'t met')
|
||||
continue
|
||||
|
||||
# enhance
|
||||
progressbar(async_task, current_progress, 'Processing enhance ...')
|
||||
|
||||
for img in imgs:
|
||||
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
|
||||
if enhance_mask_model == 'sam':
|
||||
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
|
||||
|
||||
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, mask_model=enhance_mask_model, sam_options=SAMOptions(
|
||||
dino_prompt=enhance_mask_dino_prompt_text,
|
||||
dino_box_threshold=enhance_mask_box_threshold,
|
||||
dino_text_threshold=enhance_mask_text_threshold,
|
||||
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
|
||||
dino_debug=async_task.debugging_dino,
|
||||
max_num_boxes=enhance_mask_sam_max_num_boxes,
|
||||
model_type=enhance_mask_sam_model
|
||||
))
|
||||
if len(mask.shape) == 3:
|
||||
mask = mask[:, :, 0]
|
||||
|
||||
if int(async_task.inpaint_erode_or_dilate) != 0:
|
||||
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
|
||||
|
||||
if async_task.debugging_enhance_masks_checkbox:
|
||||
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
|
||||
yield_result(async_task, mask, async_task.black_out_nsfw, False, async_task.disable_intermediate_results)
|
||||
|
||||
print(f'[Enhance] {dino_detection_count} boxes detected')
|
||||
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
|
||||
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
|
||||
|
||||
if enhance_mask_model == 'sam' and (dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
|
||||
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
|
||||
continue
|
||||
|
||||
base_model_additional_loras_enhance = []
|
||||
inpaint_head_model_path_enhance = None
|
||||
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
|
||||
|
||||
if inpaint_parameterized_enhance:
|
||||
progressbar(async_task, current_progress, 'Downloading inpainter ...')
|
||||
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
|
||||
async_task.inpaint_engine)
|
||||
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
|
||||
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
|
||||
|
||||
progressbar(async_task, current_progress, 'Preparing enhance prompts ...')
|
||||
enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts, 'prompt')
|
||||
enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt, async_task.translate_prompts, 'negative prompt')
|
||||
|
||||
if not inpaint_parameterized_enhance and enhance_prompt == async_task.prompt and enhance_negative_prompt == async_task.negative_prompt:
|
||||
task_enhance = task.copy()
|
||||
tasks_enhance = tasks.copy()
|
||||
else:
|
||||
tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt,
|
||||
enhance_negative_prompt,
|
||||
base_model_additional_loras_enhance,
|
||||
1, True,
|
||||
use_expansion, use_style,
|
||||
use_synthetic_refiner)
|
||||
task_enhance = tasks_enhance[0]
|
||||
|
||||
# TODO could support vary, upscale and CN in the future
|
||||
# if 'cn' in goals:
|
||||
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
|
||||
if async_task.freeu_enabled:
|
||||
apply_freeu(async_task)
|
||||
patch_samplers(async_task)
|
||||
|
||||
goals_enhance = ['inpaint']
|
||||
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
|
||||
async_task, None, inpaint_head_model_path_enhance, img, mask,
|
||||
inpaint_parameterized_enhance, enhance_inpaint_strength,
|
||||
enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent,
|
||||
current_progress, True)
|
||||
|
||||
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback,
|
||||
controlnet_canny_path, controlnet_cpds_path,
|
||||
current_task_id, enhance_inpaint_strength,
|
||||
final_scheduler_name, goals_enhance,
|
||||
initial_latent_enhance, switch,
|
||||
task_enhance['c'], task_enhance['uc'],
|
||||
task_enhance, tasks_enhance, tiled,
|
||||
use_expansion, width_enhance, height_enhance)
|
||||
|
||||
# reset and prepare next iteration
|
||||
img = imgs2[0]
|
||||
generated_imgs[current_task_id] = imgs
|
||||
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException:
|
||||
if async_task.last_stop == 'skip':
|
||||
|
|
@ -1186,10 +1110,129 @@ def worker():
|
|||
execution_time = time.perf_counter() - execution_start_time
|
||||
print(f'Generating and saving time: {execution_time:.2f} seconds')
|
||||
|
||||
async_task.processing = False
|
||||
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0:
|
||||
print(f'[Enhance] Skipping, preconditions aren\'t met')
|
||||
stop_processing(async_task, processing_start_time)
|
||||
return
|
||||
|
||||
processing_time = time.perf_counter() - processing_start_time
|
||||
print(f'Processing time (total): {processing_time:.2f} seconds')
|
||||
# enhance
|
||||
progressbar(async_task, current_progress, 'Processing enhance ...')
|
||||
total_count = sum([len(imgs) for _, imgs in generated_imgs.items()]) * len(async_task.enhance_ctrls)
|
||||
base_progress = current_progress
|
||||
for generated_imgs_idx, (current_task_id, imgs) in enumerate(generated_imgs.items()):
|
||||
for imgs_idx, img in enumerate(imgs):
|
||||
for enhance_ctrls_idx, (enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field) in enumerate(async_task.enhance_ctrls):
|
||||
current_task_id = generated_imgs_idx + imgs_idx + enhance_ctrls_idx
|
||||
current_progress = int(base_progress + (100 - base_progress) * float(
|
||||
current_task_id * async_task.steps) / float(all_steps))
|
||||
progressbar(async_task, current_progress,
|
||||
f'Preparing enhancement {current_task_id + 1}/{total_count} ...')
|
||||
enhancement_task_start_time = time.perf_counter()
|
||||
|
||||
if enhance_mask_model == 'sam':
|
||||
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
|
||||
|
||||
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(
|
||||
img, mask_model=enhance_mask_model, sam_options=SAMOptions(
|
||||
dino_prompt=enhance_mask_dino_prompt_text,
|
||||
dino_box_threshold=enhance_mask_box_threshold,
|
||||
dino_text_threshold=enhance_mask_text_threshold,
|
||||
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
|
||||
dino_debug=async_task.debugging_dino,
|
||||
max_num_boxes=enhance_mask_sam_max_num_boxes,
|
||||
model_type=enhance_mask_sam_model
|
||||
))
|
||||
if len(mask.shape) == 3:
|
||||
mask = mask[:, :, 0]
|
||||
|
||||
if int(async_task.inpaint_erode_or_dilate) != 0:
|
||||
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
|
||||
|
||||
if async_task.debugging_enhance_masks_checkbox:
|
||||
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
|
||||
yield_result(async_task, mask, async_task.black_out_nsfw, False,
|
||||
async_task.disable_intermediate_results)
|
||||
|
||||
print(f'[Enhance] {dino_detection_count} boxes detected')
|
||||
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
|
||||
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
|
||||
|
||||
if enhance_mask_model == 'sam' and (
|
||||
dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
|
||||
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
|
||||
continue
|
||||
|
||||
base_model_additional_loras_enhance = []
|
||||
inpaint_head_model_path_enhance = None
|
||||
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
|
||||
|
||||
if inpaint_parameterized_enhance:
|
||||
progressbar(async_task, current_progress, 'Downloading inpainter ...')
|
||||
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
|
||||
async_task.inpaint_engine)
|
||||
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
|
||||
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
|
||||
|
||||
progressbar(async_task, current_progress, 'Preparing enhance prompts ...')
|
||||
enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts,
|
||||
'prompt')
|
||||
enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt,
|
||||
async_task.translate_prompts, 'negative prompt')
|
||||
|
||||
# positive and negative conditioning aren't available here anymore
|
||||
# if not inpaint_parameterized_enhance and enhance_prompt == async_task.prompt and enhance_negative_prompt == async_task.negative_prompt:
|
||||
# task_enhance = task.copy()
|
||||
# tasks_enhance = tasks.copy()
|
||||
# else:
|
||||
tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt,
|
||||
enhance_negative_prompt,
|
||||
base_model_additional_loras_enhance,
|
||||
1, True,
|
||||
use_expansion, use_style,
|
||||
use_synthetic_refiner)
|
||||
task_enhance = tasks_enhance[0]
|
||||
|
||||
# TODO could support vary, upscale and CN in the future
|
||||
# if 'cn' in goals:
|
||||
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
|
||||
if async_task.freeu_enabled:
|
||||
apply_freeu(async_task)
|
||||
patch_samplers(async_task)
|
||||
|
||||
goals_enhance = ['inpaint']
|
||||
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
|
||||
async_task, None, inpaint_head_model_path_enhance, img, mask,
|
||||
inpaint_parameterized_enhance, enhance_inpaint_strength,
|
||||
enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent,
|
||||
current_progress, True)
|
||||
|
||||
try:
|
||||
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback,
|
||||
controlnet_canny_path, controlnet_cpds_path,
|
||||
current_task_id, enhance_inpaint_strength,
|
||||
final_scheduler_name, goals_enhance,
|
||||
initial_latent_enhance, switch,
|
||||
task_enhance['c'], task_enhance['uc'],
|
||||
task_enhance, tasks_enhance, tiled,
|
||||
use_expansion, width_enhance, height_enhance,
|
||||
current_progress, total_count)
|
||||
img = imgs2[0]
|
||||
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException:
|
||||
if async_task.last_stop == 'skip':
|
||||
print('User skipped')
|
||||
async_task.last_stop = False
|
||||
continue
|
||||
else:
|
||||
print('User stopped')
|
||||
break
|
||||
|
||||
del task_enhance['c'], task_enhance['uc'] # Save memory
|
||||
enhancement_task_time = time.perf_counter() - enhancement_task_start_time
|
||||
print(f'Enhancement time: {enhancement_task_time:.2f} seconds')
|
||||
|
||||
stop_processing(async_task, processing_start_time)
|
||||
return
|
||||
|
||||
while True:
|
||||
time.sleep(0.01)
|
||||
|
|
|
|||
Loading…
Reference in New Issue