feat: add A1111 prompt style detection

only detects one style as Fooocus doesn't wrap {prompt} with the whole style, but has a separate prompt string for each style
This commit is contained in:
Manuel Schmid 2024-01-28 23:52:06 +01:00
parent 236278948b
commit 5e84a45e22
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 104 additions and 12 deletions

View File

@ -6,7 +6,7 @@ from PIL import Image
import modules.config
import fooocus_version
# import advanced_parameters
from modules.util import quote, unquote, is_json
from modules.util import quote, unquote, extract_styles_from_prompt, is_json
from modules.flags import MetadataScheme, Performance, Steps
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
@ -65,17 +65,10 @@ class A1111MetadataParser(MetadataParser):
prompt += ('' if prompt == '' else "\n") + line
# set defaults
data = {
'styles': '[]'
}
# if shared.opts.infotext_styles != "Ignore":
# found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt,
# negative_prompt)
#
# if shared.opts.infotext_styles == "Apply":
# res["Styles array"] = found_styles
# elif shared.opts.infotext_styles == "Apply if any" and found_styles:
# res["Styles array"] = found_styles
data = {}
found_styles, prompt, negative_prompt = extract_styles_from_prompt(prompt, negative_prompt)
data['styles'] = str(found_styles)
data |= {
'prompt': prompt,

View File

@ -1,3 +1,5 @@
import typing
import numpy as np
import datetime
import random
@ -9,6 +11,7 @@ import json
from PIL import Image
from hashlib import sha256
import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@ -207,6 +210,102 @@ def unquote(text):
return text
def unwrap_style_text_from_prompt(style_text, prompt):
"""
Checks the prompt to see if the style text is wrapped around it. If so,
returns True plus the prompt text without the style text. Otherwise, returns
False with the original prompt.
Note that the "cleaned" version of the style text is only used for matching
purposes here. It isn't returned; the original style text is not modified.
"""
stripped_prompt = prompt
stripped_style_text = style_text
if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text that isn't part of the style.
try:
left, right = stripped_style_text.split("{prompt}", 2)
except ValueError as e:
# If the style text has multple "{prompt}"s, we can't split it into
# two parts. This is an error, but we can't do anything about it.
print(f"Unable to compare style text to prompt:\n{style_text}")
print(f"Error: {e}")
return False, prompt
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
return True, prompt
else:
# Work out whether the given prompt ends with the style text. If so, we
# return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
if prompt.endswith(", "):
prompt = prompt[:-2]
return True, prompt
return False, prompt
def extract_original_prompts(style, prompt, negative_prompt):
"""
Takes a style and compares it to the prompt and negative prompt. If the style
matches, returns True plus the prompt and negative prompt with the style text
removed. Otherwise, returns False with the original prompt and negative prompt.
"""
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
match_positive, extracted_positive = unwrap_style_text_from_prompt(
style.prompt, prompt
)
if not match_positive:
return False, prompt, negative_prompt
match_negative, extracted_negative = unwrap_style_text_from_prompt(
style.negative_prompt, negative_prompt
)
if not match_negative:
return False, prompt, negative_prompt
return True, extracted_positive, extracted_negative
def extract_styles_from_prompt(prompt, negative_prompt):
extracted = []
applicable_styles = []
for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items():
applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt))
while True:
found_style = None
for style in applicable_styles:
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
style, prompt, negative_prompt
)
if is_match:
found_style = style
prompt = new_prompt
negative_prompt = new_neg_prompt
break
if not found_style:
break
applicable_styles.remove(found_style)
extracted.append(found_style.name)
return list(reversed(extracted)), prompt, negative_prompt
class PromptStyle(typing.NamedTuple):
name: str
prompt: str
negative_prompt: str
def is_json(data: str) -> bool:
try:
loaded_json = json.loads(data)