feat: only use valid inline loras, add subfolder support (#2968)
This commit is contained in:
parent ac14d9d03c
commit 7537612bcc
@@ -547,6 +547,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
 model_filenames = []
 lora_filenames = []
+lora_filenames_no_special = []
 vae_filenames = []
 wildcard_filenames = []
@@ -556,6 +557,16 @@ sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors'
 loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]


+def remove_special_loras(lora_filenames):
+    global loras_metadata_remove
+
+    loras_no_special = lora_filenames.copy()
+    for lora_to_remove in loras_metadata_remove:
+        if lora_to_remove in loras_no_special:
+            loras_no_special.remove(lora_to_remove)
+    return loras_no_special
+
+
 def get_model_filenames(folder_paths, extensions=None, name_filter=None):
     if extensions is None:
         extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
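For reference, a minimal standalone sketch of how this filter behaves. Only the Hyper-SD filename is taken from the diff above; the other filenames are hypothetical stand-ins for entries of modules.config.lora_filenames.

# Standalone sketch of remove_special_loras(); most filenames here are made up for illustration.
loras_metadata_remove = ['sdxl_hyper_sd_4step_lora.safetensors']

def remove_special_loras(lora_filenames):
    loras_no_special = lora_filenames.copy()
    for lora_to_remove in loras_metadata_remove:
        if lora_to_remove in loras_no_special:
            loras_no_special.remove(lora_to_remove)
    return loras_no_special

print(remove_special_loras(['my_style.safetensors',
                            'subdir/detail_tweaker.safetensors',
                            'sdxl_hyper_sd_4step_lora.safetensors']))
# -> ['my_style.safetensors', 'subdir/detail_tweaker.safetensors']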
@@ -570,9 +581,10 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):


 def update_files():
-    global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
+    global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
     model_filenames = get_model_filenames(paths_checkpoints)
     lora_filenames = get_model_filenames(paths_loras)
+    lora_filenames_no_special = remove_special_loras(lora_filenames)
     vae_filenames = get_model_filenames(path_vae)
     wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
     available_presets = get_presets()
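A hedged call-site sketch, assuming it runs inside a Fooocus checkout where modules.config imports cleanly; it only illustrates that update_files() now keeps two lists in sync.

import modules.config as config  # assumes a Fooocus working directory on sys.path

config.update_files()
# lora_filenames holds everything found under paths_loras (subfolders included);
# lora_filenames_no_special is the same list minus the LCM/Lightning/Hyper-SD patch LoRAs.
assert set(config.lora_filenames_no_special).issubset(set(config.lora_filenames))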
@@ -205,7 +205,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
 def get_sha256(filepath):
     global hash_cache
     if filepath not in hash_cache:
-        # is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors'
         hash_cache[filepath] = sha256(filepath)

     return hash_cache[filepath]
@@ -293,12 +292,6 @@ class MetadataParser(ABC):
         self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
         self.vae_name = Path(vae_name).stem

-    @staticmethod
-    def remove_special_loras(lora_filenames):
-        for lora_to_remove in modules.config.loras_metadata_remove:
-            if lora_to_remove in lora_filenames:
-                lora_filenames.remove(lora_to_remove)
-

 class A1111MetadataParser(MetadataParser):
     def get_scheme(self) -> MetadataScheme:
@@ -415,13 +408,11 @@ class A1111MetadataParser(MetadataParser):
            lora_data = data['lora_hashes']

        if lora_data != '':
-            lora_filenames = modules.config.lora_filenames.copy()
-            self.remove_special_loras(lora_filenames)
            for li, lora in enumerate(lora_data.split(', ')):
                lora_split = lora.split(': ')
                lora_name = lora_split[0]
                lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
-                for filename in lora_filenames:
+                for filename in modules.config.lora_filenames_no_special:
                    path = Path(filename)
                    if lora_name == path.stem:
                        data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
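To make the resolution step concrete, here is a small self-contained sketch of the stem lookup in the loop above; the filenames and the parsed name/weight pair are hypothetical.

from pathlib import Path

# Hypothetical resolved LoRA list (note the subfolder entry) and one parsed lora_hashes item.
lora_filenames_no_special = ['styles/watercolor_v2.safetensors', 'detail_tweaker.safetensors']
lora_name, lora_weight = 'watercolor_v2', '0.8'

data = {}
for filename in lora_filenames_no_special:
    if lora_name == Path(filename).stem:
        data['lora_combined_1'] = f'{filename} : {lora_weight}'

print(data)  # {'lora_combined_1': 'styles/watercolor_v2.safetensors : 0.8'}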
@@ -510,19 +501,15 @@ class FooocusMetadataParser(MetadataParser):
         return MetadataScheme.FOOOCUS

     def parse_json(self, metadata: dict) -> dict:
-        model_filenames = modules.config.model_filenames.copy()
-        lora_filenames = modules.config.lora_filenames.copy()
-        vae_filenames = modules.config.vae_filenames.copy()
-        self.remove_special_loras(lora_filenames)
         for key, value in metadata.items():
             if value in ['', 'None']:
                 continue
             if key in ['base_model', 'refiner_model']:
-                metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
+                metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
             elif key.startswith('lora_combined_'):
-                metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
+                metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
             elif key == 'vae':
-                metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
+                metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
             else:
                 continue
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import numpy as np
 import datetime
 import random
@@ -360,6 +362,14 @@ def is_json(data: str) -> bool:
     return True


+def get_filname_by_stem(lora_name, filenames: List[str]) -> str | None:
+    for filename in filenames:
+        path = Path(filename)
+        if lora_name == path.stem:
+            return filename
+    return None
+
+
 def get_file_from_folder_list(name, folders):
     if not isinstance(folders, list):
         folders = [folders]
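The helper matches on Path.stem, which is what makes entries stored in subfolders resolvable from a bare LoRA name; a quick standalone illustration with hypothetical paths:

from pathlib import Path

# The directory part is ignored by .stem, so a reference like <lora:hero_v1:0.8>
# still finds a file kept in a subfolder.
lora_filenames = ['characters/hero_v1.safetensors', 'detail_tweaker.safetensors']
print([f for f in lora_filenames if Path(f).stem == 'hero_v1'])
# -> ['characters/hero_v1.safetensors']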
@@ -377,28 +387,35 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:


 def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
-                                      prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
+                                      skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
     found_loras = []
-    prompt_without_loras = ""
-    for token in prompt.split(" "):
+    prompt_without_loras = ''
+    cleaned_prompt = ''
+    for token in prompt.split(','):
         matches = LORAS_PROMPT_PATTERN.findall(token)

-        if matches:
-            for match in matches:
-                found_loras.append((f"{match[1]}.safetensors", float(match[2])))
-            prompt_without_loras += token.replace(match[0], '')
-        else:
-            prompt_without_loras += token
-        prompt_without_loras += ' '
+        if len(matches) == 0:
+            prompt_without_loras += token + ', '
+            continue
+        for match in matches:
+            lora_name = match[1] + '.safetensors'
+            if not skip_file_check:
+                lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
+            if lora_name is not None:
+                found_loras.append((lora_name, float(match[2])))
+            token = token.replace(match[0], '')
+        prompt_without_loras += token + ', '

-    cleaned_prompt = prompt_without_loras[:-1]
+    if prompt_without_loras != '':
+        cleaned_prompt = prompt_without_loras[:-2]
+
     if prompt_cleanup:
         cleaned_prompt = cleanup_prompt(prompt_without_loras)

     new_loras = []
     lora_names = [lora[0] for lora in loras]
     for found_lora in found_loras:
-        if deduplicate_loras and found_lora[0] in lora_names:
+        if deduplicate_loras and (found_lora[0] in lora_names or found_lora in new_loras):
             continue
         new_loras.append(found_lora)
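A hedged usage sketch of the reworked parser, assuming a Fooocus checkout where modules.util and modules.config import cleanly and a file named hero_v1.safetensors (hypothetical) exists under one of the configured LoRA folders:

import modules.config as config
import modules.util as util

config.update_files()  # make sure lora_filenames_no_special is populated

loras, cleaned = util.parse_lora_references_from_prompt(
    'a castle, <lora:hero_v1:0.8>, sunset, <lora:not_on_disk:0.5>',
    loras=[],
    loras_limit=5,
)
# With skip_file_check left at its False default, only references that resolve to a
# real file are kept: 'not_on_disk' is dropped, while 'hero_v1' maps to the stored
# filename, subfolder prefix included.
print(loras)    # e.g. [('characters/hero_v1.safetensors', 0.8)]
print(cleaned)  # e.g. 'a castle, sunset'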
@@ -7,13 +7,13 @@ class TestUtils(unittest.TestCase):
     def test_can_parse_tokens_with_lora(self):
         test_cases = [
             {
-                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
                 "output": (
                     [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
             },
             # Test can not exceed limit
             {
-                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
                 "output": (
                     [('hey-lora.safetensors', 0.4)],
                     'some prompt, very cool, cool'
@@ -25,6 +25,7 @@ class TestUtils(unittest.TestCase):
                     "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
                     [("hey-lora.safetensors", 0.4)],
                     5,
+                    True
                 ),
                 "output": (
                     [
@@ -37,18 +38,35 @@ class TestUtils(unittest.TestCase):
                     'some prompt, very cool'
                 )
             },
             # test correct matching even if there is no space separating loras in the same token
             {
-                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
+                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
                 "output": (
                     [
                         ('hey-lora.safetensors', 0.4),
                         ('you-lora.safetensors', 0.2)
                     ],
-                    'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
+                    'some prompt, very cool'
                 ),
             },
+            # test deduplication, also selected loras are never overridden with loras in prompt
+            {
+                "input": (
+                    "some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
+                    [('you-lora.safetensors', 0.3)],
+                    3,
+                    True
+                ),
+                "output": (
+                    [
+                        ('you-lora.safetensors', 0.3),
+                        ('hey-lora.safetensors', 0.4)
+                    ],
+                    'some prompt, very cool'
+                ),
+            },
             {
-                "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
+                "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
                 "output": (
                     [],
                     '<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
@@ -57,7 +75,7 @@ class TestUtils(unittest.TestCase):
         ]

         for test in test_cases:
-            prompt, loras, loras_limit = test["input"]
+            prompt, loras, loras_limit, skip_file_check = test["input"]
             expected = test["output"]
-            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
+            actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
             self.assertEqual(expected, actual)
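The updated cases still run under the standard unittest runner. A programmatic equivalent of `python -m unittest` from the repository root, assuming the test module lives under a tests/ directory (the diff does not show file names):

import unittest

# Discover and run the suite; the 'tests' directory name is an assumption.
suite = unittest.defaultTestLoader.discover('tests')
unittest.TextTestRunner(verbosity=2).run(suite)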