feat: only filter lora of selected performance instead of all performance LoRAs

This commit is contained in:
Manuel Schmid 2024-05-30 00:16:34 +02:00
parent 55b01a81a6
commit e3060e00d4
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 67 additions and 4 deletions

View File

@ -462,8 +462,10 @@ def worker():
progressbar(async_task, 2, 'Loading models ...')
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance_selection)
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number, lora_filenames=lora_filenames)
loras += performance_loras
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)

View File

@ -398,10 +398,15 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True,
lora_filenames=None) -> tuple[List[Tuple[AnyStr, float]], str]:
if lora_filenames is None:
lora_filenames = []
found_loras = []
prompt_without_loras = ''
cleaned_prompt = ''
for token in prompt.split(','):
matches = LORAS_PROMPT_PATTERN.findall(token)
@ -411,7 +416,7 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
for match in matches:
lora_name = match[1] + '.safetensors'
if not skip_file_check:
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
lora_name = get_filname_by_stem(match[1], lora_filenames)
if lora_name is not None:
found_loras.append((lora_name, float(match[2])))
token = token.replace(match[0], '')

View File

@ -1,5 +1,7 @@
import os
import unittest
import modules.flags
from modules import util
@ -77,5 +79,59 @@ class TestUtils(unittest.TestCase):
for test in test_cases:
prompt, loras, loras_limit, skip_file_check = test["input"]
expected = test["output"]
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit,
skip_file_check=skip_file_check)
self.assertEqual(expected, actual)
def test_can_parse_tokens_and_strip_performance_lora(self):
lora_filenames = [
'hey-lora.safetensors',
modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
modules.flags.PerformanceLoRA.LIGHTNING.value,
os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value)
]
test_cases = [
{
"input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt'
),
},
{
"input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt'
),
},
{
"input": ("some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.EXTREME_SPEED),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt'
),
},
{
"input": ("some prompt, <lora:sdxl_lightning_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.LIGHTNING),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt'
),
},
{
"input": ("some prompt, <lora:sdxl_hyper_sd_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.HYPER_SD),
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt'
),
}
]
for test in test_cases:
prompt, loras, loras_limit, skip_file_check, performance = test["input"]
lora_filenames = modules.util.remove_performance_lora(lora_filenames, performance)
expected = test["output"]
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, lora_filenames=lora_filenames)
self.assertEqual(expected, actual)