feat: scan wildcard subdirectories (#2466)

* Fix typo

* Scan wildcards recursively

Adds a method for getting the top-most occurrence of a given file in a directory tree

* Use already existing method for locating files

* Fix issue with incorrect files being loaded

When using the `name-filter` parameter in `get_model_filenames`, it doesn't guarantee the best match to be in the first index. This change adds a step to ensure the correct wildcard is being loaded.

* feat: make path for wildcards configurable, cache filenames on refresh files, rename button variable

* Fix formatting

---------

Co-authored-by: Manuel Schmid <manuel.schmid@odt.net>
This commit is contained in:
Cruxial 2024-03-10 21:35:41 +01:00 committed by GitHub
parent 400471f7af
commit f6117180d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 17 deletions

View File

@ -179,6 +179,7 @@ path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
path_outputs = get_path_output()
@ -508,22 +509,26 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
model_filenames = []
lora_filenames = []
wildcard_filenames = []
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
def get_model_filenames(folder_paths, name_filter=None):
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None:
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
files = []
for folder in folder_paths:
files += get_files_from_folder(folder, extensions, name_filter)
return files
def update_all_model_names():
global model_filenames, lora_filenames
def update_files():
global model_filenames, lora_filenames, wildcard_filenames
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
return
@ -647,4 +652,4 @@ def downloading_upscale_model():
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
update_all_model_names()
update_files()

View File

@ -2,13 +2,12 @@ import os
import re
import json
import math
import modules.config
from modules.util import get_files_from_folder
# cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
wildcards_max_bfs_depth = 64
@ -60,7 +59,7 @@ def apply_style(style, positive):
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
def apply_wildcards(wildcard_text, rng):
for _ in range(wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
@ -69,7 +68,8 @@ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
words = open(os.path.join(directory, f'{placeholder}.txt'), encoding='utf-8').read().splitlines()
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
@ -82,8 +82,9 @@ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
def get_words(arrays, totalMult, index):
if(len(arrays) == 1):
if len(arrays) == 1:
return [arrays[0].split(',')[index]]
else:
words = arrays[0].split(',')

View File

@ -163,7 +163,7 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
return date_string, os.path.abspath(result), filename
def get_files_from_folder(folder_path, exensions=None, name_filter=None):
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")
@ -175,7 +175,7 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
relative_path = ""
for filename in sorted(files, key=lambda s: s.casefold()):
_, file_extension = os.path.splitext(filename)
if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _):
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)

View File

@ -366,7 +366,7 @@ with shared.gradio_root:
lora_ctrls += [lora_enabled, lora_model, lora_weight]
with gr.Row():
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
refresh_files = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
with gr.Tab(label='Advanced'):
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01,
value=modules.config.default_cfg_scale,
@ -512,19 +512,18 @@ with shared.gradio_root:
def dev_mode_checked(r):
return gr.update(visible=r)
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
queue=False, show_progress=False)
def model_refresh_clicked():
modules.config.update_all_model_names()
def refresh_files_clicked():
modules.config.update_files()
results = [gr.update(choices=modules.config.model_filenames)]
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
for i in range(modules.config.default_max_lora_number):
results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
refresh_files.click(refresh_files_clicked, [], [base_model, refiner_model] + lora_ctrls,
queue=False, show_progress=False)
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +