feat: add support for lora inline prompt references (#2323)
* Adding support to inline prompt references
* Added unittests
* Added an initial documentation for development guidelines
* Added a negative number
* Renamed parameter
* Removed wrongly committed file
* Code fixes
* Fixed circular reference
* Fixed typo. Added TODO
* Fixed merge
* Code cleanup
* Added missing reference function
* Removed function from util.py... again...
* Update modules/async_worker.py: implemented suggested change
  Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com>
* Removed another circular reference
* Renamed module
* Addressed PR comments
* Added return type to function
* refactor: move apply_wildcards to module util
* refactor: code cleanup, unify usage of tuples in lora list
* docs: add instructions for running unittests on embedded python, code cleanup
* refactor: code cleanup, move makedirs_with_log back to util

---------

Co-authored-by: cantor-set <cantor-set@no-email.net>
Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com>
Co-authored-by: Manuel Schmid <dev@mash1t.de>
Parent: 3a55e7e391
Commit: 3bae73e23e
Development documentation (new file):

@@ -0,0 +1,11 @@
+## Running unit tests
+
+Native python:
+
+```
+python -m unittest tests/
+```
+
+Embedded python (Windows zip file installation method):
+
+```
+..\python_embeded\python.exe -m unittest
+```
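Not part of the commit, but if it helps: the same tests can be driven from Python itself using only the standard library's discovery API. This assumes it is run from the repository root so that the `tests/` package is importable.

```python
# Programmatic equivalent of `python -m unittest`, standard library only.
import unittest

suite = unittest.defaultTestLoader.discover("tests")   # collects test_*.py under tests/
unittest.TextTestRunner(verbosity=2).run(suite)
```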
modules/async_worker.py:

@@ -4,6 +4,7 @@ from modules.patch import PatchSettings, patch_settings, patch_all
patch_all()


class AsyncTask:
    def __init__(self, args):
        self.args = args
@@ -44,11 +45,12 @@ def worker():
    import args_manager

    from extras.censor import censor_batch, censor_single
-    from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name
+    from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name
    from modules.private_logger import log
    from extras.expansion import safe_str
-    from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
-        get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras
+    from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil,
+                              get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras,
+                              parse_lora_references_from_prompt, apply_wildcards)
    from modules.upscaler import perform_upscale
    from modules.flags import Performance
    from modules.meta_parser import get_metadata_parser, MetadataScheme
@@ -69,7 +71,8 @@ def worker():
        print(f'[Fooocus] {text}')
        async_task.yields.append(['preview', (number, text, None)])

-    def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13):
+    def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False,
+                     progressbar_index=13):
        if not isinstance(imgs, list):
            imgs = [imgs]
@@ -152,7 +155,8 @@ def worker():
        base_model_name = args.pop()
        refiner_model_name = args.pop()
        refiner_switch = args.pop()
-        loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)])
+        loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in
+                                   range(modules.config.default_max_lora_number)])
        input_image_checkbox = args.pop()
        current_tab = args.pop()
        uov_method = args.pop()
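`get_enabled_loras` (defined in modules/util.py further down) now yields name/weight tuples instead of two-element lists. A minimal sketch of what it does with the (enabled, name, weight) triples popped above; the filenames are invented and the import assumes a Fooocus checkout on `sys.path`:

```python
from modules.util import get_enabled_loras  # assumes a Fooocus checkout on sys.path

# (enabled, name, weight) triples as collected from the UI controls; values are invented
raw = [(True, "style-lora.safetensors", 0.6),
       (False, "unused-lora.safetensors", 1.0)]

print(get_enabled_loras(raw))  # -> [('style-lora.safetensors', 0.6)]
```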
@@ -202,7 +206,8 @@ def worker():
        inpaint_erode_or_dilate = args.pop()

        save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
-        metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
+        metadata_scheme = MetadataScheme(
+            args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS

        cn_tasks = {x: [] for x in flags.ip_list}
        for _ in range(flags.controlnet_image_count):
@@ -433,6 +438,9 @@ def worker():
        extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []

        progressbar(async_task, 3, 'Loading models ...')
+
+        loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
+
        pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
                                    loras=loras, base_model_additional_loras=base_model_additional_loras,
                                    use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
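This is the integration point for the new feature: `<lora:name:weight>` tokens in the positive prompt are appended to the LoRAs selected in the UI before the pipeline is refreshed. A minimal sketch of the behaviour at this call site, with invented filenames and assuming a Fooocus checkout on `sys.path`:

```python
from modules.util import parse_lora_references_from_prompt  # assumes a Fooocus checkout

ui_loras = [("style-lora.safetensors", 0.6)]  # already filtered by get_enabled_loras
prompt = "cinematic portrait, <lora:detail-tweaker:0.8>, soft light"

print(parse_lora_references_from_prompt(prompt, ui_loras, 5))
# -> [('style-lora.safetensors', 0.6), ('detail-tweaker.safetensors', 0.8)]
```

UI entries keep precedence because they are placed before the prompt references, and the combined list is capped at `default_max_lora_number`.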
@@ -450,8 +458,10 @@ def worker():
            task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
            task_prompt = apply_arrays(task_prompt, i)
            task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
-            task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts]
-            task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts]
+            task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
+                                           extra_positive_prompts]
+            task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
+                                           extra_negative_prompts]

            positive_basic_workloads = []
            negative_basic_workloads = []
@@ -652,7 +662,8 @@ def worker():
                )

            if debugging_inpaint_preprocessor:
-                yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True)
+                yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw,
+                             do_not_show_finished_images=True)
                return

            progressbar(async_task, 13, 'VAE Inpaint encoding ...')
@@ -807,7 +818,8 @@ def worker():
                done_steps = current_task_id * steps + step
                async_task.yields.append(['preview', (
                    int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
-                    f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)])
+                    f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling',
+                    y)])

        for current_task_id, task in enumerate(tasks):
            execution_start_time = time.perf_counter()
@@ -862,7 +874,8 @@ def worker():
                    d = [('Prompt', 'prompt', task['log_positive_prompt']),
                         ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
                         ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
-                         ('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
+                         ('Styles', 'styles',
+                          str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
                         ('Performance', 'performance', performance_selection.value)]

                    if performance_selection.steps() != steps:
@@ -885,7 +898,8 @@ def worker():
                    if refiner_swap_method != flags.refiner_swap_method:
                        d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
                    if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
-                        d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
+                        d.append(
+                            ('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))

                    d.append(('Sampler', 'sampler', sampler_name))
                    d.append(('Scheduler', 'scheduler', scheduler_name))
@@ -905,11 +919,13 @@ def worker():
                    metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
                                             task['log_negative_prompt'], task['negative'],
                                             steps, base_model_name, refiner_model_name, loras, vae_name)
-                    d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
+                    d.append(('Metadata Scheme', 'metadata_scheme',
+                              metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
                    d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
                    img_paths.append(log(x, d, metadata_parser, output_format, task))

-                yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
+                yield_result(async_task, img_paths, black_out_nsfw, False,
+                             do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
            except ldm_patched.modules.model_management.InterruptProcessingException as e:
                if async_task.last_stop == 'skip':
                    print('User skipped')
modules/config.py:

@@ -8,7 +8,8 @@ import modules.flags
import modules.sdxl_styles

from modules.model_loader import load_file_from_url
-from modules.util import get_files_from_folder, makedirs_with_log
+from modules.util import makedirs_with_log
+from modules.extra_utils import get_files_from_folder
from modules.flags import OutputFormat, Performance, MetadataScheme
@@ -20,7 +21,7 @@ def get_config_path(key, default_value):
    else:
        return os.path.abspath(default_value)

+wildcards_max_bfs_depth = 64
config_path = get_config_path('config_path', "./config.txt")
config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
config_dict = {}
modules/extra_utils.py (new file):

@@ -0,0 +1,20 @@
+import os
+
+
+def get_files_from_folder(folder_path, extensions=None, name_filter=None):
+    if not os.path.isdir(folder_path):
+        raise ValueError("Folder path is not a valid directory.")
+
+    filenames = []
+
+    for root, _, files in os.walk(folder_path, topdown=False):
+        relative_path = os.path.relpath(root, folder_path)
+        if relative_path == ".":
+            relative_path = ""
+        for filename in sorted(files, key=lambda s: s.casefold()):
+            _, file_extension = os.path.splitext(filename)
+            if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
+                path = os.path.join(relative_path, filename)
+                filenames.append(path)
+
+    return filenames
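`get_files_from_folder` was moved here out of modules/util.py, apparently to avoid the circular imports the commit log mentions. A minimal usage sketch with a hypothetical relative path; note that `name_filter`, when given, is matched against the filename without its extension:

```python
from modules.extra_utils import get_files_from_folder  # assumes a Fooocus checkout

# Hypothetical call, run from the repository root: collect the bundled style definitions,
# mirroring how modules/sdxl_styles.py uses this helper below.
style_files = get_files_from_folder("sdxl_styles", extensions=['.json'])
print(style_files)  # relative paths, sorted case-insensitively within each directory
```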
modules/sdxl_styles.py:

@@ -2,14 +2,12 @@ import os
import re
import json
import math
-import modules.config

-from modules.util import get_files_from_folder
+from modules.extra_utils import get_files_from_folder
from random import Random

# cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
-wildcards_max_bfs_depth = 64


def normalize_key(k):
@@ -25,7 +23,6 @@ def normalize_key(k):

styles = {}

styles_files = get_files_from_folder(styles_path, ['.json'])

for x in ['sdxl_styles_fooocus.json',
@@ -65,34 +62,7 @@ def apply_style(style, positive):
    return p.replace('{prompt}', positive).splitlines(), n.splitlines()


-def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order):
-    for _ in range(wildcards_max_bfs_depth):
-        placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
-        if len(placeholders) == 0:
-            return wildcard_text
-
-        print(f'[Wildcards] processing: {wildcard_text}')
-        for placeholder in placeholders:
-            try:
-                matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
-                words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
-                words = [x for x in words if x != '']
-                assert len(words) > 0
-                if read_wildcards_in_order:
-                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
-                else:
-                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
-            except:
-                print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
-                      f'Using "{placeholder}" as a normal word.')
-                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
-        print(f'[Wildcards] {wildcard_text}')
-
-    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
-    return wildcard_text
-
-
-def get_words(arrays, totalMult, index):
+def get_words(arrays, total_mult, index):
    if len(arrays) == 1:
        return [arrays[0].split(',')[index]]
    else:
@@ -101,7 +71,7 @@ def get_words(arrays, totalMult, index):
        index -= index % len(words)
        index /= len(words)
        index = math.floor(index)
-        return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index)
+        return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)


def apply_arrays(text, index):
modules/util.py:

@@ -1,11 +1,12 @@
-import typing
import numpy as np
import datetime
import random
import math
import os
import cv2
+import re
+from typing import List, Tuple, AnyStr, NamedTuple

import json
import hashlib
|
||||||
import modules.sdxl_styles
|
import modules.sdxl_styles
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
|
# Regexp compiled once. Matches entries with the following pattern:
|
||||||
|
# <lora:some_lora:1>
|
||||||
|
# <lora:aNotherLora:-1.6>
|
||||||
|
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
|
||||||
|
|
||||||
HASH_SHA256_LENGTH = 10
|
HASH_SHA256_LENGTH = 10
|
||||||
|
|
||||||
|
|
||||||
def erode_or_dilate(x, k):
|
def erode_or_dilate(x, k):
|
||||||
k = int(k)
|
k = int(k)
|
||||||
if k > 0:
|
if k > 0:
|
||||||
|
|
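The pattern is compiled in verbose mode (`re.X`), so the spaces inside the raw string are ignored and `<lora:name:weight>` is matched with an optional sign and a decimal weight. A standalone check (not part of the commit) of what it accepts and rejects, with the regex copied verbatim from the hunk above:

```python
import re

# Copied from modules/util.py above; re.X makes the embedded spaces purely cosmetic.
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)

for token in ["cool <lora:you-lora:0.2>", "<lora:aNotherLora:-1.6>", "<lora:foo:1..2>", "<lora:bar:.>"]:
    m = LORAS_PROMPT_PATTERN.match(token)
    print(f"{token!r} -> {m.groups() if m else None}")
# The first two print ('you-lora', '0.2') and ('aNotherLora', '-1.6'); the malformed weights do not match.
```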
@@ -163,25 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
    return date_string, os.path.abspath(result), filename


-def get_files_from_folder(folder_path, extensions=None, name_filter=None):
-    if not os.path.isdir(folder_path):
-        raise ValueError("Folder path is not a valid directory.")
-
-    filenames = []
-
-    for root, dirs, files in os.walk(folder_path, topdown=False):
-        relative_path = os.path.relpath(root, folder_path)
-        if relative_path == ".":
-            relative_path = ""
-        for filename in sorted(files, key=lambda s: s.casefold()):
-            _, file_extension = os.path.splitext(filename)
-            if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
-                path = os.path.join(relative_path, filename)
-                filenames.append(path)
-
-    return filenames
-
-
def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
    print(f"Calculating sha256 for {filename}: ", end='')
    if use_addnet_hash:
@@ -355,7 +345,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
    return list(reversed(extracted)), real_prompt, negative_prompt


-class PromptStyle(typing.NamedTuple):
+class PromptStyle(NamedTuple):
    name: str
    prompt: str
    negative_prompt: str
@@ -394,4 +384,47 @@ def makedirs_with_log(path):


def get_enabled_loras(loras: list) -> list:
-    return [[lora[1], lora[2]] for lora in loras if lora[0]]
+    return [(lora[1], lora[2]) for lora in loras if lora[0]]
+
+
+def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
+    new_loras = []
+    updated_loras = []
+    for token in prompt.split(","):
+        m = LORAS_PROMPT_PATTERN.match(token)
+
+        if m:
+            new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
+
+    for lora in loras + new_loras:
+        if lora[0] != "None":
+            updated_loras.append(lora)
+
+    return updated_loras[:loras_limit]
+
+
+def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
+    for _ in range(modules.config.wildcards_max_bfs_depth):
+        placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
+        if len(placeholders) == 0:
+            return wildcard_text
+
+        print(f'[Wildcards] processing: {wildcard_text}')
+        for placeholder in placeholders:
+            try:
+                matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
+                words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
+                words = [x for x in words if x != '']
+                assert len(words) > 0
+                if read_wildcards_in_order:
+                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
+                else:
+                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
+            except:
+                print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
+                      f'Using "{placeholder}" as a normal word.')
+                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
+        print(f'[Wildcards] {wildcard_text}')
+
+    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
+    return wildcard_text
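Two behaviours of the new parser are worth noting, both covered by the unit tests added further down: malformed weights are silently ignored, and the merged list is truncated to `loras_limit`. A small check, assuming a Fooocus checkout on `sys.path`:

```python
from modules import util  # same import style as the new unit tests

# Malformed weights never match the pattern, so nothing is added.
print(util.parse_lora_references_from_prompt("<lora:foo:1..2>, <lora:baz:+>", [], 5))
# -> []

# The merged UI + prompt list is capped at loras_limit.
print(util.parse_lora_references_from_prompt("a prompt, <lora:l1:0.4>, <lora:l2:0.2>", [], 1))
# -> [('l1.safetensors', 0.4)]
```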
Tests package initializer (new file):

@@ -0,0 +1,4 @@
+import sys
+import pathlib
+
+sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())
Unit tests for modules.util (new file):

@@ -0,0 +1,48 @@
+import unittest
+
+from modules import util
+
+
+class TestUtils(unittest.TestCase):
+    def test_can_parse_tokens_with_lora(self):
+        test_cases = [
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
+                "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
+            },
+            # Test can not exceed limit
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
+                "output": [("hey-lora.safetensors", 0.4)],
+            },
+            # test Loras from UI take precedence over prompt
+            {
+                "input": (
+                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
+                    [("hey-lora.safetensors", 0.4)],
+                    5,
+                ),
+                "output": [
+                    ("hey-lora.safetensors", 0.4),
+                    ("l1.safetensors", 0.4),
+                    ("l2.safetensors", -0.2),
+                    ("l3.safetensors", 0.3),
+                    ("l4.safetensors", 0.5),
+                ],
+            },
+            # Test lora specification not separated by comma are ignored, only latest specified is used
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
+                "output": [("you-lora.safetensors", 0.2)],
+            },
+            {
+                "input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6),
+                "output": []
+            }
+        ]
+
+        for test in test_cases:
+            prompt, loras, loras_limit = test["input"]
+            expected = test["output"]
+            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
+            self.assertEqual(expected, actual)