feat: add support for lora inline prompt references (#2323)

* Adding support to inline prompt references

* Added unittests

* Added an initial documentation for development guidelines

* Added a negative number

* renamed parameter

* removed wrongly committed file

* Code fixes

* Fixed circular reference

* Fixed typo. Added TODO

* Fixed merge

* Code cleanup

* Added missing reference function

* Removed function from util.py... again...

* Update modules/async_worker.py

Implemented suggested change

Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com>

* Removed another circular reference

* Renamed module

* Addressed PR comments

* Added return type to function

* refactor: move apply_wildcards to module util

* refactor: code cleanup, unify usage of tuples in lora list

* docs: add instructions for running unittests on embedded python, code cleanup

* refactor: code cleanup, move makedirs_with_log back to util

---------

Co-authored-by: cantor-set <cantor-set@no-email.net>
Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com>
Co-authored-by: Manuel Schmid <dev@mash1t.de>
This commit is contained in:
cantor-set 2024-05-18 11:19:46 -04:00 committed by GitHub
parent 3a55e7e391
commit 3bae73e23e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 176 additions and 73 deletions

11
development.md Normal file
View File

@ -0,0 +1,11 @@
## Running unit tests
Native python:
```
python -m unittest tests/
```
Embedded python (Windows zip file installation method):
```
..\python_embeded\python.exe -m unittest
```

0
modules/__init__.py Normal file
View File

View File

@ -4,6 +4,7 @@ from modules.patch import PatchSettings, patch_settings, patch_all
patch_all() patch_all()
class AsyncTask: class AsyncTask:
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
@ -44,11 +45,12 @@ def worker():
import args_manager import args_manager
from extras.censor import censor_batch, censor_single from extras.censor import censor_batch, censor_single
from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name
from modules.private_logger import log from modules.private_logger import log
from extras.expansion import safe_str from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil,
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras,
parse_lora_references_from_prompt, apply_wildcards)
from modules.upscaler import perform_upscale from modules.upscaler import perform_upscale
from modules.flags import Performance from modules.flags import Performance
from modules.meta_parser import get_metadata_parser, MetadataScheme from modules.meta_parser import get_metadata_parser, MetadataScheme
@ -69,7 +71,8 @@ def worker():
print(f'[Fooocus] {text}') print(f'[Fooocus] {text}')
async_task.yields.append(['preview', (number, text, None)]) async_task.yields.append(['preview', (number, text, None)])
def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False,
progressbar_index=13):
if not isinstance(imgs, list): if not isinstance(imgs, list):
imgs = [imgs] imgs = [imgs]
@ -152,7 +155,8 @@ def worker():
base_model_name = args.pop() base_model_name = args.pop()
refiner_model_name = args.pop() refiner_model_name = args.pop()
refiner_switch = args.pop() refiner_switch = args.pop()
loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)]) loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in
range(modules.config.default_max_lora_number)])
input_image_checkbox = args.pop() input_image_checkbox = args.pop()
current_tab = args.pop() current_tab = args.pop()
uov_method = args.pop() uov_method = args.pop()
@ -202,7 +206,8 @@ def worker():
inpaint_erode_or_dilate = args.pop() inpaint_erode_or_dilate = args.pop()
save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS metadata_scheme = MetadataScheme(
args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
cn_tasks = {x: [] for x in flags.ip_list} cn_tasks = {x: [] for x in flags.ip_list}
for _ in range(flags.controlnet_image_count): for _ in range(flags.controlnet_image_count):
@ -433,6 +438,9 @@ def worker():
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
progressbar(async_task, 3, 'Loading models ...') progressbar(async_task, 3, 'Loading models ...')
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras, loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
@ -450,8 +458,10 @@ def worker():
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
task_prompt = apply_arrays(task_prompt, i) task_prompt = apply_arrays(task_prompt, i)
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] extra_positive_prompts]
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
extra_negative_prompts]
positive_basic_workloads = [] positive_basic_workloads = []
negative_basic_workloads = [] negative_basic_workloads = []
@ -652,7 +662,8 @@ def worker():
) )
if debugging_inpaint_preprocessor: if debugging_inpaint_preprocessor:
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True) yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw,
do_not_show_finished_images=True)
return return
progressbar(async_task, 13, 'VAE Inpaint encoding ...') progressbar(async_task, 13, 'VAE Inpaint encoding ...')
@ -807,7 +818,8 @@ def worker():
done_steps = current_task_id * steps + step done_steps = current_task_id * steps + step
async_task.yields.append(['preview', ( async_task.yields.append(['preview', (
int(15.0 + 85.0 * float(done_steps) / float(all_steps)), int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)]) f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling',
y)])
for current_task_id, task in enumerate(tasks): for current_task_id, task in enumerate(tasks):
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
@ -862,7 +874,8 @@ def worker():
d = [('Prompt', 'prompt', task['log_positive_prompt']), d = [('Prompt', 'prompt', task['log_positive_prompt']),
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), ('Styles', 'styles',
str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
('Performance', 'performance', performance_selection.value)] ('Performance', 'performance', performance_selection.value)]
if performance_selection.steps() != steps: if performance_selection.steps() != steps:
@ -885,7 +898,8 @@ def worker():
if refiner_swap_method != flags.refiner_swap_method: if refiner_swap_method != flags.refiner_swap_method:
d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method)) d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr: if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) d.append(
('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
d.append(('Sampler', 'sampler', sampler_name)) d.append(('Sampler', 'sampler', sampler_name))
d.append(('Scheduler', 'scheduler', scheduler_name)) d.append(('Scheduler', 'scheduler', scheduler_name))
@ -905,11 +919,13 @@ def worker():
metadata_parser.set_data(task['log_positive_prompt'], task['positive'], metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'], task['log_negative_prompt'], task['negative'],
steps, base_model_name, refiner_model_name, loras, vae_name) steps, base_model_name, refiner_model_name, loras, vae_name)
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Metadata Scheme', 'metadata_scheme',
metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_parser, output_format, task)) img_paths.append(log(x, d, metadata_parser, output_format, task))
yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) yield_result(async_task, img_paths, black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
except ldm_patched.modules.model_management.InterruptProcessingException as e: except ldm_patched.modules.model_management.InterruptProcessingException as e:
if async_task.last_stop == 'skip': if async_task.last_stop == 'skip':
print('User skipped') print('User skipped')

View File

@ -8,7 +8,8 @@ import modules.flags
import modules.sdxl_styles import modules.sdxl_styles
from modules.model_loader import load_file_from_url from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder, makedirs_with_log from modules.util import makedirs_with_log
from modules.extra_utils import get_files_from_folder
from modules.flags import OutputFormat, Performance, MetadataScheme from modules.flags import OutputFormat, Performance, MetadataScheme
@ -20,7 +21,7 @@ def get_config_path(key, default_value):
else: else:
return os.path.abspath(default_value) return os.path.abspath(default_value)
wildcards_max_bfs_depth = 64
config_path = get_config_path('config_path', "./config.txt") config_path = get_config_path('config_path', "./config.txt")
config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt") config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
config_dict = {} config_dict = {}

20
modules/extra_utils.py Normal file
View File

@ -0,0 +1,20 @@
import os
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
    """Recursively list the files below ``folder_path``, relative to it.

    Directories are walked bottom-up (``os.walk(..., topdown=False)``) and the
    files of each directory are emitted in case-insensitive name order.

    Args:
        folder_path: Directory to scan.
        extensions: Optional container of extensions including the dot
            (e.g. ``['.json']``); a file is kept when its lower-cased
            extension is in it. ``None`` keeps every extension.
        name_filter: Optional substring that must occur in the file name
            without its extension. ``None`` disables the filter.

    Returns:
        List of file paths relative to ``folder_path``.

    Raises:
        ValueError: If ``folder_path`` is not an existing directory.
    """
    if not os.path.isdir(folder_path):
        raise ValueError("Folder path is not a valid directory.")

    collected = []
    for current_dir, _subdirs, dir_files in os.walk(folder_path, topdown=False):
        prefix = os.path.relpath(current_dir, folder_path)
        if prefix == ".":
            # Files directly inside folder_path get no directory prefix.
            prefix = ""

        for entry in sorted(dir_files, key=lambda name: name.casefold()):
            stem, suffix = os.path.splitext(entry)
            if extensions is not None and suffix.lower() not in extensions:
                continue
            if name_filter is not None and name_filter not in stem:
                continue
            collected.append(os.path.join(prefix, entry))

    return collected

View File

@ -2,14 +2,12 @@ import os
import re import re
import json import json
import math import math
import modules.config
from modules.util import get_files_from_folder from modules.extra_utils import get_files_from_folder
from random import Random from random import Random
# cannot use modules.config - validators causing circular imports # cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
wildcards_max_bfs_depth = 64
def normalize_key(k): def normalize_key(k):
@ -25,7 +23,6 @@ def normalize_key(k):
styles = {} styles = {}
styles_files = get_files_from_folder(styles_path, ['.json']) styles_files = get_files_from_folder(styles_path, ['.json'])
for x in ['sdxl_styles_fooocus.json', for x in ['sdxl_styles_fooocus.json',
@ -65,34 +62,7 @@ def apply_style(style, positive):
return p.replace('{prompt}', positive).splitlines(), n.splitlines() return p.replace('{prompt}', positive).splitlines(), n.splitlines()
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order): def get_words(arrays, total_mult, index):
for _ in range(wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
def get_words(arrays, totalMult, index):
if len(arrays) == 1: if len(arrays) == 1:
return [arrays[0].split(',')[index]] return [arrays[0].split(',')[index]]
else: else:
@ -101,7 +71,7 @@ def get_words(arrays, totalMult, index):
index -= index % len(words) index -= index % len(words)
index /= len(words) index /= len(words)
index = math.floor(index) index = math.floor(index)
return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index) return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
def apply_arrays(text, index): def apply_arrays(text, index):

View File

@ -1,11 +1,12 @@
import typing
import numpy as np import numpy as np
import datetime import datetime
import random import random
import math import math
import os import os
import cv2 import cv2
import re
from typing import List, Tuple, AnyStr, NamedTuple
import json import json
import hashlib import hashlib
@ -14,8 +15,16 @@ from PIL import Image
import modules.sdxl_styles import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
# Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1>
# <lora:aNotherLora:-1.6>
#
# The pattern is compiled in verbose mode (re.X), so the spaces inside it are
# only there for readability and are not matched literally.
#   group 1: the LoRA name - one or more characters other than ':'
#   group 2: the weight - optional sign followed by either digits with an
#            optional fractional part ("1", "1.", "1.6") or a leading-dot
#            fraction (".5")
# The leading ".*" is greedy: when this is applied with .match() to text that
# contains several references, the groups capture the LAST one.
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
HASH_SHA256_LENGTH = 10 HASH_SHA256_LENGTH = 10
def erode_or_dilate(x, k): def erode_or_dilate(x, k):
k = int(k) k = int(k)
if k > 0: if k > 0:
@ -163,25 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
return date_string, os.path.abspath(result), filename return date_string, os.path.abspath(result), filename
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
    """Collect the relative paths of all files found under ``folder_path``.

    The tree is traversed with ``os.walk(..., topdown=False)``; within each
    directory the files are ordered case-insensitively by name.

    Args:
        folder_path: Root directory of the search.
        extensions: Optional container of dot-prefixed extensions; the
            lower-cased extension of a file must be in it. ``None`` accepts all.
        name_filter: Optional substring required in the extension-less file
            name. ``None`` accepts all.

    Returns:
        List of paths relative to ``folder_path``.

    Raises:
        ValueError: If ``folder_path`` is not a directory.
    """
    if not os.path.isdir(folder_path):
        raise ValueError("Folder path is not a valid directory.")

    def keep(filename):
        # Same short-circuit order as the filters are documented: extension first, then name.
        stem, ext = os.path.splitext(filename)
        return (extensions is None or ext.lower() in extensions) and (name_filter is None or name_filter in stem)

    found = []
    for walk_root, _dirs, walk_files in os.walk(folder_path, topdown=False):
        rel_dir = os.path.relpath(walk_root, folder_path)
        # The root itself is reported as "."; drop it so top-level files have no prefix.
        rel_dir = "" if rel_dir == "." else rel_dir
        ordered = sorted(walk_files, key=lambda s: s.casefold())
        found.extend(os.path.join(rel_dir, name) for name in ordered if keep(name))
    return found
def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH): def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
print(f"Calculating sha256 for {filename}: ", end='') print(f"Calculating sha256 for {filename}: ", end='')
if use_addnet_hash: if use_addnet_hash:
@ -355,7 +345,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
return list(reversed(extracted)), real_prompt, negative_prompt return list(reversed(extracted)), real_prompt, negative_prompt
class PromptStyle(typing.NamedTuple): class PromptStyle(NamedTuple):
name: str name: str
prompt: str prompt: str
negative_prompt: str negative_prompt: str
@ -394,4 +384,47 @@ def makedirs_with_log(path):
def get_enabled_loras(loras: list) -> list: def get_enabled_loras(loras: list) -> list:
return [[lora[1], lora[2]] for lora in loras if lora[0]] return [(lora[1], lora[2]) for lora in loras if lora[0]]
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
    """Merge LoRAs referenced inline in ``prompt`` with an existing LoRA list.

    The prompt is split on commas and each piece is matched against
    ``LORAS_PROMPT_PATTERN``; a match yields ``("<name>.safetensors", weight)``.
    Entries from ``loras`` come first, followed by the ones found in the
    prompt, entries named ``"None"`` are dropped, and the result is cut to
    ``loras_limit`` items.

    Args:
        prompt: Prompt text that may contain ``<lora:name:weight>`` references.
        loras: Already selected ``(filename, weight)`` entries (must be a list,
            it is concatenated with the parsed entries).
        loras_limit: Maximum number of entries to return.

    Returns:
        The combined list of ``(filename, weight)`` tuples.
    """
    matches = (LORAS_PROMPT_PATTERN.match(piece) for piece in prompt.split(","))
    prompt_loras = [
        (f"{found.group(1)}.safetensors", float(found.group(2)))
        for found in matches
        if found
    ]

    # Existing entries keep priority over inline references when the limit truncates.
    combined = [entry for entry in loras + prompt_loras if entry[0] != "None"]
    return combined[:loras_limit]
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
    """Replace ``__name__`` wildcard placeholders with words from wildcard files.

    The text is rescanned after every pass, so placeholders introduced by a
    substituted word are expanded on a later pass. At most
    ``modules.config.wildcards_max_bfs_depth`` passes are made.

    Args:
        wildcard_text: Text that may contain ``__name__`` placeholders.
        rng: Object providing ``choice(sequence)``; used to pick a word when
            ``read_wildcards_in_order`` is falsy.
        i: Index used to pick a word (``words[i % len(words)]``) when
            ``read_wildcards_in_order`` is truthy.
        read_wildcards_in_order: Pick words by index instead of at random.

    Returns:
        The text with placeholders substituted. A placeholder whose wildcard
        file cannot be found, read or is empty is replaced by its bare name.
        If placeholders are still present after the last pass, the partially
        expanded text is returned.
    """
    for _ in range(modules.config.wildcards_max_bfs_depth):
        placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
        if len(placeholders) == 0:
            return wildcard_text

        print(f'[Wildcards] processing: {wildcard_text}')
        for placeholder in placeholders:
            try:
                # A wildcard file matches when its name without extension equals the placeholder.
                matches = [x for x in modules.config.wildcard_filenames
                           if os.path.splitext(os.path.basename(x))[0] == placeholder]
                wildcard_path = os.path.join(modules.config.path_wildcards, matches[0])
                # Context manager so the file handle is closed even if reading fails.
                with open(wildcard_path, encoding='utf-8') as wildcard_file:
                    words = wildcard_file.read().splitlines()
                words = [x for x in words if x != '']
                if not words:
                    # Explicit check instead of assert, which is stripped under -O.
                    raise ValueError(f'{placeholder}.txt is empty')
                if read_wildcards_in_order:
                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
                else:
                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
            except Exception:
                # Best-effort fallback; KeyboardInterrupt/SystemExit are no longer swallowed.
                print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
                      f'Using "{placeholder}" as a normal word.')
                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
            print(f'[Wildcards] {wildcard_text}')

    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
    return wildcard_text

4
tests/__init__.py Normal file
View File

@ -0,0 +1,4 @@
import sys
import pathlib
# Make the project root (the parent of this tests package) importable so that
# `modules` can be found when the interpreter does not put the working
# directory on sys.path. The entry must be a str: the import system ignores
# non-string sys.path entries such as pathlib.Path objects.
sys.path.append(str(pathlib.Path(__file__).resolve().parent.parent))

48
tests/test_utils.py Normal file
View File

@ -0,0 +1,48 @@
import unittest
from modules import util
class TestUtils(unittest.TestCase):
    """Unit tests for helpers in ``modules.util``."""

    def test_can_parse_tokens_with_lora(self):
        """Inline ``<lora:name:weight>`` references are parsed and merged with the given LoRA list.

        Each case is ``input = (prompt, loras, loras_limit)`` and the expected
        return value of ``util.parse_lora_references_from_prompt``.
        """
        test_cases = [
            # Comma separated references are all picked up, in prompt order
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
                "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
            },
            # Test can not exceed limit
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
                "output": [("hey-lora.safetensors", 0.4)],
            },
            # Test LoRAs from UI take precedence over prompt
            {
                "input": (
                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
                    [("hey-lora.safetensors", 0.4)],
                    5,
                ),
                "output": [
                    ("hey-lora.safetensors", 0.4),
                    ("l1.safetensors", 0.4),
                    ("l2.safetensors", -0.2),
                    ("l3.safetensors", 0.3),
                    ("l4.safetensors", 0.5),
                ],
            },
            # Test LoRA specifications not separated by a comma: only the last one in the token is used
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
                "output": [("you-lora.safetensors", 0.2)],
            },
            # Test references with malformed or missing weights are ignored
            {
                "input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6),
                "output": []
            }
        ]

        for test in test_cases:
            prompt, loras, loras_limit = test["input"]
            expected = test["output"]
            # subTest keeps the remaining cases running after a failure and
            # reports which input failed.
            with self.subTest(prompt=prompt, loras=loras, loras_limit=loras_limit):
                actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
                self.assertEqual(expected, actual)