Added unittests
parent 31263b40f6
commit 24acbc39fe
@@ -52,6 +52,8 @@ def worker():
     )
     from modules.upscaler import perform_upscale
 
+    MAX_LORAS = 5
+
     try:
         async_gradio_app = shared.gradio_root
         flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
@@ -140,7 +142,7 @@ def worker():
         base_model_name = args.pop()
         refiner_model_name = args.pop()
         refiner_switch = args.pop()
-        loras = [[str(args.pop()), float(args.pop())] for _ in range(5)]
+        loras = [[str(args.pop()), float(args.pop())] for _ in range(MAX_LORAS)]
         input_image_checkbox = args.pop()
         current_tab = args.pop()
         uov_method = args.pop()
@@ -381,7 +383,7 @@ def worker():
         progressbar(async_task, 3, 'Loading models ...')
 
         # Parse lora references from prompt
-        loras = parse_lora_references_from_prompt(prompt, loras)
+        loras = parse_lora_references_from_prompt(prompt, loras, loras_limit=MAX_LORAS)
 
         pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
                                     loras=loras, base_model_additional_loras=base_model_additional_loras,

@@ -12,6 +12,10 @@ from PIL import Image
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
 
+# Regexp compiled once. Matches entries with the following pattern:
+# <lora:some_lora:1>
+# <lora:aNotherLora:-1.6>
+LORAS_PROMPT_PATTERN = re.compile(".*<lora:(.+):([-+]?(?:\d*\.*\d*))>.*")
 
 def erode_or_dilate(x, k):
     k = int(k)
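
For reference, a minimal sketch of how the compiled pattern behaves. The token strings below are made-up examples, not taken from the commit; the expression is copied from the diff above, written here as a raw string:

    import re

    # Same expression as LORAS_PROMPT_PATTERN in the diff above.
    pattern = re.compile(r".*<lora:(.+):([-+]?(?:\d*\.*\d*))>.*")

    print(pattern.match("cool <lora:hey-lora:0.4>").groups())   # ('hey-lora', '0.4')
    print(pattern.match("<lora:aNotherLora:-1.6>").groups())    # ('aNotherLora', '-1.6')
    print(pattern.match("plain text, no lora tag"))             # None

Compared with the inline pattern it replaces, this one also accepts signed and integer weights (e.g. -1.6 or 1).
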
@@ -183,13 +187,13 @@ def ordinal_suffix(number: int) -> str:
     return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
 
 
-def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]]):
-    pattern = re.compile(".*<lora:(.+):(([0-9]*[.])?[0-9]+)>.*")
+def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5):
 
     new_loras = []
 
     for token in items.split(","):
-        print(token)
-        m = pattern.match(token)
+        m = LORAS_PROMPT_PATTERN.match(token)
 
         if m:
            new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
@@ -200,4 +204,4 @@ def parse_lora_references_from_prompt(items: str, loras: List[Tuple[AnyStr, floa
         if lora[0] != "None":
             updated_loras.append(lora)
 
-    return updated_loras
+    return updated_loras[:loras_limit]

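
A minimal usage sketch of the updated function, matching the behaviour exercised by the tests below; the prompt text and the lora file names here are invented for illustration:

    from modules.util import parse_lora_references_from_prompt

    ui_loras = [("ui-lora.safetensors", 0.8)]   # entries already selected in the UI
    prompt = "a castle, <lora:style-lora:0.5>, <lora:detail-lora:0.3>"

    # UI entries keep precedence, prompt entries are appended, and the
    # combined list is truncated to loras_limit.
    print(parse_lora_references_from_prompt(prompt, ui_loras, loras_limit=2))
    # [('ui-lora.safetensors', 0.8), ('style-lora.safetensors', 0.5)]
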
@@ -0,0 +1,4 @@
+import sys
+import pathlib
+
+sys.path.append(pathlib.Path(f'{__file__}/../modules').parent.resolve())

@@ -0,0 +1,44 @@
+import unittest
+from modules import util
+
+
+class TestUtils(unittest.TestCase):
+    def test_can_parse_tokens_with_lora(self):
+
+        test_cases = [
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
+                "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
+            },
+            # Test can not exceed limit
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
+                "output": [("hey-lora.safetensors", 0.4)],
+            },
+            # test Loras from UI take precedence over prompt
+            {
+                "input": (
+                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
+                    [("hey-lora.safetensors", 0.4)],
+                    5,
+                ),
+                "output": [
+                    ("hey-lora.safetensors", 0.4),
+                    ("l1.safetensors", 0.4),
+                    ("l2.safetensors", 0.2),
+                    ("l3.safetensors", 0.3),
+                    ("l4.safetensors", 0.5),
+                ],
+            },
+            # Test lora specification not separated by comma are ignored, only latest specified is used
+            {
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
+                "output": [("you-lora.safetensors", 0.2)],
+            },
+        ]
+
+        for test in test_cases:
+            promp, loras, loras_limit = test["input"]
+            expected = test["output"]
+            actual = util.parse_lora_references_from_prompt(promp, loras, loras_limit)
+            self.assertEqual(expected, actual)
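
The new files form a standard unittest package, so a likely way to run them is the stdlib test runner. Both the directory layout and the test module name are assumptions here, since file paths are not visible in this extract:

    # From the repository root; assumes the new test module is named test_*.py
    # inside the package whose __init__.py is shown above.
    python -m unittest discover -v
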