From f94b96f6eb7aae9103cef7fe64ff7013dffaf49f Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 29 Jan 2024 01:52:12 +0100 Subject: [PATCH] wip: add prompt style extraction for A1111 scheme --- modules/util.py | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/modules/util.py b/modules/util.py index a6804a9a..4a709c31 100644 --- a/modules/util.py +++ b/modules/util.py @@ -231,20 +231,28 @@ def unwrap_style_text_from_prompt(style_text, prompt): # two parts. This is an error, but we can't do anything about it. print(f"Unable to compare style text to prompt:\n{style_text}") print(f"Error: {e}") - return False, prompt - if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): - prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] - return True, prompt + return False, prompt, '' + + left_pos = stripped_prompt.find(left) + right_pos = stripped_prompt.find(right) + if 0 <= left_pos < right_pos: + real_prompt = stripped_prompt[left_pos + len(left):right_pos] + prompt = stripped_prompt.replace(left + real_prompt + right, '', 1) + if prompt.startswith(", "): + prompt = prompt[2:] + if prompt.endswith(", "): + prompt = prompt[:-2] + return True, prompt, real_prompt else: - # Work out whether the given prompt ends with the style text. If so, we + # Work out whether the given prompt starts with the style text. If so, we # return True and the prompt text up to where the style text starts. if stripped_prompt.endswith(stripped_style_text): prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] if prompt.endswith(", "): prompt = prompt[:-2] - return True, prompt + return True, prompt, prompt - return False, prompt + return False, prompt, '' def extract_original_prompts(style, prompt, negative_prompt): @@ -256,19 +264,19 @@ def extract_original_prompts(style, prompt, negative_prompt): if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt - match_positive, extracted_positive = unwrap_style_text_from_prompt( + match_positive, extracted_positive, real_prompt = unwrap_style_text_from_prompt( style.prompt, prompt ) if not match_positive: - return False, prompt, negative_prompt + return False, prompt, negative_prompt, '' - match_negative, extracted_negative = unwrap_style_text_from_prompt( + match_negative, extracted_negative, _ = unwrap_style_text_from_prompt( style.negative_prompt, negative_prompt ) if not match_negative: - return False, prompt, negative_prompt + return False, prompt, negative_prompt, '' - return True, extracted_positive, extracted_negative + return True, extracted_positive, extracted_negative, real_prompt def extract_styles_from_prompt(prompt, negative_prompt): @@ -278,17 +286,22 @@ def extract_styles_from_prompt(prompt, negative_prompt): for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items(): applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt)) + real_prompt = '' + while True: found_style = None for style in applicable_styles: - is_match, new_prompt, new_neg_prompt = extract_original_prompts( + is_match, new_prompt, new_neg_prompt, new_real_prompt = extract_original_prompts( style, prompt, negative_prompt ) if is_match: found_style = style prompt = new_prompt negative_prompt = new_neg_prompt + # TODO this is a bit hacky tbh but works perfectly fine, check if all conditions are needed + if real_prompt == '' and new_real_prompt != '' and new_real_prompt != prompt: + real_prompt = new_real_prompt break if not found_style: @@ -297,7 +310,13 @@ def extract_styles_from_prompt(prompt, negative_prompt): applicable_styles.remove(found_style) extracted.append(found_style.name) - return list(reversed(extracted)), prompt, negative_prompt + # add prompt expansion if not all styles could be resolved + # TODO check if it's better to not add fooocus_expansion but just return prompt incl. fooocus_expansion words + # TODO evaluate if adding prompt expansion to metadata is a good idea + if prompt != '' and prompt != real_prompt: + extracted.append(modules.sdxl_styles.fooocus_expansion) + + return list(reversed(extracted)), real_prompt, negative_prompt class PromptStyle(typing.NamedTuple):