feat: add A1111 prompt style detection
only detects one style as Fooocus doesn't wrap {prompt} with the whole style, but has a separate prompt string for each style
This commit is contained in:
parent
236278948b
commit
5e84a45e22
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
|||
import modules.config
|
||||
import fooocus_version
|
||||
# import advanced_parameters
|
||||
from modules.util import quote, unquote, is_json
|
||||
from modules.util import quote, unquote, extract_styles_from_prompt, is_json
|
||||
from modules.flags import MetadataScheme, Performance, Steps
|
||||
|
||||
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||
|
|
@ -65,17 +65,10 @@ class A1111MetadataParser(MetadataParser):
|
|||
prompt += ('' if prompt == '' else "\n") + line
|
||||
|
||||
# set defaults
|
||||
data = {
|
||||
'styles': '[]'
|
||||
}
|
||||
# if shared.opts.infotext_styles != "Ignore":
|
||||
# found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt,
|
||||
# negative_prompt)
|
||||
#
|
||||
# if shared.opts.infotext_styles == "Apply":
|
||||
# res["Styles array"] = found_styles
|
||||
# elif shared.opts.infotext_styles == "Apply if any" and found_styles:
|
||||
# res["Styles array"] = found_styles
|
||||
data = {}
|
||||
|
||||
found_styles, prompt, negative_prompt = extract_styles_from_prompt(prompt, negative_prompt)
|
||||
data['styles'] = str(found_styles)
|
||||
|
||||
data |= {
|
||||
'prompt': prompt,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import typing
|
||||
|
||||
import numpy as np
|
||||
import datetime
|
||||
import random
|
||||
|
|
@ -9,6 +11,7 @@ import json
|
|||
from PIL import Image
|
||||
from hashlib import sha256
|
||||
|
||||
import modules.sdxl_styles
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
||||
|
|
@ -207,6 +210,102 @@ def unquote(text):
|
|||
return text
|
||||
|
||||
|
||||
def unwrap_style_text_from_prompt(style_text, prompt):
|
||||
"""
|
||||
Checks the prompt to see if the style text is wrapped around it. If so,
|
||||
returns True plus the prompt text without the style text. Otherwise, returns
|
||||
False with the original prompt.
|
||||
|
||||
Note that the "cleaned" version of the style text is only used for matching
|
||||
purposes here. It isn't returned; the original style text is not modified.
|
||||
"""
|
||||
stripped_prompt = prompt
|
||||
stripped_style_text = style_text
|
||||
if "{prompt}" in stripped_style_text:
|
||||
# Work out whether the prompt is wrapped in the style text. If so, we
|
||||
# return True and the "inner" prompt text that isn't part of the style.
|
||||
try:
|
||||
left, right = stripped_style_text.split("{prompt}", 2)
|
||||
except ValueError as e:
|
||||
# If the style text has multple "{prompt}"s, we can't split it into
|
||||
# two parts. This is an error, but we can't do anything about it.
|
||||
print(f"Unable to compare style text to prompt:\n{style_text}")
|
||||
print(f"Error: {e}")
|
||||
return False, prompt
|
||||
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
||||
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
|
||||
return True, prompt
|
||||
else:
|
||||
# Work out whether the given prompt ends with the style text. If so, we
|
||||
# return True and the prompt text up to where the style text starts.
|
||||
if stripped_prompt.endswith(stripped_style_text):
|
||||
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
|
||||
if prompt.endswith(", "):
|
||||
prompt = prompt[:-2]
|
||||
return True, prompt
|
||||
|
||||
return False, prompt
|
||||
|
||||
|
||||
def extract_original_prompts(style, prompt, negative_prompt):
|
||||
"""
|
||||
Takes a style and compares it to the prompt and negative prompt. If the style
|
||||
matches, returns True plus the prompt and negative prompt with the style text
|
||||
removed. Otherwise, returns False with the original prompt and negative prompt.
|
||||
"""
|
||||
if not style.prompt and not style.negative_prompt:
|
||||
return False, prompt, negative_prompt
|
||||
|
||||
match_positive, extracted_positive = unwrap_style_text_from_prompt(
|
||||
style.prompt, prompt
|
||||
)
|
||||
if not match_positive:
|
||||
return False, prompt, negative_prompt
|
||||
|
||||
match_negative, extracted_negative = unwrap_style_text_from_prompt(
|
||||
style.negative_prompt, negative_prompt
|
||||
)
|
||||
if not match_negative:
|
||||
return False, prompt, negative_prompt
|
||||
|
||||
return True, extracted_positive, extracted_negative
|
||||
|
||||
|
||||
def extract_styles_from_prompt(prompt, negative_prompt):
|
||||
extracted = []
|
||||
applicable_styles = []
|
||||
|
||||
for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items():
|
||||
applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt))
|
||||
|
||||
while True:
|
||||
found_style = None
|
||||
|
||||
for style in applicable_styles:
|
||||
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
|
||||
style, prompt, negative_prompt
|
||||
)
|
||||
if is_match:
|
||||
found_style = style
|
||||
prompt = new_prompt
|
||||
negative_prompt = new_neg_prompt
|
||||
break
|
||||
|
||||
if not found_style:
|
||||
break
|
||||
|
||||
applicable_styles.remove(found_style)
|
||||
extracted.append(found_style.name)
|
||||
|
||||
return list(reversed(extracted)), prompt, negative_prompt
|
||||
|
||||
|
||||
class PromptStyle(typing.NamedTuple):
|
||||
name: str
|
||||
prompt: str
|
||||
negative_prompt: str
|
||||
|
||||
|
||||
def is_json(data: str) -> bool:
|
||||
try:
|
||||
loaded_json = json.loads(data)
|
||||
|
|
|
|||
Loading…
Reference in New Issue