Merge branch 'feature/progress-bar'
# Conflicts: # fooocus_version.py # modules/async_worker.py # webui.py
This commit is contained in:
commit
dd5a14ac7f
|
|
@ -27,6 +27,7 @@ progress {
|
|||
border-radius: 5px; /* Round the corners of the progress bar */
|
||||
background-color: #f3f3f3; /* Light grey background */
|
||||
width: 100%;
|
||||
vertical-align: middle !important;
|
||||
}
|
||||
|
||||
/* Style the progress bar container */
|
||||
|
|
@ -69,6 +70,11 @@ progress::after {
|
|||
height: 30px !important;
|
||||
}
|
||||
|
||||
.progress-bar span {
|
||||
text-align: right;
|
||||
width: 200px;
|
||||
}
|
||||
|
||||
.type_row{
|
||||
height: 80px !important;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
## Running unit tests
|
||||
|
||||
Native python:
|
||||
```
|
||||
python -m unittest tests/
|
||||
```
|
||||
|
||||
Embedded python (Windows zip file installation method):
|
||||
```
|
||||
..\python_embeded\python.exe -m unittest
|
||||
```
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPConfig, CLIPImageProcessor
|
||||
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
import modules.config
|
||||
from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
|
||||
safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker')
|
||||
config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
|
||||
preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json")
|
||||
|
||||
|
||||
class Censor:
|
||||
def __init__(self):
|
||||
self.safety_checker_model: ModelPatcher | None = None
|
||||
self.clip_image_processor: CLIPImageProcessor | None = None
|
||||
self.load_device = torch.device('cpu')
|
||||
self.offload_device = torch.device('cpu')
|
||||
|
||||
def init(self):
|
||||
if self.safety_checker_model is None and self.clip_image_processor is None:
|
||||
safety_checker_model = modules.config.downloading_safety_checker_model()
|
||||
self.clip_image_processor = CLIPImageProcessor.from_json_file(preprocessor_config_path)
|
||||
clip_config = CLIPConfig.from_json_file(config_path)
|
||||
model = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
|
||||
model.eval()
|
||||
|
||||
self.load_device = model_management.text_encoder_device()
|
||||
self.offload_device = model_management.text_encoder_offload_device()
|
||||
|
||||
model.to(self.offload_device)
|
||||
|
||||
self.safety_checker_model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
def censor(self, images: list | np.ndarray) -> list | np.ndarray:
|
||||
self.init()
|
||||
model_management.load_model_gpu(self.safety_checker_model)
|
||||
|
||||
single = False
|
||||
if not isinstance(images, list) or isinstance(images, np.ndarray):
|
||||
images = [images]
|
||||
single = True
|
||||
|
||||
safety_checker_input = self.clip_image_processor(images, return_tensors="pt")
|
||||
safety_checker_input.to(device=self.load_device)
|
||||
checked_images, has_nsfw_concept = self.safety_checker_model.model(images=images,
|
||||
clip_input=safety_checker_input.pixel_values)
|
||||
checked_images = [image.astype(np.uint8) for image in checked_images]
|
||||
|
||||
if single:
|
||||
checked_images = checked_images[0]
|
||||
|
||||
return checked_images
|
||||
|
||||
|
||||
default_censor = Censor().censor
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
{
|
||||
"_name_or_path": "clip-vit-large-patch14/",
|
||||
"architectures": [
|
||||
"SafetyChecker"
|
||||
],
|
||||
"initializer_factor": 1.0,
|
||||
"logit_scale_init_value": 2.6592,
|
||||
"model_type": "clip",
|
||||
"projection_dim": 768,
|
||||
"text_config": {
|
||||
"_name_or_path": "",
|
||||
"add_cross_attention": false,
|
||||
"architectures": null,
|
||||
"attention_dropout": 0.0,
|
||||
"bad_words_ids": null,
|
||||
"bos_token_id": 0,
|
||||
"chunk_size_feed_forward": 0,
|
||||
"cross_attention_hidden_size": null,
|
||||
"decoder_start_token_id": null,
|
||||
"diversity_penalty": 0.0,
|
||||
"do_sample": false,
|
||||
"dropout": 0.0,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"eos_token_id": 2,
|
||||
"exponential_decay_length_penalty": null,
|
||||
"finetuning_task": null,
|
||||
"forced_bos_token_id": null,
|
||||
"forced_eos_token_id": null,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"id2label": {
|
||||
"0": "LABEL_0",
|
||||
"1": "LABEL_1"
|
||||
},
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"is_decoder": false,
|
||||
"is_encoder_decoder": false,
|
||||
"label2id": {
|
||||
"LABEL_0": 0,
|
||||
"LABEL_1": 1
|
||||
},
|
||||
"layer_norm_eps": 1e-05,
|
||||
"length_penalty": 1.0,
|
||||
"max_length": 20,
|
||||
"max_position_embeddings": 77,
|
||||
"min_length": 0,
|
||||
"model_type": "clip_text_model",
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_attention_heads": 12,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_hidden_layers": 12,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"pad_token_id": 1,
|
||||
"prefix": null,
|
||||
"problem_type": null,
|
||||
"pruned_heads": {},
|
||||
"remove_invalid_values": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"return_dict": true,
|
||||
"return_dict_in_generate": false,
|
||||
"sep_token_id": null,
|
||||
"task_specific_params": null,
|
||||
"temperature": 1.0,
|
||||
"tie_encoder_decoder": false,
|
||||
"tie_word_embeddings": true,
|
||||
"tokenizer_class": null,
|
||||
"top_k": 50,
|
||||
"top_p": 1.0,
|
||||
"torch_dtype": null,
|
||||
"torchscript": false,
|
||||
"transformers_version": "4.21.0.dev0",
|
||||
"typical_p": 1.0,
|
||||
"use_bfloat16": false,
|
||||
"vocab_size": 49408
|
||||
},
|
||||
"text_config_dict": {
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12
|
||||
},
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": null,
|
||||
"vision_config": {
|
||||
"_name_or_path": "",
|
||||
"add_cross_attention": false,
|
||||
"architectures": null,
|
||||
"attention_dropout": 0.0,
|
||||
"bad_words_ids": null,
|
||||
"bos_token_id": null,
|
||||
"chunk_size_feed_forward": 0,
|
||||
"cross_attention_hidden_size": null,
|
||||
"decoder_start_token_id": null,
|
||||
"diversity_penalty": 0.0,
|
||||
"do_sample": false,
|
||||
"dropout": 0.0,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"eos_token_id": null,
|
||||
"exponential_decay_length_penalty": null,
|
||||
"finetuning_task": null,
|
||||
"forced_bos_token_id": null,
|
||||
"forced_eos_token_id": null,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"id2label": {
|
||||
"0": "LABEL_0",
|
||||
"1": "LABEL_1"
|
||||
},
|
||||
"image_size": 224,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"is_decoder": false,
|
||||
"is_encoder_decoder": false,
|
||||
"label2id": {
|
||||
"LABEL_0": 0,
|
||||
"LABEL_1": 1
|
||||
},
|
||||
"layer_norm_eps": 1e-05,
|
||||
"length_penalty": 1.0,
|
||||
"max_length": 20,
|
||||
"min_length": 0,
|
||||
"model_type": "clip_vision_model",
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_attention_heads": 16,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_hidden_layers": 24,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"pad_token_id": null,
|
||||
"patch_size": 14,
|
||||
"prefix": null,
|
||||
"problem_type": null,
|
||||
"pruned_heads": {},
|
||||
"remove_invalid_values": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"return_dict": true,
|
||||
"return_dict_in_generate": false,
|
||||
"sep_token_id": null,
|
||||
"task_specific_params": null,
|
||||
"temperature": 1.0,
|
||||
"tie_encoder_decoder": false,
|
||||
"tie_word_embeddings": true,
|
||||
"tokenizer_class": null,
|
||||
"top_k": 50,
|
||||
"top_p": 1.0,
|
||||
"torch_dtype": null,
|
||||
"torchscript": false,
|
||||
"transformers_version": "4.21.0.dev0",
|
||||
"typical_p": 1.0,
|
||||
"use_bfloat16": false
|
||||
},
|
||||
"vision_config_dict": {
|
||||
"hidden_size": 1024,
|
||||
"intermediate_size": 4096,
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"crop_size": 224,
|
||||
"do_center_crop": true,
|
||||
"do_convert_rgb": true,
|
||||
"do_normalize": true,
|
||||
"do_resize": true,
|
||||
"feature_extractor_type": "CLIPFeatureExtractor",
|
||||
"image_mean": [
|
||||
0.48145466,
|
||||
0.4578275,
|
||||
0.40821073
|
||||
],
|
||||
"image_std": [
|
||||
0.26862954,
|
||||
0.26130258,
|
||||
0.27577711
|
||||
],
|
||||
"resample": 3,
|
||||
"size": 224
|
||||
}
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
# from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def cosine_distance(image_embeds, text_embeds):
|
||||
normalized_image_embeds = nn.functional.normalize(image_embeds)
|
||||
normalized_text_embeds = nn.functional.normalize(text_embeds)
|
||||
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
|
||||
|
||||
|
||||
class StableDiffusionSafetyChecker(PreTrainedModel):
|
||||
config_class = CLIPConfig
|
||||
main_input_name = "clip_input"
|
||||
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = CLIPVisionModel(config.vision_config)
|
||||
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
|
||||
|
||||
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
|
||||
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
|
||||
|
||||
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
|
||||
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, clip_input, images):
|
||||
pooled_output = self.vision_model(clip_input)[1] # pooled_output
|
||||
image_embeds = self.visual_projection(pooled_output)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
|
||||
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
|
||||
|
||||
result = []
|
||||
batch_size = image_embeds.shape[0]
|
||||
for i in range(batch_size):
|
||||
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
|
||||
|
||||
# increase this value to create a stronger `nfsw` filter
|
||||
# at the cost of increasing the possibility of filtering benign images
|
||||
adjustment = 0.0
|
||||
|
||||
for concept_idx in range(len(special_cos_dist[0])):
|
||||
concept_cos = special_cos_dist[i][concept_idx]
|
||||
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
|
||||
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["special_scores"][concept_idx] > 0:
|
||||
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
|
||||
adjustment = 0.01
|
||||
|
||||
for concept_idx in range(len(cos_dist[0])):
|
||||
concept_cos = cos_dist[i][concept_idx]
|
||||
concept_threshold = self.concept_embeds_weights[concept_idx].item()
|
||||
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["concept_scores"][concept_idx] > 0:
|
||||
result_img["bad_concepts"].append(concept_idx)
|
||||
|
||||
result.append(result_img)
|
||||
|
||||
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
|
||||
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
|
||||
if has_nsfw_concept:
|
||||
if torch.is_tensor(images) or torch.is_tensor(images[0]):
|
||||
images[idx] = torch.zeros_like(images[idx]) # black image
|
||||
else:
|
||||
images[idx] = np.zeros(images[idx].shape) # black image
|
||||
|
||||
if any(has_nsfw_concepts):
|
||||
logger.warning(
|
||||
"Potential NSFW content was detected in one or more images. A black image will be returned instead."
|
||||
" Try again with a different prompt and/or seed."
|
||||
)
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
|
||||
pooled_output = self.vision_model(clip_input)[1] # pooled_output
|
||||
image_embeds = self.visual_projection(pooled_output)
|
||||
|
||||
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
|
||||
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
|
||||
|
||||
# increase this value to create a stronger `nsfw` filter
|
||||
# at the cost of increasing the possibility of filtering benign images
|
||||
adjustment = 0.0
|
||||
|
||||
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
|
||||
# special_scores = special_scores.round(decimals=3)
|
||||
special_care = torch.any(special_scores > 0, dim=1)
|
||||
special_adjustment = special_care * 0.01
|
||||
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
|
||||
|
||||
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
|
||||
# concept_scores = concept_scores.round(decimals=3)
|
||||
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
|
||||
|
||||
images[has_nsfw_concepts] = 0.0 # black image
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
|
@ -1 +1 @@
|
|||
version = '2.3.3 (mashb1t)'
|
||||
version = '2.4.0-rc3 (mashb1t)'
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@
|
|||
"Disable seed increment": "Disable seed increment",
|
||||
"Disable automatic seed increment when image number is > 1.": "Disable automatic seed increment when image number is > 1.",
|
||||
"Read wildcards in order": "Read wildcards in order",
|
||||
"Black Out NSFW": "Black Out NSFW",
|
||||
"Use black image if NSFW is detected.": "Use black image if NSFW is detected.",
|
||||
"\ud83d\udcda History Log": "\uD83D\uDCDA History Log",
|
||||
"Image Style": "Image Style",
|
||||
"Fooocus V2": "Fooocus V2",
|
||||
|
|
|
|||
|
|
@ -46,12 +46,13 @@ def worker():
|
|||
import fooocus_version
|
||||
import args_manager
|
||||
|
||||
from modules.censor import censor_batch, censor_single
|
||||
from modules.sdxl_styles import get_random_style, random_style_name, apply_style, apply_wildcards, fooocus_expansion, apply_arrays
|
||||
from extras.censor import default_censor
|
||||
from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name
|
||||
from modules.private_logger import log
|
||||
from extras.expansion import safe_str
|
||||
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
|
||||
get_shape_ceil, resample_image, erode_or_dilate, get_enabled_loras
|
||||
from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil,
|
||||
get_shape_ceil, resample_image, erode_or_dilate, get_enabled_loras,
|
||||
parse_lora_references_from_prompt, apply_wildcards)
|
||||
from modules.upscaler import perform_upscale
|
||||
from modules.flags import Performance
|
||||
from modules.meta_parser import get_metadata_parser, MetadataScheme
|
||||
|
|
@ -72,13 +73,14 @@ def worker():
|
|||
print(f'[Fooocus] {text}')
|
||||
async_task.yields.append(['preview', (number, text, None)])
|
||||
|
||||
def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13):
|
||||
def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False,
|
||||
progressbar_index=flags.preparation_step_count):
|
||||
if not isinstance(imgs, list):
|
||||
imgs = [imgs]
|
||||
|
||||
if censor and (modules.config.default_black_out_nsfw or black_out_nsfw):
|
||||
progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
|
||||
imgs = censor_batch(imgs)
|
||||
imgs = default_censor(imgs)
|
||||
|
||||
async_task.results = async_task.results + imgs
|
||||
|
||||
|
|
@ -156,7 +158,8 @@ def worker():
|
|||
base_model_name = args.pop()
|
||||
refiner_model_name = args.pop()
|
||||
refiner_switch = args.pop()
|
||||
loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)])
|
||||
loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in
|
||||
range(modules.config.default_max_lora_number)])
|
||||
input_image_checkbox = args.pop()
|
||||
current_tab = args.pop()
|
||||
uov_method = args.pop()
|
||||
|
|
@ -206,7 +209,8 @@ def worker():
|
|||
inpaint_erode_or_dilate = args.pop()
|
||||
|
||||
save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
|
||||
metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
|
||||
metadata_scheme = MetadataScheme(
|
||||
args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
|
||||
|
||||
cn_tasks = {x: [] for x in flags.ip_list}
|
||||
for _ in range(flags.controlnet_image_count):
|
||||
|
|
@ -464,14 +468,17 @@ def worker():
|
|||
extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
|
||||
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
|
||||
|
||||
progressbar(async_task, 3, 'Loading models ...')
|
||||
progressbar(async_task, 2, 'Loading models ...')
|
||||
|
||||
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
|
||||
|
||||
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
|
||||
loras=loras, base_model_additional_loras=base_model_additional_loras,
|
||||
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
|
||||
|
||||
progressbar(async_task, 3, 'Processing prompts ...')
|
||||
tasks = []
|
||||
|
||||
|
||||
for i in range(image_number):
|
||||
if disable_seed_increment:
|
||||
task_seed = seed % (constants.MAX_SEED + 1)
|
||||
|
|
@ -482,8 +489,10 @@ def worker():
|
|||
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
|
||||
task_prompt = apply_arrays(task_prompt, i)
|
||||
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
|
||||
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts]
|
||||
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts]
|
||||
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
|
||||
extra_positive_prompts]
|
||||
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
|
||||
extra_negative_prompts]
|
||||
|
||||
positive_basic_workloads = []
|
||||
negative_basic_workloads = []
|
||||
|
|
@ -526,25 +535,25 @@ def worker():
|
|||
|
||||
if use_expansion:
|
||||
for i, t in enumerate(tasks):
|
||||
progressbar(async_task, 5, f'Preparing Fooocus text #{i + 1} ...')
|
||||
progressbar(async_task, 4, f'Preparing Fooocus text #{i + 1} ...')
|
||||
expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
|
||||
print(f'[Prompt Expansion] {expansion}')
|
||||
t['expansion'] = expansion
|
||||
t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
progressbar(async_task, 7, f'Encoding positive #{i + 1} ...')
|
||||
progressbar(async_task, 5, f'Encoding positive #{i + 1} ...')
|
||||
t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
if abs(float(cfg_scale) - 1.0) < 1e-4:
|
||||
t['uc'] = pipeline.clone_cond(t['c'])
|
||||
else:
|
||||
progressbar(async_task, 10, f'Encoding negative #{i + 1} ...')
|
||||
progressbar(async_task, 6, f'Encoding negative #{i + 1} ...')
|
||||
t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
|
||||
|
||||
if len(goals) > 0:
|
||||
progressbar(async_task, 13, 'Image processing ...')
|
||||
progressbar(async_task, 7, 'Image processing ...')
|
||||
|
||||
if 'vary' in goals:
|
||||
if 'subtle' in uov_method:
|
||||
|
|
@ -565,7 +574,7 @@ def worker():
|
|||
uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
|
||||
|
||||
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
||||
progressbar(async_task, 13, 'VAE encoding ...')
|
||||
progressbar(async_task, 8, 'VAE encoding ...')
|
||||
|
||||
candidate_vae, _ = pipeline.get_candidate_vae(
|
||||
steps=steps,
|
||||
|
|
@ -582,7 +591,7 @@ def worker():
|
|||
|
||||
if 'upscale' in goals:
|
||||
H, W, C = uov_input_image.shape
|
||||
progressbar(async_task, 13, f'Upscaling image from {str((H, W))} ...')
|
||||
progressbar(async_task, 9, f'Upscaling image from {str((H, W))} ...')
|
||||
uov_input_image = perform_upscale(uov_input_image)
|
||||
print(f'Image upscaled.')
|
||||
|
||||
|
|
@ -615,10 +624,11 @@ def worker():
|
|||
direct_return = False
|
||||
|
||||
if direct_return:
|
||||
d = [('Upscale', 'upscale', 'Fast 2x')]
|
||||
d = [('Upscale (Fast)', 'upscale_fast', '2x')]
|
||||
if modules.config.default_black_out_nsfw or black_out_nsfw:
|
||||
progressbar(async_task, 100, 'Checking for NSFW content ...')
|
||||
uov_input_image = censor_single(uov_input_image)
|
||||
uov_input_image = default_censor(uov_input_image)
|
||||
progressbar(async_task, 100, 'Saving image to system ...')
|
||||
uov_input_image_path = log(uov_input_image, d, output_format=output_format)
|
||||
yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True)
|
||||
return
|
||||
|
|
@ -630,7 +640,7 @@ def worker():
|
|||
denoising_strength = overwrite_upscale_strength
|
||||
|
||||
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
||||
progressbar(async_task, 13, 'VAE encoding ...')
|
||||
progressbar(async_task, 10, 'VAE encoding ...')
|
||||
|
||||
candidate_vae, _ = pipeline.get_candidate_vae(
|
||||
steps=steps,
|
||||
|
|
@ -687,7 +697,7 @@ def worker():
|
|||
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True)
|
||||
return
|
||||
|
||||
progressbar(async_task, 13, 'VAE Inpaint encoding ...')
|
||||
progressbar(async_task, 11, 'VAE Inpaint encoding ...')
|
||||
|
||||
inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
|
||||
inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
|
||||
|
|
@ -707,7 +717,7 @@ def worker():
|
|||
|
||||
latent_swap = None
|
||||
if candidate_vae_swap is not None:
|
||||
progressbar(async_task, 13, 'VAE SD15 encoding ...')
|
||||
progressbar(async_task, 12, 'VAE SD15 encoding ...')
|
||||
latent_swap = core.encode_vae(
|
||||
vae=candidate_vae_swap,
|
||||
pixels=inpaint_pixel_fill)['samples']
|
||||
|
|
@ -833,15 +843,17 @@ def worker():
|
|||
zsnr=False)[0]
|
||||
print(f'Using {scheduler_name} scheduler.')
|
||||
|
||||
async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])
|
||||
async_task.yields.append(['preview', (flags.preparation_step_count, 'Moving model to GPU ...', None)])
|
||||
|
||||
def callback(step, x0, x, total_steps, y):
|
||||
done_steps = current_task_id * steps + step
|
||||
async_task.yields.append(['preview', (
|
||||
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
|
||||
f'Sampling Image {current_task_id + 1}/{image_number}, Step {step + 1}/{total_steps} ...', y)])
|
||||
int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(done_steps) / float(all_steps)),
|
||||
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{image_number} ...', y)])
|
||||
|
||||
for current_task_id, task in enumerate(tasks):
|
||||
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float(current_task_id * steps) / float(all_steps))
|
||||
progressbar(async_task, current_progress, f'Preparing task {current_task_id + 1}/{image_number} ...')
|
||||
execution_start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
|
|
@ -885,16 +897,19 @@ def worker():
|
|||
|
||||
img_paths = []
|
||||
|
||||
current_progress = int(flags.preparation_step_count + (100 - flags.preparation_step_count) * float((current_task_id + 1) * steps) / float(all_steps))
|
||||
if modules.config.default_black_out_nsfw or black_out_nsfw:
|
||||
progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)),
|
||||
'Checking for NSFW content ...')
|
||||
imgs = censor_batch(imgs)
|
||||
progressbar(async_task, current_progress, 'Checking for NSFW content ...')
|
||||
imgs = default_censor(imgs)
|
||||
|
||||
progressbar(async_task, current_progress, f'Saving image {current_task_id + 1}/{image_number} to system ...')
|
||||
|
||||
for x in imgs:
|
||||
d = [('Prompt', 'prompt', task['log_positive_prompt']),
|
||||
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
|
||||
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
|
||||
('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
|
||||
('Styles', 'styles',
|
||||
str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
|
||||
('Performance', 'performance', performance_selection.value)]
|
||||
|
||||
if performance_selection.steps() != steps:
|
||||
|
|
@ -917,7 +932,8 @@ def worker():
|
|||
if refiner_swap_method != flags.refiner_swap_method:
|
||||
d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
|
||||
if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
|
||||
d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
|
||||
d.append(
|
||||
('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
|
||||
|
||||
d.append(('Sampler', 'sampler', sampler_name))
|
||||
d.append(('Scheduler', 'scheduler', scheduler_name))
|
||||
|
|
@ -937,16 +953,13 @@ def worker():
|
|||
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
|
||||
task['log_negative_prompt'], task['negative'],
|
||||
steps, base_model_name, refiner_model_name, loras, vae_name)
|
||||
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
|
||||
d.append(('Metadata Scheme', 'metadata_scheme',
|
||||
metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
|
||||
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
|
||||
img_paths.append(log(x, d, metadata_parser, output_format, task))
|
||||
|
||||
yield_result(async_task, img_paths, black_out_nsfw, False,
|
||||
do_not_show_finished_images=len(tasks) == 1
|
||||
or disable_intermediate_results
|
||||
or performance_selection == Performance.EXTREME_SPEED
|
||||
or performance_selection == Performance.LIGHTNING)
|
||||
|
||||
do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException as e:
|
||||
if async_task.last_stop == 'skip':
|
||||
print('User skipped')
|
||||
|
|
|
|||
|
|
@ -1,50 +0,0 @@
|
|||
# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
|
||||
import numpy as np
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
from PIL import Image
|
||||
import modules.config
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
safety_checker = None
|
||||
|
||||
|
||||
def numpy_to_pil(image):
|
||||
image = (image * 255).round().astype("uint8")
|
||||
pil_image = Image.fromarray(image)
|
||||
|
||||
return pil_image
|
||||
|
||||
|
||||
# check and replace nsfw content
|
||||
def check_safety(x_image):
|
||||
global safety_feature_extractor, safety_checker
|
||||
|
||||
if safety_feature_extractor is None:
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
|
||||
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
return x_checked_image, has_nsfw_concept
|
||||
|
||||
|
||||
def censor_single(x):
|
||||
x_checked_image, has_nsfw_concept = check_safety(x)
|
||||
|
||||
# replace image with black pixels, keep dimensions
|
||||
# workaround due to different numpy / pytorch image matrix format
|
||||
if has_nsfw_concept[0]:
|
||||
imageshape = x_checked_image.shape
|
||||
x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8)
|
||||
|
||||
return x_checked_image
|
||||
|
||||
|
||||
def censor_batch(images):
|
||||
images = [censor_single(image) for image in images]
|
||||
|
||||
return images
|
||||
|
|
@ -8,7 +8,8 @@ import modules.flags
|
|||
import modules.sdxl_styles
|
||||
|
||||
from modules.model_loader import load_file_from_url
|
||||
from modules.util import get_files_from_folder, makedirs_with_log
|
||||
from modules.util import makedirs_with_log
|
||||
from modules.extra_utils import get_files_from_folder
|
||||
from modules.flags import OutputFormat, Performance, MetadataScheme
|
||||
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ def get_config_path(key, default_value):
|
|||
else:
|
||||
return os.path.abspath(default_value)
|
||||
|
||||
|
||||
wildcards_max_bfs_depth = 64
|
||||
config_path = get_config_path('config_path', "./config.txt")
|
||||
config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
|
||||
config_dict = {}
|
||||
|
|
@ -199,6 +200,7 @@ path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vi
|
|||
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
||||
path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/')
|
||||
path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
|
||||
path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/')
|
||||
path_outputs = get_path_output()
|
||||
|
||||
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
|
||||
|
|
@ -463,6 +465,11 @@ example_inpaint_prompts = get_config_item_or_set_default(
|
|||
],
|
||||
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x)
|
||||
)
|
||||
default_black_out_nsfw = get_config_item_or_set_default(
|
||||
key='default_black_out_nsfw',
|
||||
default_value=False,
|
||||
validator=lambda x: isinstance(x, bool)
|
||||
)
|
||||
default_save_metadata_to_images = get_config_item_or_set_default(
|
||||
key='default_save_metadata_to_images',
|
||||
default_value=False,
|
||||
|
|
@ -731,5 +738,13 @@ def downloading_upscale_model():
|
|||
)
|
||||
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
||||
|
||||
def downloading_safety_checker_model():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin',
|
||||
model_dir=path_safety_checker,
|
||||
file_name='stable-diffusion-safety-checker.bin'
|
||||
)
|
||||
return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin')
|
||||
|
||||
|
||||
update_files()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
import os
|
||||
|
||||
|
||||
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
|
||||
if not os.path.isdir(folder_path):
|
||||
raise ValueError("Folder path is not a valid directory.")
|
||||
|
||||
filenames = []
|
||||
|
||||
for root, _, files in os.walk(folder_path, topdown=False):
|
||||
relative_path = os.path.relpath(root, folder_path)
|
||||
if relative_path == ".":
|
||||
relative_path = ""
|
||||
for filename in sorted(files, key=lambda s: s.casefold()):
|
||||
_, file_extension = os.path.splitext(filename)
|
||||
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
|
||||
path = os.path.join(relative_path, filename)
|
||||
filenames.append(path)
|
||||
|
||||
return filenames
|
||||
|
|
@ -97,6 +97,7 @@ metadata_scheme = [
|
|||
]
|
||||
|
||||
controlnet_image_count = 4
|
||||
preparation_step_count = 13
|
||||
|
||||
|
||||
class OutputFormat(Enum):
|
||||
|
|
|
|||
|
|
@ -2,14 +2,12 @@ import os
|
|||
import re
|
||||
import json
|
||||
import math
|
||||
import modules.config
|
||||
|
||||
from modules.util import get_files_from_folder
|
||||
from modules.extra_utils import get_files_from_folder
|
||||
from random import Random
|
||||
|
||||
# cannot use modules.config - validators causing circular imports
|
||||
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
||||
wildcards_max_bfs_depth = 64
|
||||
|
||||
|
||||
def normalize_key(k):
|
||||
|
|
@ -25,7 +23,6 @@ def normalize_key(k):
|
|||
|
||||
|
||||
styles = {}
|
||||
|
||||
styles_files = get_files_from_folder(styles_path, ['.json'])
|
||||
|
||||
for x in ['sdxl_styles_fooocus.json',
|
||||
|
|
@ -65,34 +62,7 @@ def apply_style(style, positive):
|
|||
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
|
||||
|
||||
|
||||
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order):
|
||||
for _ in range(wildcards_max_bfs_depth):
|
||||
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
|
||||
if len(placeholders) == 0:
|
||||
return wildcard_text
|
||||
|
||||
print(f'[Wildcards] processing: {wildcard_text}')
|
||||
for placeholder in placeholders:
|
||||
try:
|
||||
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
|
||||
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
|
||||
words = [x for x in words if x != '']
|
||||
assert len(words) > 0
|
||||
if read_wildcards_in_order:
|
||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
|
||||
else:
|
||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
|
||||
except:
|
||||
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
|
||||
f'Using "{placeholder}" as a normal word.')
|
||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
|
||||
print(f'[Wildcards] {wildcard_text}')
|
||||
|
||||
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
|
||||
return wildcard_text
|
||||
|
||||
|
||||
def get_words(arrays, totalMult, index):
|
||||
def get_words(arrays, total_mult, index):
|
||||
if len(arrays) == 1:
|
||||
return [arrays[0].split(',')[index]]
|
||||
else:
|
||||
|
|
@ -101,7 +71,7 @@ def get_words(arrays, totalMult, index):
|
|||
index -= index % len(words)
|
||||
index /= len(words)
|
||||
index = math.floor(index)
|
||||
return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index)
|
||||
return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
|
||||
|
||||
|
||||
def apply_arrays(text, index):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import typing
|
||||
|
||||
import numpy as np
|
||||
import datetime
|
||||
import random
|
||||
import math
|
||||
import os
|
||||
import cv2
|
||||
import re
|
||||
from typing import List, Tuple, AnyStr, NamedTuple
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
|
|
@ -14,8 +15,16 @@ from PIL import Image
|
|||
import modules.sdxl_styles
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
||||
|
||||
# Regexp compiled once. Matches entries with the following pattern:
|
||||
# <lora:some_lora:1>
|
||||
# <lora:aNotherLora:-1.6>
|
||||
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
|
||||
|
||||
HASH_SHA256_LENGTH = 10
|
||||
|
||||
|
||||
def erode_or_dilate(x, k):
|
||||
k = int(k)
|
||||
if k > 0:
|
||||
|
|
@ -163,25 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
|
|||
return date_string, os.path.abspath(result), filename
|
||||
|
||||
|
||||
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
|
||||
if not os.path.isdir(folder_path):
|
||||
raise ValueError("Folder path is not a valid directory.")
|
||||
|
||||
filenames = []
|
||||
|
||||
for root, dirs, files in os.walk(folder_path, topdown=False):
|
||||
relative_path = os.path.relpath(root, folder_path)
|
||||
if relative_path == ".":
|
||||
relative_path = ""
|
||||
for filename in sorted(files, key=lambda s: s.casefold()):
|
||||
_, file_extension = os.path.splitext(filename)
|
||||
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
|
||||
path = os.path.join(relative_path, filename)
|
||||
filenames.append(path)
|
||||
|
||||
return filenames
|
||||
|
||||
|
||||
def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
|
||||
print(f"Calculating sha256 for {filename}: ", end='')
|
||||
if use_addnet_hash:
|
||||
|
|
@ -355,7 +345,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
|
|||
return list(reversed(extracted)), real_prompt, negative_prompt
|
||||
|
||||
|
||||
class PromptStyle(typing.NamedTuple):
|
||||
class PromptStyle(NamedTuple):
|
||||
name: str
|
||||
prompt: str
|
||||
negative_prompt: str
|
||||
|
|
@ -382,10 +372,6 @@ def get_file_from_folder_list(name, folders):
|
|||
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
|
||||
|
||||
|
||||
def ordinal_suffix(number: int) -> str:
|
||||
return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
|
||||
|
||||
|
||||
def makedirs_with_log(path):
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
|
@ -394,4 +380,47 @@ def makedirs_with_log(path):
|
|||
|
||||
|
||||
def get_enabled_loras(loras: list) -> list:
|
||||
return [[lora[1], lora[2]] for lora in loras if lora[0]]
|
||||
return [(lora[1], lora[2]) for lora in loras if lora[0]]
|
||||
|
||||
|
||||
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
|
||||
new_loras = []
|
||||
updated_loras = []
|
||||
for token in prompt.split(","):
|
||||
m = LORAS_PROMPT_PATTERN.match(token)
|
||||
|
||||
if m:
|
||||
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
|
||||
|
||||
for lora in loras + new_loras:
|
||||
if lora[0] != "None":
|
||||
updated_loras.append(lora)
|
||||
|
||||
return updated_loras[:loras_limit]
|
||||
|
||||
|
||||
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
|
||||
for _ in range(modules.config.wildcards_max_bfs_depth):
|
||||
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
|
||||
if len(placeholders) == 0:
|
||||
return wildcard_text
|
||||
|
||||
print(f'[Wildcards] processing: {wildcard_text}')
|
||||
for placeholder in placeholders:
|
||||
try:
|
||||
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
|
||||
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
|
||||
words = [x for x in words if x != '']
|
||||
assert len(words) > 0
|
||||
if read_wildcards_in_order:
|
||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
|
||||
else:
|
||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
|
||||
except:
|
||||
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
|
||||
f'Using "{placeholder}" as a normal word.')
|
||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
|
||||
print(f'[Wildcards] {wildcard_text}')
|
||||
|
||||
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
|
||||
return wildcard_text
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
import sys
|
||||
import pathlib
|
||||
|
||||
sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import unittest
|
||||
|
||||
from modules import util
|
||||
|
||||
|
||||
class TestUtils(unittest.TestCase):
|
||||
def test_can_parse_tokens_with_lora(self):
|
||||
test_cases = [
|
||||
{
|
||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
|
||||
"output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
|
||||
},
|
||||
# Test can not exceed limit
|
||||
{
|
||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
|
||||
"output": [("hey-lora.safetensors", 0.4)],
|
||||
},
|
||||
# test Loras from UI take precedence over prompt
|
||||
{
|
||||
"input": (
|
||||
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
|
||||
[("hey-lora.safetensors", 0.4)],
|
||||
5,
|
||||
),
|
||||
"output": [
|
||||
("hey-lora.safetensors", 0.4),
|
||||
("l1.safetensors", 0.4),
|
||||
("l2.safetensors", -0.2),
|
||||
("l3.safetensors", 0.3),
|
||||
("l4.safetensors", 0.5),
|
||||
],
|
||||
},
|
||||
# Test lora specification not separated by comma are ignored, only latest specified is used
|
||||
{
|
||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
|
||||
"output": [("you-lora.safetensors", 0.2)],
|
||||
},
|
||||
{
|
||||
"input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6),
|
||||
"output": []
|
||||
}
|
||||
]
|
||||
|
||||
for test in test_cases:
|
||||
prompt, loras, loras_limit = test["input"]
|
||||
expected = test["output"]
|
||||
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
|
||||
self.assertEqual(expected, actual)
|
||||
3
webui.py
3
webui.py
|
|
@ -512,7 +512,8 @@ with shared.gradio_root:
|
|||
info='Use black image if NSFW is detected.')
|
||||
|
||||
black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x),
|
||||
inputs=black_out_nsfw, outputs=disable_preview, queue=False, show_progress=False)
|
||||
inputs=black_out_nsfw, outputs=disable_preview, queue=False,
|
||||
show_progress=False)
|
||||
|
||||
if not args_manager.args.disable_metadata:
|
||||
save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images,
|
||||
|
|
|
|||
Loading…
Reference in New Issue