diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 37ab09ae..5fa5c0bd 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -52,6 +52,8 @@ def worker():
     )
     from modules.upscaler import perform_upscale
 
+    MAX_LORAS = 5
+
     try:
         async_gradio_app = shared.gradio_root
         flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
@@ -140,7 +142,7 @@ def worker():
         base_model_name = args.pop()
         refiner_model_name = args.pop()
         refiner_switch = args.pop()
-        loras = [[str(args.pop()), float(args.pop())] for _ in range(5)]
+        loras = [[str(args.pop()), float(args.pop())] for _ in range(MAX_LORAS)]
         input_image_checkbox = args.pop()
         current_tab = args.pop()
         uov_method = args.pop()
@@ -381,7 +383,7 @@ def worker():
         progressbar(async_task, 3, 'Loading models ...')
 
         # Parse lora references from prompt
-        loras = parse_lora_references_from_prompt(prompt, loras)
+        loras = parse_lora_references_from_prompt(prompt, loras, loras_limit=MAX_LORAS)
 
         pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
                                     loras=loras, base_model_additional_loras=base_model_additional_loras,
diff --git a/modules/util.py b/modules/util.py
index 2355dae3..d350f3c6 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -12,6 +12,10 @@ from PIL import Image
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
 
+# Regexp compiled once. Matches entries with the following pattern:
+#
+#   <lora:filename:weight>
+LORAS_PROMPT_PATTERN = re.compile(".*<lora:([^:]+):([0-9.]+)>.*")
 
 def erode_or_dilate(x, k):
     k = int(k)
@@ -183,13 +187,13 @@ def ordinal_suffix(number: int) -> str:
     return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
 
 
-def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]]):
-    pattern = re.compile(".*<lora:([^:]+):([0-9.]+)>.*")
+def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5):
+
 
     new_loras = []
     for token in items.split(","):
-        print(token)
-        m = pattern.match(token)
+
+        m = LORAS_PROMPT_PATTERN.match(token)
 
         if m:
             new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
@@ -200,4 +204,4 @@ def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, floa
         if lora[0] != "None":
             updated_loras.append(lora)
 
-    return updated_loras
+    return updated_loras[:loras_limit]
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..f86b4227
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,4 @@
+import sys
+import pathlib
+
+sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())
\ No newline at end of file
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..998bf058
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,44 @@
+import unittest
+from modules import util
+
+
+class TestUtils(unittest.TestCase):
+    def test_can_parse_tokens_with_lora(self):
+
+        test_cases = [
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
+                "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
+            },
+            # Test that the number of LoRAs cannot exceed the limit
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
+                "output": [("hey-lora.safetensors", 0.4)],
+            },
+            # Test that LoRAs from the UI take precedence over those in the prompt
+            {
+                "input": (
+                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l5:0.6>, <lora:l6:0.7>",
+                    [("hey-lora.safetensors", 0.4)],
+                    5,
+                ),
+                "output": [
+                    ("hey-lora.safetensors", 0.4),
+                    ("l1.safetensors", 0.4),
+                    ("l2.safetensors", 0.2),
+                    ("l3.safetensors", 0.3),
+                    ("l4.safetensors", 0.5),
+                ],
+            },
+            # Test that LoRA references not separated by a comma are ignored; only the last one in the token is used
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4> <lora:you-lora:0.2>", [], 3),
+                "output": [("you-lora.safetensors", 0.2)],
+            },
+        ]
+
+        for test in test_cases:
+            prompt, loras, loras_limit = test["input"]
+            expected = test["output"]
+            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
+            self.assertEqual(expected, actual)
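For reference, a minimal usage sketch (not part of the diff) of the parser the patch reworks above. The `<lora:name:weight>` prompt token syntax is inferred from the regex and the expectations in `tests/test_utils.py`; the prompt text and LoRA names here are made up for illustration.

```python
# Hypothetical example: UI-selected LoRAs come first, prompt references are
# appended with a ".safetensors" suffix, and the result is cut at loras_limit.
from modules import util

ui_loras = [("hey-lora.safetensors", 0.4)]          # LoRAs chosen in the UI
prompt = "a photo of a cat, <lora:style-lora:0.6>"  # inline reference (assumed syntax)

loras = util.parse_lora_references_from_prompt(prompt, ui_loras, loras_limit=5)
print(loras)  # [('hey-lora.safetensors', 0.4), ('style-lora.safetensors', 0.6)]
```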