feat: inline lora optimisations (#2967)

* feat: add performance loras to the end of the loras array

* fix: resolve circular dependency for unit tests

* feat: allow multiple matches for each token, optimize and extract method cleanup_prompt

* fix: update unit tests

* feat: ignore custom wildcards
This commit is contained in:
Manuel Schmid 2024-05-20 17:31:51 +02:00 committed by GitHub
parent c995511705
commit 65a8b25129
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 93 additions and 37 deletions

View File

@ -237,10 +237,12 @@ def worker():
steps = performance_selection.steps() steps = performance_selection.steps()
performance_loras = []
if performance_selection == Performance.EXTREME_SPEED: if performance_selection == Performance.EXTREME_SPEED:
print('Enter LCM mode.') print('Enter LCM mode.')
progressbar(async_task, 1, 'Downloading LCM components ...') progressbar(async_task, 1, 'Downloading LCM components ...')
loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] performance_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
if refiner_model_name != 'None': if refiner_model_name != 'None':
print(f'Refiner disabled in LCM mode.') print(f'Refiner disabled in LCM mode.')
@ -259,7 +261,7 @@ def worker():
elif performance_selection == Performance.LIGHTNING: elif performance_selection == Performance.LIGHTNING:
print('Enter Lightning mode.') print('Enter Lightning mode.')
progressbar(async_task, 1, 'Downloading Lightning components ...') progressbar(async_task, 1, 'Downloading Lightning components ...')
loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)] performance_loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]
if refiner_model_name != 'None': if refiner_model_name != 'None':
print(f'Refiner disabled in Lightning mode.') print(f'Refiner disabled in Lightning mode.')
@ -278,7 +280,7 @@ def worker():
elif performance_selection == Performance.HYPER_SD: elif performance_selection == Performance.HYPER_SD:
print('Enter Hyper-SD mode.') print('Enter Hyper-SD mode.')
progressbar(async_task, 1, 'Downloading Hyper-SD components ...') progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)] performance_loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]
if refiner_model_name != 'None': if refiner_model_name != 'None':
print(f'Refiner disabled in Hyper-SD mode.') print(f'Refiner disabled in Hyper-SD mode.')
@ -458,8 +460,8 @@ def worker():
progressbar(async_task, 2, 'Loading models ...') progressbar(async_task, 2, 'Loading models ...')
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number) loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
loras += performance_loras
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras, loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)

View File

@ -8,8 +8,7 @@ import modules.flags
import modules.sdxl_styles import modules.sdxl_styles
from modules.model_loader import load_file_from_url from modules.model_loader import load_file_from_url
from modules.util import makedirs_with_log from modules.extra_utils import makedirs_with_log, get_files_from_folder
from modules.extra_utils import get_files_from_folder
from modules.flags import OutputFormat, Performance, MetadataScheme from modules.flags import OutputFormat, Performance, MetadataScheme

View File

@ -1,5 +1,11 @@
import os import os
def makedirs_with_log(path):
try:
os.makedirs(path, exist_ok=True)
except OSError as error:
print(f'Directory {path} could not be created, reason: {error}')
def get_files_from_folder(folder_path, extensions=None, name_filter=None): def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path): if not os.path.isdir(folder_path):

View File

@ -12,15 +12,15 @@ import hashlib
from PIL import Image from PIL import Image
import modules.config
import modules.sdxl_styles import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
# Regexp compiled once. Matches entries with the following pattern: # Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1> # <lora:some_lora:1>
# <lora:aNotherLora:-1.6> # <lora:aNotherLora:-1.6>
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X) LORAS_PROMPT_PATTERN = re.compile(r"(<lora:([^:]+):([+-]?(?:\d+(?:\.\d*)?|\.\d+))>)", re.X)
HASH_SHA256_LENGTH = 10 HASH_SHA256_LENGTH = 10
@ -372,31 +372,57 @@ def get_file_from_folder_list(name, folders):
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name))) return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
def makedirs_with_log(path): def get_enabled_loras(loras: list, remove_none=True) -> list:
try: return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)]
os.makedirs(path, exist_ok=True)
except OSError as error:
print(f'Directory {path} could not be created, reason: {error}')
def get_enabled_loras(loras: list) -> list: def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
return [(lora[1], lora[2]) for lora in loras if lora[0]] prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
found_loras = []
prompt_without_loras = ""
for token in prompt.split(" "):
matches = LORAS_PROMPT_PATTERN.findall(token)
if matches:
for match in matches:
found_loras.append((f"{match[1]}.safetensors", float(match[2])))
prompt_without_loras += token.replace(match[0], '')
else:
prompt_without_loras += token
prompt_without_loras += ' '
cleaned_prompt = prompt_without_loras[:-1]
if prompt_cleanup:
cleaned_prompt = cleanup_prompt(prompt_without_loras)
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
new_loras = [] new_loras = []
lora_names = [lora[0] for lora in loras]
for found_lora in found_loras:
if deduplicate_loras and found_lora[0] in lora_names:
continue
new_loras.append(found_lora)
if len(new_loras) == 0:
return loras, cleaned_prompt
updated_loras = [] updated_loras = []
for token in prompt.split(","):
m = LORAS_PROMPT_PATTERN.match(token)
if m:
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
for lora in loras + new_loras: for lora in loras + new_loras:
if lora[0] != "None": if lora[0] != "None":
updated_loras.append(lora) updated_loras.append(lora)
return updated_loras[:loras_limit] return updated_loras[:loras_limit], cleaned_prompt
def cleanup_prompt(prompt):
prompt = re.sub(' +', ' ', prompt)
prompt = re.sub(',+', ',', prompt)
cleaned_prompt = ''
for token in prompt.split(','):
token = token.strip()
if token == '':
continue
cleaned_prompt += token + ', '
return cleaned_prompt[:-2]
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str: def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:

View File

@ -8,12 +8,16 @@ class TestUtils(unittest.TestCase):
test_cases = [ test_cases = [
{ {
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5), "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
"output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)], "output": (
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
}, },
# Test can not exceed limit # Test can not exceed limit
{ {
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1), "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
"output": [("hey-lora.safetensors", 0.4)], "output": (
[('hey-lora.safetensors', 0.4)],
'some prompt, very cool, cool'
),
}, },
# test Loras from UI take precedence over prompt # test Loras from UI take precedence over prompt
{ {
@ -22,22 +26,33 @@ class TestUtils(unittest.TestCase):
[("hey-lora.safetensors", 0.4)], [("hey-lora.safetensors", 0.4)],
5, 5,
), ),
"output": [ "output": (
("hey-lora.safetensors", 0.4), [
("l1.safetensors", 0.4), ('hey-lora.safetensors', 0.4),
("l2.safetensors", -0.2), ('l1.safetensors', 0.4),
("l3.safetensors", 0.3), ('l2.safetensors', -0.2),
("l4.safetensors", 0.5), ('l3.safetensors', 0.3),
], ('l4.safetensors', 0.5)
],
'some prompt, very cool'
)
}, },
# Test lora specification not separated by comma are ignored, only latest specified is used
{ {
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3), "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
"output": [("you-lora.safetensors", 0.2)], "output": (
[
('hey-lora.safetensors', 0.4),
('you-lora.safetensors', 0.2)
],
'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
),
}, },
{ {
"input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6), "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
"output": [] "output": (
[],
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
)
} }
] ]

8
wildcards/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
*.txt
!animal.txt
!artist.txt
!color.txt
!color_flower.txt
!extended-color.txt
!flower.txt
!nationality.txt