feat: only filter lora of selected performance instead of all performance LoRAs
This commit is contained in:
parent
55b01a81a6
commit
e3060e00d4
|
|
@ -462,8 +462,10 @@ def worker():
|
|||
|
||||
progressbar(async_task, 2, 'Loading models ...')
|
||||
|
||||
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
|
||||
lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance_selection)
|
||||
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number, lora_filenames=lora_filenames)
|
||||
loras += performance_loras
|
||||
|
||||
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
|
||||
loras=loras, base_model_additional_loras=base_model_additional_loras,
|
||||
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
|
||||
|
|
|
|||
|
|
@ -398,10 +398,15 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:
|
|||
|
||||
|
||||
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
|
||||
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
|
||||
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True,
|
||||
lora_filenames=None) -> tuple[List[Tuple[AnyStr, float]], str]:
|
||||
if lora_filenames is None:
|
||||
lora_filenames = []
|
||||
|
||||
found_loras = []
|
||||
prompt_without_loras = ''
|
||||
cleaned_prompt = ''
|
||||
|
||||
for token in prompt.split(','):
|
||||
matches = LORAS_PROMPT_PATTERN.findall(token)
|
||||
|
||||
|
|
@ -411,7 +416,7 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
|
|||
for match in matches:
|
||||
lora_name = match[1] + '.safetensors'
|
||||
if not skip_file_check:
|
||||
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
|
||||
lora_name = get_filname_by_stem(match[1], lora_filenames)
|
||||
if lora_name is not None:
|
||||
found_loras.append((lora_name, float(match[2])))
|
||||
token = token.replace(match[0], '')
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import modules.flags
|
||||
from modules import util
|
||||
|
||||
|
||||
|
|
@ -77,5 +79,59 @@ class TestUtils(unittest.TestCase):
|
|||
for test in test_cases:
|
||||
prompt, loras, loras_limit, skip_file_check = test["input"]
|
||||
expected = test["output"]
|
||||
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
|
||||
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit,
|
||||
skip_file_check=skip_file_check)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_can_parse_tokens_and_strip_performance_lora(self):
|
||||
lora_filenames = [
|
||||
'hey-lora.safetensors',
|
||||
modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
|
||||
modules.flags.PerformanceLoRA.LIGHTNING.value,
|
||||
os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value)
|
||||
]
|
||||
|
||||
test_cases = [
|
||||
{
|
||||
"input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
|
||||
"output": (
|
||||
[('hey-lora.safetensors', 0.4)],
|
||||
'some prompt'
|
||||
),
|
||||
},
|
||||
{
|
||||
"input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
|
||||
"output": (
|
||||
[('hey-lora.safetensors', 0.4)],
|
||||
'some prompt'
|
||||
),
|
||||
},
|
||||
{
|
||||
"input": ("some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.EXTREME_SPEED),
|
||||
"output": (
|
||||
[('hey-lora.safetensors', 0.4)],
|
||||
'some prompt'
|
||||
),
|
||||
},
|
||||
{
|
||||
"input": ("some prompt, <lora:sdxl_lightning_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.LIGHTNING),
|
||||
"output": (
|
||||
[('hey-lora.safetensors', 0.4)],
|
||||
'some prompt'
|
||||
),
|
||||
},
|
||||
{
|
||||
"input": ("some prompt, <lora:sdxl_hyper_sd_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.HYPER_SD),
|
||||
"output": (
|
||||
[('hey-lora.safetensors', 0.4)],
|
||||
'some prompt'
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
for test in test_cases:
|
||||
prompt, loras, loras_limit, skip_file_check, performance = test["input"]
|
||||
lora_filenames = modules.util.remove_performance_lora(lora_filenames, performance)
|
||||
expected = test["output"]
|
||||
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, lora_filenames=lora_filenames)
|
||||
self.assertEqual(expected, actual)
|
||||
|
|
|
|||
Loading…
Reference in New Issue