diff --git a/args_manager.py b/args_manager.py index b1584d8b..9dcb9a49 100644 --- a/args_manager.py +++ b/args_manager.py @@ -25,6 +25,9 @@ args_parser.parser.add_argument("--disable-image-log", action='store_true', args_parser.parser.add_argument("--disable-analytics", action='store_true', help="Disables analytics for Gradio.") +args_parser.parser.add_argument("--disable-metadata", action='store_true', + help="Disables saving metadata to images.") + args_parser.parser.set_defaults( disable_cuda_malloc=True, in_browser=True, diff --git a/modules/async_worker.py b/modules/async_worker.py index 63372070..79b620ae 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -15,8 +15,11 @@ async_tasks = [] def worker(): global async_tasks + + import os import traceback import math + import json import numpy as np import torch import time @@ -36,6 +39,7 @@ def worker(): import extras.ip_adapter as ip_adapter import extras.face_crop import fooocus_version + import args_manager from modules.censor import censor_batch @@ -43,7 +47,7 @@ def worker(): from modules.private_logger import log from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, \ - get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate + get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256, quote from modules.upscaler import perform_upscale try: @@ -150,6 +154,8 @@ def worker(): inpaint_input_image = args.pop() inpaint_additional_prompt = args.pop() inpaint_mask_image_upload = args.pop() + save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False + metadata_scheme = args.pop() if not args_manager.args.disable_metadata else 'fooocus' cn_tasks = {x: [] for x in flags.ip_list} for _ in range(4): @@ -212,6 +218,17 @@ def worker(): prompt = translate2en(prompt, 'prompt') negative_prompt = translate2en(negative_prompt, 'negative prompt') + if not args_manager.args.disable_metadata: + base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name) + base_model_hash = calculate_sha256(base_model_path)[0:10] + + lora_hashes = [] + for (n, w) in loras: + if n != 'None': + lora_path = os.path.join(modules.config.path_loras, n) + lora_hashes.append(f'{n.split(".")[0]}: {calculate_sha256(lora_path)[0:10]}') + lora_hashes_string = ", ".join(lora_hashes) + modules.patch.adaptive_cfg = advanced_parameters.adaptive_cfg print(f'[Parameters] Adaptive CFG = {modules.patch.adaptive_cfg}') @@ -369,6 +386,9 @@ def worker(): progressbar(async_task, 1, 'Initializing ...') + raw_prompt = prompt + raw_negative_prompt = negative_prompt + if not skip_prompt_processing: prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='') @@ -527,7 +547,7 @@ def worker(): if direct_return: d = [('Upscale (Fast)', '2x')] - uov_input_image_path = log(uov_input_image, d, image_file_extension) + uov_input_image_path = log(uov_input_image, d, image_file_extension=image_file_extension) yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True) return @@ -793,6 +813,99 @@ def worker(): imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] img_paths = [] + metadata_string = '' + if save_metadata_to_images and metadata_scheme == 'fooocus': + metadata = { + 'prompt': raw_prompt, 'negative_prompt': raw_negative_prompt, 'styles': str(raw_style_selections), + 'real_prompt': task['log_positive_prompt'], 'real_negative_prompt': task['log_negative_prompt'], + 'seed': task['task_seed'], 'width': width, 'height': height, + 'sampler': sampler_name, 'scheduler': scheduler_name, 'performance': performance_selection, + 'steps': steps, 'refiner_switch': refiner_switch, 'sharpness': sharpness, 'cfg': cfg_scale, + 'base_model': base_model_name, 'refiner_model': refiner_model_name, + 'denoising_strength': denoising_strength, + 'freeu': advanced_parameters.freeu_enabled, + 'img2img': input_image_checkbox, + 'prompt_expansion': task['expansion'] + } + + + if advanced_parameters.freeu_enabled: + metadata |= { + 'freeu_b1': advanced_parameters.freeu_b1, 'freeu_b2': advanced_parameters.freeu_b2, 'freeu_s1': advanced_parameters.freeu_s1, 'freeu_s2': advanced_parameters.freeu_s2 + } + + if 'vary' in goals: + metadata |= { + 'uov_method': uov_method + } + + if 'upscale' in goals: + metadata |= { + 'uov_method': uov_method, 'scale': f + } + + if 'inpaint' in goals: + if len(outpaint_selections) > 0: + metadata |= { + 'outpaint_selections': outpaint_selections + } + else: + metadata |= { + 'inpaint_additional_prompt': inpaint_additional_prompt, 'inpaint_mask_upload': advanced_parameters.inpaint_mask_upload_checkbox, 'invert_mask': advanced_parameters.invert_mask_checkbox, + 'inpaint_disable_initial_latent': advanced_parameters.inpaint_disable_initial_latent, 'inpaint_engine': advanced_parameters.inpaint_engine, + 'inpaint_strength': advanced_parameters.inpaint_strength, 'inpaint_respective_field': advanced_parameters.inpaint_respective_field, + } + + if 'cn' in goals: + metadata |= { + 'canny_low_threshold': advanced_parameters.canny_low_threshold, 'canny_high_threshold': advanced_parameters.canny_high_threshold, + } + + ip_list = {x: [] for x in flags.ip_list} + cn_task_index = 1 + for cn_type in ip_list: + for cn_task in cn_tasks[cn_type]: + cn_img, cn_stop, cn_weight = cn_task + metadata |= { + f'image_prompt_{cn_task_index}': { + 'cn_type': cn_type, 'cn_stop': cn_stop, 'cn_weight': cn_weight, + } + } + cn_task_index += 1 + + metadata |= { + 'software': f'Fooocus v{fooocus_version.version}', + } + if modules.config.metadata_created_by != 'None': + metadata |= { + 'created_by': modules.config.metadata_created_by + } + metadata_string = json.dumps(metadata, ensure_ascii=False) + elif save_metadata_to_images and metadata_scheme == 'a1111': + generation_params = { + "Steps": steps, + "Sampler": sampler_name, + "CFG scale": cfg_scale, + "Seed": task['task_seed'], + "Size": f"{width}x{height}", + "Model hash": base_model_hash, + "Model": base_model_name.split('.')[0], + "Lora hashes": lora_hashes_string, + "Denoising strength": denoising_strength, + "Version": f'Fooocus v{fooocus_version.version}' + } + + if modules.config.metadata_created_by != 'None': + generation_params |= { + 'Created By': f'{modules.config.metadata_created_by}' + } + + generation_params_text = ", ".join([k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if v is not None]) + positive_prompt_resolved = ', '.join(task['positive']) + negative_prompt_resolved = ', '.join(task['negative']) + negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" + metadata_string = f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() + for x in imgs: d = [ ('Prompt', task['log_positive_prompt']), @@ -819,7 +932,7 @@ def worker(): if n != 'None': d.append((f'LoRA {li + 1}', f'{n} : {w}')) d.append(('Version', 'v' + fooocus_version.version)) - img_paths.append(log(x, d, image_file_extension)) + img_paths.append(log(x, d, metadata_string, save_metadata_to_images, image_file_extension)) yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) except ldm_patched.modules.model_management.InterruptProcessingException as e: diff --git a/modules/config.py b/modules/config.py index ddbc01b7..6e5fcf9c 100644 --- a/modules/config.py +++ b/modules/config.py @@ -361,6 +361,21 @@ example_inpaint_prompts = get_config_item_or_set_default( ], validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) ) +default_save_metadata_to_images = get_config_item_or_set_default( + key='default_save_metadata_to_images', + default_value=False, + validator=lambda x: isinstance(x, bool) +) +default_metadata_scheme = get_config_item_or_set_default( + key='default_metadata_scheme', + default_value='fooocus', + validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x] +) +metadata_created_by = get_config_item_or_set_default( + key='metadata_created_by', + default_value='', + validator=lambda x: isinstance(x, str) +) example_inpaint_prompts = [[x] for x in example_inpaint_prompts] diff --git a/modules/flags.py b/modules/flags.py index 3462ec80..121c4c1b 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -32,6 +32,11 @@ default_parameters = { cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0) } # stop, weight +metadata_scheme =[ + ('Fooocus (json)', 'fooocus'), + ('A1111 (plain text)', 'a1111'), +] + inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] performance_selections = [ diff --git a/modules/private_logger.py b/modules/private_logger.py index 4d2427ad..354a50c4 100644 --- a/modules/private_logger.py +++ b/modules/private_logger.py @@ -5,6 +5,7 @@ import json import urllib.parse from PIL import Image +from PIL.PngImagePlugin import PngInfo from modules.util import generate_temp_filename from tempfile import gettempdir @@ -19,12 +20,27 @@ def get_current_html_path(image_extension=None): return html_name -def log(img, dic, image_extension=None) -> str: +def log(img, dic, metadata=None, save_metadata_to_image=False, image_file_extension=None) -> str: path_outputs = args_manager.args.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs - _image_extension = image_extension if image_extension else modules.config.default_image_extension - date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=_image_extension) + image_file_extension = image_file_extension if image_file_extension else modules.config.default_image_file_extension + date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=image_file_extension) os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True) - Image.fromarray(img).save(local_temp_filename) + + if image_file_extension == 'png': + if save_metadata_to_image: + pnginfo = PngInfo() + pnginfo.add_text("parameters", metadata) + else: + pnginfo = None + Image.fromarray(img).save(local_temp_filename, pnginfo=pnginfo) + elif image_file_extension == 'jpg': + # TODO check if metadata works correctly here + Image.fromarray(img).save(local_temp_filename, quality=95, optimize=True, progressive=True, comment=metadata if save_metadata_to_image else None) + elif image_file_extension == 'webp': + # TODO test exif handling + Image.fromarray(img).save(local_temp_filename, quality=95, lossless=False) + else: + Image.fromarray(img).save(local_temp_filename) if args_manager.args.disable_image_log: return local_temp_filename diff --git a/modules/util.py b/modules/util.py index de9bd5b9..8430f5b8 100644 --- a/modules/util.py +++ b/modules/util.py @@ -4,8 +4,10 @@ import random import math import os import cv2 +import json from PIL import Image +from hashlib import sha256 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -174,4 +176,20 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None): path = os.path.join(relative_path, filename) filenames.append(path) - return filenames + return sorted(filenames, key=lambda x: -1 if os.sep in x else 1) + +def calculate_sha256(filename): + hash_sha256 = sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + +def quote(text): + if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text): + return text + + return json.dumps(text, ensure_ascii=False) diff --git a/readme.md b/readme.md index f0f09450..a4003608 100644 --- a/readme.md +++ b/readme.md @@ -25,6 +25,7 @@ Included adjustments: * ✨ https://github.com/lllyasviel/Fooocus/pull/1932 - use consistent file name in gradio * ✨ https://github.com/lllyasviel/Fooocus/pull/1863 - image extension support (png, jpg, webp) * ✨ https://github.com/lllyasviel/Fooocus/pull/1938 - automatically describe image on uov image upload if prompt is empty +* ✨ https://github.com/lllyasviel/Fooocus/pull/1940 - meta data handling, schemes: Fooocus (json) and A1111 (plain text). Compatible with Civitai. ✨ = new feature
🐛 = bugfix
diff --git a/webui.py b/webui.py index c5ab135d..fdd4642e 100644 --- a/webui.py +++ b/webui.py @@ -421,6 +421,16 @@ with shared.gradio_root: black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x), inputs=black_out_nsfw, outputs=disable_preview, queue=False, show_progress=False) + if not args_manager.args.disable_metadata: + save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images, + info='Adds parameters to generated images allowing manual regeneration.') + metadata_scheme = gr.Radio(label='Metadata Scheme', choices=flags.metadata_scheme, value=modules.config.default_metadata_scheme, + info='Use A1111 for compatibility with Civitai.', + visible=modules.config.default_save_metadata_to_images) + + save_metadata_to_images.change(lambda x: gr.update(visible=x), inputs=[save_metadata_to_images], outputs=[metadata_scheme], + queue=False, show_progress=False) + with gr.Tab(label='Control'): debugging_cn_preprocessor = gr.Checkbox(label='Debug Preprocessors', value=False, info='See the results from preprocessors.') @@ -641,6 +651,10 @@ with shared.gradio_root: ctrls += [input_image_checkbox, current_tab] ctrls += [uov_method, uov_input_image] ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] + + if not args_manager.args.disable_metadata: + ctrls += [save_metadata_to_images, metadata_scheme] + ctrls += ip_ctrls def parse_meta(raw_prompt_txt, is_generating):