Added unittests

This commit is contained in:
cantor-set 2024-02-22 18:42:30 -05:00
parent 31263b40f6
commit 24acbc39fe
5 changed files with 61 additions and 7 deletions

0
modules/__init__.py Normal file
View File

View File

@ -52,6 +52,8 @@ def worker():
)
from modules.upscaler import perform_upscale
MAX_LORAS = 5
try:
async_gradio_app = shared.gradio_root
flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
@ -140,7 +142,7 @@ def worker():
base_model_name = args.pop()
refiner_model_name = args.pop()
refiner_switch = args.pop()
loras = [[str(args.pop()), float(args.pop())] for _ in range(5)]
loras = [[str(args.pop()), float(args.pop())] for _ in range(MAX_LORAS)]
input_image_checkbox = args.pop()
current_tab = args.pop()
uov_method = args.pop()
@ -381,7 +383,7 @@ def worker():
progressbar(async_task, 3, 'Loading models ...')
# Parse lora references from prompt
loras = parse_lora_references_from_prompt(prompt, loras)
loras = parse_lora_references_from_prompt(prompt, loras, loras_limit=MAX_LORAS)
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,

View File

@ -12,6 +12,10 @@ from PIL import Image
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
# Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1>
# <lora:aNotherLora:-1.6>
LORAS_PROMPT_PATTERN = re.compile(".*<lora:(.+):([-+]?(?:\d*\.*\d*))>.*")
def erode_or_dilate(x, k):
k = int(k)
@ -183,13 +187,13 @@ def ordinal_suffix(number: int) -> str:
return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]]):
pattern = re.compile(".*<lora:(.+):(([0-9]*[.])?[0-9]+)>.*")
def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5):
new_loras = []
for token in items.split(","):
print(token)
m = pattern.match(token)
m = LORAS_PROMPT_PATTERN.match(token)
if m:
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
@ -200,4 +204,4 @@ def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, floa
if lora[0] != "None":
updated_loras.append(lora)
return updated_loras
return updated_loras[:loras_limit]

4
tests/__init__.py Normal file
View File

@ -0,0 +1,4 @@
import sys
import pathlib
sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())

44
tests/test_utils.py Normal file
View File

@ -0,0 +1,44 @@
import unittest
from modules import util
class TestUtils(unittest.TestCase):
def test_can_parse_tokens_with_lora(self):
test_cases = [
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
"output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
},
# Test can not exceed limit
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
"output": [("hey-lora.safetensors", 0.4)],
},
# test Loras from UI take precedence over prompt
{
"input": (
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
[("hey-lora.safetensors", 0.4)],
5,
),
"output": [
("hey-lora.safetensors", 0.4),
("l1.safetensors", 0.4),
("l2.safetensors", 0.2),
("l3.safetensors", 0.3),
("l4.safetensors", 0.5),
],
},
# Test lora specification not separated by comma are ignored, only latest specified is used
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
"output": [("you-lora.safetensors", 0.2)],
},
]
for test in test_cases:
promp, loras, loras_limit = test["input"]
expected = test["output"]
actual = util.parse_lora_references_from_prompt(promp, loras, loras_limit)
self.assertEqual(expected, actual)