# Python · 138 lines · 5.5 KiB
import os
|
|
import unittest
|
|
|
|
import modules.flags
|
|
from modules import util
|
|
|
|
|
|
class TestUtils(unittest.TestCase):
    """Table-driven tests for the LoRA prompt helpers exposed by ``modules.util``."""

    def test_can_parse_tokens_with_lora(self):
        """Check ``util.parse_lora_references_from_prompt`` against fixed cases.

        Each case's ``"input"`` is ``(prompt, loras, loras_limit,
        skip_file_check)`` and ``"output"`` is the exact value the call is
        expected to return: ``(loras, prompt_without_lora_tags)``.
        """
        test_cases = [
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
                "output": (
                    [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
            },
            # Test can not exceed limit
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt, very cool, cool'
                ),
            },
            # test Loras from UI take precedence over prompt
            {
                "input": (
                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
                    [("hey-lora.safetensors", 0.4)],
                    5,
                    True
                ),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('l1.safetensors', 0.4),
                        ('l2.safetensors', -0.2),
                        ('l3.safetensors', 0.3),
                        ('l4.safetensors', 0.5)
                    ],
                    'some prompt, very cool'
                )
            },
            # test correct matching even if there is no space separating loras in the same token
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('you-lora.safetensors', 0.2)
                    ],
                    'some prompt, very cool'
                ),
            },
            # test deduplication, also selected loras are never overridden with loras in prompt
            {
                "input": (
                    "some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
                    [('you-lora.safetensors', 0.3)],
                    3,
                    True
                ),
                "output": (
                    [
                        ('you-lora.safetensors', 0.3),
                        ('hey-lora.safetensors', 0.4)
                    ],
                    'some prompt, very cool'
                ),
            },
            # malformed weights and non-lora tags: nothing is extracted and
            # the prompt is expected back unchanged
            {
                "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
                "output": (
                    [],
                    '<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
                )
            }
        ]

        for test in test_cases:
            prompt, loras, loras_limit, skip_file_check = test["input"]
            expected = test["output"]
            # subTest keeps the remaining cases running after a failure and
            # reports which prompt/limit combination broke.
            with self.subTest(prompt=prompt, loras_limit=loras_limit):
                actual = util.parse_lora_references_from_prompt(
                    prompt,
                    loras,
                    loras_limit=loras_limit,
                    skip_file_check=skip_file_check,
                )
                self.assertEqual(expected, actual)

    def test_can_parse_tokens_and_strip_performance_lora(self):
        """Check that the active performance LoRA is not returned from a prompt.

        For every case the list of known LoRA filenames is first passed
        through ``util.remove_performance_lora`` for the case's performance
        mode, and the result is handed to
        ``util.parse_lora_references_from_prompt`` as ``lora_filenames``.
        Each case's ``"input"`` is ``(prompt, loras, loras_limit,
        skip_file_check, performance)``.
        """
        lora_filenames = [
            'hey-lora.safetensors',
            modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
            modules.flags.PerformanceLoRA.LIGHTNING.value,
            os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value)
        ]

        test_cases = [
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt'
                ),
            },
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt'
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.EXTREME_SPEED),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt'
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_lightning_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.LIGHTNING),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt'
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_hyper_sd_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.HYPER_SD),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt'
                ),
            }
        ]

        for test in test_cases:
            # NOTE(review): the fourth fixture element (skip_file_check) was
            # never forwarded to the parser here, so the parser runs with its
            # default for that argument. Kept that way on purpose; confirm
            # this is intended before wiring it through.
            prompt, loras, loras_limit, _skip_file_check, performance = test["input"]
            expected = test["output"]
            with self.subTest(performance=performance, prompt=prompt):
                # Filter a fresh copy of the full list for every case. The
                # previous code rebound `lora_filenames`, so each case ran
                # against a list already filtered by the earlier cases.
                available_filenames = util.remove_performance_lora(list(lora_filenames), performance)
                actual = util.parse_lora_references_from_prompt(
                    prompt,
                    loras,
                    loras_limit=loras_limit,
                    lora_filenames=available_filenames,
                )
                self.assertEqual(expected, actual)
|