56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
|
|
import numpy as np
|
|
import os
|
|
|
|
from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
|
|
from transformers import CLIPFeatureExtractor, CLIPConfig
|
|
from PIL import Image
|
|
import modules.config
|
|
|
|
# Directory of the bundled safety-checker package, resolved next to this file.
safety_checker_repo_root = os.path.join(os.path.dirname(__file__), "safety_checker")

# JSON configuration files shipped inside that package's "configs" folder.
config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
preprocessor_config_path = os.path.join(
    safety_checker_repo_root, "configs", "preprocessor_config.json"
)

# Model objects start out unset; check_safety() populates them on first use.
safety_feature_extractor = safety_checker = None
|
|
|
|
|
|
def numpy_to_pil(image):
    """Convert a single image array into a ``PIL.Image.Image``.

    Args:
        image: Image data as a numpy array (presumably ``(H, W, 3)`` RGB --
            confirm against callers). Non-``uint8`` input is treated as
            normalised to ``[0, 1]`` and rescaled to ``0..255``; ``uint8``
            input is taken to already be in ``0..255`` and is used as-is.

    Returns:
        A ``PIL.Image.Image`` built from the resulting ``uint8`` pixel data.
    """
    image = np.asarray(image)
    # Rescaling a uint8 array by 255 would stay in uint8 arithmetic and wrap
    # modulo 256, scrambling the pixels, so 8-bit input skips the rescale.
    if image.dtype != np.uint8:
        # Clip before casting: astype("uint8") does not saturate out-of-range
        # values (e.g. 1.01 * 255 rounds to 258, which wraps to 2).
        image = np.clip(image * 255, 0, 255).round().astype("uint8")
    return Image.fromarray(image)
|
|
|
|
|
|
def check_safety(x_image):
    """Run the NSFW safety checker on a single image.

    The CLIP feature extractor and the safety-checker model are built on the
    first call (and rebuilt if either module-level global is still ``None``),
    then cached in those globals for reuse on later calls.

    Args:
        x_image: Image array; it is converted with ``numpy_to_pil`` for the
            feature extractor and also passed unchanged to the checker as
            ``images``.

    Returns:
        The two values returned by the safety checker: the (possibly
        modified) images and the NSFW-concept flags, presumably one flag per
        image in the batch.
    """
    global safety_feature_extractor, safety_checker

    must_load = safety_feature_extractor is None or safety_checker is None
    if must_load:
        # Presumably fetches/locates the checker weights; see modules.config.
        weights_location = modules.config.downloading_safety_checker_model()
        safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            weights_location,
            config=CLIPConfig.from_json_file(config_path),
        )

    pil_view = numpy_to_pil(x_image)
    clip_features = safety_feature_extractor(pil_view, return_tensors="pt")
    checked_images, nsfw_flags = safety_checker(
        images=x_image,
        clip_input=clip_features.pixel_values,
    )
    return checked_images, nsfw_flags
|
|
|
|
|
|
def censor_single(x):
    """Return *x* as passed through the safety checker, blacked out if flagged.

    When the first NSFW flag from ``check_safety`` is truthy, the result is
    replaced by an all-black ``uint8`` array with three channels and the same
    height and width (the first two dimensions) as the checked image.
    Otherwise the checker's output image is returned unchanged.
    """
    checked, nsfw_flags = check_safety(x)
    if not nsfw_flags[0]:
        return checked

    # Build a fresh black frame rather than trusting the checker's own
    # replacement -- a workaround for the differing numpy / pytorch image
    # matrix layouts.
    height = checked.shape[0]
    width = checked.shape[1]
    return np.zeros((height, width, 3), dtype=np.uint8)
|
|
|
|
|
|
def censor_batch(images):
    """Apply ``censor_single`` to each image, returning the results as a new list."""
    return list(map(censor_single, images))