feat: add support for lora inline prompt references (#2323)

* Adding support to inline prompt references

* Added unittests

* Added an initial documentation for development guidelines

* Added a negative number

* renamed parameter

* removed wrongly committed file

* Code fixes

* Fixed circular reference

* Fixed typo. Added TODO

* Fixed merge

* Code cleanup

* Added missing refernce function

* Removed function from util.py... again...

* Update modules/async_worker.py

Implemented suggested change

Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com>

* Removed another circular reference

* Renamed module

* Addressed PR comments

* Added return type to function

* refactor: move apply_wildcards to module util

* refactor: code cleanup, unify usage of tuples in lora list

* docs: add instructions for running unittests on embedded python, code cleanup

* refactor: code cleanup, move makedirs_with_log back to util

---------

Co-authored-by: cantor-set <cantor-set@no-email.net>
Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com>
Co-authored-by: Manuel Schmid <dev@mash1t.de>
This commit is contained in:
cantor-set 2024-05-18 11:19:46 -04:00 committed by GitHub
parent 3a55e7e391
commit 3bae73e23e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 176 additions and 73 deletions

11
development.md Normal file
View File

@ -0,0 +1,11 @@
## Running unit tests
Native python:
```
python -m unittest tests/
```
Embedded python (Windows zip file installation method):
```
..\python_embeded\python.exe -m unittest
```

0
modules/__init__.py Normal file
View File

View File

@ -4,6 +4,7 @@ from modules.patch import PatchSettings, patch_settings, patch_all
patch_all()
class AsyncTask:
def __init__(self, args):
self.args = args
@ -44,11 +45,12 @@ def worker():
import args_manager
from extras.censor import censor_batch, censor_single
from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name
from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name
from modules.private_logger import log
from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras
from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil,
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras,
parse_lora_references_from_prompt, apply_wildcards)
from modules.upscaler import perform_upscale
from modules.flags import Performance
from modules.meta_parser import get_metadata_parser, MetadataScheme
@ -69,7 +71,8 @@ def worker():
print(f'[Fooocus] {text}')
async_task.yields.append(['preview', (number, text, None)])
def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13):
def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False,
progressbar_index=13):
if not isinstance(imgs, list):
imgs = [imgs]
@ -152,7 +155,8 @@ def worker():
base_model_name = args.pop()
refiner_model_name = args.pop()
refiner_switch = args.pop()
loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)])
loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in
range(modules.config.default_max_lora_number)])
input_image_checkbox = args.pop()
current_tab = args.pop()
uov_method = args.pop()
@ -202,7 +206,8 @@ def worker():
inpaint_erode_or_dilate = args.pop()
save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
metadata_scheme = MetadataScheme(
args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
cn_tasks = {x: [] for x in flags.ip_list}
for _ in range(flags.controlnet_image_count):
@ -433,13 +438,16 @@ def worker():
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
progressbar(async_task, 3, 'Loading models ...')
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
progressbar(async_task, 3, 'Processing prompts ...')
tasks = []
for i in range(image_number):
if disable_seed_increment:
task_seed = seed % (constants.MAX_SEED + 1)
@ -450,8 +458,10 @@ def worker():
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
task_prompt = apply_arrays(task_prompt, i)
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts]
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts]
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
extra_positive_prompts]
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
extra_negative_prompts]
positive_basic_workloads = []
negative_basic_workloads = []
@ -652,7 +662,8 @@ def worker():
)
if debugging_inpaint_preprocessor:
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True)
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw,
do_not_show_finished_images=True)
return
progressbar(async_task, 13, 'VAE Inpaint encoding ...')
@ -807,7 +818,8 @@ def worker():
done_steps = current_task_id * steps + step
async_task.yields.append(['preview', (
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)])
f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling',
y)])
for current_task_id, task in enumerate(tasks):
execution_start_time = time.perf_counter()
@ -862,7 +874,8 @@ def worker():
d = [('Prompt', 'prompt', task['log_positive_prompt']),
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
('Styles', 'styles',
str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
('Performance', 'performance', performance_selection.value)]
if performance_selection.steps() != steps:
@ -885,7 +898,8 @@ def worker():
if refiner_swap_method != flags.refiner_swap_method:
d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
d.append(
('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
d.append(('Sampler', 'sampler', sampler_name))
d.append(('Scheduler', 'scheduler', scheduler_name))
@ -905,11 +919,13 @@ def worker():
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'],
steps, base_model_name, refiner_model_name, loras, vae_name)
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
d.append(('Metadata Scheme', 'metadata_scheme',
metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_parser, output_format, task))
yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
yield_result(async_task, img_paths, black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
except ldm_patched.modules.model_management.InterruptProcessingException as e:
if async_task.last_stop == 'skip':
print('User skipped')

View File

@ -8,7 +8,8 @@ import modules.flags
import modules.sdxl_styles
from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder, makedirs_with_log
from modules.util import makedirs_with_log
from modules.extra_utils import get_files_from_folder
from modules.flags import OutputFormat, Performance, MetadataScheme
@ -20,7 +21,7 @@ def get_config_path(key, default_value):
else:
return os.path.abspath(default_value)
wildcards_max_bfs_depth = 64
config_path = get_config_path('config_path', "./config.txt")
config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
config_dict = {}

20
modules/extra_utils.py Normal file
View File

@ -0,0 +1,20 @@
import os
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")
filenames = []
for root, _, files in os.walk(folder_path, topdown=False):
relative_path = os.path.relpath(root, folder_path)
if relative_path == ".":
relative_path = ""
for filename in sorted(files, key=lambda s: s.casefold()):
_, file_extension = os.path.splitext(filename)
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)
return filenames

View File

@ -2,14 +2,12 @@ import os
import re
import json
import math
import modules.config
from modules.util import get_files_from_folder
from modules.extra_utils import get_files_from_folder
from random import Random
# cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
wildcards_max_bfs_depth = 64
def normalize_key(k):
@ -25,7 +23,6 @@ def normalize_key(k):
styles = {}
styles_files = get_files_from_folder(styles_path, ['.json'])
for x in ['sdxl_styles_fooocus.json',
@ -65,34 +62,7 @@ def apply_style(style, positive):
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order):
for _ in range(wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
def get_words(arrays, totalMult, index):
def get_words(arrays, total_mult, index):
if len(arrays) == 1:
return [arrays[0].split(',')[index]]
else:
@ -101,7 +71,7 @@ def get_words(arrays, totalMult, index):
index -= index % len(words)
index /= len(words)
index = math.floor(index)
return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index)
return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
def apply_arrays(text, index):

View File

@ -1,11 +1,12 @@
import typing
import numpy as np
import datetime
import random
import math
import os
import cv2
import re
from typing import List, Tuple, AnyStr, NamedTuple
import json
import hashlib
@ -14,8 +15,16 @@ from PIL import Image
import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
# Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1>
# <lora:aNotherLora:-1.6>
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
HASH_SHA256_LENGTH = 10
def erode_or_dilate(x, k):
k = int(k)
if k > 0:
@ -163,25 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
return date_string, os.path.abspath(result), filename
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")
filenames = []
for root, dirs, files in os.walk(folder_path, topdown=False):
relative_path = os.path.relpath(root, folder_path)
if relative_path == ".":
relative_path = ""
for filename in sorted(files, key=lambda s: s.casefold()):
_, file_extension = os.path.splitext(filename)
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)
return filenames
def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
print(f"Calculating sha256 for {filename}: ", end='')
if use_addnet_hash:
@ -355,7 +345,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
return list(reversed(extracted)), real_prompt, negative_prompt
class PromptStyle(typing.NamedTuple):
class PromptStyle(NamedTuple):
name: str
prompt: str
negative_prompt: str
@ -394,4 +384,47 @@ def makedirs_with_log(path):
def get_enabled_loras(loras: list) -> list:
return [[lora[1], lora[2]] for lora in loras if lora[0]]
return [(lora[1], lora[2]) for lora in loras if lora[0]]
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
new_loras = []
updated_loras = []
for token in prompt.split(","):
m = LORAS_PROMPT_PATTERN.match(token)
if m:
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
for lora in loras + new_loras:
if lora[0] != "None":
updated_loras.append(lora)
return updated_loras[:loras_limit]
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
for _ in range(modules.config.wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text

4
tests/__init__.py Normal file
View File

@ -0,0 +1,4 @@
import sys
import pathlib
sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())

48
tests/test_utils.py Normal file
View File

@ -0,0 +1,48 @@
import unittest
from modules import util
class TestUtils(unittest.TestCase):
def test_can_parse_tokens_with_lora(self):
test_cases = [
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
"output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
},
# Test can not exceed limit
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
"output": [("hey-lora.safetensors", 0.4)],
},
# test Loras from UI take precedence over prompt
{
"input": (
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
[("hey-lora.safetensors", 0.4)],
5,
),
"output": [
("hey-lora.safetensors", 0.4),
("l1.safetensors", 0.4),
("l2.safetensors", -0.2),
("l3.safetensors", 0.3),
("l4.safetensors", 0.5),
],
},
# Test lora specification not separated by comma are ignored, only latest specified is used
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
"output": [("you-lora.safetensors", 0.2)],
},
{
"input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6),
"output": []
}
]
for test in test_cases:
prompt, loras, loras_limit = test["input"]
expected = test["output"]
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
self.assertEqual(expected, actual)