wip: add prompt style extraction for A1111 scheme

This commit is contained in:
Manuel Schmid 2024-01-29 01:52:12 +01:00
parent 5e84a45e22
commit f94b96f6eb
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
1 changed files with 33 additions and 14 deletions

View File

@ -231,20 +231,28 @@ def unwrap_style_text_from_prompt(style_text, prompt):
# two parts. This is an error, but we can't do anything about it.
print(f"Unable to compare style text to prompt:\n{style_text}")
print(f"Error: {e}")
return False, prompt
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
return True, prompt
return False, prompt, ''
left_pos = stripped_prompt.find(left)
right_pos = stripped_prompt.find(right)
if 0 <= left_pos < right_pos:
real_prompt = stripped_prompt[left_pos + len(left):right_pos]
prompt = stripped_prompt.replace(left + real_prompt + right, '', 1)
if prompt.startswith(", "):
prompt = prompt[2:]
if prompt.endswith(", "):
prompt = prompt[:-2]
return True, prompt, real_prompt
else:
# Work out whether the given prompt ends with the style text. If so, we
# Work out whether the given prompt starts with the style text. If so, we
# return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
if prompt.endswith(", "):
prompt = prompt[:-2]
return True, prompt
return True, prompt, prompt
return False, prompt
return False, prompt, ''
def extract_original_prompts(style, prompt, negative_prompt):
@ -256,19 +264,19 @@ def extract_original_prompts(style, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
match_positive, extracted_positive = unwrap_style_text_from_prompt(
match_positive, extracted_positive, real_prompt = unwrap_style_text_from_prompt(
style.prompt, prompt
)
if not match_positive:
return False, prompt, negative_prompt
return False, prompt, negative_prompt, ''
match_negative, extracted_negative = unwrap_style_text_from_prompt(
match_negative, extracted_negative, _ = unwrap_style_text_from_prompt(
style.negative_prompt, negative_prompt
)
if not match_negative:
return False, prompt, negative_prompt
return False, prompt, negative_prompt, ''
return True, extracted_positive, extracted_negative
return True, extracted_positive, extracted_negative, real_prompt
def extract_styles_from_prompt(prompt, negative_prompt):
@ -278,17 +286,22 @@ def extract_styles_from_prompt(prompt, negative_prompt):
for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items():
applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt))
real_prompt = ''
while True:
found_style = None
for style in applicable_styles:
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
is_match, new_prompt, new_neg_prompt, new_real_prompt = extract_original_prompts(
style, prompt, negative_prompt
)
if is_match:
found_style = style
prompt = new_prompt
negative_prompt = new_neg_prompt
# TODO this is a bit hacky tbh but works perfectly fine, check if all conditions are needed
if real_prompt == '' and new_real_prompt != '' and new_real_prompt != prompt:
real_prompt = new_real_prompt
break
if not found_style:
@ -297,7 +310,13 @@ def extract_styles_from_prompt(prompt, negative_prompt):
applicable_styles.remove(found_style)
extracted.append(found_style.name)
return list(reversed(extracted)), prompt, negative_prompt
# add prompt expansion if not all styles could be resolved
# TODO check if it's better to not add fooocus_expansion but just return prompt incl. fooocus_expansion words
# TODO evaluate if adding prompt expansion to metadata is a good idea
if prompt != '' and prompt != real_prompt:
extracted.append(modules.sdxl_styles.fooocus_expansion)
return list(reversed(extracted)), real_prompt, negative_prompt
class PromptStyle(typing.NamedTuple):