feat: scan wildcard subdirectories (#2466)
* Fix typo * Scan wildcards recursively Adds a method for getting the top-most occurrence of a given file in a directory tree * Use already existing method for locating files * Fix issue with incorrect files being loaded When using the `name-filter` parameter in `get_model_filenames`, it doesn't guarantee the best match to be in the first index. This change adds a step to ensure the correct wildcard is being loaded. * feat: make path for wildcards configurable, cache filenames on refresh files, rename button variable * Fix formatting --------- Co-authored-by: Manuel Schmid <manuel.schmid@odt.net>
This commit is contained in:
parent
400471f7af
commit
f6117180d4
|
|
@ -179,6 +179,7 @@ path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
|
|||
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
|
||||
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
|
||||
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
||||
path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
|
||||
path_outputs = get_path_output()
|
||||
|
||||
|
||||
|
|
@ -508,22 +509,26 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
|
|||
|
||||
model_filenames = []
|
||||
lora_filenames = []
|
||||
wildcard_filenames = []
|
||||
|
||||
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
|
||||
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
|
||||
|
||||
|
||||
def get_model_filenames(folder_paths, name_filter=None):
|
||||
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
||||
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
|
||||
if extensions is None:
|
||||
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
||||
files = []
|
||||
for folder in folder_paths:
|
||||
files += get_files_from_folder(folder, extensions, name_filter)
|
||||
return files
|
||||
|
||||
|
||||
def update_all_model_names():
|
||||
global model_filenames, lora_filenames
|
||||
def update_files():
|
||||
global model_filenames, lora_filenames, wildcard_filenames
|
||||
model_filenames = get_model_filenames(paths_checkpoints)
|
||||
lora_filenames = get_model_filenames(paths_loras)
|
||||
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
|
||||
return
|
||||
|
||||
|
||||
|
|
@ -647,4 +652,4 @@ def downloading_upscale_model():
|
|||
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
||||
|
||||
|
||||
update_all_model_names()
|
||||
update_files()
|
||||
|
|
|
|||
|
|
@ -2,13 +2,12 @@ import os
|
|||
import re
|
||||
import json
|
||||
import math
|
||||
import modules.config
|
||||
|
||||
from modules.util import get_files_from_folder
|
||||
|
||||
|
||||
# cannot use modules.config - validators causing circular imports
|
||||
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
||||
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
|
||||
wildcards_max_bfs_depth = 64
|
||||
|
||||
|
||||
|
|
@ -60,7 +59,7 @@ def apply_style(style, positive):
|
|||
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
|
||||
|
||||
|
||||
def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
|
||||
def apply_wildcards(wildcard_text, rng):
|
||||
for _ in range(wildcards_max_bfs_depth):
|
||||
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
|
||||
if len(placeholders) == 0:
|
||||
|
|
@ -69,7 +68,8 @@ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
|
|||
print(f'[Wildcards] processing: {wildcard_text}')
|
||||
for placeholder in placeholders:
|
||||
try:
|
||||
words = open(os.path.join(directory, f'{placeholder}.txt'), encoding='utf-8').read().splitlines()
|
||||
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
|
||||
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
|
||||
words = [x for x in words if x != '']
|
||||
assert len(words) > 0
|
||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
|
||||
|
|
@ -82,8 +82,9 @@ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
|
|||
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
|
||||
return wildcard_text
|
||||
|
||||
|
||||
def get_words(arrays, totalMult, index):
|
||||
if(len(arrays) == 1):
|
||||
if len(arrays) == 1:
|
||||
return [arrays[0].split(',')[index]]
|
||||
else:
|
||||
words = arrays[0].split(',')
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
|
|||
return date_string, os.path.abspath(result), filename
|
||||
|
||||
|
||||
def get_files_from_folder(folder_path, exensions=None, name_filter=None):
|
||||
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
|
||||
if not os.path.isdir(folder_path):
|
||||
raise ValueError("Folder path is not a valid directory.")
|
||||
|
||||
|
|
@ -175,7 +175,7 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
|
|||
relative_path = ""
|
||||
for filename in sorted(files, key=lambda s: s.casefold()):
|
||||
_, file_extension = os.path.splitext(filename)
|
||||
if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _):
|
||||
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
|
||||
path = os.path.join(relative_path, filename)
|
||||
filenames.append(path)
|
||||
|
||||
|
|
|
|||
9
webui.py
9
webui.py
|
|
@ -366,7 +366,7 @@ with shared.gradio_root:
|
|||
lora_ctrls += [lora_enabled, lora_model, lora_weight]
|
||||
|
||||
with gr.Row():
|
||||
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
|
||||
refresh_files = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
|
||||
with gr.Tab(label='Advanced'):
|
||||
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01,
|
||||
value=modules.config.default_cfg_scale,
|
||||
|
|
@ -512,19 +512,18 @@ with shared.gradio_root:
|
|||
def dev_mode_checked(r):
|
||||
return gr.update(visible=r)
|
||||
|
||||
|
||||
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
|
||||
queue=False, show_progress=False)
|
||||
|
||||
def model_refresh_clicked():
|
||||
modules.config.update_all_model_names()
|
||||
def refresh_files_clicked():
|
||||
modules.config.update_files()
|
||||
results = [gr.update(choices=modules.config.model_filenames)]
|
||||
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
|
||||
for i in range(modules.config.default_max_lora_number):
|
||||
results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
||||
return results
|
||||
|
||||
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
||||
refresh_files.click(refresh_files_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
||||
queue=False, show_progress=False)
|
||||
|
||||
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +
|
||||
|
|
|
|||
Loading…
Reference in New Issue