From 0d1310d9e9f9eb15efc85caa228d4be0404dca29 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 20 May 2024 17:28:55 +0200 Subject: [PATCH] feat: allow multiple matches for each token, optimize and extract method cleanup_prompt --- modules/util.py | 55 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/modules/util.py b/modules/util.py index a245157f..52bc490a 100644 --- a/modules/util.py +++ b/modules/util.py @@ -12,15 +12,15 @@ import hashlib from PIL import Image +import modules.config import modules.sdxl_styles LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) - # Regexp compiled once. Matches entries with the following pattern: # # -LORAS_PROMPT_PATTERN = re.compile(r".* .*", re.X) +LORAS_PROMPT_PATTERN = re.compile(r"()", re.X) HASH_SHA256_LENGTH = 10 @@ -372,26 +372,57 @@ def get_file_from_folder_list(name, folders): return os.path.abspath(os.path.realpath(os.path.join(folders[0], name))) +def get_enabled_loras(loras: list, remove_none=True) -> list: + return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)] -def get_enabled_loras(loras: list) -> list: - return [(lora[1], lora[2]) for lora in loras if lora[0]] +def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5, + prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]: + found_loras = [] + prompt_without_loras = "" + for token in prompt.split(" "): + matches = LORAS_PROMPT_PATTERN.findall(token) + if matches: + for match in matches: + found_loras.append((f"{match[1]}.safetensors", float(match[2]))) + prompt_without_loras += token.replace(match[0], '') + else: + prompt_without_loras += token + prompt_without_loras += ' ' + + cleaned_prompt = prompt_without_loras[:-1] + if prompt_cleanup: + cleaned_prompt = cleanup_prompt(prompt_without_loras) -def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]: new_loras = [] + lora_names = [lora[0] for lora in loras] + for found_lora in found_loras: + if deduplicate_loras and found_lora[0] in lora_names: + continue + new_loras.append(found_lora) + + if len(new_loras) == 0: + return loras, cleaned_prompt + updated_loras = [] - for token in prompt.split(","): - m = LORAS_PROMPT_PATTERN.match(token) - - if m: - new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2)))) - for lora in loras + new_loras: if lora[0] != "None": updated_loras.append(lora) - return updated_loras[:loras_limit] + return updated_loras[:loras_limit], cleaned_prompt + + +def cleanup_prompt(prompt): + prompt = re.sub(' +', ' ', prompt) + prompt = re.sub(',+', ',', prompt) + cleaned_prompt = '' + for token in prompt.split(','): + token = token.strip() + if token == '': + continue + cleaned_prompt += token + ', ' + return cleaned_prompt[:-2] def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str: