From 1d606ecb7e871aa104c56fc8b4f24095f6ab8db8 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 24 Feb 2024 18:09:45 +0100 Subject: [PATCH] feat: optimize image censoring Does not save 2x to file (log and yield), but only once (log). --- modules/async_worker.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index a69200d4..19a4da6f 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -44,7 +44,7 @@ def worker(): import fooocus_version import args_manager - from modules.censor import censor_batch + from modules.censor import censor_batch, censor_single from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log @@ -71,16 +71,13 @@ def worker(): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, black_out_nsfw, do_not_show_finished_images=False, progressbar_index=13): + def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] - if modules.config.default_black_out_nsfw or black_out_nsfw: + if censor and (modules.config.default_black_out_nsfw or black_out_nsfw): progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') - imgs_censor = [cv2.imread(img) for img in imgs] - imgs_censor = censor_batch(imgs_censor) - for i, img in enumerate(imgs): - cv2.imwrite(img, imgs_censor[i]) + imgs = censor_batch(imgs) async_task.results = async_task.results + imgs @@ -560,8 +557,11 @@ def worker(): if direct_return: d = [('Upscale', 'upscale', 'Fast 2x')] + if modules.config.default_black_out_nsfw or black_out_nsfw: + progressbar(async_task, 100, 'Checking for NSFW content ...') + uov_input_image = censor_single(uov_input_image) uov_input_image_path = log(uov_input_image, d, output_format=output_format) - yield_result(async_task, uov_input_image_path, black_out_nsfw, do_not_show_finished_images=True) + yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True) return tiled = True @@ -827,6 +827,11 @@ def worker(): img_paths = [] + if modules.config.default_black_out_nsfw or black_out_nsfw: + progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)), + 'Checking for NSFW content ...') + imgs = censor_batch(imgs) + for x in imgs: d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), @@ -874,9 +879,8 @@ def worker(): d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format)) - yield_result(async_task, img_paths, black_out_nsfw, do_not_show_finished_images=len(tasks) == 1 - or disable_intermediate_results or sampler_name == 'lcm', - progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) + yield_result(async_task, img_paths, black_out_nsfw, False, + do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results or sampler_name == 'lcm') except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': print('User skipped')