This commit is contained in:
cantor-set 2024-05-04 18:53:04 +02:00 committed by GitHub
commit efdcb6b8d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 160 additions and 67 deletions

5
development.md Normal file
View File

@ -0,0 +1,5 @@
## Running unit tests
```
python -m unittest tests/
```

0
modules/__init__.py Normal file
View File

View File

@ -43,11 +43,11 @@ def worker():
import fooocus_version
import args_manager
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
from modules.sdxl_styles import apply_style, fooocus_expansion, apply_arrays
from modules.private_logger import log
from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, parse_lora_references_from_prompt
from modules.upscaler import perform_upscale
from modules.flags import Performance
from modules.meta_parser import get_metadata_parser, MetadataScheme
@ -426,6 +426,10 @@ def worker():
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
progressbar(async_task, 3, 'Loading models ...')
# Parse lora references from prompt
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner)
@ -440,11 +444,11 @@ def worker():
task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
task_rng = random.Random(task_seed) # may bind to inpaint noise in the future
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
task_prompt = modules.config.apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
task_prompt = apply_arrays(task_prompt, i)
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts]
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts]
task_negative_prompt = modules.config.apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
task_extra_positive_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts]
task_extra_negative_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts]
positive_basic_workloads = []
negative_basic_workloads = []

View File

@ -1,6 +1,7 @@
import os
import json
import math
import re
import numbers
import args_manager
import tempfile
@ -8,10 +9,9 @@ import modules.flags
import modules.sdxl_styles
from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder, makedirs_with_log
from modules.extra_utils import get_files_from_folder, makedirs_with_log
from modules.flags import OutputFormat, Performance, MetadataScheme
def get_config_path(key, default_value):
env = os.getenv(key)
if env is not None and isinstance(env, str):
@ -20,7 +20,7 @@ def get_config_path(key, default_value):
else:
return os.path.abspath(default_value)
wildcards_max_bfs_depth = 64
config_path = get_config_path('config_path', "./config.txt")
config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
config_dict = {}
@ -680,4 +680,32 @@ def downloading_upscale_model():
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order):
for _ in range(wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
update_files()

27
modules/extra_utils.py Normal file
View File

@ -0,0 +1,27 @@
import os
#TODO: Use refactor to use glob instead
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")
filenames = []
for root, _, files in os.walk(folder_path, topdown=False):
relative_path = os.path.relpath(root, folder_path)
if relative_path == ".":
relative_path = ""
for filename in sorted(files, key=lambda s: s.casefold()):
_, file_extension = os.path.splitext(filename)
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)
return filenames
def makedirs_with_log(path):
try:
os.makedirs(path, exist_ok=True)
except OSError as error:
print(f'Directory {path} could not be created, reason: {error}')

View File

@ -2,13 +2,13 @@ import os
import re
import json
import math
import modules.config
#import modules.config
from modules.util import get_files_from_folder
from modules.extra_utils import get_files_from_folder
# cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
wildcards_max_bfs_depth = 64
def normalize_key(k):
@ -59,33 +59,6 @@ def apply_style(style, positive):
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order):
for _ in range(wildcards_max_bfs_depth):
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
if len(placeholders) == 0:
return wildcard_text
print(f'[Wildcards] processing: {wildcard_text}')
for placeholder in placeholders:
try:
matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
words = [x for x in words if x != '']
assert len(words) > 0
if read_wildcards_in_order:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
else:
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
except:
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
f'Using "{placeholder}" as a normal word.')
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
print(f'[Wildcards] {wildcard_text}')
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text
def get_words(arrays, totalMult, index):
if len(arrays) == 1:
return [arrays[0].split(',')[index]]

View File

@ -1,11 +1,12 @@
import typing
import numpy as np
import datetime
import random
import math
import os
import cv2
import re
from typing import List, Tuple, AnyStr, NamedTuple
import json
import hashlib
@ -14,8 +15,16 @@ from PIL import Image
import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
# Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1>
# <lora:aNotherLora:-1.6>
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
HASH_SHA256_LENGTH = 10
def erode_or_dilate(x, k):
k = int(k)
if k > 0:
@ -163,24 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
return date_string, os.path.abspath(result), filename
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")
filenames = []
for root, dirs, files in os.walk(folder_path, topdown=False):
relative_path = os.path.relpath(root, folder_path)
if relative_path == ".":
relative_path = ""
for filename in sorted(files, key=lambda s: s.casefold()):
_, file_extension = os.path.splitext(filename)
if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)
return filenames
def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
print(f"Calculating sha256 for {filename}: ", end='')
@ -355,7 +346,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
return list(reversed(extracted)), real_prompt, negative_prompt
class PromptStyle(typing.NamedTuple):
class PromptStyle(NamedTuple):
name: str
prompt: str
negative_prompt: str
@ -383,12 +374,25 @@ def ordinal_suffix(number: int) -> str:
return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
def makedirs_with_log(path):
try:
os.makedirs(path, exist_ok=True)
except OSError as error:
print(f'Directory {path} could not be created, reason: {error}')
def get_enabled_loras(loras: list) -> list:
return [[lora[1], lora[2]] for lora in loras if lora[0]]
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
new_loras = []
for token in prompt.split(","):
m = LORAS_PROMPT_PATTERN.match(token)
if m:
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
updated_loras = []
for lora in loras + new_loras:
if lora[0] != "None":
updated_loras.append(lora)
return updated_loras[:loras_limit]

4
tests/__init__.py Normal file
View File

@ -0,0 +1,4 @@
import sys
import pathlib
sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())

48
tests/test_utils.py Normal file
View File

@ -0,0 +1,48 @@
import unittest
from modules import util
class TestUtils(unittest.TestCase):
def test_can_parse_tokens_with_lora(self):
test_cases = [
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
"output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
},
# Test can not exceed limit
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
"output": [("hey-lora.safetensors", 0.4)],
},
# test Loras from UI take precedence over prompt
{
"input": (
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
[("hey-lora.safetensors", 0.4)],
5,
),
"output": [
("hey-lora.safetensors", 0.4),
("l1.safetensors", 0.4),
("l2.safetensors", -0.2),
("l3.safetensors", 0.3),
("l4.safetensors", 0.5),
],
},
# Test lora specification not separated by comma are ignored, only latest specified is used
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
"output": [("you-lora.safetensors", 0.2)],
},
{
"input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>",[], 6),
"output": []
}
]
for test in test_cases:
promp, loras, loras_limit = test["input"]
expected = test["output"]
actual = util.parse_lora_references_from_prompt(promp, loras, loras_limit)
self.assertEqual(expected, actual)