feat: allow multiple matches for each token, optimize and extract method cleanup_prompt

This commit is contained in:
Manuel Schmid 2024-05-20 17:28:55 +02:00
parent faa985c71c
commit 0d1310d9e9
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
1 changed files with 43 additions and 12 deletions

View File

@ -12,15 +12,15 @@ import hashlib
from PIL import Image
import modules.config
import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
# Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1>
# <lora:aNotherLora:-1.6>
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
LORAS_PROMPT_PATTERN = re.compile(r"(<lora:([^:]+):([+-]?(?:\d+(?:\.\d*)?|\.\d+))>)", re.X)
HASH_SHA256_LENGTH = 10
@ -372,26 +372,57 @@ def get_file_from_folder_list(name, folders):
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
def get_enabled_loras(loras: list, remove_none=True) -> list:
return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)]
def get_enabled_loras(loras: list) -> list:
return [(lora[1], lora[2]) for lora in loras if lora[0]]
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
found_loras = []
prompt_without_loras = ""
for token in prompt.split(" "):
matches = LORAS_PROMPT_PATTERN.findall(token)
if matches:
for match in matches:
found_loras.append((f"{match[1]}.safetensors", float(match[2])))
prompt_without_loras += token.replace(match[0], '')
else:
prompt_without_loras += token
prompt_without_loras += ' '
cleaned_prompt = prompt_without_loras[:-1]
if prompt_cleanup:
cleaned_prompt = cleanup_prompt(prompt_without_loras)
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
new_loras = []
lora_names = [lora[0] for lora in loras]
for found_lora in found_loras:
if deduplicate_loras and found_lora[0] in lora_names:
continue
new_loras.append(found_lora)
if len(new_loras) == 0:
return loras, cleaned_prompt
updated_loras = []
for token in prompt.split(","):
m = LORAS_PROMPT_PATTERN.match(token)
if m:
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
for lora in loras + new_loras:
if lora[0] != "None":
updated_loras.append(lora)
return updated_loras[:loras_limit]
return updated_loras[:loras_limit], cleaned_prompt
def cleanup_prompt(prompt):
prompt = re.sub(' +', ' ', prompt)
prompt = re.sub(',+', ',', prompt)
cleaned_prompt = ''
for token in prompt.split(','):
token = token.strip()
if token == '':
continue
cleaned_prompt += token + ', '
return cleaned_prompt[:-2]
def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str: