feat: only filter lora of selected performance instead of all performance LoRAs

both metadata and history log
This commit is contained in:
Manuel Schmid 2024-05-30 00:22:31 +02:00
parent 9c8ffbbe18
commit 91281e5561
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 38 additions and 26 deletions

View File

@ -548,25 +548,9 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
model_filenames = []
lora_filenames = []
lora_filenames_no_special = []
vae_filenames = []
wildcard_filenames = []
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors'
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]
def remove_special_loras(lora_filenames):
global loras_metadata_remove
loras_no_special = lora_filenames.copy()
for lora_to_remove in loras_metadata_remove:
if lora_to_remove in loras_no_special:
loras_no_special.remove(lora_to_remove)
return loras_no_special
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None:
@ -582,10 +566,9 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
def update_files():
global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
lora_filenames_no_special = remove_special_loras(lora_filenames)
vae_filenames = get_model_filenames(path_vae)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets()

View File

@ -32,7 +32,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_str('prompt', 'Prompt', loaded_parameter_dict, results)
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
get_list('styles', 'Styles', loaded_parameter_dict, results)
get_str('performance', 'Performance', loaded_parameter_dict, results)
performance = get_str('performance', 'Performance', loaded_parameter_dict, results)
get_steps('steps', 'Steps', loaded_parameter_dict, results)
get_number('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
@ -59,19 +59,26 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)
performance_filename = None
if performance is not None and performance in Performance.list():
performance = Performance(performance)
performance_filename = performance.lora_filename()
for i in range(modules.config.default_max_lora_number):
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results, performance_filename)
return results
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None) -> str | None:
try:
h = source_dict.get(key, source_dict.get(fallback, default))
assert isinstance(h, str)
results.append(h)
return h
except:
results.append(gr.update())
return None
def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
@ -181,7 +188,7 @@ def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list,
results.append(gr.update())
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, performance_filename: str | None):
try:
split_data = source_dict.get(key, source_dict.get(fallback)).split(' : ')
enabled = True
@ -193,6 +200,9 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
name = split_data[1]
weight = split_data[2]
if name == performance_filename:
raise Exception
weight = float(weight)
results.append(enabled)
results.append(name)
@ -381,10 +391,19 @@ class A1111MetadataParser(MetadataParser):
data['styles'] = str(found_styles)
performance: Performance | None = None
performance_lora = None
if 'performance' in data and data['performance'] in Performance.list():
performance = Performance(data['performance'])
data['performance'] = performance.value
performance_lora = performance.lora_filename()
# try to load performance based on steps, fallback for direct A1111 imports
if 'steps' in data and 'performance' not in data:
if 'steps' in data and performance is None:
try:
data['performance'] = Performance[Steps(int(data['steps'])).name].value
performance = Performance.by_steps(data['steps'])
data['performance'] = performance.value
performance_lora = performance.lora_filename()
except ValueError | KeyError:
pass
@ -414,8 +433,10 @@ class A1111MetadataParser(MetadataParser):
lora_split = lora.split(': ')
lora_name = lora_split[0]
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
for filename in modules.config.lora_filenames_no_special:
for filename in modules.config.lora_filenames:
path = Path(filename)
if performance_lora is not None and path.name == performance_lora:
break
if lora_name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
break
@ -503,13 +524,19 @@ class FooocusMetadataParser(MetadataParser):
return MetadataScheme.FOOOCUS
def to_json(self, metadata: dict) -> dict:
performance = None
if 'performance' in metadata and metadata['performance'] in Performance.list():
performance = Performance(metadata['performance'])
lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance)
for key, value in metadata.items():
if value in ['', 'None']:
continue
if key in ['base_model', 'refiner_model']:
metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
else:
@ -557,6 +584,8 @@ class FooocusMetadataParser(MetadataParser):
elif value == path.stem:
return filename
return None
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
match metadata_scheme: