From 52ae346c9dfedaf0df61a6f9d376ec9c5856e497 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Wed, 15 Nov 2023 22:00:28 +0100 Subject: [PATCH] add nsfw image censoring activatable via config, uses CompVis/stable-diffusion-safety-checker --- modules/async_worker.py | 9 +++++-- modules/censor.py | 54 +++++++++++++++++++++++++++++++++++++++++ modules/config.py | 5 ++++ 3 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 modules/censor.py diff --git a/modules/async_worker.py b/modules/async_worker.py index a6807547..362cddde 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -30,6 +30,7 @@ def worker(): import fooocus_extras.ip_adapter as ip_adapter import fooocus_extras.face_crop + from modules.censor import censor_batch from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log from modules.expansion import safe_str @@ -50,12 +51,16 @@ def worker(): print(f'[Fooocus] {text}') outputs.append(['preview', (number, text, None)]) - def yield_result(imgs, do_not_show_finished_images=False): + def yield_result(imgs, do_not_show_finished_images=False, progressbar_index=13): global global_results if not isinstance(imgs, list): imgs = [imgs] + if modules.config.default_black_out_nsfw: + progressbar(progressbar_index, 'Checking for NSFW content ...') + imgs = censor_batch(imgs) + global_results = global_results + imgs if do_not_show_finished_images: @@ -711,7 +716,7 @@ def worker(): d.append((f'LoRA [{n}] weight', w)) log(x, d, single_line_number=3) - yield_result(imgs, do_not_show_finished_images=len(tasks) == 1) + yield_result(imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)))) except fcbh.model_management.InterruptProcessingException as e: if shared.last_stop == 'skip': print('User skipped') diff --git a/modules/censor.py b/modules/censor.py new file mode 100644 index 00000000..fac6db09 --- /dev/null +++ b/modules/censor.py @@ -0,0 +1,54 @@ +# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py + +import numpy as np +import torch +import modules.core as core + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor +from PIL import Image + +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = None +safety_checker = None + + +def numpy_to_pil(image): + image = (image * 255).round().astype("uint8") + + #pil_image = Image.fromarray(image, 'RGB') + pil_image = Image.fromarray(image) + + return pil_image + + +# check and replace nsfw content +def check_safety(x_image): + global safety_feature_extractor, safety_checker + + if safety_feature_extractor is None: + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + + return x_checked_image, has_nsfw_concept + + +def censor_single(x): + x_checked_image, has_nsfw_concept = check_safety(x) + + # replace image with black pixels, keep dimensions + # workaround due to different numpy / pytorch image matrix format + if has_nsfw_concept[0]: + imageshape = x_checked_image.shape + x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8) + + return x_checked_image + + +def censor_batch(images): + images = [censor_single(image) for image in images] + + return images diff --git a/modules/config.py b/modules/config.py index 7216a26d..ac81ef0b 100644 --- a/modules/config.py +++ b/modules/config.py @@ -268,6 +268,11 @@ default_overwrite_switch = get_config_item_or_set_default( default_value=-1, validator=lambda x: isinstance(x, int) ) +default_black_out_nsfw = get_config_item_or_set_default( + key='default_black_out_nsfw', + default_value=False, + validator=lambda x: isinstance(x, bool) +) def add_ratio(x):