refactor: code cleanup, unify usage of tuples in lora list

This commit is contained in:
Manuel Schmid 2024-05-18 16:43:53 +02:00
parent a78b3841c9
commit 4e610411fb
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 6 additions and 13 deletions

View File

@ -148,7 +148,7 @@ def worker():
base_model_name = args.pop()
refiner_model_name = args.pop()
refiner_switch = args.pop()
loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)])
loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in range(modules.config.default_max_lora_number)])
input_image_checkbox = args.pop()
current_tab = args.pop()
uov_method = args.pop()
@ -428,7 +428,6 @@ def worker():
progressbar(async_task, 3, 'Loading models ...')
# Parse lora references from prompt
loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,

View File

@ -2,7 +2,6 @@ import os
import re
import json
import math
#import modules.config
from modules.extra_utils import get_files_from_folder
@ -10,7 +9,6 @@ from modules.extra_utils import get_files_from_folder
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
def normalize_key(k):
k = k.replace('-', ' ')
words = k.split(' ')
@ -24,7 +22,6 @@ def normalize_key(k):
styles = {}
styles_files = get_files_from_folder(styles_path, ['.json'])
for x in ['sdxl_styles_fooocus.json',
@ -59,7 +56,7 @@ def apply_style(style, positive):
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
def get_words(arrays, totalMult, index):
def get_words(arrays, total_mult, index):
if len(arrays) == 1:
return [arrays[0].split(',')[index]]
else:
@ -68,7 +65,7 @@ def get_words(arrays, totalMult, index):
index -= index % len(words)
index /= len(words)
index = math.floor(index)
return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index)
return [word] + get_words(arrays[1:], math.floor(total_mult / len(words)), index)
def apply_arrays(text, index):

View File

@ -375,22 +375,19 @@ def ordinal_suffix(number: int) -> str:
def get_enabled_loras(loras: list) -> list:
return [[lora[1], lora[2]] for lora in loras if lora[0]]
return [(lora[1], lora[2]) for lora in loras if lora[0]]
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
new_loras = []
updated_loras = []
for token in prompt.split(","):
m = LORAS_PROMPT_PATTERN.match(token)
if m:
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
updated_loras = []
for lora in loras + new_loras:
if lora[0] != "None":
updated_loras.append(lora)