diff --git a/modules/metadata.py b/modules/metadata.py index eb97498f..0d585f1c 100644 --- a/modules/metadata.py +++ b/modules/metadata.py @@ -6,7 +6,7 @@ from PIL import Image import modules.config import fooocus_version # import advanced_parameters -from modules.util import quote, unquote, is_json +from modules.util import quote, unquote, extract_styles_from_prompt, is_json from modules.flags import MetadataScheme, Performance, Steps re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' @@ -65,17 +65,10 @@ class A1111MetadataParser(MetadataParser): prompt += ('' if prompt == '' else "\n") + line # set defaults - data = { - 'styles': '[]' - } - # if shared.opts.infotext_styles != "Ignore": - # found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, - # negative_prompt) - # - # if shared.opts.infotext_styles == "Apply": - # res["Styles array"] = found_styles - # elif shared.opts.infotext_styles == "Apply if any" and found_styles: - # res["Styles array"] = found_styles + data = {} + + found_styles, prompt, negative_prompt = extract_styles_from_prompt(prompt, negative_prompt) + data['styles'] = str(found_styles) data |= { 'prompt': prompt, diff --git a/modules/util.py b/modules/util.py index f7fcc4e7..a6804a9a 100644 --- a/modules/util.py +++ b/modules/util.py @@ -1,3 +1,5 @@ +import typing + import numpy as np import datetime import random @@ -9,6 +11,7 @@ import json from PIL import Image from hashlib import sha256 +import modules.sdxl_styles LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -207,6 +210,102 @@ def unquote(text): return text +def unwrap_style_text_from_prompt(style_text, prompt): + """ + Checks the prompt to see if the style text is wrapped around it. If so, + returns True plus the prompt text without the style text. Otherwise, returns + False with the original prompt. + + Note that the "cleaned" version of the style text is only used for matching + purposes here. It isn't returned; the original style text is not modified. + """ + stripped_prompt = prompt + stripped_style_text = style_text + if "{prompt}" in stripped_style_text: + # Work out whether the prompt is wrapped in the style text. If so, we + # return True and the "inner" prompt text that isn't part of the style. + try: + left, right = stripped_style_text.split("{prompt}", 2) + except ValueError as e: + # If the style text has multple "{prompt}"s, we can't split it into + # two parts. This is an error, but we can't do anything about it. + print(f"Unable to compare style text to prompt:\n{style_text}") + print(f"Error: {e}") + return False, prompt + if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): + prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] + return True, prompt + else: + # Work out whether the given prompt ends with the style text. If so, we + # return True and the prompt text up to where the style text starts. + if stripped_prompt.endswith(stripped_style_text): + prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] + if prompt.endswith(", "): + prompt = prompt[:-2] + return True, prompt + + return False, prompt + + +def extract_original_prompts(style, prompt, negative_prompt): + """ + Takes a style and compares it to the prompt and negative prompt. If the style + matches, returns True plus the prompt and negative prompt with the style text + removed. Otherwise, returns False with the original prompt and negative prompt. + """ + if not style.prompt and not style.negative_prompt: + return False, prompt, negative_prompt + + match_positive, extracted_positive = unwrap_style_text_from_prompt( + style.prompt, prompt + ) + if not match_positive: + return False, prompt, negative_prompt + + match_negative, extracted_negative = unwrap_style_text_from_prompt( + style.negative_prompt, negative_prompt + ) + if not match_negative: + return False, prompt, negative_prompt + + return True, extracted_positive, extracted_negative + + +def extract_styles_from_prompt(prompt, negative_prompt): + extracted = [] + applicable_styles = [] + + for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items(): + applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt)) + + while True: + found_style = None + + for style in applicable_styles: + is_match, new_prompt, new_neg_prompt = extract_original_prompts( + style, prompt, negative_prompt + ) + if is_match: + found_style = style + prompt = new_prompt + negative_prompt = new_neg_prompt + break + + if not found_style: + break + + applicable_styles.remove(found_style) + extracted.append(found_style.name) + + return list(reversed(extracted)), prompt, negative_prompt + + +class PromptStyle(typing.NamedTuple): + name: str + prompt: str + negative_prompt: str + + def is_json(data: str) -> bool: try: loaded_json = json.loads(data)