feat: optimize performance lora filtering in metadata (#3048)

* feat: add remove_performance_lora method

* feat: use class PerformanceLoRA instead of strings in config

* refactor: clean up flags, use __members__ to check if an enum contains a key

* feat: only filter the LoRA of the selected performance instead of all performance LoRAs

* fix: disable intermediate results for all restricted performances

these performances generate too fast for Gradio, which becomes a bottleneck when showing intermediate results

* refactor: rename parse_json to to_json, rename parse_string to to_string

* feat: use speed steps as default instead of hardcoded 30

* feat: add method by_steps to Performance

* refactor: remove method ordinal_suffix, not needed anymore

* feat: only filter the LoRA of the selected performance instead of all performance LoRAs

applies to both metadata and the history log

* feat: do not filter LoRAs in the metadata parser but in the metadata load action instead (see the sketch below)
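
A rough sketch of how the pieces below fit together (identifiers are taken from the diffs in this commit; this is an illustrative overview, not a verbatim excerpt):

    # async_worker: drop only the selected performance's LoRA from the candidate
    # filenames before resolving <lora:...> tokens from the prompt
    lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance_selection)
    loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number,
                                                      lora_filenames=lora_filenames)
    loras += performance_loras  # the selected performance still contributes its LoRA exactly once

    # meta_parser: the parsers keep performance LoRAs in metadata and the history log;
    # load_parameter_button_click resolves Performance(performance).lora_filename() and
    # skips only that file when restoring LoRA slots, so it is never applied twice.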
Manuel Schmid 2024-05-30 16:14:28 +02:00 committed by GitHub
parent 3ef663c5b7
commit 4e658bb63a
8 changed files with 144 additions and 53 deletions


@@ -462,8 +462,10 @@ def worker():
progressbar(async_task, 2, 'Loading models ...')
- loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
+ lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance_selection)
+ loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number, lora_filenames=lora_filenames)
loras += performance_loras
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)


@@ -548,25 +548,9 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
model_filenames = []
lora_filenames = []
- lora_filenames_no_special = []
vae_filenames = []
wildcard_filenames = []
- sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
- sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
- sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors'
- loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]
- def remove_special_loras(lora_filenames):
- global loras_metadata_remove
- loras_no_special = lora_filenames.copy()
- for lora_to_remove in loras_metadata_remove:
- if lora_to_remove in loras_no_special:
- loras_no_special.remove(lora_to_remove)
- return loras_no_special
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None:
@@ -582,10 +566,9 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
def update_files():
- global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
+ global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
- lora_filenames_no_special = remove_special_loras(lora_filenames)
vae_filenames = get_model_filenames(path_vae)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets()
@@ -634,26 +617,27 @@ def downloading_sdxl_lcm_lora():
load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
model_dir=paths_loras[0],
- file_name=sdxl_lcm_lora
+ file_name=modules.flags.PerformanceLoRA.EXTREME_SPEED.value
)
- return sdxl_lcm_lora
+ return modules.flags.PerformanceLoRA.EXTREME_SPEED.value
def downloading_sdxl_lightning_lora():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_lightning_4step_lora.safetensors',
model_dir=paths_loras[0],
- file_name=sdxl_lightning_lora
+ file_name=modules.flags.PerformanceLoRA.LIGHTNING.value
)
- return sdxl_lightning_lora
+ return modules.flags.PerformanceLoRA.LIGHTNING.value
def downloading_sdxl_hyper_sd_lora():
load_file_from_url(
url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_hyper_sd_4step_lora.safetensors',
model_dir=paths_loras[0],
- file_name=sdxl_hyper_sd_lora
+ file_name=modules.flags.PerformanceLoRA.HYPER_SD.value
)
- return sdxl_hyper_sd_lora
+ return modules.flags.PerformanceLoRA.HYPER_SD.value
def downloading_controlnet_canny():


@@ -48,7 +48,8 @@ SAMPLERS = KSAMPLER | SAMPLER_EXTRA
KSAMPLER_NAMES = list(KSAMPLER.keys())
- SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd"]
+ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo",
+ "align_your_steps", "tcd"]
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
sampler_list = SAMPLER_NAMES
@@ -91,6 +92,7 @@ sdxl_aspect_ratios = [
'1664*576', '1728*576'
]
class MetadataScheme(Enum):
FOOOCUS = 'fooocus'
A1111 = 'a1111'
@@ -115,6 +117,14 @@ class OutputFormat(Enum):
return list(map(lambda c: c.value, cls))
+ class PerformanceLoRA(Enum):
+ QUALITY = None
+ SPEED = None
+ EXTREME_SPEED = 'sdxl_lcm_lora.safetensors'
+ LIGHTNING = 'sdxl_lightning_4step_lora.safetensors'
+ HYPER_SD = 'sdxl_hyper_sd_4step_lora.safetensors'
class Steps(IntEnum):
QUALITY = 60
SPEED = 30
@@ -142,6 +152,10 @@ class Performance(Enum):
def list(cls) -> list:
return list(map(lambda c: c.value, cls))
+ @classmethod
+ def by_steps(cls, steps: int | str):
+ return cls[Steps(int(steps)).name]
@classmethod
def has_restricted_features(cls, x) -> bool:
if isinstance(x, Performance):
@@ -149,7 +163,10 @@ class Performance(Enum):
return x in [cls.EXTREME_SPEED.value, cls.LIGHTNING.value, cls.HYPER_SD.value]
def steps(self) -> int | None:
- return Steps[self.name].value if Steps[self.name] else None
+ return Steps[self.name].value if self.name in Steps.__members__ else None
def steps_uov(self) -> int | None:
- return StepsUOV[self.name].value if Steps[self.name] else None
+ return StepsUOV[self.name].value if self.name in StepsUOV.__members__ else None
+ def lora_filename(self) -> str | None:
+ return PerformanceLoRA[self.name].value if self.name in PerformanceLoRA.__members__ else None
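
The new helpers in a nutshell (a minimal sketch; return values follow directly from the enum members shown above, and has_restricted_features is called with a plain value here, as webui.py does below):

    from modules.flags import Performance

    Performance.by_steps(30)                    # Performance.SPEED; also accepts the string '30'
    Performance.SPEED.steps()                   # 30
    Performance.QUALITY.lora_filename()         # None, since PerformanceLoRA.QUALITY is None
    Performance.EXTREME_SPEED.lora_filename()   # 'sdxl_lcm_lora.safetensors'
    Performance.has_restricted_features(Performance.EXTREME_SPEED.value)  # True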


@@ -32,7 +32,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_str('prompt', 'Prompt', loaded_parameter_dict, results)
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
get_list('styles', 'Styles', loaded_parameter_dict, results)
- get_str('performance', 'Performance', loaded_parameter_dict, results)
+ performance = get_str('performance', 'Performance', loaded_parameter_dict, results)
get_steps('steps', 'Steps', loaded_parameter_dict, results)
get_number('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
@@ -59,19 +59,27 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)
+ # prevent performance LoRAs to be added twice, by performance and by lora
+ performance_filename = None
+ if performance is not None and performance in Performance.list():
+ performance = Performance(performance)
+ performance_filename = performance.lora_filename()
for i in range(modules.config.default_max_lora_number):
- get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
+ get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results, performance_filename)
return results
- def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
+ def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None) -> str | None:
try:
h = source_dict.get(key, source_dict.get(fallback, default))
assert isinstance(h, str)
results.append(h)
+ return h
except:
results.append(gr.update())
+ return None
def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
@@ -181,7 +189,7 @@ def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list,
results.append(gr.update())
- def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
+ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, performance_filename: str | None):
try:
split_data = source_dict.get(key, source_dict.get(fallback)).split(' : ')
enabled = True
@@ -193,6 +201,9 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
name = split_data[1]
weight = split_data[2]
+ if name == performance_filename:
+ raise Exception
weight = float(weight)
results.append(enabled)
results.append(name)
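
For illustration, a hypothetical metadata dict as it reaches load_parameter_button_click (the keys and the 'name : weight' format follow the parsers below; the 'Extreme Speed' value string and the except branch of get_lora sit outside this hunk and are assumed from the existing Performance enum):

    metadata = {
        'performance': 'Extreme Speed',
        'lora_combined_1': 'sdxl_lcm_lora.safetensors : 1.0',
        'lora_combined_2': 'hey-lora.safetensors : 0.4',
    }
    # performance_filename resolves to 'sdxl_lcm_lora.safetensors', so get_lora raises for
    # slot 1 and that slot falls through to its except branch; only 'hey-lora.safetensors'
    # is restored explicitly, and the LCM LoRA is applied once via the performance setting.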
@@ -248,7 +259,7 @@ class MetadataParser(ABC):
self.full_prompt: str = ''
self.raw_negative_prompt: str = ''
self.full_negative_prompt: str = ''
- self.steps: int = 30
+ self.steps: int = Steps.SPEED.value
self.base_model_name: str = ''
self.base_model_hash: str = ''
self.refiner_model_name: str = ''
@@ -261,11 +272,11 @@
raise NotImplementedError
@abstractmethod
- def parse_json(self, metadata: dict | str) -> dict:
+ def to_json(self, metadata: dict | str) -> dict:
raise NotImplementedError
@abstractmethod
- def parse_string(self, metadata: dict) -> str:
+ def to_string(self, metadata: dict) -> str:
raise NotImplementedError
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
@@ -328,7 +339,7 @@ class A1111MetadataParser(MetadataParser):
'version': 'Version'
}
- def parse_json(self, metadata: str) -> dict:
+ def to_json(self, metadata: str) -> dict:
metadata_prompt = ''
metadata_negative_prompt = ''
@@ -382,9 +393,9 @@
data['styles'] = str(found_styles)
# try to load performance based on steps, fallback for direct A1111 imports
- if 'steps' in data and 'performance' not in data:
+ if 'steps' in data and 'performance' in data is None:
try:
- data['performance'] = Performance[Steps(int(data['steps'])).name].value
+ data['performance'] = Performance.by_steps(data['steps']).value
except ValueError | KeyError:
pass
@@ -414,7 +425,7 @@
lora_split = lora.split(': ')
lora_name = lora_split[0]
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
- for filename in modules.config.lora_filenames_no_special:
+ for filename in modules.config.lora_filenames:
path = Path(filename)
if lora_name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
@@ -422,7 +433,7 @@
return data
- def parse_string(self, metadata: dict) -> str:
+ def to_string(self, metadata: dict) -> str:
data = {k: v for _, k, v in metadata}
width, height = eval(data['resolution'])
@@ -502,14 +513,14 @@ class FooocusMetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme:
return MetadataScheme.FOOOCUS
- def parse_json(self, metadata: dict) -> dict:
+ def to_json(self, metadata: dict) -> dict:
for key, value in metadata.items():
if value in ['', 'None']:
continue
if key in ['base_model', 'refiner_model']:
metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
elif key.startswith('lora_combined_'):
- metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
+ metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
else:
@@ -517,7 +528,7 @@ class FooocusMetadataParser(MetadataParser):
return metadata
- def parse_string(self, metadata: list) -> str:
+ def to_string(self, metadata: list) -> str:
for li, (label, key, value) in enumerate(metadata):
# remove model folder paths from metadata
if key.startswith('lora_combined_'):
@@ -557,6 +568,8 @@
elif value == path.stem:
return filename
return None
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
match metadata_scheme:


@@ -27,7 +27,7 @@ def log(img, metadata, metadata_parser: MetadataParser | None = None, output_for
date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format)
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
- parsed_parameters = metadata_parser.parse_string(metadata.copy()) if metadata_parser is not None else ''
+ parsed_parameters = metadata_parser.to_string(metadata.copy()) if metadata_parser is not None else ''
image = Image.fromarray(img)
if output_format == OutputFormat.PNG.value:


@@ -16,6 +16,7 @@ from PIL import Image
import modules.config
import modules.sdxl_styles
+ from modules.flags import Performance
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -381,9 +382,6 @@ def get_file_from_folder_list(name, folders):
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
- def ordinal_suffix(number: int) -> str:
- return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
def makedirs_with_log(path):
try:
@@ -397,10 +395,15 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
- skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
+ skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True,
+ lora_filenames=None) -> tuple[List[Tuple[AnyStr, float]], str]:
+ if lora_filenames is None:
+ lora_filenames = []
found_loras = []
prompt_without_loras = ''
cleaned_prompt = ''
for token in prompt.split(','):
matches = LORAS_PROMPT_PATTERN.findall(token)
@@ -410,7 +413,7 @@
for match in matches:
lora_name = match[1] + '.safetensors'
if not skip_file_check:
- lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
+ lora_name = get_filname_by_stem(match[1], lora_filenames)
if lora_name is not None:
found_loras.append((lora_name, float(match[2])))
token = token.replace(match[0], '')
@@ -440,6 +443,22 @@
return updated_loras[:loras_limit], cleaned_prompt
+ def remove_performance_lora(filenames: list, performance: Performance | None):
+ loras_without_performance = filenames.copy()
+ if performance is None:
+ return loras_without_performance
+ performance_lora = performance.lora_filename()
+ for filename in filenames:
+ path = Path(filename)
+ if performance_lora == path.name:
+ loras_without_performance.remove(filename)
+ return loras_without_performance
def cleanup_prompt(prompt):
prompt = re.sub(' +', ' ', prompt)
prompt = re.sub(',+', ',', prompt)
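
A minimal usage sketch of the two helpers above (filenames are illustrative; matching is by basename via Path(filename).name, so entries nested in subfolders are removed as well, and the expected results mirror the test cases below):

    from modules.flags import Performance
    from modules.util import parse_lora_references_from_prompt, remove_performance_lora

    filenames = ['hey-lora.safetensors', 'sdxl_lcm_lora.safetensors']

    filtered = remove_performance_lora(filenames, Performance.EXTREME_SPEED)
    # ['hey-lora.safetensors'] -- QUALITY/SPEED leave the list untouched, since their PerformanceLoRA is None

    loras, prompt = parse_lora_references_from_prompt(
        'some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>', [], lora_filenames=filtered)
    # loras == [('hey-lora.safetensors', 0.4)], prompt == 'some prompt'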


@@ -1,5 +1,7 @@
+ import os
import unittest
+ import modules.flags
from modules import util
@@ -77,5 +79,59 @@ class TestUtils(unittest.TestCase):
for test in test_cases:
prompt, loras, loras_limit, skip_file_check = test["input"]
expected = test["output"]
- actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
+ actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit,
+ skip_file_check=skip_file_check)
self.assertEqual(expected, actual)
+ def test_can_parse_tokens_and_strip_performance_lora(self):
+ lora_filenames = [
+ 'hey-lora.safetensors',
+ modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
+ modules.flags.PerformanceLoRA.LIGHTNING.value,
+ os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value)
+ ]
+ test_cases = [
+ {
+ "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
+ "output": (
+ [('hey-lora.safetensors', 0.4)],
+ 'some prompt'
+ ),
+ },
+ {
+ "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
+ "output": (
+ [('hey-lora.safetensors', 0.4)],
+ 'some prompt'
+ ),
+ },
+ {
+ "input": ("some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.EXTREME_SPEED),
+ "output": (
+ [('hey-lora.safetensors', 0.4)],
+ 'some prompt'
+ ),
+ },
+ {
+ "input": ("some prompt, <lora:sdxl_lightning_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.LIGHTNING),
+ "output": (
+ [('hey-lora.safetensors', 0.4)],
+ 'some prompt'
+ ),
+ },
+ {
+ "input": ("some prompt, <lora:sdxl_hyper_sd_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.HYPER_SD),
+ "output": (
+ [('hey-lora.safetensors', 0.4)],
+ 'some prompt'
+ ),
+ }
+ ]
+ for test in test_cases:
+ prompt, loras, loras_limit, skip_file_check, performance = test["input"]
+ lora_filenames = modules.util.remove_performance_lora(lora_filenames, performance)
+ expected = test["output"]
+ actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, lora_filenames=lora_filenames)
+ self.assertEqual(expected, actual)


@@ -461,8 +461,8 @@ with shared.gradio_root:
interactive=not modules.config.default_black_out_nsfw,
info='Disable preview during generation.')
disable_intermediate_results = gr.Checkbox(label='Disable Intermediate Results',
- value=modules.config.default_performance == flags.Performance.EXTREME_SPEED.value,
- interactive=modules.config.default_performance != flags.Performance.EXTREME_SPEED.value,
+ value=flags.Performance.has_restricted_features(modules.config.default_performance),
+ interactive=not flags.Performance.has_restricted_features(modules.config.default_performance),
info='Disable intermediate results during generation, only show final gallery.')
disable_seed_increment = gr.Checkbox(label='Disable seed increment',
info='Disable automatic seed increment when image number is > 1.',
@@ -713,7 +713,7 @@ with shared.gradio_root:
parsed_parameters = {}
else:
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
- parsed_parameters = metadata_parser.parse_json(parameters)
+ parsed_parameters = metadata_parser.to_json(parameters)
return modules.meta_parser.load_parameter_button_click(parsed_parameters, state_is_generating)