diff --git a/.gitignore b/.gitignore index 05ce1df8..da9cf974 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ config.txt config_modification_tutorial.txt user_path_config.txt user_path_config-deprecated.txt +/models/safety_checker_models /modules/*.png /repositories /venv diff --git a/models/safety_checker_models/put_safety_checker_models_here b/models/safety_checker_models/put_safety_checker_models_here new file mode 100644 index 00000000..e69de29b diff --git a/modules/censor.py b/modules/censor.py index fac6db09..e2352218 100644 --- a/modules/censor.py +++ b/modules/censor.py @@ -7,6 +7,7 @@ import modules.core as core from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor from PIL import Image +import modules.config safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_feature_extractor = None @@ -27,8 +28,8 @@ def check_safety(x_image): global safety_feature_extractor, safety_checker if safety_feature_extractor is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) diff --git a/modules/config.py b/modules/config.py index cab19d03..b4fbf93c 100644 --- a/modules/config.py +++ b/modules/config.py @@ -126,6 +126,7 @@ path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlne path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/') path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') path_outputs = get_dir_or_set_default('path_outputs', '../outputs/') +path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/') def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):