Merge branch 'feature/add-nsfw-filter'

# Conflicts:
#	modules/async_worker.py
#	modules/censor.py
#	modules/config.py
This commit is contained in:
Manuel Schmid 2023-12-16 19:39:19 +01:00
commit cc9f3d6c71
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
4 changed files with 5 additions and 2 deletions

1
.gitignore vendored
View File

@ -18,6 +18,7 @@ config.txt
config_modification_tutorial.txt
user_path_config.txt
user_path_config-deprecated.txt
/models/safety_checker_models
/modules/*.png
/repositories
/venv

View File

@ -7,6 +7,7 @@ import modules.core as core
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image
import modules.config
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
@ -27,8 +28,8 @@ def check_safety(x_image):
global safety_feature_extractor, safety_checker
if safety_feature_extractor is None:
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

View File

@ -126,6 +126,7 @@ path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlne
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
path_outputs = get_dir_or_set_default('path_outputs', '../outputs/')
path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/')
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):