feat: inline lora optimisations (#2967)
* feat: add performance loras to the end of the loras array
* fix: resolve circular dependency for unit tests
* feat: allow multiple matches for each token, optimize and extract method cleanup_prompt
* fix: update unit tests
* feat: ignore custom wildcards
This commit is contained in:
parent
c995511705
commit
65a8b25129
@@ -237,10 +237,12 @@ def worker():
         steps = performance_selection.steps()
 
+        performance_loras = []
+
         if performance_selection == Performance.EXTREME_SPEED:
             print('Enter LCM mode.')
             progressbar(async_task, 1, 'Downloading LCM components ...')
-            loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
+            performance_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
 
             if refiner_model_name != 'None':
                 print(f'Refiner disabled in LCM mode.')
@@ -259,7 +261,7 @@ def worker():
         elif performance_selection == Performance.LIGHTNING:
             print('Enter Lightning mode.')
             progressbar(async_task, 1, 'Downloading Lightning components ...')
-            loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]
+            performance_loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]
 
             if refiner_model_name != 'None':
                 print(f'Refiner disabled in Lightning mode.')
@@ -278,7 +280,7 @@ def worker():
         elif performance_selection == Performance.HYPER_SD:
             print('Enter Hyper-SD mode.')
             progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
-            loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]
+            performance_loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]
 
             if refiner_model_name != 'None':
                 print(f'Refiner disabled in Hyper-SD mode.')
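All three speed presets (Extreme Speed/LCM, Lightning, Hyper-SD) now stage their accelerator LoRA in performance_loras instead of appending it to loras directly; the next hunk shows where that list is merged back in, and a short sketch of the resulting ordering follows it.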
@@ -458,8 +460,8 @@ def worker():
 
         progressbar(async_task, 2, 'Loading models ...')
 
-        loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
-
+        loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
+        loras += performance_loras
         pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
                                     loras=loras, base_model_additional_loras=base_model_additional_loras,
                                     use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
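The ordering here is the point of the first commit bullet: parse_lora_references_from_prompt merges UI-selected and prompt-referenced LoRAs and trims the result to default_max_lora_number, and only afterwards are the performance LoRAs appended, so an accelerator LoRA can no longer be crowded out of the array. A minimal sketch with illustrative values (the file names are made up for the example):

    # assuming a limit of 3; 'ui-lora' came from the UI, a/b/c from the prompt
    loras = [('ui-lora.safetensors', 0.5)]
    loras, prompt = parse_lora_references_from_prompt(
        'cool <lora:a:0.4> <lora:b:0.2> <lora:c:0.1>', loras, loras_limit=3)
    # loras -> [('ui-lora.safetensors', 0.5), ('a.safetensors', 0.4), ('b.safetensors', 0.2)]
    loras += performance_loras  # e.g. the LCM LoRA, always appended after trimming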
@@ -8,8 +8,7 @@ import modules.flags
 import modules.sdxl_styles
 
 from modules.model_loader import load_file_from_url
-from modules.util import makedirs_with_log
-from modules.extra_utils import get_files_from_folder
+from modules.extra_utils import makedirs_with_log, get_files_from_folder
 from modules.flags import OutputFormat, Performance, MetadataScheme
 
 
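Moving makedirs_with_log into modules.extra_utils is what resolves the circular dependency named in the commit message: modules.util now imports modules.config (see the hunk after next), so modules.config importing from modules.util would have closed an import cycle.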
@@ -1,5 +1,11 @@
 import os
 
 
+def makedirs_with_log(path):
+    try:
+        os.makedirs(path, exist_ok=True)
+    except OSError as error:
+        print(f'Directory {path} could not be created, reason: {error}')
+
 def get_files_from_folder(folder_path, extensions=None, name_filter=None):
     if not os.path.isdir(folder_path):
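The helper's behavior is unchanged by the move: it logs a failed directory creation instead of raising. A minimal usage sketch (the path is hypothetical):

    from modules.extra_utils import makedirs_with_log

    makedirs_with_log('outputs/generated')  # prints a message on OSError instead of raising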
@@ -12,15 +12,15 @@ import hashlib
 
 from PIL import Image
 
-
+import modules.config
 import modules.sdxl_styles
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
 
 
 # Regexp compiled once. Matches entries with the following pattern:
 # <lora:some_lora:1>
 # <lora:aNotherLora:-1.6>
-LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
+LORAS_PROMPT_PATTERN = re.compile(r"(<lora:([^:]+):([+-]?(?:\d+(?:\.\d*)?|\.\d+))>)", re.X)
 
 HASH_SHA256_LENGTH = 10
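A sketch of the behavioral difference (illustrative token): the old pattern was anchored with .* and used with match(), so it could report at most one LoRA per token, while the new pattern captures the whole tag as group 1, letting findall() return every tag in a token along with its name and weight:

    token = '<lora:hey-lora:0.4><lora:you-lora:0.2>'
    LORAS_PROMPT_PATTERN.findall(token)
    # [('<lora:hey-lora:0.4>', 'hey-lora', '0.4'),
    #  ('<lora:you-lora:0.2>', 'you-lora', '0.2')]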
@@ -372,31 +372,57 @@ def get_file_from_folder_list(name, folders):
     return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
 
 
-def makedirs_with_log(path):
-    try:
-        os.makedirs(path, exist_ok=True)
-    except OSError as error:
-        print(f'Directory {path} could not be created, reason: {error}')
-
-
-def get_enabled_loras(loras: list) -> list:
-    return [(lora[1], lora[2]) for lora in loras if lora[0]]
+def get_enabled_loras(loras: list, remove_none=True) -> list:
+    return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)]
 
 
-def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
+def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
+                                      prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
+    found_loras = []
+    prompt_without_loras = ""
+    for token in prompt.split(" "):
+        matches = LORAS_PROMPT_PATTERN.findall(token)
+
+        if matches:
+            for match in matches:
+                found_loras.append((f"{match[1]}.safetensors", float(match[2])))
+                prompt_without_loras += token.replace(match[0], '')
+        else:
+            prompt_without_loras += token
+        prompt_without_loras += ' '
+
+    cleaned_prompt = prompt_without_loras[:-1]
+    if prompt_cleanup:
+        cleaned_prompt = cleanup_prompt(prompt_without_loras)
+
     new_loras = []
-    updated_loras = []
-    for token in prompt.split(","):
-        m = LORAS_PROMPT_PATTERN.match(token)
-
-        if m:
-            new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
-
+    lora_names = [lora[0] for lora in loras]
+    for found_lora in found_loras:
+        if deduplicate_loras and found_lora[0] in lora_names:
+            continue
+        new_loras.append(found_lora)
+
+    if len(new_loras) == 0:
+        return loras, cleaned_prompt
+
+    updated_loras = []
     for lora in loras + new_loras:
         if lora[0] != "None":
             updated_loras.append(lora)
 
-    return updated_loras[:loras_limit]
+    return updated_loras[:loras_limit], cleaned_prompt
+
+
+def cleanup_prompt(prompt):
+    prompt = re.sub(' +', ' ', prompt)
+    prompt = re.sub(',+', ',', prompt)
+    cleaned_prompt = ''
+    for token in prompt.split(','):
+        token = token.strip()
+        if token == '':
+            continue
+        cleaned_prompt += token + ', '
+    return cleaned_prompt[:-2]
 
 
 def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
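Taken together, the function now returns both the LoRA list and the prompt with the tags stripped and normalized by cleanup_prompt. A minimal sketch mirroring the first unit test below:

    loras, prompt = parse_lora_references_from_prompt(
        'some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>', [], loras_limit=5)
    # loras  -> [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)]
    # prompt -> 'some prompt, very cool, cool'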
@@ -8,12 +8,16 @@ class TestUtils(unittest.TestCase):
         test_cases = [
             {
                 "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
-                "output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
+                "output": (
+                    [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
             },
             # Test can not exceed limit
             {
                 "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
-                "output": [("hey-lora.safetensors", 0.4)],
+                "output": (
+                    [('hey-lora.safetensors', 0.4)],
+                    'some prompt, very cool, cool'
+                ),
             },
             # test Loras from UI take precedence over prompt
             {
@@ -22,22 +26,33 @@ class TestUtils(unittest.TestCase):
                     [("hey-lora.safetensors", 0.4)],
                     5,
                 ),
-                "output": [
-                    ("hey-lora.safetensors", 0.4),
-                    ("l1.safetensors", 0.4),
-                    ("l2.safetensors", -0.2),
-                    ("l3.safetensors", 0.3),
-                    ("l4.safetensors", 0.5),
-                ],
+                "output": (
+                    [
+                        ('hey-lora.safetensors', 0.4),
+                        ('l1.safetensors', 0.4),
+                        ('l2.safetensors', -0.2),
+                        ('l3.safetensors', 0.3),
+                        ('l4.safetensors', 0.5)
+                    ],
+                    'some prompt, very cool'
+                )
             },
-            # Test lora specification not separated by comma are ignored, only latest specified is used
             {
                 "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
-                "output": [("you-lora.safetensors", 0.2)],
+                "output": (
+                    [
+                        ('hey-lora.safetensors', 0.4),
+                        ('you-lora.safetensors', 0.2)
+                    ],
+                    'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
+                ),
             },
             {
-                "input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6),
-                "output": []
+                "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
+                "output": (
+                    [],
+                    '<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
+                )
             }
         ]
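The expected prompt in the multiple-tags case ('<lora:you-lora:0.2><lora:hey-lora:0.4>', tags in swapped order) follows from the parser: for each match in a token, the token with only that match removed is appended to the prompt, so a token carrying two tags contributes two leftovers, each missing a different tag.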
@@ -0,0 +1,8 @@
+*.txt
+!animal.txt
+!artist.txt
+!color.txt
+!color_flower.txt
+!extended-color.txt
+!flower.txt
+!nationality.txt
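This new file is presumably a .gitignore for the wildcards folder (per the 'ignore custom wildcards' commit bullet): every user-added *.txt wildcard is ignored while the bundled wildcard files stay tracked.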