From 65a8b25129c52ccb6f9fe5933202712b533c977d Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Mon, 20 May 2024 17:31:51 +0200 Subject: [PATCH 1/3] feat: inline lora optimisations (#2967) * feat: add performance loras to the end of the loras array * fix: resolve circular dependency for unit tests * feat: allow multiple matches for each token, optimize and extract method cleanup_prompt * fix: update unit tests * feat: ignore custom wildcards --- modules/async_worker.py | 12 +++++---- modules/config.py | 3 +-- modules/extra_utils.py | 6 +++++ modules/util.py | 60 +++++++++++++++++++++++++++++------------ tests/test_utils.py | 41 +++++++++++++++++++--------- wildcards/.gitignore | 8 ++++++ 6 files changed, 93 insertions(+), 37 deletions(-) create mode 100644 wildcards/.gitignore diff --git a/modules/async_worker.py b/modules/async_worker.py index 8e4d6d95..594886d2 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -237,10 +237,12 @@ def worker(): steps = performance_selection.steps() + performance_loras = [] + if performance_selection == Performance.EXTREME_SPEED: print('Enter LCM mode.') progressbar(async_task, 1, 'Downloading LCM components ...') - loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] + performance_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] if refiner_model_name != 'None': print(f'Refiner disabled in LCM mode.') @@ -259,7 +261,7 @@ def worker(): elif performance_selection == Performance.LIGHTNING: print('Enter Lightning mode.') progressbar(async_task, 1, 'Downloading Lightning components ...') - loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)] + performance_loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)] if refiner_model_name != 'None': print(f'Refiner disabled in Lightning mode.') @@ -278,7 +280,7 @@ def worker(): elif performance_selection == Performance.HYPER_SD: print('Enter Hyper-SD mode.') progressbar(async_task, 1, 'Downloading Hyper-SD components ...') - loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)] + performance_loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)] if refiner_model_name != 'None': print(f'Refiner disabled in Hyper-SD mode.') @@ -458,8 +460,8 @@ def worker(): progressbar(async_task, 2, 'Loading models ...') - loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) - + loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) + loras += performance_loras pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) diff --git a/modules/config.py b/modules/config.py index db7036c5..913fb281 100644 --- a/modules/config.py +++ b/modules/config.py @@ -8,8 +8,7 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.util import makedirs_with_log -from modules.extra_utils import get_files_from_folder +from modules.extra_utils import makedirs_with_log, get_files_from_folder from modules.flags import OutputFormat, Performance, MetadataScheme diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 3e95e8b5..9906c820 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -1,5 +1,11 @@ import os +def makedirs_with_log(path): + try: + os.makedirs(path, exist_ok=True) + except OSError as error: + print(f'Directory {path} could not be created, reason: {error}') + def get_files_from_folder(folder_path, extensions=None, name_filter=None): if not os.path.isdir(folder_path): diff --git a/modules/util.py b/modules/util.py index 8e85ffbe..52bc490a 100644 --- a/modules/util.py +++ b/modules/util.py @@ -12,15 +12,15 @@ import hashlib from PIL import Image +import modules.config import modules.sdxl_styles LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) - # Regexp compiled once. Matches entries with the following pattern: # # -LORAS_PROMPT_PATTERN = re.compile(r".* .*", re.X) +LORAS_PROMPT_PATTERN = re.compile(r"()", re.X) HASH_SHA256_LENGTH = 10 @@ -372,31 +372,57 @@ def get_file_from_folder_list(name, folders): return os.path.abspath(os.path.realpath(os.path.join(folders[0], name))) -def makedirs_with_log(path): - try: - os.makedirs(path, exist_ok=True) - except OSError as error: - print(f'Directory {path} could not be created, reason: {error}') +def get_enabled_loras(loras: list, remove_none=True) -> list: + return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)] -def get_enabled_loras(loras: list) -> list: - return [(lora[1], lora[2]) for lora in loras if lora[0]] +def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5, + prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]: + found_loras = [] + prompt_without_loras = "" + for token in prompt.split(" "): + matches = LORAS_PROMPT_PATTERN.findall(token) + if matches: + for match in matches: + found_loras.append((f"{match[1]}.safetensors", float(match[2]))) + prompt_without_loras += token.replace(match[0], '') + else: + prompt_without_loras += token + prompt_without_loras += ' ' + + cleaned_prompt = prompt_without_loras[:-1] + if prompt_cleanup: + cleaned_prompt = cleanup_prompt(prompt_without_loras) -def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]: new_loras = [] + lora_names = [lora[0] for lora in loras] + for found_lora in found_loras: + if deduplicate_loras and found_lora[0] in lora_names: + continue + new_loras.append(found_lora) + + if len(new_loras) == 0: + return loras, cleaned_prompt + updated_loras = [] - for token in prompt.split(","): - m = LORAS_PROMPT_PATTERN.match(token) - - if m: - new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2)))) - for lora in loras + new_loras: if lora[0] != "None": updated_loras.append(lora) - return updated_loras[:loras_limit] + return updated_loras[:loras_limit], cleaned_prompt + + +def cleanup_prompt(prompt): + prompt = re.sub(' +', ' ', prompt) + prompt = re.sub(',+', ',', prompt) + cleaned_prompt = '' + for token in prompt.split(','): + token = token.strip() + if token == '': + continue + cleaned_prompt += token + ', ' + return cleaned_prompt[:-2] def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str: diff --git a/tests/test_utils.py b/tests/test_utils.py index 0698dcc8..9f81005b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,12 +8,16 @@ class TestUtils(unittest.TestCase): test_cases = [ { "input": ("some prompt, very cool, , cool ", [], 5), - "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)], + "output": ( + [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'), }, # Test can not exceed limit { "input": ("some prompt, very cool, , cool ", [], 1), - "output": [("hey-lora.safetensors", 0.4)], + "output": ( + [('hey-lora.safetensors', 0.4)], + 'some prompt, very cool, cool' + ), }, # test Loras from UI take precedence over prompt { @@ -22,22 +26,33 @@ class TestUtils(unittest.TestCase): [("hey-lora.safetensors", 0.4)], 5, ), - "output": [ - ("hey-lora.safetensors", 0.4), - ("l1.safetensors", 0.4), - ("l2.safetensors", -0.2), - ("l3.safetensors", 0.3), - ("l4.safetensors", 0.5), - ], + "output": ( + [ + ('hey-lora.safetensors', 0.4), + ('l1.safetensors', 0.4), + ('l2.safetensors', -0.2), + ('l3.safetensors', 0.3), + ('l4.safetensors', 0.5) + ], + 'some prompt, very cool' + ) }, - # Test lora specification not separated by comma are ignored, only latest specified is used { "input": ("some prompt, very cool, ", [], 3), - "output": [("you-lora.safetensors", 0.2)], + "output": ( + [ + ('hey-lora.safetensors', 0.4), + ('you-lora.safetensors', 0.2) + ], + 'some prompt, very cool, ' + ), }, { - "input": (", , and ", [], 6), - "output": [] + "input": (", , , and ", [], 6), + "output": ( + [], + ', , , and ' + ) } ] diff --git a/wildcards/.gitignore b/wildcards/.gitignore new file mode 100644 index 00000000..7e4ac188 --- /dev/null +++ b/wildcards/.gitignore @@ -0,0 +1,8 @@ +*.txt +!animal.txt +!artist.txt +!color.txt +!color_flower.txt +!extended-color.txt +!flower.txt +!nationality.txt \ No newline at end of file From ac14d9d03ce731c0f57961ab1fde9c4e276bad99 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Mon, 20 May 2024 17:33:12 +0200 Subject: [PATCH 2/3] feat: change code owner from @lllyasviel to @mashb1t (#2948) --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 585eb87a..f9876685 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @lllyasviel +* @mashb1t From 045d03ddada61d21fd3369fd1a92de62dee37cc8 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 20 May 2024 19:08:35 +0200 Subject: [PATCH 3/3] feat: only use valid inline loras, add subfolder support --- modules/config.py | 14 +++++++++++++- modules/meta_parser.py | 21 ++++----------------- modules/util.py | 41 +++++++++++++++++++++++++++++------------ tests/test_utils.py | 32 +++++++++++++++++++++++++------- 4 files changed, 71 insertions(+), 37 deletions(-) diff --git a/modules/config.py b/modules/config.py index 913fb281..94046661 100644 --- a/modules/config.py +++ b/modules/config.py @@ -547,6 +547,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file: model_filenames = [] lora_filenames = [] +lora_filenames_no_special = [] vae_filenames = [] wildcard_filenames = [] @@ -556,6 +557,16 @@ sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors' loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora] +def remove_special_loras(lora_filenames): + global loras_metadata_remove + + loras_no_special = lora_filenames.copy() + for lora_to_remove in loras_metadata_remove: + if lora_to_remove in loras_no_special: + loras_no_special.remove(lora_to_remove) + return loras_no_special + + def get_model_filenames(folder_paths, extensions=None, name_filter=None): if extensions is None: extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] @@ -570,9 +581,10 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None): def update_files(): - global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets + global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets model_filenames = get_model_filenames(paths_checkpoints) lora_filenames = get_model_filenames(paths_loras) + lora_filenames_no_special = remove_special_loras(lora_filenames) vae_filenames = get_model_filenames(path_vae) wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt']) available_presets = get_presets() diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 84032e82..2469da5f 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -205,7 +205,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list): def get_sha256(filepath): global hash_cache if filepath not in hash_cache: - # is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors' hash_cache[filepath] = sha256(filepath) return hash_cache[filepath] @@ -293,12 +292,6 @@ class MetadataParser(ABC): self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) self.vae_name = Path(vae_name).stem - @staticmethod - def remove_special_loras(lora_filenames): - for lora_to_remove in modules.config.loras_metadata_remove: - if lora_to_remove in lora_filenames: - lora_filenames.remove(lora_to_remove) - class A1111MetadataParser(MetadataParser): def get_scheme(self) -> MetadataScheme: @@ -415,13 +408,11 @@ class A1111MetadataParser(MetadataParser): lora_data = data['lora_hashes'] if lora_data != '': - lora_filenames = modules.config.lora_filenames.copy() - self.remove_special_loras(lora_filenames) for li, lora in enumerate(lora_data.split(', ')): lora_split = lora.split(': ') lora_name = lora_split[0] lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1] - for filename in lora_filenames: + for filename in modules.config.lora_filenames_no_special: path = Path(filename) if lora_name == path.stem: data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}' @@ -510,19 +501,15 @@ class FooocusMetadataParser(MetadataParser): return MetadataScheme.FOOOCUS def parse_json(self, metadata: dict) -> dict: - model_filenames = modules.config.model_filenames.copy() - lora_filenames = modules.config.lora_filenames.copy() - vae_filenames = modules.config.vae_filenames.copy() - self.remove_special_loras(lora_filenames) for key, value in metadata.items(): if value in ['', 'None']: continue if key in ['base_model', 'refiner_model']: - metadata[key] = self.replace_value_with_filename(key, value, model_filenames) + metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames) elif key.startswith('lora_combined_'): - metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) + metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special) elif key == 'vae': - metadata[key] = self.replace_value_with_filename(key, value, vae_filenames) + metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames) else: continue diff --git a/modules/util.py b/modules/util.py index 52bc490a..cb5580fb 100644 --- a/modules/util.py +++ b/modules/util.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import datetime import random @@ -360,6 +362,14 @@ def is_json(data: str) -> bool: return True +def get_filname_by_stem(lora_name, filenames: List[str]) -> str | None: + for filename in filenames: + path = Path(filename) + if lora_name == path.stem: + return filename + return None + + def get_file_from_folder_list(name, folders): if not isinstance(folders, list): folders = [folders] @@ -377,28 +387,35 @@ def get_enabled_loras(loras: list, remove_none=True) -> list: def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5, - prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]: + skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]: found_loras = [] - prompt_without_loras = "" - for token in prompt.split(" "): + prompt_without_loras = '' + cleaned_prompt = '' + for token in prompt.split(','): matches = LORAS_PROMPT_PATTERN.findall(token) - if matches: - for match in matches: - found_loras.append((f"{match[1]}.safetensors", float(match[2]))) - prompt_without_loras += token.replace(match[0], '') - else: - prompt_without_loras += token - prompt_without_loras += ' ' + if len(matches) == 0: + prompt_without_loras += token + ', ' + continue + for match in matches: + lora_name = match[1] + '.safetensors' + if not skip_file_check: + lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special) + if lora_name is not None: + found_loras.append((lora_name, float(match[2]))) + token = token.replace(match[0], '') + prompt_without_loras += token + ', ' + + if prompt_without_loras != '': + cleaned_prompt = prompt_without_loras[:-2] - cleaned_prompt = prompt_without_loras[:-1] if prompt_cleanup: cleaned_prompt = cleanup_prompt(prompt_without_loras) new_loras = [] lora_names = [lora[0] for lora in loras] for found_lora in found_loras: - if deduplicate_loras and found_lora[0] in lora_names: + if deduplicate_loras and (found_lora[0] in lora_names or found_lora in new_loras): continue new_loras.append(found_lora) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9f81005b..6fd550db 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,13 +7,13 @@ class TestUtils(unittest.TestCase): def test_can_parse_tokens_with_lora(self): test_cases = [ { - "input": ("some prompt, very cool, , cool ", [], 5), + "input": ("some prompt, very cool, , cool ", [], 5, True), "output": ( [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'), }, # Test can not exceed limit { - "input": ("some prompt, very cool, , cool ", [], 1), + "input": ("some prompt, very cool, , cool ", [], 1, True), "output": ( [('hey-lora.safetensors', 0.4)], 'some prompt, very cool, cool' @@ -25,6 +25,7 @@ class TestUtils(unittest.TestCase): "some prompt, very cool, , , , , , ", [("hey-lora.safetensors", 0.4)], 5, + True ), "output": ( [ @@ -37,18 +38,35 @@ class TestUtils(unittest.TestCase): 'some prompt, very cool' ) }, + # test correct matching even if there is no space separating loras in the same token { - "input": ("some prompt, very cool, ", [], 3), + "input": ("some prompt, very cool, ", [], 3, True), "output": ( [ ('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2) ], - 'some prompt, very cool, ' + 'some prompt, very cool' + ), + }, + # test deduplication, also selected loras are never overridden with loras in prompt + { + "input": ( + "some prompt, very cool, ", + [('you-lora.safetensors', 0.3)], + 3, + True + ), + "output": ( + [ + ('you-lora.safetensors', 0.3), + ('hey-lora.safetensors', 0.4) + ], + 'some prompt, very cool' ), }, { - "input": (", , , and ", [], 6), + "input": (", , , and ", [], 6, True), "output": ( [], ', , , and ' @@ -57,7 +75,7 @@ class TestUtils(unittest.TestCase): ] for test in test_cases: - prompt, loras, loras_limit = test["input"] + prompt, loras, loras_limit, skip_file_check = test["input"] expected = test["output"] - actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit) + actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check) self.assertEqual(expected, actual)