feat: only use valid inline loras, add subfolder support (#2968)
This commit is contained in:
parent
ac14d9d03c
commit
7537612bcc
|
|
@ -547,6 +547,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
|
||||||
|
|
||||||
model_filenames = []
|
model_filenames = []
|
||||||
lora_filenames = []
|
lora_filenames = []
|
||||||
|
lora_filenames_no_special = []
|
||||||
vae_filenames = []
|
vae_filenames = []
|
||||||
wildcard_filenames = []
|
wildcard_filenames = []
|
||||||
|
|
||||||
|
|
@ -556,6 +557,16 @@ sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors'
|
||||||
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]
|
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_special_loras(lora_filenames):
|
||||||
|
global loras_metadata_remove
|
||||||
|
|
||||||
|
loras_no_special = lora_filenames.copy()
|
||||||
|
for lora_to_remove in loras_metadata_remove:
|
||||||
|
if lora_to_remove in loras_no_special:
|
||||||
|
loras_no_special.remove(lora_to_remove)
|
||||||
|
return loras_no_special
|
||||||
|
|
||||||
|
|
||||||
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
|
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
|
||||||
if extensions is None:
|
if extensions is None:
|
||||||
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
||||||
|
|
@ -570,9 +581,10 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
|
||||||
|
|
||||||
|
|
||||||
def update_files():
|
def update_files():
|
||||||
global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
|
global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
|
||||||
model_filenames = get_model_filenames(paths_checkpoints)
|
model_filenames = get_model_filenames(paths_checkpoints)
|
||||||
lora_filenames = get_model_filenames(paths_loras)
|
lora_filenames = get_model_filenames(paths_loras)
|
||||||
|
lora_filenames_no_special = remove_special_loras(lora_filenames)
|
||||||
vae_filenames = get_model_filenames(path_vae)
|
vae_filenames = get_model_filenames(path_vae)
|
||||||
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
|
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
|
||||||
available_presets = get_presets()
|
available_presets = get_presets()
|
||||||
|
|
|
||||||
|
|
@ -205,7 +205,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
|
||||||
def get_sha256(filepath):
|
def get_sha256(filepath):
|
||||||
global hash_cache
|
global hash_cache
|
||||||
if filepath not in hash_cache:
|
if filepath not in hash_cache:
|
||||||
# is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors'
|
|
||||||
hash_cache[filepath] = sha256(filepath)
|
hash_cache[filepath] = sha256(filepath)
|
||||||
|
|
||||||
return hash_cache[filepath]
|
return hash_cache[filepath]
|
||||||
|
|
@ -293,12 +292,6 @@ class MetadataParser(ABC):
|
||||||
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
|
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
|
||||||
self.vae_name = Path(vae_name).stem
|
self.vae_name = Path(vae_name).stem
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def remove_special_loras(lora_filenames):
|
|
||||||
for lora_to_remove in modules.config.loras_metadata_remove:
|
|
||||||
if lora_to_remove in lora_filenames:
|
|
||||||
lora_filenames.remove(lora_to_remove)
|
|
||||||
|
|
||||||
|
|
||||||
class A1111MetadataParser(MetadataParser):
|
class A1111MetadataParser(MetadataParser):
|
||||||
def get_scheme(self) -> MetadataScheme:
|
def get_scheme(self) -> MetadataScheme:
|
||||||
|
|
@ -415,13 +408,11 @@ class A1111MetadataParser(MetadataParser):
|
||||||
lora_data = data['lora_hashes']
|
lora_data = data['lora_hashes']
|
||||||
|
|
||||||
if lora_data != '':
|
if lora_data != '':
|
||||||
lora_filenames = modules.config.lora_filenames.copy()
|
|
||||||
self.remove_special_loras(lora_filenames)
|
|
||||||
for li, lora in enumerate(lora_data.split(', ')):
|
for li, lora in enumerate(lora_data.split(', ')):
|
||||||
lora_split = lora.split(': ')
|
lora_split = lora.split(': ')
|
||||||
lora_name = lora_split[0]
|
lora_name = lora_split[0]
|
||||||
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
|
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
|
||||||
for filename in lora_filenames:
|
for filename in modules.config.lora_filenames_no_special:
|
||||||
path = Path(filename)
|
path = Path(filename)
|
||||||
if lora_name == path.stem:
|
if lora_name == path.stem:
|
||||||
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
|
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
|
||||||
|
|
@ -510,19 +501,15 @@ class FooocusMetadataParser(MetadataParser):
|
||||||
return MetadataScheme.FOOOCUS
|
return MetadataScheme.FOOOCUS
|
||||||
|
|
||||||
def parse_json(self, metadata: dict) -> dict:
|
def parse_json(self, metadata: dict) -> dict:
|
||||||
model_filenames = modules.config.model_filenames.copy()
|
|
||||||
lora_filenames = modules.config.lora_filenames.copy()
|
|
||||||
vae_filenames = modules.config.vae_filenames.copy()
|
|
||||||
self.remove_special_loras(lora_filenames)
|
|
||||||
for key, value in metadata.items():
|
for key, value in metadata.items():
|
||||||
if value in ['', 'None']:
|
if value in ['', 'None']:
|
||||||
continue
|
continue
|
||||||
if key in ['base_model', 'refiner_model']:
|
if key in ['base_model', 'refiner_model']:
|
||||||
metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
|
metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
|
||||||
elif key.startswith('lora_combined_'):
|
elif key.startswith('lora_combined_'):
|
||||||
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
|
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
|
||||||
elif key == 'vae':
|
elif key == 'vae':
|
||||||
metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
|
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
|
|
@ -360,6 +362,14 @@ def is_json(data: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_filname_by_stem(lora_name, filenames: List[str]) -> str | None:
|
||||||
|
for filename in filenames:
|
||||||
|
path = Path(filename)
|
||||||
|
if lora_name == path.stem:
|
||||||
|
return filename
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_file_from_folder_list(name, folders):
|
def get_file_from_folder_list(name, folders):
|
||||||
if not isinstance(folders, list):
|
if not isinstance(folders, list):
|
||||||
folders = [folders]
|
folders = [folders]
|
||||||
|
|
@ -377,28 +387,35 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:
|
||||||
|
|
||||||
|
|
||||||
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
|
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
|
||||||
prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
|
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
|
||||||
found_loras = []
|
found_loras = []
|
||||||
prompt_without_loras = ""
|
prompt_without_loras = ''
|
||||||
for token in prompt.split(" "):
|
cleaned_prompt = ''
|
||||||
|
for token in prompt.split(','):
|
||||||
matches = LORAS_PROMPT_PATTERN.findall(token)
|
matches = LORAS_PROMPT_PATTERN.findall(token)
|
||||||
|
|
||||||
if matches:
|
if len(matches) == 0:
|
||||||
for match in matches:
|
prompt_without_loras += token + ', '
|
||||||
found_loras.append((f"{match[1]}.safetensors", float(match[2])))
|
continue
|
||||||
prompt_without_loras += token.replace(match[0], '')
|
for match in matches:
|
||||||
else:
|
lora_name = match[1] + '.safetensors'
|
||||||
prompt_without_loras += token
|
if not skip_file_check:
|
||||||
prompt_without_loras += ' '
|
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
|
||||||
|
if lora_name is not None:
|
||||||
|
found_loras.append((lora_name, float(match[2])))
|
||||||
|
token = token.replace(match[0], '')
|
||||||
|
prompt_without_loras += token + ', '
|
||||||
|
|
||||||
|
if prompt_without_loras != '':
|
||||||
|
cleaned_prompt = prompt_without_loras[:-2]
|
||||||
|
|
||||||
cleaned_prompt = prompt_without_loras[:-1]
|
|
||||||
if prompt_cleanup:
|
if prompt_cleanup:
|
||||||
cleaned_prompt = cleanup_prompt(prompt_without_loras)
|
cleaned_prompt = cleanup_prompt(prompt_without_loras)
|
||||||
|
|
||||||
new_loras = []
|
new_loras = []
|
||||||
lora_names = [lora[0] for lora in loras]
|
lora_names = [lora[0] for lora in loras]
|
||||||
for found_lora in found_loras:
|
for found_lora in found_loras:
|
||||||
if deduplicate_loras and found_lora[0] in lora_names:
|
if deduplicate_loras and (found_lora[0] in lora_names or found_lora in new_loras):
|
||||||
continue
|
continue
|
||||||
new_loras.append(found_lora)
|
new_loras.append(found_lora)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,13 @@ class TestUtils(unittest.TestCase):
|
||||||
def test_can_parse_tokens_with_lora(self):
|
def test_can_parse_tokens_with_lora(self):
|
||||||
test_cases = [
|
test_cases = [
|
||||||
{
|
{
|
||||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
|
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
|
||||||
"output": (
|
"output": (
|
||||||
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
|
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
|
||||||
},
|
},
|
||||||
# Test can not exceed limit
|
# Test can not exceed limit
|
||||||
{
|
{
|
||||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
|
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
|
||||||
"output": (
|
"output": (
|
||||||
[('hey-lora.safetensors', 0.4)],
|
[('hey-lora.safetensors', 0.4)],
|
||||||
'some prompt, very cool, cool'
|
'some prompt, very cool, cool'
|
||||||
|
|
@ -25,6 +25,7 @@ class TestUtils(unittest.TestCase):
|
||||||
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
|
"some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
|
||||||
[("hey-lora.safetensors", 0.4)],
|
[("hey-lora.safetensors", 0.4)],
|
||||||
5,
|
5,
|
||||||
|
True
|
||||||
),
|
),
|
||||||
"output": (
|
"output": (
|
||||||
[
|
[
|
||||||
|
|
@ -37,18 +38,35 @@ class TestUtils(unittest.TestCase):
|
||||||
'some prompt, very cool'
|
'some prompt, very cool'
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
# test correct matching even if there is no space separating loras in the same token
|
||||||
{
|
{
|
||||||
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
|
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
|
||||||
"output": (
|
"output": (
|
||||||
[
|
[
|
||||||
('hey-lora.safetensors', 0.4),
|
('hey-lora.safetensors', 0.4),
|
||||||
('you-lora.safetensors', 0.2)
|
('you-lora.safetensors', 0.2)
|
||||||
],
|
],
|
||||||
'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
|
'some prompt, very cool'
|
||||||
|
),
|
||||||
|
},
|
||||||
|
# test deduplication, also selected loras are never overridden with loras in prompt
|
||||||
|
{
|
||||||
|
"input": (
|
||||||
|
"some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
|
||||||
|
[('you-lora.safetensors', 0.3)],
|
||||||
|
3,
|
||||||
|
True
|
||||||
|
),
|
||||||
|
"output": (
|
||||||
|
[
|
||||||
|
('you-lora.safetensors', 0.3),
|
||||||
|
('hey-lora.safetensors', 0.4)
|
||||||
|
],
|
||||||
|
'some prompt, very cool'
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
|
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
|
||||||
"output": (
|
"output": (
|
||||||
[],
|
[],
|
||||||
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
|
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
|
||||||
|
|
@ -57,7 +75,7 @@ class TestUtils(unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
for test in test_cases:
|
for test in test_cases:
|
||||||
prompt, loras, loras_limit = test["input"]
|
prompt, loras, loras_limit, skip_file_check = test["input"]
|
||||||
expected = test["output"]
|
expected = test["output"]
|
||||||
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit)
|
actual = util.parse_lora_references_from_prompt(prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
|
||||||
self.assertEqual(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue