diff --git a/development.md b/development.md new file mode 100644 index 00000000..bbb3def9 --- /dev/null +++ b/development.md @@ -0,0 +1,11 @@ +## Running unit tests + +Native python: +``` +python -m unittest tests/ +``` + +Embedded python (Windows zip file installation method): +``` +..\python_embeded\python.exe -m unittest +``` diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/async_worker.py b/modules/async_worker.py index 6f0b30a9..7f0a46e3 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -4,6 +4,7 @@ from modules.patch import PatchSettings, patch_settings, patch_all patch_all() + class AsyncTask: def __init__(self, args): self.args = args @@ -44,11 +45,12 @@ def worker(): import args_manager from extras.censor import censor_batch, censor_single - from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name + from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str - from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ - get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras + from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, + get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, + parse_lora_references_from_prompt, apply_wildcards) from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme @@ -69,7 +71,8 @@ def worker(): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): + def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, + progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] @@ -152,7 +155,8 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)]) + loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in + range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() @@ -202,7 +206,8 @@ def worker(): inpaint_erode_or_dilate = args.pop() save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False - metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS + metadata_scheme = MetadataScheme( + args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS cn_tasks = {x: [] for x in flags.ip_list} for _ in range(flags.controlnet_image_count): @@ -433,13 +438,16 @@ def worker(): extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] progressbar(async_task, 3, 'Loading models ...') + + loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) + pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, 
base_model_additional_loras=base_model_additional_loras, use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) progressbar(async_task, 3, 'Processing prompts ...') tasks = [] - + for i in range(image_number): if disable_seed_increment: task_seed = seed % (constants.MAX_SEED + 1) @@ -450,8 +458,10 @@ def worker(): task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_arrays(task_prompt, i) task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) - task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] - task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] + task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_positive_prompts] + task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_negative_prompts] positive_basic_workloads = [] negative_basic_workloads = [] @@ -652,7 +662,8 @@ def worker(): ) if debugging_inpaint_preprocessor: - yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True) + yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, + do_not_show_finished_images=True) return progressbar(async_task, 13, 'VAE Inpaint encoding ...') @@ -807,7 +818,8 @@ def worker(): done_steps = current_task_id * steps + step async_task.yields.append(['preview', ( int(15.0 + 85.0 * float(done_steps) / float(all_steps)), - f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)]) + f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', + y)]) for current_task_id, task in enumerate(tasks): execution_start_time = time.perf_counter() @@ -862,7 +874,8 @@ def worker(): d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), - ('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), + ('Styles', 'styles', + str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), ('Performance', 'performance', performance_selection.value)] if performance_selection.steps() != steps: @@ -885,7 +898,8 @@ def worker(): if refiner_swap_method != flags.refiner_swap_method: d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method)) if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr: - d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) + d.append( + ('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) d.append(('Sampler', 'sampler', sampler_name)) d.append(('Scheduler', 'scheduler', scheduler_name)) @@ -905,11 +919,13 @@ def worker(): metadata_parser.set_data(task['log_positive_prompt'], task['positive'], task['log_negative_prompt'], task['negative'], steps, base_model_name, refiner_model_name, loras, vae_name) - d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) + d.append(('Metadata Scheme', 'metadata_scheme', + metadata_scheme.value if 
save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format, task)) - yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) + yield_result(async_task, img_paths, black_out_nsfw, False, + do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': print('User skipped') diff --git a/modules/config.py b/modules/config.py index ffb74a23..11fe3181 100644 --- a/modules/config.py +++ b/modules/config.py @@ -8,7 +8,8 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.util import get_files_from_folder, makedirs_with_log +from modules.util import makedirs_with_log +from modules.extra_utils import get_files_from_folder from modules.flags import OutputFormat, Performance, MetadataScheme @@ -20,7 +21,7 @@ def get_config_path(key, default_value): else: return os.path.abspath(default_value) - +wildcards_max_bfs_depth = 64 config_path = get_config_path('config_path', "./config.txt") config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt") config_dict = {} diff --git a/modules/extra_utils.py b/modules/extra_utils.py new file mode 100644 index 00000000..3e95e8b5 --- /dev/null +++ b/modules/extra_utils.py @@ -0,0 +1,20 @@ +import os + + +def get_files_from_folder(folder_path, extensions=None, name_filter=None): + if not os.path.isdir(folder_path): + raise ValueError("Folder path is not a valid directory.") + + filenames = [] + + for root, _, files in os.walk(folder_path, topdown=False): + relative_path = os.path.relpath(root, folder_path) + if relative_path == ".": + relative_path = "" + for filename in sorted(files, key=lambda s: s.casefold()): + _, file_extension = os.path.splitext(filename) + if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _): + path = os.path.join(relative_path, filename) + filenames.append(path) + + return filenames diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 5b6afb59..12ab6c5c 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -2,14 +2,12 @@ import os import re import json import math -import modules.config -from modules.util import get_files_from_folder +from modules.extra_utils import get_files_from_folder from random import Random # cannot use modules.config - validators causing circular imports styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) -wildcards_max_bfs_depth = 64 def normalize_key(k): @@ -25,7 +23,6 @@ def normalize_key(k): styles = {} - styles_files = get_files_from_folder(styles_path, ['.json']) for x in ['sdxl_styles_fooocus.json', @@ -65,34 +62,7 @@ def apply_style(style, positive): return p.replace('{prompt}', positive).splitlines(), n.splitlines() -def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order): - for _ in range(wildcards_max_bfs_depth): - placeholders = re.findall(r'__([\w-]+)__', wildcard_text) - if len(placeholders) == 0: - return wildcard_text - - print(f'[Wildcards] processing: {wildcard_text}') - for placeholder in placeholders: - try: - matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder] - words = 
open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
-                words = [x for x in words if x != '']
-                assert len(words) > 0
-                if read_wildcards_in_order:
-                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
-                else:
-                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
-            except:
-                print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
-                      f'Using "{placeholder}" as a normal word.')
-                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
-        print(f'[Wildcards] {wildcard_text}')
-
-    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
-    return wildcard_text
-
-
-def get_words(arrays, totalMult, index):
+def get_words(arrays, total_mult, index):
     if len(arrays) == 1:
         return [arrays[0].split(',')[index]]
     else:
@@ -101,7 +71,7 @@ def get_words(arrays, totalMult, index):
         index -= index % len(words)
         index /= len(words)
         index = math.floor(index)
-        return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index)
+        return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
 
 
 def apply_arrays(text, index):
diff --git a/modules/util.py b/modules/util.py
index d2feecb6..73430230 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -1,11 +1,12 @@
-import typing
-
 import numpy as np
 import datetime
 import random
 import math
 import os
 import cv2
+import re
+from typing import List, Tuple, AnyStr, NamedTuple
+
 import json
 import hashlib
 
@@ -14,8 +15,16 @@ from PIL import Image
 import modules.sdxl_styles
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+
+
+# Regexp compiled once. Matches entries with the following pattern:
+# <lora:some_lora:1>
+# <lora:another-lora:-0.5>
+LORAS_PROMPT_PATTERN = re.compile(r".* <lora:(.+):(-?\d+(?:\.\d+)?)> .*", re.X)
+
 HASH_SHA256_LENGTH = 10
 
+
 def erode_or_dilate(x, k):
     k = int(k)
     if k > 0:
@@ -163,25 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'):
     return date_string, os.path.abspath(result), filename
 
 
-def get_files_from_folder(folder_path, extensions=None, name_filter=None):
-    if not os.path.isdir(folder_path):
-        raise ValueError("Folder path is not a valid directory.")
-
-    filenames = []
-
-    for root, dirs, files in os.walk(folder_path, topdown=False):
-        relative_path = os.path.relpath(root, folder_path)
-        if relative_path == ".":
-            relative_path = ""
-        for filename in sorted(files, key=lambda s: s.casefold()):
-            _, file_extension = os.path.splitext(filename)
-            if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _):
-                path = os.path.join(relative_path, filename)
-                filenames.append(path)
-
-    return filenames
-
-
 def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
     print(f"Calculating sha256 for {filename}: ", end='')
     if use_addnet_hash:
@@ -355,7 +345,7 @@ def extract_styles_from_prompt(prompt, negative_prompt):
     return list(reversed(extracted)), real_prompt, negative_prompt
 
 
-class PromptStyle(typing.NamedTuple):
+class PromptStyle(NamedTuple):
     name: str
     prompt: str
     negative_prompt: str
@@ -394,4 +384,47 @@ def makedirs_with_log(path):
 
 
 def get_enabled_loras(loras: list) -> list:
-    return [[lora[1], lora[2]] for lora in loras if lora[0]]
+    return [(lora[1], lora[2]) for lora in loras if lora[0]]
+
+
+def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
+    new_loras = []
+    updated_loras = []
+    for token in prompt.split(","):
+        m = LORAS_PROMPT_PATTERN.match(token)
+
+        if m:
+            new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
+
+    for lora in loras + new_loras:
+        if lora[0] != "None":
+            updated_loras.append(lora)
+
+    return updated_loras[:loras_limit]
+
+
+def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
+    for _ in range(modules.config.wildcards_max_bfs_depth):
+        placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
+        if len(placeholders) == 0:
+            return wildcard_text
+
+        print(f'[Wildcards] processing: {wildcard_text}')
+        for placeholder in placeholders:
+            try:
+                matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder]
+                words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines()
+                words = [x for x in words if x != '']
+                assert len(words) > 0
+                if read_wildcards_in_order:
+                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1)
+                else:
+                    wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
+            except:
+                print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
+                      f'Using "{placeholder}" as a normal word.')
+                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
+        print(f'[Wildcards] {wildcard_text}')
+
+    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
+    return wildcard_text
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..c424468f
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,4 @@
+import sys
+import pathlib
+
+sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..0698dcc8
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,48 @@
+import unittest
+
+from modules import util
+
+
+class TestUtils(unittest.TestCase):
+    def test_can_parse_tokens_with_lora(self):
+        test_cases = [
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
+                "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
+            },
+            # Test can not exceed limit
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
+                "output": [("hey-lora.safetensors", 0.4)],
+            },
+            # test Loras from UI take precedence over prompt
+            {
+                "input": (
+                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l5:0.6>, <lora:l6:0.7>",
+                    [("hey-lora.safetensors", 0.4)],
+                    5,
+                ),
+                "output": [
+                    ("hey-lora.safetensors", 0.4),
+                    ("l1.safetensors", 0.4),
+                    ("l2.safetensors", -0.2),
+                    ("l3.safetensors", 0.3),
+                    ("l4.safetensors", 0.5),
+                ],
+            },
+            # Test lora specification not separated by comma are ignored, only latest specified is used
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
+                "output": [("you-lora.safetensors", 0.2)],
+            },
+            {
+                "input": ("<lora:no-weight>, <lora:bad-weight:high>, and <lora:unterminated:0.3", [], 6),
+                "output": []
+            }
+        ]
+
+        for test in test_cases:
+            prompt, loras, loras_limit = test["input"]
+            expected = test["output"]
+            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
+            self.assertEqual(expected, actual)
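
For reference, a minimal usage sketch of the prompt-based LoRA loading this change introduces. The `<lora:name:weight>` syntax and the precedence/limit behaviour follow `parse_lora_references_from_prompt` and the unit tests above; the model names and weights in the example are illustrative placeholders, not values from the change itself.

```python
from modules.util import parse_lora_references_from_prompt

# LoRAs already enabled in the UI arrive as (filename, weight) pairs.
ui_loras = [("detail-tweaker.safetensors", 0.6)]  # illustrative entry

# <lora:name:weight> tokens in the prompt are parsed, get ".safetensors" appended,
# and are placed after the UI entries before the combined list is capped at the limit.
prompt = "cinematic photo, <lora:film-grain:0.4>, <lora:color-pop:-0.2>"

loras = parse_lora_references_from_prompt(prompt, ui_loras, loras_limit=5)
# [("detail-tweaker.safetensors", 0.6),
#  ("film-grain.safetensors", 0.4),
#  ("color-pop.safetensors", -0.2)]
```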