From 65a8b25129c52ccb6f9fe5933202712b533c977d Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Mon, 20 May 2024 17:31:51 +0200 Subject: [PATCH] feat: inline lora optimisations (#2967) * feat: add performance loras to the end of the loras array * fix: resolve circular dependency for unit tests * feat: allow multiple matches for each token, optimize and extract method cleanup_prompt * fix: update unit tests * feat: ignore custom wildcards --- modules/async_worker.py | 12 +++++---- modules/config.py | 3 +-- modules/extra_utils.py | 6 +++++ modules/util.py | 60 +++++++++++++++++++++++++++++------------ tests/test_utils.py | 41 +++++++++++++++++++--------- wildcards/.gitignore | 8 ++++++ 6 files changed, 93 insertions(+), 37 deletions(-) create mode 100644 wildcards/.gitignore diff --git a/modules/async_worker.py b/modules/async_worker.py index 8e4d6d95..594886d2 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -237,10 +237,12 @@ def worker(): steps = performance_selection.steps() + performance_loras = [] + if performance_selection == Performance.EXTREME_SPEED: print('Enter LCM mode.') progressbar(async_task, 1, 'Downloading LCM components ...') - loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] + performance_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] if refiner_model_name != 'None': print(f'Refiner disabled in LCM mode.') @@ -259,7 +261,7 @@ def worker(): elif performance_selection == Performance.LIGHTNING: print('Enter Lightning mode.') progressbar(async_task, 1, 'Downloading Lightning components ...') - loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)] + performance_loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)] if refiner_model_name != 'None': print(f'Refiner disabled in Lightning mode.') @@ -278,7 +280,7 @@ def worker(): elif performance_selection == Performance.HYPER_SD: print('Enter Hyper-SD mode.') progressbar(async_task, 1, 'Downloading Hyper-SD components ...') - loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)] + performance_loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)] if refiner_model_name != 'None': print(f'Refiner disabled in Hyper-SD mode.') @@ -458,8 +460,8 @@ def worker(): progressbar(async_task, 2, 'Loading models ...') - loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) - + loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) + loras += performance_loras pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) diff --git a/modules/config.py b/modules/config.py index db7036c5..913fb281 100644 --- a/modules/config.py +++ b/modules/config.py @@ -8,8 +8,7 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.util import makedirs_with_log -from modules.extra_utils import get_files_from_folder +from modules.extra_utils import makedirs_with_log, get_files_from_folder from modules.flags import OutputFormat, Performance, MetadataScheme diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 3e95e8b5..9906c820 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -1,5 +1,11 @@ import os +def makedirs_with_log(path): + try: + os.makedirs(path, exist_ok=True) + except OSError as error: + print(f'Directory {path} could not be created, reason: {error}') + def get_files_from_folder(folder_path, extensions=None, name_filter=None): if not os.path.isdir(folder_path): diff --git a/modules/util.py b/modules/util.py index 8e85ffbe..52bc490a 100644 --- a/modules/util.py +++ b/modules/util.py @@ -12,15 +12,15 @@ import hashlib from PIL import Image +import modules.config import modules.sdxl_styles LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) - # Regexp compiled once. Matches entries with the following pattern: # # -LORAS_PROMPT_PATTERN = re.compile(r".* .*", re.X) +LORAS_PROMPT_PATTERN = re.compile(r"()", re.X) HASH_SHA256_LENGTH = 10 @@ -372,31 +372,57 @@ def get_file_from_folder_list(name, folders): return os.path.abspath(os.path.realpath(os.path.join(folders[0], name))) -def makedirs_with_log(path): - try: - os.makedirs(path, exist_ok=True) - except OSError as error: - print(f'Directory {path} could not be created, reason: {error}') +def get_enabled_loras(loras: list, remove_none=True) -> list: + return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)] -def get_enabled_loras(loras: list) -> list: - return [(lora[1], lora[2]) for lora in loras if lora[0]] +def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5, + prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]: + found_loras = [] + prompt_without_loras = "" + for token in prompt.split(" "): + matches = LORAS_PROMPT_PATTERN.findall(token) + if matches: + for match in matches: + found_loras.append((f"{match[1]}.safetensors", float(match[2]))) + prompt_without_loras += token.replace(match[0], '') + else: + prompt_without_loras += token + prompt_without_loras += ' ' + + cleaned_prompt = prompt_without_loras[:-1] + if prompt_cleanup: + cleaned_prompt = cleanup_prompt(prompt_without_loras) -def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]: new_loras = [] + lora_names = [lora[0] for lora in loras] + for found_lora in found_loras: + if deduplicate_loras and found_lora[0] in lora_names: + continue + new_loras.append(found_lora) + + if len(new_loras) == 0: + return loras, cleaned_prompt + updated_loras = [] - for token in prompt.split(","): - m = LORAS_PROMPT_PATTERN.match(token) - - if m: - new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2)))) - for lora in loras + new_loras: if lora[0] != "None": updated_loras.append(lora) - return updated_loras[:loras_limit] + return updated_loras[:loras_limit], cleaned_prompt + + +def cleanup_prompt(prompt): + prompt = re.sub(' +', ' ', prompt) + prompt = re.sub(',+', ',', prompt) + cleaned_prompt = '' + for token in prompt.split(','): + token = token.strip() + if token == '': + continue + cleaned_prompt += token + ', ' + return cleaned_prompt[:-2] def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str: diff --git a/tests/test_utils.py b/tests/test_utils.py index 0698dcc8..9f81005b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,12 +8,16 @@ class TestUtils(unittest.TestCase): test_cases = [ { "input": ("some prompt, very cool, , cool ", [], 5), - "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)], + "output": ( + [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'), }, # Test can not exceed limit { "input": ("some prompt, very cool, , cool ", [], 1), - "output": [("hey-lora.safetensors", 0.4)], + "output": ( + [('hey-lora.safetensors', 0.4)], + 'some prompt, very cool, cool' + ), }, # test Loras from UI take precedence over prompt { @@ -22,22 +26,33 @@ class TestUtils(unittest.TestCase): [("hey-lora.safetensors", 0.4)], 5, ), - "output": [ - ("hey-lora.safetensors", 0.4), - ("l1.safetensors", 0.4), - ("l2.safetensors", -0.2), - ("l3.safetensors", 0.3), - ("l4.safetensors", 0.5), - ], + "output": ( + [ + ('hey-lora.safetensors', 0.4), + ('l1.safetensors', 0.4), + ('l2.safetensors', -0.2), + ('l3.safetensors', 0.3), + ('l4.safetensors', 0.5) + ], + 'some prompt, very cool' + ) }, - # Test lora specification not separated by comma are ignored, only latest specified is used { "input": ("some prompt, very cool, ", [], 3), - "output": [("you-lora.safetensors", 0.2)], + "output": ( + [ + ('hey-lora.safetensors', 0.4), + ('you-lora.safetensors', 0.2) + ], + 'some prompt, very cool, ' + ), }, { - "input": (", , and ", [], 6), - "output": [] + "input": (", , , and ", [], 6), + "output": ( + [], + ', , , and ' + ) } ] diff --git a/wildcards/.gitignore b/wildcards/.gitignore new file mode 100644 index 00000000..7e4ac188 --- /dev/null +++ b/wildcards/.gitignore @@ -0,0 +1,8 @@ +*.txt +!animal.txt +!artist.txt +!color.txt +!color_flower.txt +!extended-color.txt +!flower.txt +!nationality.txt \ No newline at end of file