diff --git a/css/style.css b/css/style.css
index b9e6e2ce..b5f7a448 100644
--- a/css/style.css
+++ b/css/style.css
@@ -27,6 +27,7 @@ progress {
     border-radius: 5px; /* Round the corners of the progress bar */
     background-color: #f3f3f3; /* Light grey background */
     width: 100%;
+    vertical-align: middle !important;
 }
 
 /* Style the progress bar container */
@@ -69,6 +70,11 @@ progress::after {
     height: 30px !important;
 }
 
+.progress-bar span {
+    text-align: right;
+    width: 200px;
+}
+
 .type_row{
     height: 80px !important;
 }
diff --git a/development.md b/development.md
new file mode 100644
index 00000000..bbb3def9
--- /dev/null
+++ b/development.md
@@ -0,0 +1,11 @@
+## Running unit tests
+
+Native python:
+```
+python -m unittest discover tests
+```
+
+Embedded python (Windows zip file installation method):
+```
+..\python_embeded\python.exe -m unittest
+```
diff --git a/extras/censor.py b/extras/censor.py
new file mode 100644
index 00000000..45617fd8
--- /dev/null
+++ b/extras/censor.py
@@ -0,0 +1,62 @@
+import os
+
+import numpy as np
+import torch
+from transformers import CLIPConfig, CLIPImageProcessor
+
+import ldm_patched.modules.model_management as model_management
+import modules.config
+from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
+from ldm_patched.modules.model_patcher import ModelPatcher
+
+safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker')
+config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
+preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json")
+
+
+class Censor:
+    def __init__(self):
+        self.safety_checker_model: ModelPatcher | None = None
+        self.clip_image_processor: CLIPImageProcessor | None = None
+        self.load_device = torch.device('cpu')
+        self.offload_device = torch.device('cpu')
+
+    def init(self):
+        # lazy-load the safety checker the first time censoring is requested
+        if self.safety_checker_model is None and self.clip_image_processor is None:
+            safety_checker_model = modules.config.downloading_safety_checker_model()
+            self.clip_image_processor = CLIPImageProcessor.from_json_file(preprocessor_config_path)
+            clip_config = CLIPConfig.from_json_file(config_path)
+            model = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
+            model.eval()
+
+            self.load_device = model_management.text_encoder_device()
+            self.offload_device = model_management.text_encoder_offload_device()
+
+            model.to(self.offload_device)
+
+            self.safety_checker_model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
+
+    def censor(self, images: list | np.ndarray) -> list | np.ndarray:
+        self.init()
+        model_management.load_model_gpu(self.safety_checker_model)
+
+        # a bare ndarray is treated as a single image and wrapped in a list
+        single = False
+        if not isinstance(images, list):
+            images = [images]
+            single = True
+
+        safety_checker_input = self.clip_image_processor(images, return_tensors="pt")
+        safety_checker_input.to(device=self.load_device)
+        checked_images, has_nsfw_concept = self.safety_checker_model.model(images=images,
+                                                                           clip_input=safety_checker_input.pixel_values)
+        checked_images = [image.astype(np.uint8) for image in checked_images]
+
+        if single:
+            checked_images = checked_images[0]
+
+        return checked_images
+
+
+default_censor = Censor().censor
diff --git a/extras/safety_checker/configs/config.json b/extras/safety_checker/configs/config.json
new file mode 100644
index 00000000..aa454d22
--- /dev/null
+++ b/extras/safety_checker/configs/config.json
@@ -0,0 +1,171 @@
+{
+  "_name_or_path": 
"clip-vit-large-patch14/", + "architectures": [ + "SafetyChecker" + ], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + "text_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": 1, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "task_specific_params": null, + "temperature": 1.0, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "transformers_version": "4.21.0.dev0", + "typical_p": 1.0, + "use_bfloat16": false, + "vocab_size": 49408 + }, + "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12 + }, + "torch_dtype": "float32", + "transformers_version": null, + "vision_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "image_size": 224, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "clip_vision_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 16, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 24, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + 
"problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "task_specific_params": null, + "temperature": 1.0, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "transformers_version": "4.21.0.dev0", + "typical_p": 1.0, + "use_bfloat16": false + }, + "vision_config_dict": { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14 + } +} diff --git a/extras/safety_checker/configs/preprocessor_config.json b/extras/safety_checker/configs/preprocessor_config.json new file mode 100644 index 00000000..5294955f --- /dev/null +++ b/extras/safety_checker/configs/preprocessor_config.json @@ -0,0 +1,20 @@ +{ + "crop_size": 224, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_resize": true, + "feature_extractor_type": "CLIPFeatureExtractor", + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "resample": 3, + "size": 224 +} diff --git a/extras/safety_checker/models/safety_checker.py b/extras/safety_checker/models/safety_checker.py new file mode 100644 index 00000000..ea38bf03 --- /dev/null +++ b/extras/safety_checker/models/safety_checker.py @@ -0,0 +1,126 @@ +# from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+    normalized_image_embeds = nn.functional.normalize(image_embeds)
+    normalized_text_embeds = nn.functional.normalize(text_embeds)
+    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+    config_class = CLIPConfig
+    main_input_name = "clip_input"
+
+    _no_split_modules = ["CLIPEncoderLayer"]
+
+    def __init__(self, config: CLIPConfig):
+        super().__init__(config)
+
+        self.vision_model = CLIPVisionModel(config.vision_config)
+        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+
+    @torch.no_grad()
+    def forward(self, clip_input, images):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+
+        result = []
+        batch_size = image_embeds.shape[0]
+        for i in range(batch_size):
+            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+            # increase this value to create a stronger `nsfw` filter
+            # at the cost of increasing the possibility of filtering benign images
+            adjustment = 0.0
+
+            for concept_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concept_idx]
+                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["special_scores"][concept_idx] > 0:
+                    result_img["special_care"].append({concept_idx: result_img["special_scores"][concept_idx]})
+                    # a "special care" hit tightens the threshold for the general concepts below
+                    adjustment = 0.01
+
+            for concept_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concept_idx]
+                concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["concept_scores"][concept_idx] > 0:
+                    result_img["bad_concepts"].append(concept_idx)
+
+            result.append(result_img)
+
+        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                if torch.is_tensor(images) or torch.is_tensor(images[0]):
+                    images[idx] = torch.zeros_like(images[idx])  # black image
+                else:
+                    images[idx] = np.zeros(images[idx].shape)  # black image
+
+        if any(has_nsfw_concepts):
+            logger.warning(
+                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+                " Try again with a different prompt and/or seed."
+ ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/fooocus_version.py b/fooocus_version.py index d61557b0..85e56125 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.3.3 (mashb1t)' +version = '2.4.0-rc3 (mashb1t)' diff --git a/language/en.json b/language/en.json index fb6e3fe5..ca4acecc 100644 --- a/language/en.json +++ b/language/en.json @@ -65,6 +65,8 @@ "Disable seed increment": "Disable seed increment", "Disable automatic seed increment when image number is > 1.": "Disable automatic seed increment when image number is > 1.", "Read wildcards in order": "Read wildcards in order", + "Black Out NSFW": "Black Out NSFW", + "Use black image if NSFW is detected.": "Use black image if NSFW is detected.", "\ud83d\udcda History Log": "\uD83D\uDCDA History Log", "Image Style": "Image Style", "Fooocus V2": "Fooocus V2", diff --git a/models/safety_checker/put_safety_checker_models_here b/models/safety_checker/put_safety_checker_models_here new file mode 100644 index 00000000..e69de29b diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/async_worker.py b/modules/async_worker.py index 08fe2373..119abb83 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -46,12 +46,13 @@ def worker(): import fooocus_version import args_manager - from modules.censor import censor_batch, censor_single - from modules.sdxl_styles import get_random_style, random_style_name, apply_style, apply_wildcards, fooocus_expansion, apply_arrays + from extras.censor import default_censor + from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str - from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ - get_shape_ceil, resample_image, erode_or_dilate, get_enabled_loras + from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, + get_shape_ceil, resample_image, erode_or_dilate, get_enabled_loras, + parse_lora_references_from_prompt, apply_wildcards) from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme @@ -72,13 +73,14 @@ def worker(): print(f'[Fooocus] {text}') 
async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): + def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, + progressbar_index=flags.preparation_step_count): if not isinstance(imgs, list): imgs = [imgs] if censor and (modules.config.default_black_out_nsfw or black_out_nsfw): progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') - imgs = censor_batch(imgs) + imgs = default_censor(imgs) async_task.results = async_task.results + imgs @@ -156,7 +158,8 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)]) + loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in + range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() @@ -206,7 +209,8 @@ def worker(): inpaint_erode_or_dilate = args.pop() save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False - metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS + metadata_scheme = MetadataScheme( + args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS cn_tasks = {x: [] for x in flags.ip_list} for _ in range(flags.controlnet_image_count): @@ -464,14 +468,17 @@ def worker(): extra_positive_prompts = prompts[1:] if len(prompts) > 1 else [] extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] - progressbar(async_task, 3, 'Loading models ...') + progressbar(async_task, 2, 'Loading models ...') + + loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) + pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) progressbar(async_task, 3, 'Processing prompts ...') tasks = [] - + for i in range(image_number): if disable_seed_increment: task_seed = seed % (constants.MAX_SEED + 1) @@ -482,8 +489,10 @@ def worker(): task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_arrays(task_prompt, i) task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) - task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] - task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] + task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_positive_prompts] + task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_negative_prompts] positive_basic_workloads = [] negative_basic_workloads = [] @@ -526,25 +535,25 @@ def worker(): if use_expansion: for i, t in enumerate(tasks): - progressbar(async_task, 5, f'Preparing Fooocus text #{i + 1} ...') + progressbar(async_task, 4, f'Preparing Fooocus text #{i + 1} ...') expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed']) print(f'[Prompt Expansion] 
{expansion}') t['expansion'] = expansion t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy. for i, t in enumerate(tasks): - progressbar(async_task, 7, f'Encoding positive #{i + 1} ...') + progressbar(async_task, 5, f'Encoding positive #{i + 1} ...') t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k']) for i, t in enumerate(tasks): if abs(float(cfg_scale) - 1.0) < 1e-4: t['uc'] = pipeline.clone_cond(t['c']) else: - progressbar(async_task, 10, f'Encoding negative #{i + 1} ...') + progressbar(async_task, 6, f'Encoding negative #{i + 1} ...') t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k']) if len(goals) > 0: - progressbar(async_task, 13, 'Image processing ...') + progressbar(async_task, 7, 'Image processing ...') if 'vary' in goals: if 'subtle' in uov_method: @@ -565,7 +574,7 @@ def worker(): uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil) initial_pixels = core.numpy_to_pytorch(uov_input_image) - progressbar(async_task, 13, 'VAE encoding ...') + progressbar(async_task, 8, 'VAE encoding ...') candidate_vae, _ = pipeline.get_candidate_vae( steps=steps, @@ -582,7 +591,7 @@ def worker(): if 'upscale' in goals: H, W, C = uov_input_image.shape - progressbar(async_task, 13, f'Upscaling image from {str((H, W))} ...') + progressbar(async_task, 9, f'Upscaling image from {str((H, W))} ...') uov_input_image = perform_upscale(uov_input_image) print(f'Image upscaled.') @@ -615,10 +624,11 @@ def worker(): direct_return = False if direct_return: - d = [('Upscale', 'upscale', 'Fast 2x')] + d = [('Upscale (Fast)', 'upscale_fast', '2x')] if modules.config.default_black_out_nsfw or black_out_nsfw: progressbar(async_task, 100, 'Checking for NSFW content ...') - uov_input_image = censor_single(uov_input_image) + uov_input_image = default_censor(uov_input_image) + progressbar(async_task, 100, 'Saving image to system ...') uov_input_image_path = log(uov_input_image, d, output_format=output_format) yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True) return @@ -630,7 +640,7 @@ def worker(): denoising_strength = overwrite_upscale_strength initial_pixels = core.numpy_to_pytorch(uov_input_image) - progressbar(async_task, 13, 'VAE encoding ...') + progressbar(async_task, 10, 'VAE encoding ...') candidate_vae, _ = pipeline.get_candidate_vae( steps=steps, @@ -687,7 +697,7 @@ def worker(): yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True) return - progressbar(async_task, 13, 'VAE Inpaint encoding ...') + progressbar(async_task, 11, 'VAE Inpaint encoding ...') inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill) inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image) @@ -707,7 +717,7 @@ def worker(): latent_swap = None if candidate_vae_swap is not None: - progressbar(async_task, 13, 'VAE SD15 encoding ...') + progressbar(async_task, 12, 'VAE SD15 encoding ...') latent_swap = core.encode_vae( vae=candidate_vae_swap, pixels=inpaint_pixel_fill)['samples'] @@ -833,15 +843,17 @@ def worker(): zsnr=False)[0] print(f'Using {scheduler_name} scheduler.') - async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)]) + async_task.yields.append(['preview', (flags.preparation_step_count, 'Moving model to GPU ...', None)]) def callback(step, x0, x, total_steps, y): done_steps = current_task_id * steps + step 
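+            # scale sampling progress onto the part of the progress bar that
+            # remains after the preparation steps (preparation_step_count..100)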
async_task.yields.append(['preview', ( - int(15.0 + 85.0 * float(done_steps) / float(all_steps)), - f'Sampling Image {current_task_id + 1}/{image_number}, Step {step + 1}/{total_steps} ...', y)]) + int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(done_steps) / float(all_steps)), + f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{image_number} ...', y)]) for current_task_id, task in enumerate(tasks): + current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(current_task_id * steps) / float(all_steps)) + progressbar(async_task, current_progress, f'Preparing task {current_task_id + 1}/{image_number} ...') execution_start_time = time.perf_counter() try: @@ -885,16 +897,19 @@ def worker(): img_paths = [] + current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float((current_task_id + 1) * steps) / float(all_steps)) if modules.config.default_black_out_nsfw or black_out_nsfw: - progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)), - 'Checking for NSFW content ...') - imgs = censor_batch(imgs) + progressbar(async_task, current_progress, 'Checking for NSFW content ...') + imgs = default_censor(imgs) + + progressbar(async_task, current_progress, f'Saving image {current_task_id + 1}/{image_number} to system ...') for x in imgs: d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), - ('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), + ('Styles', 'styles', + str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), ('Performance', 'performance', performance_selection.value)] if performance_selection.steps() != steps: @@ -917,7 +932,8 @@ def worker(): if refiner_swap_method != flags.refiner_swap_method: d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method)) if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr: - d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) + d.append( + ('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) d.append(('Sampler', 'sampler', sampler_name)) d.append(('Scheduler', 'scheduler', scheduler_name)) @@ -937,16 +953,13 @@ def worker(): metadata_parser.set_data(task['log_positive_prompt'], task['positive'], task['log_negative_prompt'], task['negative'], steps, base_model_name, refiner_model_name, loras, vae_name) - d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) + d.append(('Metadata Scheme', 'metadata_scheme', + metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format, task)) yield_result(async_task, img_paths, black_out_nsfw, False, - do_not_show_finished_images=len(tasks) == 1 - or disable_intermediate_results - or performance_selection == Performance.EXTREME_SPEED - or performance_selection == Performance.LIGHTNING) - + do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': 
print('User skipped') diff --git a/modules/censor.py b/modules/censor.py deleted file mode 100644 index 724042e7..00000000 --- a/modules/censor.py +++ /dev/null @@ -1,50 +0,0 @@ -# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py -import numpy as np - -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from transformers import AutoFeatureExtractor -from PIL import Image -import modules.config - -safety_model_id = "CompVis/stable-diffusion-safety-checker" -safety_feature_extractor = None -safety_checker = None - - -def numpy_to_pil(image): - image = (image * 255).round().astype("uint8") - pil_image = Image.fromarray(image) - - return pil_image - - -# check and replace nsfw content -def check_safety(x_image): - global safety_feature_extractor, safety_checker - - if safety_feature_extractor is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) - - safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") - x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) - - return x_checked_image, has_nsfw_concept - - -def censor_single(x): - x_checked_image, has_nsfw_concept = check_safety(x) - - # replace image with black pixels, keep dimensions - # workaround due to different numpy / pytorch image matrix format - if has_nsfw_concept[0]: - imageshape = x_checked_image.shape - x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8) - - return x_checked_image - - -def censor_batch(images): - images = [censor_single(image) for image in images] - - return images diff --git a/modules/config.py b/modules/config.py index 6b936845..21b986a5 100644 --- a/modules/config.py +++ b/modules/config.py @@ -8,7 +8,8 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.util import get_files_from_folder, makedirs_with_log +from modules.util import makedirs_with_log +from modules.extra_utils import get_files_from_folder from modules.flags import OutputFormat, Performance, MetadataScheme @@ -20,7 +21,7 @@ def get_config_path(key, default_value): else: return os.path.abspath(default_value) - +wildcards_max_bfs_depth = 64 config_path = get_config_path('config_path', "./config.txt") config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt") config_dict = {} @@ -199,6 +200,7 @@ path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vi path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/') path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/') +path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/') path_outputs = get_path_output() def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): @@ -463,6 +465,11 @@ example_inpaint_prompts = get_config_item_or_set_default( ], validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) ) +default_black_out_nsfw = 
get_config_item_or_set_default( + key='default_black_out_nsfw', + default_value=False, + validator=lambda x: isinstance(x, bool) +) default_save_metadata_to_images = get_config_item_or_set_default( key='default_save_metadata_to_images', default_value=False, @@ -731,5 +738,13 @@ def downloading_upscale_model(): ) return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') +def downloading_safety_checker_model(): + load_file_from_url( + url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin', + model_dir=path_safety_checker, + file_name='stable-diffusion-safety-checker.bin' + ) + return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin') + update_files() diff --git a/modules/extra_utils.py b/modules/extra_utils.py new file mode 100644 index 00000000..3e95e8b5 --- /dev/null +++ b/modules/extra_utils.py @@ -0,0 +1,20 @@ +import os + + +def get_files_from_folder(folder_path, extensions=None, name_filter=None): + if not os.path.isdir(folder_path): + raise ValueError("Folder path is not a valid directory.") + + filenames = [] + + for root, _, files in os.walk(folder_path, topdown=False): + relative_path = os.path.relpath(root, folder_path) + if relative_path == ".": + relative_path = "" + for filename in sorted(files, key=lambda s: s.casefold()): + _, file_extension = os.path.splitext(filename) + if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _): + path = os.path.join(relative_path, filename) + filenames.append(path) + + return filenames diff --git a/modules/flags.py b/modules/flags.py index e939849e..a406af4f 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -97,6 +97,7 @@ metadata_scheme = [ ] controlnet_image_count = 4 +preparation_step_count = 13 class OutputFormat(Enum): diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 5b6afb59..12ab6c5c 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -2,14 +2,12 @@ import os import re import json import math -import modules.config -from modules.util import get_files_from_folder +from modules.extra_utils import get_files_from_folder from random import Random # cannot use modules.config - validators causing circular imports styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) -wildcards_max_bfs_depth = 64 def normalize_key(k): @@ -25,7 +23,6 @@ def normalize_key(k): styles = {} - styles_files = get_files_from_folder(styles_path, ['.json']) for x in ['sdxl_styles_fooocus.json', @@ -65,34 +62,7 @@ def apply_style(style, positive): return p.replace('{prompt}', positive).splitlines(), n.splitlines() -def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order): - for _ in range(wildcards_max_bfs_depth): - placeholders = re.findall(r'__([\w-]+)__', wildcard_text) - if len(placeholders) == 0: - return wildcard_text - - print(f'[Wildcards] processing: {wildcard_text}') - for placeholder in placeholders: - try: - matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder] - words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines() - words = [x for x in words if x != ''] - assert len(words) > 0 - if read_wildcards_in_order: - wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1) - else: - wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1) - except: - print(f'[Wildcards] Warning: 
{placeholder}.txt missing or empty. '
-                      f'Using "{placeholder}" as a normal word.')
-                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
-        print(f'[Wildcards] {wildcard_text}')
-
-    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
-    return wildcard_text
-
-
-def get_words(arrays, totalMult, index):
+def get_words(arrays, total_mult, index):
     if len(arrays) == 1:
         return [arrays[0].split(',')[index]]
     else:
@@ -101,7 +71,7 @@
         index -= index % len(words)
         index /= len(words)
         index = math.floor(index)
-        return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index)
+        return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
 
 
 def apply_arrays(text, index):
diff --git a/modules/util.py b/modules/util.py
index d2feecb6..8e85ffbe 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -1,11 +1,12 @@
-import typing
-
 import numpy as np
 import datetime
 import random
 import math
 import os
 import cv2
+import re
+from typing import List, Tuple, AnyStr, NamedTuple
+
 import json
 import hashlib
@@ -14,8 +15,16 @@
 from PIL import Image
 
 import modules.sdxl_styles
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+
+
+# Regexp compiled once. Matches entries with the following pattern:
+# <lora:some_lora:1>
+# <lora:another-lora:-0.6>
+LORAS_PROMPT_PATTERN = re.compile(r".*<lora:([^:]+):(-?\d+(?:\.\d*)?)>.*", re.X)
+
 HASH_SHA256_LENGTH = 10
 
+
 def erode_or_dilate(x, k):
     k = int(k)
     if k > 0:
@@ -163,25 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
     return date_string, os.path.abspath(result), filename
 
 
-def get_files_from_folder(folder_path, extensions=None, name_filter=None):
-    if not os.path.isdir(folder_path):
-        raise ValueError("Folder path is not a valid directory.")
-
-    filenames = []
-
-    for root, dirs, files in os.walk(folder_path, topdown=False):
-        relative_path = os.path.relpath(root, folder_path)
-        if relative_path == ".":
-            relative_path = ""
-        for filename in sorted(files, key=lambda s: s.casefold()):
-            _, file_extension = os.path.splitext(filename)
-            if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
-                path = os.path.join(relative_path, filename)
-                filenames.append(path)
-
-    return filenames
-
-
 def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
     print(f"Calculating sha256 for {filename}: ", end='')
     if use_addnet_hash:
@@ -355,7 +345,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
     return list(reversed(extracted)), real_prompt, negative_prompt
 
 
-class PromptStyle(typing.NamedTuple):
+class PromptStyle(NamedTuple):
     name: str
     prompt: str
    negative_prompt: str
@@ -382,10 +372,6 @@ def get_file_from_folder_list(name, folders):
     return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
 
 
-def ordinal_suffix(number: int) -> str:
-    return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
-
-
 def makedirs_with_log(path):
     try:
         os.makedirs(path, exist_ok=True)
@@ -394,4 +380,49 @@
 
 
 def get_enabled_loras(loras: list) -> list:
-    return [[lora[1], lora[2]] for lora in loras if lora[0]]
+    return [(lora[1], lora[2]) for lora in loras if lora[0]]
+
+
+def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
+    new_loras = []
+    updated_loras = []
+    for token in prompt.split(","):
+        m = LORAS_PROMPT_PATTERN.match(token)
+
+        if m:
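+            # group(1) is the lora name and group(2) its weight,
+            # e.g. "hey-lora" and "0.4" for a <lora:hey-lora:0.4> reference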
+            new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
+
+    for lora in loras + new_loras:
+        if lora[0] != "None":
+            updated_loras.append(lora)
+
+    return updated_loras[:loras_limit]
+
+
+def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
+    for _ in range(modules.config.wildcards_max_bfs_depth):
+        placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
+        if len(placeholders) == 0:
+            return wildcard_text
+
+        print(f'[Wildcards] processing: {wildcard_text}')
+        for placeholder in placeholders:
+            try:
+                matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
+                words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
+                words = [x for x in words if x != '']
+                assert len(words) > 0
+                if read_wildcards_in_order:
+                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
+                else:
+                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
+            except:
+                print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
+                      f'Using "{placeholder}" as a normal word.')
+                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
+        print(f'[Wildcards] {wildcard_text}')
+
+    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
+    return wildcard_text
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..c424468f
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,5 @@
+import sys
+import pathlib
+
+# put the repository root on sys.path so `from modules import util` resolves
+sys.path.append(str(pathlib.Path(__file__).resolve().parent.parent))
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..0698dcc8
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,48 @@
+import unittest
+
+from modules import util
+
+
+class TestUtils(unittest.TestCase):
+    def test_can_parse_tokens_with_lora(self):
+        test_cases = [
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
+                "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
+            },
+            # Test can not exceed limit
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
+                "output": [("hey-lora.safetensors", 0.4)],
+            },
+            # test Loras from UI take precedence over prompt
+            {
+                "input": (
+                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l5:0.1>, <lora:l6:0.2>",
+                    [("hey-lora.safetensors", 0.4)],
+                    5,
+                ),
+                "output": [
+                    ("hey-lora.safetensors", 0.4),
+                    ("l1.safetensors", 0.4),
+                    ("l2.safetensors", -0.2),
+                    ("l3.safetensors", 0.3),
+                    ("l4.safetensors", 0.5),
+                ],
+            },
+            # Test lora specification not separated by comma are ignored, only latest specified is used
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
+                "output": [("you-lora.safetensors", 0.2)],
+            },
+            {
+                "input": ("<lora:no-weight>, <lora:bad:weight> and <notlora:foo:0.3>", [], 6),
+                "output": []
+            }
+        ]
+
+        for test in test_cases:
+            prompt, loras, loras_limit = test["input"]
+            expected = test["output"]
+            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
+            self.assertEqual(expected, actual)
diff --git a/webui.py b/webui.py
index 78bf984a..0536f1f6 100644
--- a/webui.py
+++ b/webui.py
@@ -512,7 +512,8 @@ with shared.gradio_root:
                                                   info='Use black image if NSFW is detected.')
 
                 black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x),
-                                      inputs=black_out_nsfw, outputs=disable_preview, queue=False, show_progress=False)
+                                      inputs=black_out_nsfw, outputs=disable_preview, queue=False,
+                                      show_progress=False)
 
                 if not args_manager.args.disable_metadata:
                     save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images',
value=modules.config.default_save_metadata_to_images,