add nsfw image censoring

activatable via config, uses CompVis/stable-diffusion-safety-checker
This commit is contained in:
Manuel Schmid 2023-11-15 22:00:28 +01:00
parent 943098f8da
commit 52ae346c9d
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 66 additions and 2 deletions

View File

@ -30,6 +30,7 @@ def worker():
import fooocus_extras.ip_adapter as ip_adapter
import fooocus_extras.face_crop
from modules.censor import censor_batch
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion
from modules.private_logger import log
from modules.expansion import safe_str
@ -50,12 +51,16 @@ def worker():
print(f'[Fooocus] {text}')
outputs.append(['preview', (number, text, None)])
def yield_result(imgs, do_not_show_finished_images=False):
def yield_result(imgs, do_not_show_finished_images=False, progressbar_index=13):
global global_results
if not isinstance(imgs, list):
imgs = [imgs]
if modules.config.default_black_out_nsfw:
progressbar(progressbar_index, 'Checking for NSFW content ...')
imgs = censor_batch(imgs)
global_results = global_results + imgs
if do_not_show_finished_images:
@ -711,7 +716,7 @@ def worker():
d.append((f'LoRA [{n}] weight', w))
log(x, d, single_line_number=3)
yield_result(imgs, do_not_show_finished_images=len(tasks) == 1)
yield_result(imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))))
except fcbh.model_management.InterruptProcessingException as e:
if shared.last_stop == 'skip':
print('User skipped')

54
modules/censor.py Normal file
View File

@ -0,0 +1,54 @@
# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
import numpy as np
import torch
import modules.core as core
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None
def numpy_to_pil(image):
image = (image * 255).round().astype("uint8")
#pil_image = Image.fromarray(image, 'RGB')
pil_image = Image.fromarray(image)
return pil_image
# check and replace nsfw content
def check_safety(x_image):
global safety_feature_extractor, safety_checker
if safety_feature_extractor is None:
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
return x_checked_image, has_nsfw_concept
def censor_single(x):
x_checked_image, has_nsfw_concept = check_safety(x)
# replace image with black pixels, keep dimensions
# workaround due to different numpy / pytorch image matrix format
if has_nsfw_concept[0]:
imageshape = x_checked_image.shape
x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8)
return x_checked_image
def censor_batch(images):
images = [censor_single(image) for image in images]
return images

View File

@ -268,6 +268,11 @@ default_overwrite_switch = get_config_item_or_set_default(
default_value=-1,
validator=lambda x: isinstance(x, int)
)
default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw',
default_value=False,
validator=lambda x: isinstance(x, bool)
)
def add_ratio(x):