fix: add workaround for multiline prompts

This commit is contained in:
Manuel Schmid 2024-02-02 22:04:28 +01:00
parent 349556bfa6
commit ed4a958da8
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 44 additions and 22 deletions

View File

@ -342,9 +342,6 @@ def worker():
progressbar(async_task, 1, 'Initializing ...')
raw_prompt = prompt
raw_negative_prompt = negative_prompt
if not skip_prompt_processing:
prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
@ -805,7 +802,9 @@ def worker():
metadata_parser = None
if save_metadata_to_images:
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
metadata_parser.set_data(task['positive'], task['negative'], steps, base_model_name, refiner_model_name, loras)
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'],
steps, base_model_name, refiner_model_name, loras)
for li, (n, w) in enumerate(loras):
if n != 'None':

View File

@ -8,7 +8,7 @@ import gradio as gr
from PIL import Image
import modules.config
import modules.config
import modules.sdxl_styles
from modules.flags import MetadataScheme, Performance, Steps
from modules.flags import lora_count
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, calculate_sha256
@ -187,7 +187,9 @@ def get_sha256(filepath):
class MetadataParser(ABC):
def __init__(self):
self.raw_prompt: str = ''
self.full_prompt: str = ''
self.raw_negative_prompt: str = ''
self.full_negative_prompt: str = ''
self.steps: int = 30
self.base_model_name: str = ''
@ -208,8 +210,10 @@ class MetadataParser(ABC):
def parse_string(self, metadata: dict) -> str:
raise NotImplementedError
def set_data(self, full_prompt, full_negative_prompt, steps, base_model_name, refiner_model_name, loras):
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, refiner_model_name, loras):
self.raw_prompt = raw_prompt
self.full_prompt = full_prompt
self.raw_negative_prompt = raw_negative_prompt
self.full_negative_prompt = full_negative_prompt
self.steps = steps
self.base_model_name = Path(base_model_name).stem
@ -235,6 +239,8 @@ class A1111MetadataParser(MetadataParser):
return MetadataScheme.A1111
fooocus_to_a1111 = {
'raw_prompt': 'Raw prompt',
'raw_negative_prompt': 'Raw negative prompt',
'negative_prompt': 'Negative prompt',
'styles': 'Styles',
'performance': 'Performance',
@ -260,8 +266,8 @@ class A1111MetadataParser(MetadataParser):
}
def parse_json(self, metadata: str) -> dict:
prompt = ''
negative_prompt = ''
metadata_prompt = ''
metadata_negative_prompt = ''
done_with_prompt = False
@ -276,16 +282,15 @@ class A1111MetadataParser(MetadataParser):
done_with_prompt = True
line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip()
if done_with_prompt:
negative_prompt += ('' if negative_prompt == '' else "\n") + line
metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line
else:
prompt += ('' if prompt == '' else "\n") + line
metadata_prompt += ('' if metadata_prompt == '' else "\n") + line
found_styles, prompt, negative_prompt = extract_styles_from_prompt(prompt, negative_prompt)
found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt)
data = {
'prompt': prompt,
'negative_prompt': negative_prompt,
'styles': str(found_styles)
'negative_prompt': negative_prompt
}
for k, v in re_param.findall(lastline):
@ -295,12 +300,24 @@ class A1111MetadataParser(MetadataParser):
m = re_imagesize.match(v)
if m is not None:
data[f'resolution'] = str((m.group(1), m.group(2)))
data['resolution'] = str((m.group(1), m.group(2)))
else:
data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v
except Exception:
print(f"Error parsing \"{k}: {v}\"")
# workaround for multiline prompts
if 'raw_prompt' in data:
data['prompt'] = data['raw_prompt']
raw_prompt = data['raw_prompt'].replace("\n", ', ')
if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles:
found_styles.append(modules.sdxl_styles.fooocus_expansion)
if 'raw_negative_prompt' in data:
data['negative_prompt'] = data['raw_negative_prompt']
data['styles'] = str(found_styles)
# try to load performance based on steps, fallback for direct A1111 imports
if 'steps' in data and 'performance' not in data:
try:
@ -344,7 +361,11 @@ class A1111MetadataParser(MetadataParser):
self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'],
self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
self.fooocus_to_a1111['base_model_hash']: self.base_model_hash,
self.fooocus_to_a1111['performance']: data['performance'],
# workaround for multiline prompts
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
}
# TODO evaluate if this should always be added

View File

@ -311,18 +311,20 @@ def extract_styles_from_prompt(prompt, negative_prompt):
applicable_styles.remove(found_style)
extracted.append(found_style.name)
# TODO check multiline prompt
# add prompt expansion if not all styles could be resolved
if prompt != '':
if prompt != real_prompt:
if real_prompt != '':
extracted.append(modules.sdxl_styles.fooocus_expansion)
# find real_prompt when only prompt expansion is selected
if real_prompt == '':
else:
# find real_prompt when only prompt expansion is selected
first_word = prompt.split(', ')[0]
first_word_positions = [i for i in range(len(prompt)) if prompt.startswith(first_word, i)]
real_prompt = prompt[:first_word_positions[-1]]
if real_prompt.endswith(', '):
real_prompt = real_prompt[:-2]
if len(first_word_positions) > 1:
real_prompt = prompt[:first_word_positions[-1]]
extracted.append(modules.sdxl_styles.fooocus_expansion)
if real_prompt.endswith(', '):
real_prompt = real_prompt[:-2]
return list(reversed(extracted)), real_prompt, negative_prompt
@ -337,6 +339,6 @@ def is_json(data: str) -> bool:
try:
loaded_json = json.loads(data)
assert isinstance(loaded_json, dict)
except ValueError:
except (ValueError, AssertionError):
return False
return True