refactor: move apply_wildcards to module util

This commit is contained in:
Manuel Schmid 2024-05-18 16:43:02 +02:00
parent df543a5e65
commit a78b3841c9
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 33 additions and 34 deletions

View File

@ -46,8 +46,9 @@ def worker():
from modules.sdxl_styles import apply_style, fooocus_expansion, apply_arrays
from modules.private_logger import log
from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, parse_lora_references_from_prompt
from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil,
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras,
parse_lora_references_from_prompt, apply_wildcards)
from modules.upscaler import perform_upscale
from modules.flags import Performance
from modules.meta_parser import get_metadata_parser, MetadataScheme
@ -444,11 +445,11 @@ def worker():
task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
task_rng = random.Random(task_seed) # may bind to inpaint noise in the future
task_prompt = modules.config.apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
task_prompt = apply_arrays(task_prompt, i)
task_negative_prompt = modules.config.apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
task_extra_positive_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts]
task_extra_negative_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts]
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts]
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts]
positive_basic_workloads = []
negative_basic_workloads = []

View File

@ -680,32 +680,4 @@ def downloading_upscale_model():
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order):
for _ in range(wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
update_files()

View File

@ -396,3 +396,29 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
return updated_loras[:loras_limit]
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
for _ in range(modules.config.wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text