From 5e3816a8b35437005d7efe43da97f24cd3dc0321 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 19 Feb 2024 23:12:33 +0100 Subject: [PATCH] fix: add nsfw filter support again accidentally deleted when merging --- modules/async_worker.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 2ffebb43..d3652b72 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -29,6 +29,7 @@ def worker(): import shared import random import copy + import cv2 import modules.default_pipeline as pipeline import modules.core as core import modules.flags as flags @@ -53,6 +54,7 @@ def worker(): from modules.upscaler import perform_upscale from modules.flags import Performance, lora_count from modules.meta_parser import get_metadata_parser, MetadataScheme + from pathlib import Path pid = os.getpid() print(f'Started worker with PID {pid}') @@ -70,10 +72,17 @@ def worker(): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, black_out_nsfw, do_not_show_finished_images=False): + def yield_result(async_task, imgs, black_out_nsfw, do_not_show_finished_images=False, progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] + if modules.config.default_black_out_nsfw or black_out_nsfw: + progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') + imgs_censor = [cv2.imread(img) for img in imgs] + imgs_censor = censor_batch(imgs_censor) + for i, img in enumerate(imgs): + cv2.imwrite(img, imgs_censor[i]) + async_task.results = async_task.results + imgs if do_not_show_finished_images: @@ -867,7 +876,8 @@ def worker(): img_paths.append(log(x, d, metadata_parser, output_format)) yield_result(async_task, img_paths, black_out_nsfw, do_not_show_finished_images=len(tasks) == 1 - or disable_intermediate_results or sampler_name == 'lcm') + or disable_intermediate_results or sampler_name == 'lcm', + progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': print('User skipped')