Merge branch 'feature/add-nsfw-filter'
# Conflicts: # modules/async_worker.py # modules/censor.py # modules/config.py
This commit is contained in:
commit
cc9f3d6c71
|
|
@ -18,6 +18,7 @@ config.txt
|
|||
config_modification_tutorial.txt
|
||||
user_path_config.txt
|
||||
user_path_config-deprecated.txt
|
||||
/models/safety_checker_models
|
||||
/modules/*.png
|
||||
/repositories
|
||||
/venv
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import modules.core as core
|
|||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
from PIL import Image
|
||||
import modules.config
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
|
|
@ -27,8 +28,8 @@ def check_safety(x_image):
|
|||
global safety_feature_extractor, safety_checker
|
||||
|
||||
if safety_feature_extractor is None:
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
|
||||
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
|
|
|
|||
|
|
@ -126,6 +126,7 @@ path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlne
|
|||
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
|
||||
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
||||
path_outputs = get_dir_or_set_default('path_outputs', '../outputs/')
|
||||
path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/')
|
||||
|
||||
|
||||
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
|
||||
|
|
|
|||
Loading…
Reference in New Issue