From 31263b40f640d8efd7ed719b9c664d52acc5c88d Mon Sep 17 00:00:00 2001 From: cantor-set Date: Wed, 21 Feb 2024 21:15:24 -0500 Subject: [PATCH 01/22] Adding support to inline prompt references --- modules/async_worker.py | 17 +++++++++++++++-- modules/util.py | 22 ++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 40abb7fa..37ab09ae 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -39,8 +39,17 @@ def worker(): from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log from extras.expansion import safe_str - from modules.util import remove_empty_str, HWC3, resize_image, \ - get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix + from modules.util import ( + remove_empty_str, + HWC3, resize_image, + get_image_shape_ceil, + set_image_shape_ceil, + get_shape_ceil, + resample_image, + erode_or_dilate, + ordinal_suffix, + parse_lora_references_from_prompt + ) from modules.upscaler import perform_upscale try: @@ -370,6 +379,10 @@ def worker(): extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] progressbar(async_task, 3, 'Loading models ...') + + # Parse lora references from prompt + loras = parse_lora_references_from_prompt(prompt, loras) + pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, use_synthetic_refiner=use_synthetic_refiner) diff --git a/modules/util.py b/modules/util.py index c309480a..2355dae3 100644 --- a/modules/util.py +++ b/modules/util.py @@ -4,6 +4,8 @@ import random import math import os import cv2 +import re +from typing import List, Tuple, AnyStr from PIL import Image @@ -179,3 +181,23 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None): def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') + + +def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]]): + pattern = re.compile(".*.*") + new_loras = [] + + for token in items.split(","): + print(token) + m = pattern.match(token) + + if m: + new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2)))) + + updated_loras = [] + for lora in loras + new_loras: + + if lora[0] != "None": + updated_loras.append(lora) + + return updated_loras From 24acbc39fe97e2c1b12759471042b4f43e924e2a Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 22 Feb 2024 18:42:30 -0500 Subject: [PATCH 02/22] Added unittests --- modules/__init__.py | 0 modules/async_worker.py | 6 ++++-- modules/util.py | 14 ++++++++----- tests/__init__.py | 4 ++++ tests/test_utils.py | 44 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 7 deletions(-) create mode 100644 modules/__init__.py create mode 100644 tests/__init__.py create mode 100644 tests/test_utils.py diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/async_worker.py b/modules/async_worker.py index 37ab09ae..5fa5c0bd 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -52,6 +52,8 @@ def worker(): ) from modules.upscaler import perform_upscale + MAX_LORAS = 5 + try: async_gradio_app = shared.gradio_root flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}''' @@ -140,7 +142,7 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = [[str(args.pop()), float(args.pop())] for _ in range(5)] + loras = [[str(args.pop()), float(args.pop())] for _ in range(MAX_LORAS)] input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() @@ -381,7 +383,7 @@ def worker(): progressbar(async_task, 3, 'Loading models ...') # Parse lora references from prompt - loras = parse_lora_references_from_prompt(prompt, loras) + loras = parse_lora_references_from_prompt(prompt, loras, loras_limit=MAX_LORAS) pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, diff --git a/modules/util.py b/modules/util.py index 2355dae3..d350f3c6 100644 --- a/modules/util.py +++ b/modules/util.py @@ -12,6 +12,10 @@ from PIL import Image LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +# Regexp compiled once. Matches entries with the following pattern: +# +# +LORAS_PROMPT_PATTERN = re.compile(".*.*") def erode_or_dilate(x, k): k = int(k) @@ -183,13 +187,13 @@ def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') -def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]]): - pattern = re.compile(".*.*") +def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5): + new_loras = [] for token in items.split(","): - print(token) - m = pattern.match(token) + + m = LORAS_PROMPT_PATTERN.match(token) if m: new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2)))) @@ -200,4 +204,4 @@ def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, floa if lora[0] != "None": updated_loras.append(lora) - return updated_loras + return updated_loras[:loras_limit] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..f86b4227 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +import sys +import pathlib + +sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve()) \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..998bf058 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,44 @@ +import unittest +from modules import util + + +class TestUtils(unittest.TestCase): + def test_can_parse_tokens_with_lora(self): + + test_cases = [ + { + "input": ("some prompt, very cool, , cool ", [], 5), + "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)], + }, + # Test can not exceed limit + { + "input": ("some prompt, very cool, , cool ", [], 1), + "output": [("hey-lora.safetensors", 0.4)], + }, + # test Loras from UI take precedence over prompt + { + "input": ( + "some prompt, very cool, , , , , , ", + [("hey-lora.safetensors", 0.4)], + 5, + ), + "output": [ + ("hey-lora.safetensors", 0.4), + ("l1.safetensors", 0.4), + ("l2.safetensors", 0.2), + ("l3.safetensors", 0.3), + ("l4.safetensors", 0.5), + ], + }, + # Test lora specification not separated by comma are ignored, only latest specified is used + { + "input": ("some prompt, very cool, ", [], 3), + "output": [("you-lora.safetensors", 0.2)], + }, + ] + + for test in test_cases: + promp, loras, loras_limit = test["input"] + expected = test["output"] + actual = util.parse_lora_references_from_prompt(promp, loras, loras_limit) + self.assertEqual(expected, actual) From 2e6d4c3abbb286b4570e7fa736a44b6424eef5cb Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 22 Feb 2024 18:44:20 -0500 Subject: [PATCH 03/22] Added an initial documentation for development guidelines --- development.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 development.md diff --git a/development.md b/development.md new file mode 100644 index 00000000..a402a872 --- /dev/null +++ b/development.md @@ -0,0 +1,5 @@ +## Running unit tests +``` +python -m unittest tests/ +``` + From e802016043aabedc904a06f97ceae56df7a3a221 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 22 Feb 2024 18:50:06 -0500 Subject: [PATCH 04/22] Added a negative number --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 998bf058..c9861c9b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,14 +18,14 @@ class TestUtils(unittest.TestCase): # test Loras from UI take precedence over prompt { "input": ( - "some prompt, very cool, , , , , , ", + "some prompt, very cool, , , , , , ", [("hey-lora.safetensors", 0.4)], 5, ), "output": [ ("hey-lora.safetensors", 0.4), ("l1.safetensors", 0.4), - ("l2.safetensors", 0.2), + ("l2.safetensors", -0.2), ("l3.safetensors", 0.3), ("l4.safetensors", 0.5), ], From 9e60a8c3f88d764f8d8e7e971c02209344c0cc80 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 22 Feb 2024 18:51:51 -0500 Subject: [PATCH 05/22] renamed parameter --- modules/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/util.py b/modules/util.py index d350f3c6..527bc82b 100644 --- a/modules/util.py +++ b/modules/util.py @@ -187,11 +187,11 @@ def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') -def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5): +def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5): new_loras = [] - for token in items.split(","): + for token in prompt.split(","): m = LORAS_PROMPT_PATTERN.match(token) From 84945aac1f1651c2d4072c00b7f0ae4dd806d8e9 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Sun, 3 Mar 2024 19:22:56 -0500 Subject: [PATCH 06/22] removed wrongly committed file --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index ba8efb3b..b2111c1f 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -4,7 +4,7 @@ transformers==4.30.2 safetensors==0.3.1 accelerate==0.21.0 pyyaml==6.0 -#Pillow==9.2.0 +Pillow==9.2.0 scipy==1.9.3 tqdm==4.64.1 psutil==5.9.5 From 0a24a8a67adc2ce924b5b19ce31285df7507b33b Mon Sep 17 00:00:00 2001 From: cantor-set Date: Sun, 3 Mar 2024 19:25:07 -0500 Subject: [PATCH 07/22] Code fixes --- modules/async_worker.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 2ff1b71c..4047a86e 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -63,8 +63,6 @@ def worker(): pid = os.getpid() print(f'Started worker with PID {pid}') - MAX_LORAS = 5 - try: async_gradio_app = shared.gradio_root flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}''' @@ -165,8 +163,6 @@ def worker(): refiner_model_name = args.pop() refiner_switch = args.pop() - loras = [[str(args.pop()), float(args.pop())] for _ in range(MAX_LORAS)] - loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() @@ -430,7 +426,7 @@ def worker(): progressbar(async_task, 3, 'Loading models ...') # Parse lora references from prompt - loras = parse_lora_references_from_prompt(prompt, loras, loras_limit=MAX_LORAS) + loras = parse_lora_references_from_prompt(prompt, loras, loras_limit=modules.config.default_max_lora_number) pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, From 9d1c02d30d417d859164ac880063a028ca47a619 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Sun, 3 Mar 2024 20:40:44 -0500 Subject: [PATCH 08/22] Fixed circular reference --- modules/config.py | 3 ++- modules/path_utils.py | 20 ++++++++++++++++++++ modules/sdxl_styles.py | 3 +-- modules/util.py | 19 ------------------- tests/__init__.py | 2 +- 5 files changed, 24 insertions(+), 23 deletions(-) create mode 100644 modules/path_utils.py diff --git a/modules/config.py b/modules/config.py index 09c8fd7c..91a7546a 100644 --- a/modules/config.py +++ b/modules/config.py @@ -7,7 +7,8 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.util import get_files_from_folder, makedirs_with_log +from modules.util import makedirs_with_log +from modules.path_utils import get_files_from_folder from modules.flags import Performance, MetadataScheme def get_config_path(key, default_value): diff --git a/modules/path_utils.py b/modules/path_utils.py new file mode 100644 index 00000000..b4e27168 --- /dev/null +++ b/modules/path_utils.py @@ -0,0 +1,20 @@ + +import os + +def get_files_from_folder(folder_path, exensions=None, name_filter=None): + if not os.path.isdir(folder_path): + raise ValueError("Folder path is not a valid directory.") + + filenames = [] + + for root, dirs, files in os.walk(folder_path, topdown=False): + relative_path = os.path.relpath(root, folder_path) + if relative_path == ".": + relative_path = "" + for filename in sorted(files, key=lambda s: s.casefold()): + _, file_extension = os.path.splitext(filename) + if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _): + path = os.path.join(relative_path, filename) + filenames.append(path) + + return filenames diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 71afc402..ff3497f9 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -3,8 +3,7 @@ import re import json import math -from modules.util import get_files_from_folder - +from modules.path_utils import get_files_from_folder # cannot use modules.config - validators causing circular imports styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) diff --git a/modules/util.py b/modules/util.py index fb12af32..44f1e7ba 100644 --- a/modules/util.py +++ b/modules/util.py @@ -175,25 +175,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'): return date_string, os.path.abspath(result), filename -def get_files_from_folder(folder_path, exensions=None, name_filter=None): - if not os.path.isdir(folder_path): - raise ValueError("Folder path is not a valid directory.") - - filenames = [] - - for root, dirs, files in os.walk(folder_path, topdown=False): - relative_path = os.path.relpath(root, folder_path) - if relative_path == ".": - relative_path = "" - for filename in sorted(files, key=lambda s: s.casefold()): - _, file_extension = os.path.splitext(filename) - if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _): - path = os.path.join(relative_path, filename) - filenames.append(path) - - return filenames - - def calculate_sha256(filename, length=HASH_SHA256_LENGTH) -> str: hash_sha256 = sha256() blksize = 1024 * 1024 diff --git a/tests/__init__.py b/tests/__init__.py index f86b4227..c424468f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ import sys import pathlib -sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve()) \ No newline at end of file +sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve()) From 1362533eda090088256c1eb6a23e2a00d3520c83 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Sun, 3 Mar 2024 20:47:17 -0500 Subject: [PATCH 09/22] Fixed typo. Added TODO --- modules/path_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/path_utils.py b/modules/path_utils.py index b4e27168..0a021520 100644 --- a/modules/path_utils.py +++ b/modules/path_utils.py @@ -1,19 +1,20 @@ import os -def get_files_from_folder(folder_path, exensions=None, name_filter=None): +#TODO: Use refactor to use glob instead +def get_files_from_folder(folder_path, extensions=None, name_filter=None): if not os.path.isdir(folder_path): raise ValueError("Folder path is not a valid directory.") filenames = [] - for root, dirs, files in os.walk(folder_path, topdown=False): + for root, _, files in os.walk(folder_path, topdown=False): relative_path = os.path.relpath(root, folder_path) if relative_path == ".": relative_path = "" for filename in sorted(files, key=lambda s: s.casefold()): _, file_extension = os.path.splitext(filename) - if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _): + if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _): path = os.path.join(relative_path, filename) filenames.append(path) From 8a9f6090bc589096b9941c08e2d6366fc2e04d69 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 21 Mar 2024 10:34:29 -0400 Subject: [PATCH 10/22] Fixed merge --- modules/async_worker.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 8f5beca8..17455053 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -46,22 +46,9 @@ def worker(): from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays from modules.private_logger import log from extras.expansion import safe_str -<<<<<<< HEAD - from modules.util import ( - remove_empty_str, - HWC3, resize_image, - get_image_shape_ceil, - set_image_shape_ceil, - get_shape_ceil, - resample_image, - erode_or_dilate, - ordinal_suffix, - parse_lora_references_from_prompt - ) -======= from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras ->>>>>>> 978267f461e204c6c4359a79ed818ee2e3e1af39 + from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme @@ -161,13 +148,7 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() -<<<<<<< HEAD - - loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)]) - -======= loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)]) ->>>>>>> 978267f461e204c6c4359a79ed818ee2e3e1af39 input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() From 9923852f1e661373b5e775a03ec96c12e60a62b2 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 21 Mar 2024 10:40:15 -0400 Subject: [PATCH 11/22] Code cleanup --- modules/async_worker.py | 1 - modules/util.py | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 17455053..c22b8dc8 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -48,7 +48,6 @@ def worker(): from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras - from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme diff --git a/modules/util.py b/modules/util.py index c143e2c4..e36ea6ee 100644 --- a/modules/util.py +++ b/modules/util.py @@ -1,14 +1,11 @@ -import typing - import numpy as np import datetime import random import math import os import cv2 - import re -from typing import List, Tuple, AnyStr +from typing import List, Tuple, AnyStr, NamedTuple import json @@ -320,7 +317,7 @@ def extract_styles_from_prompt(prompt, negative_prompt): return list(reversed(extracted)), real_prompt, negative_prompt -class PromptStyle(typing.NamedTuple): +class PromptStyle(NamedTuple): name: str prompt: str negative_prompt: str From f289174af7426cdb8c8065ca2a375609903cd9ba Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 21 Mar 2024 22:46:59 -0400 Subject: [PATCH 12/22] Added missing refernce function --- modules/async_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index c22b8dc8..edf87dfc 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -47,7 +47,7 @@ def worker(): from modules.private_logger import log from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ - get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras + get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, parse_lora_references_from_prompt from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme From 6fed0961234a2068fcca37f5f0d511b1f85a372c Mon Sep 17 00:00:00 2001 From: cantor-set Date: Wed, 27 Mar 2024 12:36:05 -0400 Subject: [PATCH 13/22] Removed function from util.py... again... --- modules/util.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/modules/util.py b/modules/util.py index 5b621226..b261d79e 100644 --- a/modules/util.py +++ b/modules/util.py @@ -401,23 +401,4 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo if lora[0] != "None": updated_loras.append(lora) - return updated_loras[:loras_limit] - -def get_files_from_folder(folder_path, extensions=None, name_filter=None): - if not os.path.isdir(folder_path): - raise ValueError("Folder path is not a valid directory.") - - filenames = [] - - for root, dirs, files in os.walk(folder_path, topdown=False): - relative_path = os.path.relpath(root, folder_path) - if relative_path == ".": - relative_path = "" - for filename in sorted(files, key=lambda s: s.casefold()): - _, file_extension = os.path.splitext(filename) - if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _): - path = os.path.join(relative_path, filename) - filenames.append(path) - - return filenames - + return updated_loras[:loras_limit] \ No newline at end of file From af8980105ab0c7f33b269073d48a0a67019185d5 Mon Sep 17 00:00:00 2001 From: cantor-set <32692347+cantor-set@users.noreply.github.com> Date: Tue, 9 Apr 2024 23:48:35 -0400 Subject: [PATCH 14/22] Update modules/async_worker.py Implemented suggested change Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> --- modules/async_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 4e02e6e1..33c6d95e 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -428,7 +428,7 @@ def worker(): progressbar(async_task, 3, 'Loading models ...') # Parse lora references from prompt - loras = parse_lora_references_from_prompt(prompt, loras, loras_limit=modules.config.default_max_lora_number) + loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, From 3e8681a4d773109e7f43d144f6c47a6a81d0f5cb Mon Sep 17 00:00:00 2001 From: cantor-set Date: Wed, 10 Apr 2024 00:34:00 -0400 Subject: [PATCH 15/22] Removed another circular reference --- modules/async_worker.py | 10 +++++----- modules/config.py | 35 +++++++++++++++++++++++++++++++---- modules/sdxl_styles.py | 33 +++------------------------------ modules/util.py | 10 ++-------- 4 files changed, 41 insertions(+), 47 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 33c6d95e..5f03eaf4 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -43,7 +43,7 @@ def worker(): import fooocus_version import args_manager - from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays + from modules.sdxl_styles import apply_style, fooocus_expansion, apply_arrays from modules.private_logger import log from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ @@ -444,11 +444,11 @@ def worker(): task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not task_rng = random.Random(task_seed) # may bind to inpaint noise in the future - task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) + task_prompt = modules.config.apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_arrays(task_prompt, i) - task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) - task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] - task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] + task_negative_prompt = modules.config.apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) + task_extra_positive_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] + task_extra_negative_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] positive_basic_workloads = [] negative_basic_workloads = [] diff --git a/modules/config.py b/modules/config.py index f893c844..5747d1a9 100644 --- a/modules/config.py +++ b/modules/config.py @@ -1,6 +1,7 @@ import os import json import math +import re import numbers import args_manager import tempfile @@ -8,9 +9,7 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url - -from modules.util import makedirs_with_log -from modules.path_utils import get_files_from_folder +from modules.extra_utils import get_files_from_folder, makedirs_with_log from modules.flags import OutputFormat, Performance, MetadataScheme def get_config_path(key, default_value): @@ -21,7 +20,7 @@ def get_config_path(key, default_value): else: return os.path.abspath(default_value) - +wildcards_max_bfs_depth = 64 config_path = get_config_path('config_path', "./config.txt") config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt") config_dict = {} @@ -681,4 +680,32 @@ def downloading_upscale_model(): return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') +def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order): + for _ in range(wildcards_max_bfs_depth): + placeholders = re.findall(r'__([\w-]+)__', wildcard_text) + if len(placeholders) == 0: + return wildcard_text + + print(f'[Wildcards] processing: {wildcard_text}') + for placeholder in placeholders: + try: + matches = [x for x in wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder] + words = open(os.path.join(path_wildcards, matches[0]), encoding='utf-8').read().splitlines() + words = [x for x in words if x != ''] + assert len(words) > 0 + if read_wildcards_in_order: + wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1) + else: + wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1) + except: + print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. ' + f'Using "{placeholder}" as a normal word.') + wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder) + print(f'[Wildcards] {wildcard_text}') + + print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}') + return wildcard_text + + + update_files() diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 51f452ea..e4b754e8 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -2,13 +2,13 @@ import os import re import json import math -import modules.config +#import modules.config -from modules.path_utils import get_files_from_folder +from modules.extra_utils import get_files_from_folder # cannot use modules.config - validators causing circular imports styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) -wildcards_max_bfs_depth = 64 + def normalize_key(k): @@ -59,33 +59,6 @@ def apply_style(style, positive): return p.replace('{prompt}', positive).splitlines(), n.splitlines() -def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order): - for _ in range(wildcards_max_bfs_depth): - placeholders = re.findall(r'__([\w-]+)__', wildcard_text) - if len(placeholders) == 0: - return wildcard_text - - print(f'[Wildcards] processing: {wildcard_text}') - for placeholder in placeholders: - try: - matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder] - words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines() - words = [x for x in words if x != ''] - assert len(words) > 0 - if read_wildcards_in_order: - wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1) - else: - wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1) - except: - print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. ' - f'Using "{placeholder}" as a normal word.') - wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder) - print(f'[Wildcards] {wildcard_text}') - - print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}') - return wildcard_text - - def get_words(arrays, totalMult, index): if len(arrays) == 1: return [arrays[0].split(',')[index]] diff --git a/modules/util.py b/modules/util.py index b261d79e..8b641ea0 100644 --- a/modules/util.py +++ b/modules/util.py @@ -374,13 +374,6 @@ def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') - -def makedirs_with_log(path): - try: - os.makedirs(path, exist_ok=True) - except OSError as error: - print(f'Directory {path} could not be created, reason: {error}') - def get_enabled_loras(loras: list) -> list: return [[lora[1], lora[2]] for lora in loras if lora[0]] @@ -401,4 +394,5 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo if lora[0] != "None": updated_loras.append(lora) - return updated_loras[:loras_limit] \ No newline at end of file + return updated_loras[:loras_limit] + From 1c3c9bc714c3c845bba9321d739f1820d92309db Mon Sep 17 00:00:00 2001 From: cantor-set Date: Wed, 10 Apr 2024 00:34:29 -0400 Subject: [PATCH 16/22] Renamed module --- modules/extra_utils.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 modules/extra_utils.py diff --git a/modules/extra_utils.py b/modules/extra_utils.py new file mode 100644 index 00000000..6bbf7557 --- /dev/null +++ b/modules/extra_utils.py @@ -0,0 +1,27 @@ + +import os + +#TODO: Use refactor to use glob instead +def get_files_from_folder(folder_path, extensions=None, name_filter=None): + if not os.path.isdir(folder_path): + raise ValueError("Folder path is not a valid directory.") + + filenames = [] + + for root, _, files in os.walk(folder_path, topdown=False): + relative_path = os.path.relpath(root, folder_path) + if relative_path == ".": + relative_path = "" + for filename in sorted(files, key=lambda s: s.casefold()): + _, file_extension = os.path.splitext(filename) + if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _): + path = os.path.join(relative_path, filename) + filenames.append(path) + + return filenames + +def makedirs_with_log(path): + try: + os.makedirs(path, exist_ok=True) + except OSError as error: + print(f'Directory {path} could not be created, reason: {error}') From 06726e795e3dac7eb34eb27320f0ed21c8becb7c Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 18 Apr 2024 09:28:30 -0400 Subject: [PATCH 17/22] Addressed PR comments --- modules/path_utils.py | 21 --------------------- modules/util.py | 2 +- tests/test_utils.py | 4 ++++ 3 files changed, 5 insertions(+), 22 deletions(-) delete mode 100644 modules/path_utils.py diff --git a/modules/path_utils.py b/modules/path_utils.py deleted file mode 100644 index 0a021520..00000000 --- a/modules/path_utils.py +++ /dev/null @@ -1,21 +0,0 @@ - -import os - -#TODO: Use refactor to use glob instead -def get_files_from_folder(folder_path, extensions=None, name_filter=None): - if not os.path.isdir(folder_path): - raise ValueError("Folder path is not a valid directory.") - - filenames = [] - - for root, _, files in os.walk(folder_path, topdown=False): - relative_path = os.path.relpath(root, folder_path) - if relative_path == ".": - relative_path = "" - for filename in sorted(files, key=lambda s: s.casefold()): - _, file_extension = os.path.splitext(filename) - if (extensions is None or file_extension.lower() in extensions) and (name_filter is None or name_filter in _): - path = os.path.join(relative_path, filename) - filenames.append(path) - - return filenames diff --git a/modules/util.py b/modules/util.py index 8b641ea0..f1b6ddf3 100644 --- a/modules/util.py +++ b/modules/util.py @@ -20,7 +20,7 @@ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.L # Regexp compiled once. Matches entries with the following pattern: # # -LORAS_PROMPT_PATTERN = re.compile(".*.*") +LORAS_PROMPT_PATTERN = re.compile(r".* .*", re.X) HASH_SHA256_LENGTH = 10 diff --git a/tests/test_utils.py b/tests/test_utils.py index c9861c9b..7b1cf87c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -35,6 +35,10 @@ class TestUtils(unittest.TestCase): "input": ("some prompt, very cool, ", [], 3), "output": [("you-lora.safetensors", 0.2)], }, + { + "input": (", , and ",[], 6), + "output": [] + } ] for test in test_cases: From a27d49bf8ae1277992e9ca68c1b159b8c3d74460 Mon Sep 17 00:00:00 2001 From: cantor-set Date: Thu, 18 Apr 2024 09:33:32 -0400 Subject: [PATCH 18/22] Added return type to function --- modules/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/util.py b/modules/util.py index f1b6ddf3..7251ba14 100644 --- a/modules/util.py +++ b/modules/util.py @@ -377,7 +377,7 @@ def ordinal_suffix(number: int) -> str: def get_enabled_loras(loras: list) -> list: return [[lora[1], lora[2]] for lora in loras if lora[0]] -def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5): +def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]: new_loras = [] From a78b3841c9c8b79c4caeb8f64d8d228fc96c1f18 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 May 2024 16:43:02 +0200 Subject: [PATCH 19/22] refactor: move apply_wildcards to module util --- modules/async_worker.py | 13 +++++++------ modules/config.py | 28 ---------------------------- modules/util.py | 26 ++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 5f03eaf4..6f52711e 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -46,8 +46,9 @@ def worker(): from modules.sdxl_styles import apply_style, fooocus_expansion, apply_arrays from modules.private_logger import log from extras.expansion import safe_str - from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ - get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, parse_lora_references_from_prompt + from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, + get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, + parse_lora_references_from_prompt, apply_wildcards) from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme @@ -444,11 +445,11 @@ def worker(): task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not task_rng = random.Random(task_seed) # may bind to inpaint noise in the future - task_prompt = modules.config.apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) + task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_arrays(task_prompt, i) - task_negative_prompt = modules.config.apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) - task_extra_positive_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] - task_extra_negative_prompts = [modules.config.apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] + task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) + task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] + task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] positive_basic_workloads = [] negative_basic_workloads = [] diff --git a/modules/config.py b/modules/config.py index 5747d1a9..090a996c 100644 --- a/modules/config.py +++ b/modules/config.py @@ -680,32 +680,4 @@ def downloading_upscale_model(): return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') -def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order): - for _ in range(wildcards_max_bfs_depth): - placeholders = re.findall(r'__([\w-]+)__', wildcard_text) - if len(placeholders) == 0: - return wildcard_text - - print(f'[Wildcards] processing: {wildcard_text}') - for placeholder in placeholders: - try: - matches = [x for x in wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder] - words = open(os.path.join(path_wildcards, matches[0]), encoding='utf-8').read().splitlines() - words = [x for x in words if x != ''] - assert len(words) > 0 - if read_wildcards_in_order: - wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1) - else: - wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1) - except: - print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. ' - f'Using "{placeholder}" as a normal word.') - wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder) - print(f'[Wildcards] {wildcard_text}') - - print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}') - return wildcard_text - - - update_files() diff --git a/modules/util.py b/modules/util.py index 7251ba14..ea5fa87e 100644 --- a/modules/util.py +++ b/modules/util.py @@ -396,3 +396,29 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo return updated_loras[:loras_limit] + +def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str: + for _ in range(modules.config.wildcards_max_bfs_depth): + placeholders = re.findall(r'__([\w-]+)__', wildcard_text) + if len(placeholders) == 0: + return wildcard_text + + print(f'[Wildcards] processing: {wildcard_text}') + for placeholder in placeholders: + try: + matches = [x for x in modules.config.wildcard_filenames if os.path.splitext(os.path.basename(x))[0] == placeholder] + words = open(os.path.join(modules.config.path_wildcards, matches[0]), encoding='utf-8').read().splitlines() + words = [x for x in words if x != ''] + assert len(words) > 0 + if read_wildcards_in_order: + wildcard_text = wildcard_text.replace(f'__{placeholder}__', words[i % len(words)], 1) + else: + wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1) + except: + print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. ' + f'Using "{placeholder}" as a normal word.') + wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder) + print(f'[Wildcards] {wildcard_text}') + + print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}') + return wildcard_text From 4e610411fb372a4b3b7882dc150dbe6501d5c607 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 May 2024 16:43:53 +0200 Subject: [PATCH 20/22] refactor: code cleanup, unify usage of tuples in lora list --- modules/async_worker.py | 3 +-- modules/sdxl_styles.py | 7 ++----- modules/util.py | 9 +++------ 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 6f52711e..3cbac64c 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -148,7 +148,7 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)]) + loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() @@ -428,7 +428,6 @@ def worker(): progressbar(async_task, 3, 'Loading models ...') - # Parse lora references from prompt loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index e4b754e8..125de174 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -2,7 +2,6 @@ import os import re import json import math -#import modules.config from modules.extra_utils import get_files_from_folder @@ -10,7 +9,6 @@ from modules.extra_utils import get_files_from_folder styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) - def normalize_key(k): k = k.replace('-', ' ') words = k.split(' ') @@ -24,7 +22,6 @@ def normalize_key(k): styles = {} - styles_files = get_files_from_folder(styles_path, ['.json']) for x in ['sdxl_styles_fooocus.json', @@ -59,7 +56,7 @@ def apply_style(style, positive): return p.replace('{prompt}', positive).splitlines(), n.splitlines() -def get_words(arrays, totalMult, index): +def get_words(arrays, total_mult, index): if len(arrays) == 1: return [arrays[0].split(',')[index]] else: @@ -68,7 +65,7 @@ def get_words(arrays, totalMult, index): index -= index % len(words) index /= len(words) index = math.floor(index) - return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index) + return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index) def apply_arrays(text, index): diff --git a/modules/util.py b/modules/util.py index ea5fa87e..234a8af5 100644 --- a/modules/util.py +++ b/modules/util.py @@ -375,22 +375,19 @@ def ordinal_suffix(number: int) -> str: def get_enabled_loras(loras: list) -> list: - return [[lora[1], lora[2]] for lora in loras if lora[0]] + return [(lora[1], lora[2]) for lora in loras if lora[0]] + def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]: - new_loras = [] - + updated_loras = [] for token in prompt.split(","): - m = LORAS_PROMPT_PATTERN.match(token) if m: new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2)))) - updated_loras = [] for lora in loras + new_loras: - if lora[0] != "None": updated_loras.append(lora) From f4fc21d05d49d726680c67ea794759d6360be811 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 May 2024 16:55:39 +0200 Subject: [PATCH 21/22] docs: add instructions for running unittests on embedded python, code cleanup --- development.md | 6 ++++++ tests/test_utils.py | 10 +++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/development.md b/development.md index a402a872..bbb3def9 100644 --- a/development.md +++ b/development.md @@ -1,5 +1,11 @@ ## Running unit tests + +Native python: ``` python -m unittest tests/ ``` +Embedded python (Windows zip file installation method): +``` +..\python_embeded\python.exe -m unittest +``` diff --git a/tests/test_utils.py b/tests/test_utils.py index 7b1cf87c..0698dcc8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,10 @@ import unittest + from modules import util class TestUtils(unittest.TestCase): def test_can_parse_tokens_with_lora(self): - test_cases = [ { "input": ("some prompt, very cool, , cool ", [], 5), @@ -36,13 +36,13 @@ class TestUtils(unittest.TestCase): "output": [("you-lora.safetensors", 0.2)], }, { - "input": (", , and ",[], 6), - "output": [] + "input": (", , and ", [], 6), + "output": [] } ] for test in test_cases: - promp, loras, loras_limit = test["input"] + prompt, loras, loras_limit = test["input"] expected = test["output"] - actual = util.parse_lora_references_from_prompt(promp, loras, loras_limit) + actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit) self.assertEqual(expected, actual) From 80cad7787f28a2b6057ac132bf6eede96fb9518f Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 May 2024 17:18:23 +0200 Subject: [PATCH 22/22] refactor: code cleanup, move makedirs_with_log back to util --- modules/async_worker.py | 42 ++++++++++++++++++++++++++--------------- modules/config.py | 5 +++-- modules/extra_utils.py | 9 +-------- modules/util.py | 8 +++++++- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index e8a068db..7f0a46e3 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -4,6 +4,7 @@ from modules.patch import PatchSettings, patch_settings, patch_all patch_all() + class AsyncTask: def __init__(self, args): self.args = args @@ -44,12 +45,12 @@ def worker(): import args_manager from extras.censor import censor_batch, censor_single - from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name + from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, - get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, - parse_lora_references_from_prompt, apply_wildcards) + get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, + parse_lora_references_from_prompt, apply_wildcards) from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme @@ -70,7 +71,8 @@ def worker(): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): + def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, + progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] @@ -153,7 +155,8 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in range(modules.config.default_max_lora_number)]) + loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in + range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() @@ -203,7 +206,8 @@ def worker(): inpaint_erode_or_dilate = args.pop() save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False - metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS + metadata_scheme = MetadataScheme( + args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS cn_tasks = {x: [] for x in flags.ip_list} for _ in range(flags.controlnet_image_count): @@ -443,7 +447,7 @@ def worker(): progressbar(async_task, 3, 'Processing prompts ...') tasks = [] - + for i in range(image_number): if disable_seed_increment: task_seed = seed % (constants.MAX_SEED + 1) @@ -454,8 +458,10 @@ def worker(): task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_arrays(task_prompt, i) task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) - task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] - task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] + task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_positive_prompts] + task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_negative_prompts] positive_basic_workloads = [] negative_basic_workloads = [] @@ -656,7 +662,8 @@ def worker(): ) if debugging_inpaint_preprocessor: - yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True) + yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, + do_not_show_finished_images=True) return progressbar(async_task, 13, 'VAE Inpaint encoding ...') @@ -811,7 +818,8 @@ def worker(): done_steps = current_task_id * steps + step async_task.yields.append(['preview', ( int(15.0 + 85.0 * float(done_steps) / float(all_steps)), - f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)]) + f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', + y)]) for current_task_id, task in enumerate(tasks): execution_start_time = time.perf_counter() @@ -866,7 +874,8 @@ def worker(): d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), - ('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), + ('Styles', 'styles', + str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), ('Performance', 'performance', performance_selection.value)] if performance_selection.steps() != steps: @@ -889,7 +898,8 @@ def worker(): if refiner_swap_method != flags.refiner_swap_method: d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method)) if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr: - d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) + d.append( + ('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) d.append(('Sampler', 'sampler', sampler_name)) d.append(('Scheduler', 'scheduler', scheduler_name)) @@ -909,11 +919,13 @@ def worker(): metadata_parser.set_data(task['log_positive_prompt'], task['positive'], task['log_negative_prompt'], task['negative'], steps, base_model_name, refiner_model_name, loras, vae_name) - d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) + d.append(('Metadata Scheme', 'metadata_scheme', + metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format, task)) - yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) + yield_result(async_task, img_paths, black_out_nsfw, False, + do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': print('User skipped') diff --git a/modules/config.py b/modules/config.py index 701222c8..11fe3181 100644 --- a/modules/config.py +++ b/modules/config.py @@ -1,7 +1,6 @@ import os import json import math -import re import numbers import args_manager import tempfile @@ -9,9 +8,11 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.extra_utils import get_files_from_folder, makedirs_with_log +from modules.util import makedirs_with_log +from modules.extra_utils import get_files_from_folder from modules.flags import OutputFormat, Performance, MetadataScheme + def get_config_path(key, default_value): env = os.getenv(key) if env is not None and isinstance(env, str): diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 6bbf7557..3e95e8b5 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -1,7 +1,6 @@ - import os -#TODO: Use refactor to use glob instead + def get_files_from_folder(folder_path, extensions=None, name_filter=None): if not os.path.isdir(folder_path): raise ValueError("Folder path is not a valid directory.") @@ -19,9 +18,3 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): filenames.append(path) return filenames - -def makedirs_with_log(path): - try: - os.makedirs(path, exist_ok=True) - except OSError as error: - print(f'Directory {path} could not be created, reason: {error}') diff --git a/modules/util.py b/modules/util.py index 7d695df7..73430230 100644 --- a/modules/util.py +++ b/modules/util.py @@ -172,7 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'): return date_string, os.path.abspath(result), filename - def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH): print(f"Calculating sha256 for {filename}: ", end='') if use_addnet_hash: @@ -377,6 +376,13 @@ def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') +def makedirs_with_log(path): + try: + os.makedirs(path, exist_ok=True) + except OSError as error: + print(f'Directory {path} could not be created, reason: {error}') + + def get_enabled_loras(loras: list) -> list: return [(lora[1], lora[2]) for lora in loras if lora[0]]