feat: optimize model management of image censoring (#2960)
The censor module now follows general Fooocus model management principles: the safety checker is wrapped in a ModelPatcher, kept on its offload device, and loaded on demand via model_management.load_model_gpu. It also includes code optimisations for reusability: the former censor_single/censor_batch entry points are merged into a single default_censor that accepts either one image or a list.
parent dad228907e
commit 35b74dfa64
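Note (illustration, not part of this commit): the "Fooocus model management principle" referred to above means wrapping a model in a ModelPatcher with explicit load/offload devices and letting ldm_patched.modules.model_management move the weights onto the GPU only when they are actually needed. A minimal sketch of that pattern, condensed from the censor.py changes below; the LazySafetyModel wrapper name is hypothetical and not part of the codebase:

    # Sketch of the lazy-loading pattern this commit adopts (condensed from the diff below).
    import torch
    import ldm_patched.modules.model_management as model_management
    from ldm_patched.modules.model_patcher import ModelPatcher


    class LazySafetyModel:
        def __init__(self):
            self.patcher: ModelPatcher | None = None

        def init(self, model: torch.nn.Module):
            # Build the patcher once and park the weights on the offload device.
            if self.patcher is None:
                model.eval()
                load_device = model_management.text_encoder_device()
                offload_device = model_management.text_encoder_offload_device()
                model.to(offload_device)
                self.patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)

        def __call__(self, *args, **kwargs):
            # model_management moves the weights to the GPU right before inference
            # and may offload them again when other models need the VRAM.
            assert self.patcher is not None, 'call init() first'
            model_management.load_model_gpu(self.patcher)
            return self.patcher.model(*args, **kwargs)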
@@ -1,56 +1,60 @@
 # modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
 
-import numpy as np
 import os
 
-from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
-from transformers import CLIPFeatureExtractor, CLIPConfig
-from PIL import Image
+import numpy as np
+import torch
+from transformers import CLIPConfig, CLIPImageProcessor
 
+import ldm_patched.modules.model_management as model_management
 import modules.config
+from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
+from ldm_patched.modules.model_patcher import ModelPatcher
 
 safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker')
 config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
 preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json")
 
-safety_feature_extractor = None
-safety_checker = None
 
+class Censor:
+    def __init__(self):
+        self.safety_checker_model: ModelPatcher | None = None
+        self.clip_image_processor: CLIPImageProcessor | None = None
+        self.load_device = torch.device('cpu')
+        self.offload_device = torch.device('cpu')
 
-def numpy_to_pil(image):
-    image = (image * 255).round().astype("uint8")
-    pil_image = Image.fromarray(image)
-
-    return pil_image
-
-
-# check and replace nsfw content
-def check_safety(x_image):
-    global safety_feature_extractor, safety_checker
-
-    if safety_feature_extractor is None or safety_checker is None:
-        safety_checker_model = modules.config.downloading_safety_checker_model()
-        safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path)
-        clip_config = CLIPConfig.from_json_file(config_path)
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
-
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
-    return x_checked_image, has_nsfw_concept
-
-
-def censor_single(x):
-    x_checked_image, has_nsfw_concept = check_safety(x)
-
-    # replace image with black pixels, keep dimensions
-    # workaround due to different numpy / pytorch image matrix format
-    if has_nsfw_concept[0]:
-        imageshape = x_checked_image.shape
-        x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8)
-
-    return x_checked_image
-
-
-def censor_batch(images):
-    images = [censor_single(image) for image in images]
-
-    return images
+    def init(self):
+        if self.safety_checker_model is None and self.clip_image_processor is None:
+            safety_checker_model = modules.config.downloading_safety_checker_model()
+            self.clip_image_processor = CLIPImageProcessor.from_json_file(preprocessor_config_path)
+            clip_config = CLIPConfig.from_json_file(config_path)
+            model = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
+            model.eval()
+
+            self.load_device = model_management.text_encoder_device()
+            self.offload_device = model_management.text_encoder_offload_device()
+
+            model.to(self.offload_device)
+
+            self.safety_checker_model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
+
+    def censor(self, images: list | np.ndarray) -> list | np.ndarray:
+        self.init()
+        model_management.load_model_gpu(self.safety_checker_model)
+
+        single = False
+        if not isinstance(images, list) or isinstance(images, np.ndarray):
+            images = [images]
+            single = True
+
+        safety_checker_input = self.clip_image_processor(images, return_tensors="pt")
+        safety_checker_input.to(device=self.load_device)
+        checked_images, has_nsfw_concept = self.safety_checker_model.model(images=images,
+                                                                           clip_input=safety_checker_input.pixel_values)
+        checked_images = [image.astype(np.uint8) for image in checked_images]
+
+        if single:
+            checked_images = checked_images[0]
+
+        return checked_images
+
+
+default_censor = Censor().censor
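
Note (illustration, not part of this commit): with the Censor class above, both former call sites collapse onto the single default_censor entry point, which accepts either one image or a list and returns a result of the same shape. A hypothetical usage sketch, assuming a working Fooocus checkout (the safety-checker weights are downloaded on first use) and HWC uint8 numpy images as used elsewhere in the pipeline:

    import numpy as np
    from extras.censor import default_censor

    image = np.zeros((512, 512, 3), dtype=np.uint8)   # one HWC image

    checked = default_censor(image)                   # single image in -> single image out
    checked_batch = default_censor([image, image])    # list in -> list out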
@@ -44,7 +44,7 @@ def worker():
     import fooocus_version
     import args_manager
 
-    from extras.censor import censor_batch, censor_single
+    from extras.censor import default_censor
     from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name
     from modules.private_logger import log
     from extras.expansion import safe_str
@@ -78,7 +78,7 @@ def worker():
 
         if censor and (modules.config.default_black_out_nsfw or black_out_nsfw):
             progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
-            imgs = censor_batch(imgs)
+            imgs = default_censor(imgs)
 
         async_task.results = async_task.results + imgs
 
@@ -615,7 +615,8 @@ def worker():
             d = [('Upscale (Fast)', 'upscale_fast', '2x')]
             if modules.config.default_black_out_nsfw or black_out_nsfw:
                 progressbar(async_task, 100, 'Checking for NSFW content ...')
-                uov_input_image = censor_single(uov_input_image)
+                uov_input_image = default_censor(uov_input_image)
+            progressbar(async_task, 100, 'Saving image to system ...')
             uov_input_image_path = log(uov_input_image, d, output_format=output_format)
             yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True)
             return
@@ -883,12 +884,12 @@ def worker():
             imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
 
             img_paths = []
 
+            current_progress = int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))
             if modules.config.default_black_out_nsfw or black_out_nsfw:
-                progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)),
-                            'Checking for NSFW content ...')
-                imgs = censor_batch(imgs)
+                progressbar(async_task, current_progress, 'Checking for NSFW content ...')
+                imgs = default_censor(imgs)
 
+            progressbar(async_task, current_progress, 'Saving image to system ...')
             for x in imgs:
                 d = [('Prompt', 'prompt', task['log_positive_prompt']),
                      ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
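
Note (illustration, not part of this commit): hoisting the percentage into current_progress lets the same value drive both the NSFW-check and the save progress messages. A worked example with hypothetical task numbers:

    # Hypothetical values: first of two tasks, 30 sampling steps each.
    current_task_id, steps, all_steps = 0, 30, 60
    current_progress = int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))
    assert current_progress == 57   # 15 + 85 * 30/60 = 57.5, truncated by int()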