feat: inline lora optimisations (#2967)
* feat: add performance loras to the end of the loras array * fix: resolve circular dependency for unit tests * feat: allow multiple matches for each token, optimize and extract method cleanup_prompt * fix: update unit tests * feat: ignore custom wildcards
This commit is contained in:
parent
c995511705
commit
65a8b25129
|
|
@ -237,10 +237,12 @@ def worker():
|
|||
|
||||
steps = performance_selection.steps()
|
||||
|
||||
performance_loras = []
|
||||
|
||||
if performance_selection == Performance.EXTREME_SPEED:
|
||||
print('Enter LCM mode.')
|
||||
progressbar(async_task, 1, 'Downloading LCM components ...')
|
||||
loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
|
||||
performance_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
|
||||
|
||||
if refiner_model_name != 'None':
|
||||
print(f'Refiner disabled in LCM mode.')
|
||||
|
|
@ -259,7 +261,7 @@ def worker():
|
|||
elif performance_selection == Performance.LIGHTNING:
|
||||
print('Enter Lightning mode.')
|
||||
progressbar(async_task, 1, 'Downloading Lightning components ...')
|
||||
loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]
|
||||
performance_loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]
|
||||
|
||||
if refiner_model_name != 'None':
|
||||
print(f'Refiner disabled in Lightning mode.')
|
||||
|
|
@ -278,7 +280,7 @@ def worker():
|
|||
elif performance_selection == Performance.HYPER_SD:
|
||||
print('Enter Hyper-SD mode.')
|
||||
progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
|
||||
loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]
|
||||
performance_loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]
|
||||
|
||||
if refiner_model_name != 'None':
|
||||
print(f'Refiner disabled in Hyper-SD mode.')
|
||||
|
|
@ -458,8 +460,8 @@ def worker():
|
|||
|
||||
progressbar(async_task, 2, 'Loading models ...')
|
||||
|
||||
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
|
||||
|
||||
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
|
||||
loras += performance_loras
|
||||
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
|
||||
loras=loras, base_model_additional_loras=base_model_additional_loras,
|
||||
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ import modules.flags
|
|||
import modules.sdxl_styles
|
||||
|
||||
from modules.model_loader import load_file_from_url
|
||||
from modules.util import makedirs_with_log
|
||||
from modules.extra_utils import get_files_from_folder
|
||||
from modules.extra_utils import makedirs_with_log, get_files_from_folder
|
||||
from modules.flags import OutputFormat, Performance, MetadataScheme
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
import os
|
||||
|
||||
def makedirs_with_log(path):
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
except OSError as error:
|
||||
print(f'Directory {path} could not be created, reason: {error}')
|
||||
|
||||
|
||||
def get_files_from_folder(folder_path, extensions=None, name_filter=None):
|
||||
if not os.path.isdir(folder_path):
|
||||
|
|
|
|||
|
|
@ -12,15 +12,15 @@ import hashlib
|
|||
|
||||
from PIL import Image
|
||||
|
||||
import modules.config
|
||||
import modules.sdxl_styles
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
||||
|
||||
# Regexp compiled once. Matches entries with the following pattern:
|
||||
# <lora:some_lora:1>
|
||||
# <lora:aNotherLora:-1.6>
|
||||
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
|
||||
LORAS_PROMPT_PATTERN = re.compile(r"(<lora:([^:]+):([+-]?(?:\d+(?:\.\d*)?|\.\d+))>)", re.X)
|
||||
|
||||
HASH_SHA256_LENGTH = 10
|
||||
|
||||
|
|
@ -372,31 +372,57 @@ def get_file_from_folder_list(name, folders):
|
|||
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
|
||||
|
||||
|
||||
def makedirs_with_log(path):
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
except OSError as error:
|
||||
print(f'Directory {path} could not be created, reason: {error}')
|
||||
def get_enabled_loras(loras: list, remove_none=True) -> list:
|
||||
return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)]
|
||||
|
||||
|
||||
def get_enabled_loras(loras: list) -> list:
|
||||
return [(lora[1], lora[2]) for lora in loras if lora[0]]
|
||||
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
|
||||
prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
|
||||
found_loras = []
|
||||
prompt_without_loras = ""
|
||||
for token in prompt.split(" "):
|
||||
matches = LORAS_PROMPT_PATTERN.findall(token)
|
||||
|
||||
if matches:
|
||||
for match in matches:
|
||||
found_loras.append((f"{match[1]}.safetensors", float(match[2])))
|
||||
prompt_without_loras += token.replace(match[0], '')
|
||||
else:
|
||||
prompt_without_loras += token
|
||||
prompt_without_loras += ' '
|
||||
|
||||
cleaned_prompt = prompt_without_loras[:-1]
|
||||
if prompt_cleanup:
|
||||
cleaned_prompt = cleanup_prompt(prompt_without_loras)
|
||||
|
||||
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
|
||||
new_loras = []
|
||||
lora_names = [lora[0] for lora in loras]
|
||||
for found_lora in found_loras:
|
||||
if deduplicate_loras and found_lora[0] in lora_names:
|
||||
continue
|
||||
new_loras.append(found_lora)
|
||||
|
||||
if len(new_loras) == 0:
|
||||
return loras, cleaned_prompt
|
||||
|
||||
updated_loras = []
|
||||
for token in prompt.split(","):
|
||||
m = LORAS_PROMPT_PATTERN.match(token)
|
||||
|
||||
if m:
|
||||
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
|
||||
|
||||
for lora in loras + new_loras:
|
||||
if lora[0] != "None":
|
||||
updated_loras.append(lora)
|
||||
|
||||
return updated_loras[:loras_limit]
|
||||
return updated_loras[:loras_limit], cleaned_prompt
|
||||
|
||||
|
||||
def cleanup_prompt(prompt):
|
||||
prompt = re.sub(' +', ' ', prompt)
|
||||
prompt = re.sub(',+', ',', prompt)
|
||||
cleaned_prompt = ''
|
||||
for token in prompt.split(','):
|
||||
token = token.strip()
|
||||
if token == '':
|
||||
continue
|
||||
cleaned_prompt += token + ', '
|
||||
return cleaned_prompt[:-2]
|
||||
|
||||
|
||||
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
|
||||
|
|
|
|||
|
|
@ -8,12 +8,16 @@ class TestUtils(unittest.TestCase):
|
|||
test_cases = [
|
||||
{
|
||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
|
||||
"output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
|
||||
"output": (
|
||||
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
|
||||
},
|
||||
# Test can not exceed limit
|
||||
{
|
||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
|
||||
"output": [("hey-lora.safetensors", 0.4)],
|
||||
"output": (
|
||||
[('hey-lora.safetensors', 0.4)],
|
||||
'some prompt, very cool, cool'
|
||||
),
|
||||
},
|
||||
# test Loras from UI take precedence over prompt
|
||||
{
|
||||
|
|
@ -22,22 +26,33 @@ class TestUtils(unittest.TestCase):
|
|||
[("hey-lora.safetensors", 0.4)],
|
||||
5,
|
||||
),
|
||||
"output": [
|
||||
("hey-lora.safetensors", 0.4),
|
||||
("l1.safetensors", 0.4),
|
||||
("l2.safetensors", -0.2),
|
||||
("l3.safetensors", 0.3),
|
||||
("l4.safetensors", 0.5),
|
||||
"output": (
|
||||
[
|
||||
('hey-lora.safetensors', 0.4),
|
||||
('l1.safetensors', 0.4),
|
||||
('l2.safetensors', -0.2),
|
||||
('l3.safetensors', 0.3),
|
||||
('l4.safetensors', 0.5)
|
||||
],
|
||||
'some prompt, very cool'
|
||||
)
|
||||
},
|
||||
# Test lora specification not separated by comma are ignored, only latest specified is used
|
||||
{
|
||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
|
||||
"output": [("you-lora.safetensors", 0.2)],
|
||||
"output": (
|
||||
[
|
||||
('hey-lora.safetensors', 0.4),
|
||||
('you-lora.safetensors', 0.2)
|
||||
],
|
||||
'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
|
||||
),
|
||||
},
|
||||
{
|
||||
"input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6),
|
||||
"output": []
|
||||
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
|
||||
"output": (
|
||||
[],
|
||||
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
|
||||
)
|
||||
}
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
*.txt
|
||||
!animal.txt
|
||||
!artist.txt
|
||||
!color.txt
|
||||
!color_flower.txt
|
||||
!extended-color.txt
|
||||
!flower.txt
|
||||
!nationality.txt
|
||||
Loading…
Reference in New Issue