diff --git a/modules/config.py b/modules/config.py index 913fb281..94046661 100644 --- a/modules/config.py +++ b/modules/config.py @@ -547,6 +547,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file: model_filenames = [] lora_filenames = [] +lora_filenames_no_special = [] vae_filenames = [] wildcard_filenames = [] @@ -556,6 +557,16 @@ sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors' loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora] +def remove_special_loras(lora_filenames): + global loras_metadata_remove + + loras_no_special = lora_filenames.copy() + for lora_to_remove in loras_metadata_remove: + if lora_to_remove in loras_no_special: + loras_no_special.remove(lora_to_remove) + return loras_no_special + + def get_model_filenames(folder_paths, extensions=None, name_filter=None): if extensions is None: extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] @@ -570,9 +581,10 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None): def update_files(): - global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets + global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets model_filenames = get_model_filenames(paths_checkpoints) lora_filenames = get_model_filenames(paths_loras) + lora_filenames_no_special = remove_special_loras(lora_filenames) vae_filenames = get_model_filenames(path_vae) wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt']) available_presets = get_presets() diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 84032e82..2469da5f 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -205,7 +205,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list): def get_sha256(filepath): global hash_cache if filepath not in hash_cache: - # is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors' hash_cache[filepath] = sha256(filepath) return hash_cache[filepath] @@ -293,12 +292,6 @@ class MetadataParser(ABC): self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) self.vae_name = Path(vae_name).stem - @staticmethod - def remove_special_loras(lora_filenames): - for lora_to_remove in modules.config.loras_metadata_remove: - if lora_to_remove in lora_filenames: - lora_filenames.remove(lora_to_remove) - class A1111MetadataParser(MetadataParser): def get_scheme(self) -> MetadataScheme: @@ -415,13 +408,11 @@ class A1111MetadataParser(MetadataParser): lora_data = data['lora_hashes'] if lora_data != '': - lora_filenames = modules.config.lora_filenames.copy() - self.remove_special_loras(lora_filenames) for li, lora in enumerate(lora_data.split(', ')): lora_split = lora.split(': ') lora_name = lora_split[0] lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1] - for filename in lora_filenames: + for filename in modules.config.lora_filenames_no_special: path = Path(filename) if lora_name == path.stem: data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}' @@ -510,19 +501,15 @@ class FooocusMetadataParser(MetadataParser): return MetadataScheme.FOOOCUS def parse_json(self, metadata: dict) -> dict: - model_filenames = modules.config.model_filenames.copy() - lora_filenames = modules.config.lora_filenames.copy() - vae_filenames = modules.config.vae_filenames.copy() - self.remove_special_loras(lora_filenames) for key, value in metadata.items(): if value in ['', 'None']: continue if key in ['base_model', 'refiner_model']: - metadata[key] = self.replace_value_with_filename(key, value, model_filenames) + metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames) elif key.startswith('lora_combined_'): - metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) + metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special) elif key == 'vae': - metadata[key] = self.replace_value_with_filename(key, value, vae_filenames) + metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames) else: continue diff --git a/modules/util.py b/modules/util.py index 52bc490a..cb5580fb 100644 --- a/modules/util.py +++ b/modules/util.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import datetime import random @@ -360,6 +362,14 @@ def is_json(data: str) -> bool: return True +def get_filname_by_stem(lora_name, filenames: List[str]) -> str | None: + for filename in filenames: + path = Path(filename) + if lora_name == path.stem: + return filename + return None + + def get_file_from_folder_list(name, folders): if not isinstance(folders, list): folders = [folders] @@ -377,28 +387,35 @@ def get_enabled_loras(loras: list, remove_none=True) -> list: def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5, - prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]: + skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]: found_loras = [] - prompt_without_loras = "" - for token in prompt.split(" "): + prompt_without_loras = '' + cleaned_prompt = '' + for token in prompt.split(','): matches = LORAS_PROMPT_PATTERN.findall(token) - if matches: - for match in matches: - found_loras.append((f"{match[1]}.safetensors", float(match[2]))) - prompt_without_loras += token.replace(match[0], '') - else: - prompt_without_loras += token - prompt_without_loras += ' ' + if len(matches) == 0: + prompt_without_loras += token + ', ' + continue + for match in matches: + lora_name = match[1] + '.safetensors' + if not skip_file_check: + lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special) + if lora_name is not None: + found_loras.append((lora_name, float(match[2]))) + token = token.replace(match[0], '') + prompt_without_loras += token + ', ' + + if prompt_without_loras != '': + cleaned_prompt = prompt_without_loras[:-2] - cleaned_prompt = prompt_without_loras[:-1] if prompt_cleanup: cleaned_prompt = cleanup_prompt(prompt_without_loras) new_loras = [] lora_names = [lora[0] for lora in loras] for found_lora in found_loras: - if deduplicate_loras and found_lora[0] in lora_names: + if deduplicate_loras and (found_lora[0] in lora_names or found_lora in new_loras): continue new_loras.append(found_lora) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9f81005b..6fd550db 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,13 +7,13 @@ class TestUtils(unittest.TestCase): def test_can_parse_tokens_with_lora(self): test_cases = [ { - "input": ("some prompt, very cool, , cool ", [], 5), + "input": ("some prompt, very cool, , cool ", [], 5, True), "output": ( [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'), }, # Test can not exceed limit { - "input": ("some prompt, very cool, , cool ", [], 1), + "input": ("some prompt, very cool, , cool ", [], 1, True), "output": ( [('hey-lora.safetensors', 0.4)], 'some prompt, very cool, cool' @@ -25,6 +25,7 @@ class TestUtils(unittest.TestCase): "some prompt, very cool, , , , , , ", [("hey-lora.safetensors", 0.4)], 5, + True ), "output": ( [ @@ -37,18 +38,35 @@ class TestUtils(unittest.TestCase): 'some prompt, very cool' ) }, + # test correct matching even if there is no space separating loras in the same token { - "input": ("some prompt, very cool, ", [], 3), + "input": ("some prompt, very cool, ", [], 3, True), "output": ( [ ('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2) ], - 'some prompt, very cool, ' + 'some prompt, very cool' + ), + }, + # test deduplication, also selected loras are never overridden with loras in prompt + { + "input": ( + "some prompt, very cool, ", + [('you-lora.safetensors', 0.3)], + 3, + True + ), + "output": ( + [ + ('you-lora.safetensors', 0.3), + ('hey-lora.safetensors', 0.4) + ], + 'some prompt, very cool' ), }, { - "input": (", , , and ", [], 6), + "input": (", , , and ", [], 6, True), "output": ( [], ', , , and ' @@ -57,7 +75,7 @@ class TestUtils(unittest.TestCase): ] for test in test_cases: - prompt, loras, loras_limit = test["input"] + prompt, loras, loras_limit, skip_file_check = test["input"] expected = test["output"] - actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit) + actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check) self.assertEqual(expected, actual)