refactor: code cleanup, unify usage of tuples in lora list
This commit is contained in:
parent
a78b3841c9
commit
4e610411fb
|
|
@ -148,7 +148,7 @@ def worker():
|
|||
base_model_name = args.pop()
|
||||
refiner_model_name = args.pop()
|
||||
refiner_switch = args.pop()
|
||||
loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)])
|
||||
loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in range(modules.config.default_max_lora_number)])
|
||||
input_image_checkbox = args.pop()
|
||||
current_tab = args.pop()
|
||||
uov_method = args.pop()
|
||||
|
|
@ -428,7 +428,6 @@ def worker():
|
|||
|
||||
progressbar(async_task, 3, 'Loading models ...')
|
||||
|
||||
# Parse lora references from prompt
|
||||
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
|
||||
|
||||
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import os
|
|||
import re
|
||||
import json
|
||||
import math
|
||||
#import modules.config
|
||||
|
||||
from modules.extra_utils import get_files_from_folder
|
||||
|
||||
|
|
@ -10,7 +9,6 @@ from modules.extra_utils import get_files_from_folder
|
|||
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
||||
|
||||
|
||||
|
||||
def normalize_key(k):
|
||||
k = k.replace('-', ' ')
|
||||
words = k.split(' ')
|
||||
|
|
@ -24,7 +22,6 @@ def normalize_key(k):
|
|||
|
||||
|
||||
styles = {}
|
||||
|
||||
styles_files = get_files_from_folder(styles_path, ['.json'])
|
||||
|
||||
for x in ['sdxl_styles_fooocus.json',
|
||||
|
|
@ -59,7 +56,7 @@ def apply_style(style, positive):
|
|||
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
|
||||
|
||||
|
||||
def get_words(arrays, totalMult, index):
|
||||
def get_words(arrays, total_mult, index):
|
||||
if len(arrays) == 1:
|
||||
return [arrays[0].split(',')[index]]
|
||||
else:
|
||||
|
|
@ -68,7 +65,7 @@ def get_words(arrays, totalMult, index):
|
|||
index -= index % len(words)
|
||||
index /= len(words)
|
||||
index = math.floor(index)
|
||||
return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index)
|
||||
return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
|
||||
|
||||
|
||||
def apply_arrays(text, index):
|
||||
|
|
|
|||
|
|
@ -375,22 +375,19 @@ def ordinal_suffix(number: int) -> str:
|
|||
|
||||
|
||||
def get_enabled_loras(loras: list) -> list:
|
||||
return [[lora[1], lora[2]] for lora in loras if lora[0]]
|
||||
return [(lora[1], lora[2]) for lora in loras if lora[0]]
|
||||
|
||||
|
||||
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
|
||||
|
||||
new_loras = []
|
||||
|
||||
updated_loras = []
|
||||
for token in prompt.split(","):
|
||||
|
||||
m = LORAS_PROMPT_PATTERN.match(token)
|
||||
|
||||
if m:
|
||||
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
|
||||
|
||||
updated_loras = []
|
||||
for lora in loras + new_loras:
|
||||
|
||||
if lora[0] != "None":
|
||||
updated_loras.append(lora)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue