feat: scan wildcard subdirectories (#2466)
* Fix typo * Scan wildcards recursively Adds a method for getting the top-most occurrence of a given file in a directory tree * Use already existing method for locating files * Fix issue with incorrect files being loaded When using the `name-filter` parameter in `get_model_filenames`, it doesn't guarantee the best match to be in the first index. This change adds a step to ensure the correct wildcard is being loaded. * feat: make path for wildcards configurable, cache filenames on refresh files, rename button variable * Fix formatting --------- Co-authored-by: Manuel Schmid <manuel.schmid@odt.net>
This commit is contained in:
parent
400471f7af
commit
f6117180d4
|
|
@ -179,6 +179,7 @@ path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
|
||||||
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
|
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
|
||||||
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
|
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
|
||||||
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
||||||
|
path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
|
||||||
path_outputs = get_path_output()
|
path_outputs = get_path_output()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -508,22 +509,26 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
|
||||||
|
|
||||||
model_filenames = []
|
model_filenames = []
|
||||||
lora_filenames = []
|
lora_filenames = []
|
||||||
|
wildcard_filenames = []
|
||||||
|
|
||||||
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
|
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
|
||||||
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
|
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
|
||||||
|
|
||||||
|
|
||||||
def get_model_filenames(folder_paths, name_filter=None):
|
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
|
||||||
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
if extensions is None:
|
||||||
|
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
||||||
files = []
|
files = []
|
||||||
for folder in folder_paths:
|
for folder in folder_paths:
|
||||||
files += get_files_from_folder(folder, extensions, name_filter)
|
files += get_files_from_folder(folder, extensions, name_filter)
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
||||||
def update_all_model_names():
|
def update_files():
|
||||||
global model_filenames, lora_filenames
|
global model_filenames, lora_filenames, wildcard_filenames
|
||||||
model_filenames = get_model_filenames(paths_checkpoints)
|
model_filenames = get_model_filenames(paths_checkpoints)
|
||||||
lora_filenames = get_model_filenames(paths_loras)
|
lora_filenames = get_model_filenames(paths_loras)
|
||||||
|
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -647,4 +652,4 @@ def downloading_upscale_model():
|
||||||
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
||||||
|
|
||||||
|
|
||||||
update_all_model_names()
|
update_files()
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,12 @@ import os
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
import modules.config
|
||||||
|
|
||||||
from modules.util import get_files_from_folder
|
from modules.util import get_files_from_folder
|
||||||
|
|
||||||
|
|
||||||
# cannot use modules.config - validators causing circular imports
|
# cannot use modules.config - validators causing circular imports
|
||||||
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
||||||
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
|
|
||||||
wildcards_max_bfs_depth = 64
|
wildcards_max_bfs_depth = 64
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -60,7 +59,7 @@ def apply_style(style, positive):
|
||||||
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
|
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
|
||||||
|
|
||||||
|
|
||||||
def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
|
def apply_wildcards(wildcard_text, rng):
|
||||||
for _ in range(wildcards_max_bfs_depth):
|
for _ in range(wildcards_max_bfs_depth):
|
||||||
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
|
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
|
||||||
if len(placeholders) == 0:
|
if len(placeholders) == 0:
|
||||||
|
|
@ -69,7 +68,8 @@ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
|
||||||
print(f'[Wildcards] processing: {wildcard_text}')
|
print(f'[Wildcards] processing: {wildcard_text}')
|
||||||
for placeholder in placeholders:
|
for placeholder in placeholders:
|
||||||
try:
|
try:
|
||||||
words = open(os.path.join(directory, f'{placeholder}.txt'), encoding='utf-8').read().splitlines()
|
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
|
||||||
|
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
|
||||||
words = [x for x in words if x != '']
|
words = [x for x in words if x != '']
|
||||||
assert len(words) > 0
|
assert len(words) > 0
|
||||||
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
|
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
|
||||||
|
|
@ -82,8 +82,9 @@ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
|
||||||
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
|
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
|
||||||
return wildcard_text
|
return wildcard_text
|
||||||
|
|
||||||
|
|
||||||
def get_words(arrays, totalMult, index):
|
def get_words(arrays, totalMult, index):
|
||||||
if(len(arrays) == 1):
|
if len(arrays) == 1:
|
||||||
return [arrays[0].split(',')[index]]
|
return [arrays[0].split(',')[index]]
|
||||||
else:
|
else:
|
||||||
words = arrays[0].split(',')
|
words = arrays[0].split(',')
|
||||||
|
|
|
||||||
|
|
@ -163,7 +163,7 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
|
||||||
return date_string, os.path.abspath(result), filename
|
return date_string, os.path.abspath(result), filename
|
||||||
|
|
||||||
|
|
||||||
def get_files_from_folder(folder_path, exensions=None, name_filter=None):
|
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
|
||||||
if not os.path.isdir(folder_path):
|
if not os.path.isdir(folder_path):
|
||||||
raise ValueError("Folder path is not a valid directory.")
|
raise ValueError("Folder path is not a valid directory.")
|
||||||
|
|
||||||
|
|
@ -175,7 +175,7 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
|
||||||
relative_path = ""
|
relative_path = ""
|
||||||
for filename in sorted(files, key=lambda s: s.casefold()):
|
for filename in sorted(files, key=lambda s: s.casefold()):
|
||||||
_, file_extension = os.path.splitext(filename)
|
_, file_extension = os.path.splitext(filename)
|
||||||
if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _):
|
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
|
||||||
path = os.path.join(relative_path, filename)
|
path = os.path.join(relative_path, filename)
|
||||||
filenames.append(path)
|
filenames.append(path)
|
||||||
|
|
||||||
|
|
|
||||||
9
webui.py
9
webui.py
|
|
@ -366,7 +366,7 @@ with shared.gradio_root:
|
||||||
lora_ctrls += [lora_enabled, lora_model, lora_weight]
|
lora_ctrls += [lora_enabled, lora_model, lora_weight]
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
|
refresh_files = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
|
||||||
with gr.Tab(label='Advanced'):
|
with gr.Tab(label='Advanced'):
|
||||||
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01,
|
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01,
|
||||||
value=modules.config.default_cfg_scale,
|
value=modules.config.default_cfg_scale,
|
||||||
|
|
@ -512,19 +512,18 @@ with shared.gradio_root:
|
||||||
def dev_mode_checked(r):
|
def dev_mode_checked(r):
|
||||||
return gr.update(visible=r)
|
return gr.update(visible=r)
|
||||||
|
|
||||||
|
|
||||||
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
|
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
|
||||||
queue=False, show_progress=False)
|
queue=False, show_progress=False)
|
||||||
|
|
||||||
def model_refresh_clicked():
|
def refresh_files_clicked():
|
||||||
modules.config.update_all_model_names()
|
modules.config.update_files()
|
||||||
results = [gr.update(choices=modules.config.model_filenames)]
|
results = [gr.update(choices=modules.config.model_filenames)]
|
||||||
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
|
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
|
||||||
for i in range(modules.config.default_max_lora_number):
|
for i in range(modules.config.default_max_lora_number):
|
||||||
results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
refresh_files.click(refresh_files_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
||||||
queue=False, show_progress=False)
|
queue=False, show_progress=False)
|
||||||
|
|
||||||
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +
|
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue