diff --git a/modules/async_worker.py b/modules/async_worker.py index fdb7a3c4..41a769e0 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -268,7 +268,7 @@ def worker(): def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, current_task_id, denoising_strength, final_scheduler_name, goals, initial_latent, switch, positive_cond, - negative_cond, task, tasks, tiled, use_expansion, width, height): + negative_cond, task, tasks, tiled, use_expansion, width, height, base_progress, total_count): if async_task.last_stop is not False: ldm_patched.modules.model_management.interrupt_current_processing() if 'cn' in goals: @@ -301,13 +301,13 @@ def worker(): del positive_cond, negative_cond # Save memory if inpaint_worker.current_task is not None: imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] - current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float( + current_progress = int(base_progress + (100 - base_progress) * float( (current_task_id + 1) * async_task.steps) / float(all_steps)) if modules.config.default_black_out_nsfw or async_task.black_out_nsfw: progressbar(async_task, current_progress, 'Checking for NSFW content ...') imgs = default_censor(imgs) progressbar(async_task, current_progress, - f'Saving image {current_task_id + 1}/{async_task.image_number} to system ...') + f'Saving image {current_task_id + 1}/{total_count} to system ...') img_paths = save_and_log(async_task, height, imgs, task, use_expansion, width) yield_result(async_task, img_paths, async_task.black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results) @@ -907,6 +907,10 @@ def worker(): prompt = prompt + '\n' + fallback_prompt return prompt + def stop_processing(async_task, processing_start_time): + async_task.processing = False + processing_time = time.perf_counter() - processing_start_time + print(f'Processing time (total): {processing_time:.2f} seconds') @torch.no_grad() @torch.inference_mode() @@ -1041,6 +1045,9 @@ def worker(): all_steps = async_task.steps * async_task.image_number + if async_task.enhance_checkbox and len(async_task.enhance_ctrls) != 0: + all_steps += async_task.image_number * len(async_task.enhance_ctrls) * async_task.steps + print(f'[Parameters] Denoising Strength = {denoising_strength}') if isinstance(initial_latent, dict) and 'samples' in initial_latent: @@ -1060,11 +1067,17 @@ def worker(): processing_start_time = time.perf_counter() + base_progress = int(flags.preparation_step_count) + current_progress = base_progress + total_count = async_task.image_number + def callback(step, x0, x, total_steps, y): done_steps = current_task_id * async_task.steps + step async_task.yields.append(['preview', ( - int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(done_steps) / float(all_steps)), - f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{async_task.image_number} ...', y)]) + int(base_progress + (100 - base_progress) * float(done_steps) / float(all_steps)), + f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)]) + + generated_imgs = {} for current_task_id, task in enumerate(tasks): current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float( @@ -1079,99 +1092,10 @@ def worker(): current_task_id, denoising_strength, final_scheduler_name, goals, initial_latent, switch, task['c'], task['uc'], task, - tasks, tiled, use_expansion, width, height) + tasks, tiled, use_expansion, width, height, + flags.preparation_step_count, async_task.image_number) - if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0: - print(f'[Enhance] Skipping, preconditions aren\'t met') - continue - - # enhance - progressbar(async_task, current_progress, 'Processing enhance ...') - - for img in imgs: - for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls: - if enhance_mask_model == 'sam': - print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"') - - mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, mask_model=enhance_mask_model, sam_options=SAMOptions( - dino_prompt=enhance_mask_dino_prompt_text, - dino_box_threshold=enhance_mask_box_threshold, - dino_text_threshold=enhance_mask_text_threshold, - dino_erode_or_dilate=async_task.dino_erode_or_dilate, - dino_debug=async_task.debugging_dino, - max_num_boxes=enhance_mask_sam_max_num_boxes, - model_type=enhance_mask_sam_model - )) - if len(mask.shape) == 3: - mask = mask[:, :, 0] - - if int(async_task.inpaint_erode_or_dilate) != 0: - mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate) - - if async_task.debugging_enhance_masks_checkbox: - async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)]) - yield_result(async_task, mask, async_task.black_out_nsfw, False, async_task.disable_intermediate_results) - - print(f'[Enhance] {dino_detection_count} boxes detected') - print(f'[Enhance] {sam_detection_count} segments detected in boxes') - print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask') - - if enhance_mask_model == 'sam' and (dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0): - print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping') - continue - - base_model_additional_loras_enhance = [] - inpaint_head_model_path_enhance = None - inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail - - if inpaint_parameterized_enhance: - progressbar(async_task, current_progress, 'Downloading inpainter ...') - inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models( - async_task.inpaint_engine) - if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance: - base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)] - - progressbar(async_task, current_progress, 'Preparing enhance prompts ...') - enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts, 'prompt') - enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt, async_task.translate_prompts, 'negative prompt') - - if not inpaint_parameterized_enhance and enhance_prompt == async_task.prompt and enhance_negative_prompt == async_task.negative_prompt: - task_enhance = task.copy() - tasks_enhance = tasks.copy() - else: - tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt, - enhance_negative_prompt, - base_model_additional_loras_enhance, - 1, True, - use_expansion, use_style, - use_synthetic_refiner) - task_enhance = tasks_enhance[0] - - # TODO could support vary, upscale and CN in the future - # if 'cn' in goals: - # apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width) - if async_task.freeu_enabled: - apply_freeu(async_task) - patch_samplers(async_task) - - goals_enhance = ['inpaint'] - enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint( - async_task, None, inpaint_head_model_path_enhance, img, mask, - inpaint_parameterized_enhance, enhance_inpaint_strength, - enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent, - current_progress, True) - - imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback, - controlnet_canny_path, controlnet_cpds_path, - current_task_id, enhance_inpaint_strength, - final_scheduler_name, goals_enhance, - initial_latent_enhance, switch, - task_enhance['c'], task_enhance['uc'], - task_enhance, tasks_enhance, tiled, - use_expansion, width_enhance, height_enhance) - - # reset and prepare next iteration - img = imgs2[0] + generated_imgs[current_task_id] = imgs except ldm_patched.modules.model_management.InterruptProcessingException: if async_task.last_stop == 'skip': @@ -1186,10 +1110,129 @@ def worker(): execution_time = time.perf_counter() - execution_start_time print(f'Generating and saving time: {execution_time:.2f} seconds') - async_task.processing = False + if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0: + print(f'[Enhance] Skipping, preconditions aren\'t met') + stop_processing(async_task, processing_start_time) + return - processing_time = time.perf_counter() - processing_start_time - print(f'Processing time (total): {processing_time:.2f} seconds') + # enhance + progressbar(async_task, current_progress, 'Processing enhance ...') + total_count = sum([len(imgs) for _, imgs in generated_imgs.items()]) * len(async_task.enhance_ctrls) + base_progress = current_progress + for generated_imgs_idx, (current_task_id, imgs) in enumerate(generated_imgs.items()): + for imgs_idx, img in enumerate(imgs): + for enhance_ctrls_idx, (enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field) in enumerate(async_task.enhance_ctrls): + current_task_id = generated_imgs_idx + imgs_idx + enhance_ctrls_idx + current_progress = int(base_progress + (100 - base_progress) * float( + current_task_id * async_task.steps) / float(all_steps)) + progressbar(async_task, current_progress, + f'Preparing enhancement {current_task_id + 1}/{total_count} ...') + enhancement_task_start_time = time.perf_counter() + + if enhance_mask_model == 'sam': + print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"') + + mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image( + img, mask_model=enhance_mask_model, sam_options=SAMOptions( + dino_prompt=enhance_mask_dino_prompt_text, + dino_box_threshold=enhance_mask_box_threshold, + dino_text_threshold=enhance_mask_text_threshold, + dino_erode_or_dilate=async_task.dino_erode_or_dilate, + dino_debug=async_task.debugging_dino, + max_num_boxes=enhance_mask_sam_max_num_boxes, + model_type=enhance_mask_sam_model + )) + if len(mask.shape) == 3: + mask = mask[:, :, 0] + + if int(async_task.inpaint_erode_or_dilate) != 0: + mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate) + + if async_task.debugging_enhance_masks_checkbox: + async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)]) + yield_result(async_task, mask, async_task.black_out_nsfw, False, + async_task.disable_intermediate_results) + + print(f'[Enhance] {dino_detection_count} boxes detected') + print(f'[Enhance] {sam_detection_count} segments detected in boxes') + print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask') + + if enhance_mask_model == 'sam' and ( + dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0): + print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping') + continue + + base_model_additional_loras_enhance = [] + inpaint_head_model_path_enhance = None + inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail + + if inpaint_parameterized_enhance: + progressbar(async_task, current_progress, 'Downloading inpainter ...') + inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models( + async_task.inpaint_engine) + if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance: + base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)] + + progressbar(async_task, current_progress, 'Preparing enhance prompts ...') + enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts, + 'prompt') + enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt, + async_task.translate_prompts, 'negative prompt') + + # positive and negative conditioning aren't available here anymore + # if not inpaint_parameterized_enhance and enhance_prompt == async_task.prompt and enhance_negative_prompt == async_task.negative_prompt: + # task_enhance = task.copy() + # tasks_enhance = tasks.copy() + # else: + tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt, + enhance_negative_prompt, + base_model_additional_loras_enhance, + 1, True, + use_expansion, use_style, + use_synthetic_refiner) + task_enhance = tasks_enhance[0] + + # TODO could support vary, upscale and CN in the future + # if 'cn' in goals: + # apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width) + if async_task.freeu_enabled: + apply_freeu(async_task) + patch_samplers(async_task) + + goals_enhance = ['inpaint'] + enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint( + async_task, None, inpaint_head_model_path_enhance, img, mask, + inpaint_parameterized_enhance, enhance_inpaint_strength, + enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent, + current_progress, True) + + try: + imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback, + controlnet_canny_path, controlnet_cpds_path, + current_task_id, enhance_inpaint_strength, + final_scheduler_name, goals_enhance, + initial_latent_enhance, switch, + task_enhance['c'], task_enhance['uc'], + task_enhance, tasks_enhance, tiled, + use_expansion, width_enhance, height_enhance, + current_progress, total_count) + img = imgs2[0] + + except ldm_patched.modules.model_management.InterruptProcessingException: + if async_task.last_stop == 'skip': + print('User skipped') + async_task.last_stop = False + continue + else: + print('User stopped') + break + + del task_enhance['c'], task_enhance['uc'] # Save memory + enhancement_task_time = time.perf_counter() - enhancement_task_start_time + print(f'Enhancement time: {enhancement_task_time:.2f} seconds') + + stop_processing(async_task, processing_start_time) + return while True: time.sleep(0.01)