use config to set cache dir for safety checker
This commit is contained in:
parent
ea764855f8
commit
cdaaec3e71
|
|
@ -18,6 +18,7 @@ config.txt
|
||||||
config_modification_tutorial.txt
|
config_modification_tutorial.txt
|
||||||
user_path_config.txt
|
user_path_config.txt
|
||||||
user_path_config-deprecated.txt
|
user_path_config-deprecated.txt
|
||||||
|
/models/safety_checker_models
|
||||||
/modules/*.png
|
/modules/*.png
|
||||||
/repositories
|
/repositories
|
||||||
/venv
|
/venv
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import modules.core as core
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import modules.config
|
||||||
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
safety_feature_extractor = None
|
safety_feature_extractor = None
|
||||||
|
|
@ -27,8 +28,8 @@ def check_safety(x_image):
|
||||||
global safety_feature_extractor, safety_checker
|
global safety_feature_extractor, safety_checker
|
||||||
|
|
||||||
if safety_feature_extractor is None:
|
if safety_feature_extractor is None:
|
||||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
|
||||||
|
|
||||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,7 @@ path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlne
|
||||||
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
|
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
|
||||||
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
||||||
path_outputs = get_dir_or_set_default('path_outputs', '../outputs/')
|
path_outputs = get_dir_or_set_default('path_outputs', '../outputs/')
|
||||||
|
path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/')
|
||||||
|
|
||||||
|
|
||||||
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
|
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue