wip: add prompt style extraction for A1111 scheme
This commit is contained in:
parent
5e84a45e22
commit
f94b96f6eb
|
|
@ -231,20 +231,28 @@ def unwrap_style_text_from_prompt(style_text, prompt):
|
|||
# two parts. This is an error, but we can't do anything about it.
|
||||
print(f"Unable to compare style text to prompt:\n{style_text}")
|
||||
print(f"Error: {e}")
|
||||
return False, prompt
|
||||
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
||||
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
|
||||
return True, prompt
|
||||
return False, prompt, ''
|
||||
|
||||
left_pos = stripped_prompt.find(left)
|
||||
right_pos = stripped_prompt.find(right)
|
||||
if 0 <= left_pos < right_pos:
|
||||
real_prompt = stripped_prompt[left_pos + len(left):right_pos]
|
||||
prompt = stripped_prompt.replace(left + real_prompt + right, '', 1)
|
||||
if prompt.startswith(", "):
|
||||
prompt = prompt[2:]
|
||||
if prompt.endswith(", "):
|
||||
prompt = prompt[:-2]
|
||||
return True, prompt, real_prompt
|
||||
else:
|
||||
# Work out whether the given prompt ends with the style text. If so, we
|
||||
# Work out whether the given prompt starts with the style text. If so, we
|
||||
# return True and the prompt text up to where the style text starts.
|
||||
if stripped_prompt.endswith(stripped_style_text):
|
||||
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
|
||||
if prompt.endswith(", "):
|
||||
prompt = prompt[:-2]
|
||||
return True, prompt
|
||||
return True, prompt, prompt
|
||||
|
||||
return False, prompt
|
||||
return False, prompt, ''
|
||||
|
||||
|
||||
def extract_original_prompts(style, prompt, negative_prompt):
|
||||
|
|
@ -256,19 +264,19 @@ def extract_original_prompts(style, prompt, negative_prompt):
|
|||
if not style.prompt and not style.negative_prompt:
|
||||
return False, prompt, negative_prompt
|
||||
|
||||
match_positive, extracted_positive = unwrap_style_text_from_prompt(
|
||||
match_positive, extracted_positive, real_prompt = unwrap_style_text_from_prompt(
|
||||
style.prompt, prompt
|
||||
)
|
||||
if not match_positive:
|
||||
return False, prompt, negative_prompt
|
||||
return False, prompt, negative_prompt, ''
|
||||
|
||||
match_negative, extracted_negative = unwrap_style_text_from_prompt(
|
||||
match_negative, extracted_negative, _ = unwrap_style_text_from_prompt(
|
||||
style.negative_prompt, negative_prompt
|
||||
)
|
||||
if not match_negative:
|
||||
return False, prompt, negative_prompt
|
||||
return False, prompt, negative_prompt, ''
|
||||
|
||||
return True, extracted_positive, extracted_negative
|
||||
return True, extracted_positive, extracted_negative, real_prompt
|
||||
|
||||
|
||||
def extract_styles_from_prompt(prompt, negative_prompt):
|
||||
|
|
@ -278,17 +286,22 @@ def extract_styles_from_prompt(prompt, negative_prompt):
|
|||
for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items():
|
||||
applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt))
|
||||
|
||||
real_prompt = ''
|
||||
|
||||
while True:
|
||||
found_style = None
|
||||
|
||||
for style in applicable_styles:
|
||||
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
|
||||
is_match, new_prompt, new_neg_prompt, new_real_prompt = extract_original_prompts(
|
||||
style, prompt, negative_prompt
|
||||
)
|
||||
if is_match:
|
||||
found_style = style
|
||||
prompt = new_prompt
|
||||
negative_prompt = new_neg_prompt
|
||||
# TODO this is a bit hacky tbh but works perfectly fine, check if all conditions are needed
|
||||
if real_prompt == '' and new_real_prompt != '' and new_real_prompt != prompt:
|
||||
real_prompt = new_real_prompt
|
||||
break
|
||||
|
||||
if not found_style:
|
||||
|
|
@ -297,7 +310,13 @@ def extract_styles_from_prompt(prompt, negative_prompt):
|
|||
applicable_styles.remove(found_style)
|
||||
extracted.append(found_style.name)
|
||||
|
||||
return list(reversed(extracted)), prompt, negative_prompt
|
||||
# add prompt expansion if not all styles could be resolved
|
||||
# TODO check if it's better to not add fooocus_expansion but just return prompt incl. fooocus_expansion words
|
||||
# TODO evaluate if adding prompt expansion to metadata is a good idea
|
||||
if prompt != '' and prompt != real_prompt:
|
||||
extracted.append(modules.sdxl_styles.fooocus_expansion)
|
||||
|
||||
return list(reversed(extracted)), real_prompt, negative_prompt
|
||||
|
||||
|
||||
class PromptStyle(typing.NamedTuple):
|
||||
|
|
|
|||
Loading…
Reference in New Issue