fix: add workaround for multiline prompts
This commit is contained in:
parent 349556bfa6
commit ed4a958da8
modules/async_worker.py
@@ -342,9 +342,6 @@ def worker():
         progressbar(async_task, 1, 'Initializing ...')
 
-        raw_prompt = prompt
-        raw_negative_prompt = negative_prompt
-
         if not skip_prompt_processing:
 
             prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
@@ -805,7 +802,9 @@ def worker():
             metadata_parser = None
             if save_metadata_to_images:
                 metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
-                metadata_parser.set_data(task['positive'], task['negative'], steps, base_model_name, refiner_model_name, loras)
+                metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
+                                         task['log_negative_prompt'], task['negative'],
+                                         steps, base_model_name, refiner_model_name, loras)
 
             for li, (n, w) in enumerate(loras):
                 if n != 'None':
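
Note on the worker-side change above, which is the heart of the fix: set_data now receives the raw (logged) prompts alongside the fully processed ones. A minimal standalone sketch of why the raw text has to travel separately -- the variable names and the ', ' rejoin are illustrative, not taken from the codebase:

# Prompt preprocessing splits the input on newlines (see the
# prompt.splitlines() call in the first hunk) and recombines the pieces,
# so line-break information is gone from the processed prompt.
raw_prompt = "cinematic portrait\nsharp focus, 85mm"
lines = [s.strip() for s in raw_prompt.splitlines() if s.strip() != '']
processed = ', '.join(lines)

assert processed == "cinematic portrait, sharp focus, 85mm"
assert processed != raw_prompt  # the original multiline input is unrecoverable
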
modules/meta_parser.py
@@ -8,7 +8,7 @@ import gradio as gr
 from PIL import Image
 
 import modules.config
-import modules.config
+import modules.sdxl_styles
 from modules.flags import MetadataScheme, Performance, Steps
 from modules.flags import lora_count
 from modules.util import quote, unquote, extract_styles_from_prompt, is_json, calculate_sha256
@@ -187,7 +187,9 @@ def get_sha256(filepath):
 class MetadataParser(ABC):
     def __init__(self):
+        self.raw_prompt: str = ''
         self.full_prompt: str = ''
+        self.raw_negative_prompt: str = ''
         self.full_negative_prompt: str = ''
         self.steps: int = 30
         self.base_model_name: str = ''
@@ -208,8 +210,10 @@ class MetadataParser(ABC):
     def parse_string(self, metadata: dict) -> str:
         raise NotImplementedError
 
-    def set_data(self, full_prompt, full_negative_prompt, steps, base_model_name, refiner_model_name, loras):
+    def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, refiner_model_name, loras):
+        self.raw_prompt = raw_prompt
         self.full_prompt = full_prompt
+        self.raw_negative_prompt = raw_negative_prompt
         self.full_negative_prompt = full_negative_prompt
         self.steps = steps
         self.base_model_name = Path(base_model_name).stem
@@ -235,6 +239,8 @@ class A1111MetadataParser(MetadataParser):
         return MetadataScheme.A1111
 
     fooocus_to_a1111 = {
+        'raw_prompt': 'Raw prompt',
+        'raw_negative_prompt': 'Raw negative prompt',
         'negative_prompt': 'Negative prompt',
         'styles': 'Styles',
         'performance': 'Performance',
@@ -260,8 +266,8 @@ class A1111MetadataParser(MetadataParser):
     }
 
     def parse_json(self, metadata: str) -> dict:
-        prompt = ''
-        negative_prompt = ''
+        metadata_prompt = ''
+        metadata_negative_prompt = ''
 
         done_with_prompt = False
@@ -276,16 +282,15 @@ class A1111MetadataParser(MetadataParser):
                     done_with_prompt = True
                     line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip()
                 if done_with_prompt:
-                    negative_prompt += ('' if negative_prompt == '' else "\n") + line
+                    metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line
                 else:
-                    prompt += ('' if prompt == '' else "\n") + line
+                    metadata_prompt += ('' if metadata_prompt == '' else "\n") + line
 
-        found_styles, prompt, negative_prompt = extract_styles_from_prompt(prompt, negative_prompt)
+        found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt)
 
         data = {
             'prompt': prompt,
-            'negative_prompt': negative_prompt,
-            'styles': str(found_styles)
+            'negative_prompt': negative_prompt
         }
 
         for k, v in re_param.findall(lastline):
@@ -295,12 +300,24 @@ class A1111MetadataParser(MetadataParser):
 
                 m = re_imagesize.match(v)
                 if m is not None:
-                    data[f'resolution'] = str((m.group(1), m.group(2)))
+                    data['resolution'] = str((m.group(1), m.group(2)))
                 else:
                     data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v
             except Exception:
                 print(f"Error parsing \"{k}: {v}\"")
 
+        # workaround for multiline prompts
+        if 'raw_prompt' in data:
+            data['prompt'] = data['raw_prompt']
+            raw_prompt = data['raw_prompt'].replace("\n", ', ')
+            if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles:
+                found_styles.append(modules.sdxl_styles.fooocus_expansion)
+
+        if 'raw_negative_prompt' in data:
+            data['negative_prompt'] = data['raw_negative_prompt']
+
+        data['styles'] = str(found_styles)
+
         # try to load performance based on steps, fallback for direct A1111 imports
         if 'steps' in data and 'performance' not in data:
            try:
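
The restore logic added to parse_json reads cleanly in isolation: a stored raw prompt wins over the prompt reassembled from the header lines, and any surplus text beyond the flattened raw prompt implies prompt expansion was active. A sketch of that step with invented inputs:

# 'Fooocus V2' stands in for modules.sdxl_styles.fooocus_expansion here.
fooocus_expansion = 'Fooocus V2'

data = {'raw_prompt': 'a cat\nsitting on a mat'}
metadata_prompt = 'a cat, sitting on a mat, intricate, highly detailed'
found_styles = []

if 'raw_prompt' in data:
    data['prompt'] = data['raw_prompt']
    raw_prompt = data['raw_prompt'].replace("\n", ', ')
    # extra tokens beyond the flattened raw prompt -> expansion was applied
    if metadata_prompt != raw_prompt and fooocus_expansion not in found_styles:
        found_styles.append(fooocus_expansion)

assert data['prompt'] == 'a cat\nsitting on a mat'
assert found_styles == ['Fooocus V2']
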
@@ -344,7 +361,11 @@ class A1111MetadataParser(MetadataParser):
             self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'],
             self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
             self.fooocus_to_a1111['base_model_hash']: self.base_model_hash,
+
             self.fooocus_to_a1111['performance']: data['performance'],
+            # workaround for multiline prompts
+            self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
+            self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
         }
 
         # TODO evaluate if this should always be added
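
On the write side, the raw prompts become two extra A1111-style parameters. The parameter line is newline-sensitive, so the values presumably pass through the quote helper imported from modules.util; a hypothetical illustration of that kind of quoting, using plain json.dumps instead of the real helper:

import json

raw_prompt = 'a cat\nsitting on a mat'
# JSON escaping turns the newline into \n, keeping the parameter on one line
param = f'Raw prompt: {json.dumps(raw_prompt)}'
assert param == 'Raw prompt: "a cat\\nsitting on a mat"'
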
modules/util.py
@@ -311,18 +311,20 @@ def extract_styles_from_prompt(prompt, negative_prompt):
             applicable_styles.remove(found_style)
             extracted.append(found_style.name)
 
-    # TODO check multiline prompt
+    # add prompt expansion if not all styles could be resolved
     if prompt != '':
-        if prompt != real_prompt:
+        if real_prompt != '':
             extracted.append(modules.sdxl_styles.fooocus_expansion)
-
-        # find real_prompt when only prompt expansion is selected
-        if real_prompt == '':
+        else:
+            # find real_prompt when only prompt expansion is selected
             first_word = prompt.split(', ')[0]
             first_word_positions = [i for i in range(len(prompt)) if prompt.startswith(first_word, i)]
-            real_prompt = prompt[:first_word_positions[-1]]
-            if real_prompt.endswith(', '):
-                real_prompt = real_prompt[:-2]
+            if len(first_word_positions) > 1:
+                real_prompt = prompt[:first_word_positions[-1]]
+                extracted.append(modules.sdxl_styles.fooocus_expansion)
+                if real_prompt.endswith(', '):
+                    real_prompt = real_prompt[:-2]
 
     return list(reversed(extracted)), real_prompt, negative_prompt
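
The rewritten tail of extract_styles_from_prompt hardens the expansion-only heuristic: the prompt's first comma-separated chunk is assumed to recur where the expanded copy begins, and expansion is now only recorded when that chunk really occurs more than once. A standalone sketch with an invented prompt:

prompt = 'red fox, sleeping in snow, red fox, winter forest, soft light'
first_word = prompt.split(', ')[0]  # 'red fox' -- first chunk, not one word
first_word_positions = [i for i in range(len(prompt)) if prompt.startswith(first_word, i)]

real_prompt = ''
if len(first_word_positions) > 1:
    real_prompt = prompt[:first_word_positions[-1]]  # text before the repeat
    if real_prompt.endswith(', '):
        real_prompt = real_prompt[:-2]               # drop trailing separator

assert real_prompt == 'red fox, sleeping in snow'
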
@@ -337,6 +339,6 @@ def is_json(data: str) -> bool:
     try:
         loaded_json = json.loads(data)
         assert isinstance(loaded_json, dict)
-    except ValueError:
+    except (ValueError, AssertionError):
         return False
     return True
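
The widened except matters because the assert inside the try raises AssertionError, which ValueError alone never catches: valid JSON that is not an object (a bare list, say) previously escaped as an exception. The changed function, runnable standalone:

import json

def is_json(data: str) -> bool:
    try:
        loaded_json = json.loads(data)
        assert isinstance(loaded_json, dict)
    except (ValueError, AssertionError):
        return False
    return True

assert is_json('{"prompt": "a cat"}')   # JSON object -> True
assert not is_json('[1, 2, 3]')         # valid JSON, not a dict -> now False
assert not is_json('not json at all')   # JSONDecodeError (a ValueError) -> False
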