From 0f658a97f774d164b07a3d3b7952692086dca092 Mon Sep 17 00:00:00 2001 From: lvmin Date: Mon, 11 Sep 2023 02:48:36 -0700 Subject: [PATCH] prompt expansion v2 --- fooocus_version.py | 2 +- modules/async_worker.py | 11 ++++++++--- modules/expansion.py | 5 ++--- modules/sdxl_styles.py | 10 ++++------ modules/util.py | 9 +++++++++ update_log.md | 4 ++++ 6 files changed, 28 insertions(+), 13 deletions(-) diff --git a/fooocus_version.py b/fooocus_version.py index 4100ebc2..e6638147 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '1.0.66' +version = '1.0.67' diff --git a/modules/async_worker.py b/modules/async_worker.py index 460b7faf..cc3f7dd8 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -18,6 +18,7 @@ def worker(): from modules.sdxl_styles import apply_style_negative, apply_style_positive, aspect_ratios from modules.private_logger import log from modules.expansion import safe_str + from modules.util import join_prompts try: async_gradio_app = shared.gradio_root @@ -81,12 +82,16 @@ def worker(): outputs.append(['preview', (5, f'Preparing positive text #{i + 1} ...', None)]) current_seed = seed + i - p_txt = apply_style_positive(style_selction, prompt) + expansion_weight = 0.35 - suffix = pipeline.expansion(p_txt, current_seed) + suffix = pipeline.expansion(prompt, current_seed) + suffix = f'({suffix}:{expansion_weight})' print(f'[Prompt Expansion] New suffix: {suffix}') - p_txt = safe_str(p_txt) + suffix + p_txt = apply_style_positive(style_selction, prompt) + p_txt = safe_str(p_txt) + + p_txt = join_prompts(p_txt, suffix) tasks.append(dict( prompt=prompt, diff --git a/modules/expansion.py b/modules/expansion.py index 9ed0f285..de361e31 100644 --- a/modules/expansion.py +++ b/modules/expansion.py @@ -6,8 +6,7 @@ from modules.path import fooocus_expansion_path fooocus_magic_split = [ ', extremely', ', trending', - ', intricate', - # '. The', + ', intricate,', ] dangrous_patterns = '[]【】()()|::' @@ -16,7 +15,7 @@ def safe_str(x): x = str(x) for _ in range(16): x = x.replace(' ', ' ') - return x.rstrip(",. \r\n") + return x.strip(",. \r\n") def remove_pattern(x, pattern): diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 3026cb13..a2e439eb 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -1,3 +1,6 @@ +from modules.util import join_prompts + + # https://github.com/twri/sdxl_prompt_styler/blob/main/sdxl_styles.json styles = [ @@ -966,9 +969,4 @@ def apply_style_positive(style, txt): def apply_style_negative(style, txt): p, n = styles.get(style, default_style) - if n == '': - return txt - elif txt == '': - return n - else: - return n + ', ' + txt + return join_prompts(n, txt) diff --git a/modules/util.py b/modules/util.py index d13f844c..9fd56667 100644 --- a/modules/util.py +++ b/modules/util.py @@ -3,6 +3,15 @@ import random import os +def join_prompts(*args, **kwargs): + prompts = [str(x) for x in args if str(x) != ""] + if len(prompts) == 0: + return "" + if len(prompts) == 1: + return prompts[0] + return ', '.join(prompts) + + def generate_temp_filename(folder='./outputs/', extension='png'): current_time = datetime.datetime.now() date_string = current_time.strftime("%Y-%m-%d") diff --git a/update_log.md b/update_log.md index 9ea8f0e8..e171856d 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,7 @@ +### 1.0.67 + +* Use dynamic weighting and lower weights for prompt expansion. + ### 1.0.64 * Fixed a small OOM problem.