diff --git a/modules/async_worker.py b/modules/async_worker.py
index d7d9b9fd..9c16d6fc 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -462,8 +462,10 @@ def worker():
 
             progressbar(async_task, 2, 'Loading models ...')
 
-            loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
+            lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance_selection)
+            loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number, lora_filenames=lora_filenames)
             loras += performance_loras
+
             pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
                                         loras=loras, base_model_additional_loras=base_model_additional_loras,
                                         use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
diff --git a/modules/util.py b/modules/util.py
index 3f712214..09d36770 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -398,10 +398,15 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:
 
 
 def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
-                                      skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
+                                      skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True,
+                                      lora_filenames=None) -> tuple[List[Tuple[AnyStr, float]], str]:
+    if lora_filenames is None:
+        lora_filenames = []
+
     found_loras = []
     prompt_without_loras = ''
     cleaned_prompt = ''
+
     for token in prompt.split(','):
         matches = LORAS_PROMPT_PATTERN.findall(token)
 
@@ -411,7 +416,7 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
         for match in matches:
             lora_name = match[1] + '.safetensors'
             if not skip_file_check:
-                lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
+                lora_name = get_filname_by_stem(match[1], lora_filenames)
             if lora_name is not None:
                 found_loras.append((lora_name, float(match[2])))
             token = token.replace(match[0], '')
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6fd550db..c1f49c13 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,7 @@
+import os
 import unittest
 
+import modules.flags
 from modules import util
 
 
@@ -77,5 +79,59 @@ class TestUtils(unittest.TestCase):
         for test in test_cases:
             prompt, loras, loras_limit, skip_file_check = test["input"]
             expected = test["output"]
-            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
+            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit,
+                                                            skip_file_check=skip_file_check)
             self.assertEqual(expected, actual)
+
+    def test_can_parse_tokens_and_strip_performance_lora(self):
+        lora_filenames = [
+            'hey-lora.safetensors',
+            modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
+            modules.flags.PerformanceLoRA.LIGHTNING.value,
+            os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value)
+        ]
+
+        test_cases = [
+            {
+                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
+                "output": (
+                    [('hey-lora.safetensors', 0.4)],
+                    'some prompt'
+                ),
+            },
+            {
+                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
+                "output": (
+                    [('hey-lora.safetensors', 0.4)],
+                    'some prompt'
+                ),
+            },
+            {
+                "input": ("some prompt, <lora:hey-lora:0.4>, <lora:sdxl_lcm_lora:1.0>", [], 5, True, modules.flags.Performance.EXTREME_SPEED),
+                "output": (
+                    [('hey-lora.safetensors', 0.4)],
+                    'some prompt'
+                ),
+            },
+            {
+                "input": ("some prompt, <lora:hey-lora:0.4>, <lora:sdxl_lightning_4step_lora:1.0>", [], 5, True, modules.flags.Performance.LIGHTNING),
+                "output": (
+                    [('hey-lora.safetensors', 0.4)],
+                    'some prompt'
+                ),
+            },
+            {
+                "input": ("some prompt, <lora:hey-lora:0.4>, <lora:sdxl_hyper_sd_4step_lora:1.0>", [], 5, True, modules.flags.Performance.HYPER_SD),
+                "output": (
+                    [('hey-lora.safetensors', 0.4)],
+                    'some prompt'
+                ),
+            }
+        ]
+
+        for test in test_cases:
+            prompt, loras, loras_limit, skip_file_check, performance = test["input"]
+            lora_filenames = modules.util.remove_performance_lora(lora_filenames, performance)
+            expected = test["output"]
+            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, lora_filenames=lora_filenames)
+            self.assertEqual(expected, actual)
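A minimal sketch of the call pattern this change introduces, assuming a Fooocus checkout on the import path; the prompt text, the `my-style` and `detail-tweaker` LoRA names, and the weights below are illustrative stand-ins, not files shipped with the repo:

```python
# Sketch only: mirrors the new async_worker flow with made-up prompt/LoRA values.
import modules.config
import modules.flags
from modules import util

prompt = 'cinematic photo, <lora:my-style:0.6>'        # hypothetical prompt token
user_loras = [('detail-tweaker.safetensors', 0.3)]     # hypothetical UI selection

# Drop the LoRA that the selected performance mode injects on its own, so a
# <lora:...> token in the prompt can no longer resolve to that same file.
filenames = util.remove_performance_lora(modules.config.lora_filenames,
                                         modules.flags.Performance.LIGHTNING)

# Prompt tokens are resolved only against the filtered filename list;
# 'my-style' is kept only if my-style.safetensors actually exists there.
loras, cleaned_prompt = util.parse_lora_references_from_prompt(
    prompt, user_loras,
    loras_limit=modules.config.default_max_lora_number,
    lora_filenames=filenames)
```

Passing the pre-filtered `lora_filenames` into the parser keeps the performance-mode LoRA from being applied twice: once via `performance_loras` and once via a matching `<lora:...>` token in the prompt.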