diff --git a/modules/async_worker.py b/modules/async_worker.py index b2af6712..baadd69a 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -36,7 +36,7 @@ def worker(): import extras.face_crop import fooocus_version - from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion + from modules.sdxl_styles import apply_style, apply_wildcards, apply_wildprompts, get_all_wildprompts, fooocus_expansion from modules.private_logger import log from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, \ @@ -121,6 +121,8 @@ def worker(): prompt = args.pop() negative_prompt = args.pop() + wildprompt_selections = args.pop() + wildprompt_generate_all = args.pop() style_selections = args.pop() performance_selection = args.pop() aspect_ratios_selection = args.pop() @@ -153,6 +155,7 @@ def worker(): outpaint_selections = [o.lower() for o in outpaint_selections] base_model_additional_loras = [] raw_style_selections = copy.deepcopy(style_selections) + raw_wildprompt_selections = copy.deepcopy(wildprompt_selections) uov_method = uov_method.lower() if fooocus_expansion in style_selections: @@ -162,6 +165,7 @@ def worker(): use_expansion = False use_style = len(style_selections) > 0 + use_wildprompt = len(wildprompt_selections) > 0 if base_model_name == refiner_model_name: print(f'Refiner disabled because base model and refiner are same.') @@ -376,11 +380,35 @@ def worker(): progressbar(async_task, 3, 'Processing prompts ...') tasks = [] - for i in range(image_number): + wildprompts = [] + wildprompt_count = len(wildprompt_selections) + + # Get wildprompts if wildprompt_generate_all is enabled and there is only one wildprompt + if wildprompt_generate_all and wildprompt_count == 1: + wildprompts = get_all_wildprompts(wildprompt_selections) + + if len(wildprompts) > 0: + totalprompts = len(wildprompts) * image_number + else: + totalprompts = image_number + + for i in range(totalprompts): task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not task_rng = random.Random(task_seed) # may bind to inpaint noise in the future + + if len(wildprompts) > 0: + wildprompt_prompt = wildprompts[i // image_number] + else: + wildprompt_prompt = apply_wildprompts(wildprompt_selections, task_rng) if use_wildprompt else '' + wildprompt_prompt = apply_wildcards(wildprompt_prompt, task_rng) task_prompt = apply_wildcards(prompt, task_rng) + + if wildprompt_prompt and task_prompt: + task_prompt = f"{wildprompt_prompt}, {task_prompt}" + elif wildprompt_prompt: + task_prompt = wildprompt_prompt + task_negative_prompt = apply_wildcards(negative_prompt, task_rng) task_extra_positive_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_positive_prompts] task_extra_negative_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_negative_prompts] @@ -780,6 +808,7 @@ def worker(): ('Negative Prompt', task['log_negative_prompt']), ('Fooocus V2 Expansion', task['expansion']), ('Styles', str(raw_style_selections)), + ('Wildprompts', str(raw_wildprompt_selections)), ('Performance', performance_selection), ('Resolution', str((width, height))), ('Sharpness', sharpness), @@ -799,7 +828,7 @@ def worker(): if n != 'None': d.append((f'LoRA {li + 1}', f'{n} : {w}')) d.append(('Version', 'v' + fooocus_version.version)) - log(x, d) + log(x, d, str(wildprompt_selections).replace("'", "")) yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1) except ldm_patched.modules.model_management.InterruptProcessingException as e: diff --git a/modules/config.py b/modules/config.py index 58107806..923d6664 100644 --- a/modules/config.py +++ b/modules/config.py @@ -233,6 +233,11 @@ default_styles = get_config_item_or_set_default( ], validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x) ) +default_wildprompts = get_config_item_or_set_default( + key='default_wildprompts', + default_value=[], + validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_wildprompt_names for y in x) +) default_prompt_negative = get_config_item_or_set_default( key='default_prompt_negative', default_value='', diff --git a/modules/private_logger.py b/modules/private_logger.py index 968bd4f5..8e61d8b5 100644 --- a/modules/private_logger.py +++ b/modules/private_logger.py @@ -12,20 +12,18 @@ log_cache = {} def get_current_html_path(): - date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs, - extension='png') + date_string, local_temp_filename, only_name, logpath = generate_temp_filename(folder=modules.config.path_outputs, extension='png') html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html') return html_name -def log(img, dic): +def log(img, dic, wildprompt=''): if args_manager.args.disable_image_log: return - - date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs, extension='png') + date_string, local_temp_filename, only_name, logpath = generate_temp_filename(folder=modules.config.path_outputs, extension='png', wildprompt=wildprompt) os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True) Image.fromarray(img).save(local_temp_filename) - html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html') + html_name = os.path.join(os.path.dirname(logpath), 'log.html') css_styles = ( "