From cdaaec3e71866c0ceedcd368c2ca329a223681a0 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 16 Dec 2023 18:54:46 +0100 Subject: [PATCH] use config to set cache dir for safety checker --- .gitignore | 1 + models/safety_checker_models/put_safety_checker_models_here | 0 modules/censor.py | 5 +++-- modules/config.py | 1 + 4 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 models/safety_checker_models/put_safety_checker_models_here diff --git a/.gitignore b/.gitignore index 05ce1df8..da9cf974 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ config.txt config_modification_tutorial.txt user_path_config.txt user_path_config-deprecated.txt +/models/safety_checker_models /modules/*.png /repositories /venv diff --git a/models/safety_checker_models/put_safety_checker_models_here b/models/safety_checker_models/put_safety_checker_models_here new file mode 100644 index 00000000..e69de29b diff --git a/modules/censor.py b/modules/censor.py index fac6db09..e2352218 100644 --- a/modules/censor.py +++ b/modules/censor.py @@ -7,6 +7,7 @@ import modules.core as core from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor from PIL import Image +import modules.config safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_feature_extractor = None @@ -27,8 +28,8 @@ def check_safety(x_image): global safety_feature_extractor, safety_checker if safety_feature_extractor is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) diff --git a/modules/config.py b/modules/config.py index ec0b49c3..c6515bd9 100644 --- a/modules/config.py +++ b/modules/config.py @@ -126,6 +126,7 @@ path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlne path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/') path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') path_outputs = get_dir_or_set_default('path_outputs', '../outputs/') +path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/') def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):