From ed4a958da862b53c5541ab6dd9961073ceea566c Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 2 Feb 2024 22:04:28 +0100 Subject: [PATCH] fix: add workaround for multiline prompts --- modules/async_worker.py | 7 +++---- modules/meta_parser.py | 41 +++++++++++++++++++++++++++++++---------- modules/util.py | 18 ++++++++++-------- 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index c3f9be36..1e501ebb 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -342,9 +342,6 @@ def worker(): progressbar(async_task, 1, 'Initializing ...') - raw_prompt = prompt - raw_negative_prompt = negative_prompt - if not skip_prompt_processing: prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='') @@ -805,7 +802,9 @@ def worker(): metadata_parser = None if save_metadata_to_images: metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme) - metadata_parser.set_data(task['positive'], task['negative'], steps, base_model_name, refiner_model_name, loras) + metadata_parser.set_data(task['log_positive_prompt'], task['positive'], + task['log_negative_prompt'], task['negative'], + steps, base_model_name, refiner_model_name, loras) for li, (n, w) in enumerate(loras): if n != 'None': diff --git a/modules/meta_parser.py b/modules/meta_parser.py index b2931c46..aa0cd10e 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -8,7 +8,7 @@ import gradio as gr from PIL import Image import modules.config -import modules.config +import modules.sdxl_styles from modules.flags import MetadataScheme, Performance, Steps from modules.flags import lora_count from modules.util import quote, unquote, extract_styles_from_prompt, is_json, calculate_sha256 @@ -187,7 +187,9 @@ def get_sha256(filepath): class MetadataParser(ABC): def __init__(self): + self.raw_prompt: str = '' self.full_prompt: str = '' + self.raw_negative_prompt: str = '' self.full_negative_prompt: str = '' self.steps: int = 30 self.base_model_name: str = '' @@ -208,8 +210,10 @@ class MetadataParser(ABC): def parse_string(self, metadata: dict) -> str: raise NotImplementedError - def set_data(self, full_prompt, full_negative_prompt, steps, base_model_name, refiner_model_name, loras): + def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, refiner_model_name, loras): + self.raw_prompt = raw_prompt self.full_prompt = full_prompt + self.raw_negative_prompt = raw_negative_prompt self.full_negative_prompt = full_negative_prompt self.steps = steps self.base_model_name = Path(base_model_name).stem @@ -235,6 +239,8 @@ class A1111MetadataParser(MetadataParser): return MetadataScheme.A1111 fooocus_to_a1111 = { + 'raw_prompt': 'Raw prompt', + 'raw_negative_prompt': 'Raw negative prompt', 'negative_prompt': 'Negative prompt', 'styles': 'Styles', 'performance': 'Performance', @@ -260,8 +266,8 @@ class A1111MetadataParser(MetadataParser): } def parse_json(self, metadata: str) -> dict: - prompt = '' - negative_prompt = '' + metadata_prompt = '' + metadata_negative_prompt = '' done_with_prompt = False @@ -276,16 +282,15 @@ class A1111MetadataParser(MetadataParser): done_with_prompt = True line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip() if done_with_prompt: - negative_prompt += ('' if negative_prompt == '' else "\n") + line + metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line else: - prompt += ('' if prompt == '' else "\n") + line + metadata_prompt += ('' if metadata_prompt == '' else "\n") + line - found_styles, prompt, negative_prompt = extract_styles_from_prompt(prompt, negative_prompt) + found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt) data = { 'prompt': prompt, - 'negative_prompt': negative_prompt, - 'styles': str(found_styles) + 'negative_prompt': negative_prompt } for k, v in re_param.findall(lastline): @@ -295,12 +300,24 @@ class A1111MetadataParser(MetadataParser): m = re_imagesize.match(v) if m is not None: - data[f'resolution'] = str((m.group(1), m.group(2))) + data['resolution'] = str((m.group(1), m.group(2))) else: data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v except Exception: print(f"Error parsing \"{k}: {v}\"") + # workaround for multiline prompts + if 'raw_prompt' in data: + data['prompt'] = data['raw_prompt'] + raw_prompt = data['raw_prompt'].replace("\n", ', ') + if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles: + found_styles.append(modules.sdxl_styles.fooocus_expansion) + + if 'raw_negative_prompt' in data: + data['negative_prompt'] = data['raw_negative_prompt'] + + data['styles'] = str(found_styles) + # try to load performance based on steps, fallback for direct A1111 imports if 'steps' in data and 'performance' not in data: try: @@ -344,7 +361,11 @@ class A1111MetadataParser(MetadataParser): self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'], self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem, self.fooocus_to_a1111['base_model_hash']: self.base_model_hash, + self.fooocus_to_a1111['performance']: data['performance'], + # workaround for multiline prompts + self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, + self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt, } # TODO evaluate if this should always be added diff --git a/modules/util.py b/modules/util.py index d4596799..4a590f51 100644 --- a/modules/util.py +++ b/modules/util.py @@ -311,18 +311,20 @@ def extract_styles_from_prompt(prompt, negative_prompt): applicable_styles.remove(found_style) extracted.append(found_style.name) + # TODO check multiline prompt # add prompt expansion if not all styles could be resolved if prompt != '': - if prompt != real_prompt: + if real_prompt != '': extracted.append(modules.sdxl_styles.fooocus_expansion) - - # find real_prompt when only prompt expansion is selected - if real_prompt == '': + else: + # find real_prompt when only prompt expansion is selected first_word = prompt.split(', ')[0] first_word_positions = [i for i in range(len(prompt)) if prompt.startswith(first_word, i)] - real_prompt = prompt[:first_word_positions[-1]] - if real_prompt.endswith(', '): - real_prompt = real_prompt[:-2] + if len(first_word_positions) > 1: + real_prompt = prompt[:first_word_positions[-1]] + extracted.append(modules.sdxl_styles.fooocus_expansion) + if real_prompt.endswith(', '): + real_prompt = real_prompt[:-2] return list(reversed(extracted)), real_prompt, negative_prompt @@ -337,6 +339,6 @@ def is_json(data: str) -> bool: try: loaded_json = json.loads(data) assert isinstance(loaded_json, dict) - except ValueError: + except (ValueError, AssertionError): return False return True