diff --git a/modules/async_worker.py b/modules/async_worker.py index e563e667..ddef06c1 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -17,7 +17,6 @@ def worker(): import os import traceback import math - import json import numpy as np import torch import time @@ -43,8 +42,9 @@ def worker(): from modules.private_logger import log from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, \ - get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256, quote + get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256 from modules.upscaler import perform_upscale + from modules.metadata import MetadataScheme try: async_gradio_app = shared.gradio_root @@ -144,7 +144,8 @@ def worker(): inpaint_additional_prompt = args.pop() inpaint_mask_image_upload = args.pop() save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False - metadata_scheme = args.pop() if not args_manager.args.disable_metadata else 'fooocus' + metadata_scheme = args.pop() if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS.value + assert metadata_scheme in [item.value for item in MetadataScheme] cn_tasks = {x: [] for x in flags.ip_list} for _ in range(4): @@ -793,129 +794,37 @@ def worker(): if inpaint_worker.current_task is not None: imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] - metadata_string = '' - if save_metadata_to_images and metadata_scheme == 'fooocus': - metadata = { - # prompt with wildcards - 'prompt': raw_prompt, 'negative_prompt': raw_negative_prompt, - # prompt with resolved wildcards - 'real_prompt': task['log_positive_prompt'], 'real_negative_prompt': task['log_negative_prompt'], - # prompt with resolved wildcards, styles and prompt expansion - 'complete_prompt_positive': task['positive'], 'complete_prompt_negative': task['negative'], - 'styles': str(raw_style_selections), - 'seed': task['task_seed'], 'width': width, 'height': height, - 'sampler': sampler_name, 'scheduler': scheduler_name, 'performance': performance_selection, - 'steps': steps, 'refiner_switch': refiner_switch, 'sharpness': sharpness, 'cfg': cfg_scale, - 'base_model': base_model_name, 'base_model_hash': base_model_hash, 'refiner_model': refiner_model_name, - 'denoising_strength': denoising_strength, - 'freeu': advanced_parameters.freeu_enabled, - 'img2img': input_image_checkbox, - 'prompt_expansion': task['expansion'] - } - - if advanced_parameters.freeu_enabled: - metadata |= { - 'freeu_b1': advanced_parameters.freeu_b1, 'freeu_b2': advanced_parameters.freeu_b2, 'freeu_s1': advanced_parameters.freeu_s1, 'freeu_s2': advanced_parameters.freeu_s2 - } - - if 'vary' in goals: - metadata |= { - 'uov_method': uov_method - } - - if 'upscale' in goals: - metadata |= { - 'uov_method': uov_method, 'scale': f - } - - if 'inpaint' in goals: - if len(outpaint_selections) > 0: - metadata |= { - 'outpaint_selections': outpaint_selections - } - - metadata |= { - 'inpaint_additional_prompt': inpaint_additional_prompt, 'inpaint_mask_upload': advanced_parameters.inpaint_mask_upload_checkbox, 'invert_mask': advanced_parameters.invert_mask_checkbox, - 'inpaint_disable_initial_latent': advanced_parameters.inpaint_disable_initial_latent, 'inpaint_engine': advanced_parameters.inpaint_engine, - 'inpaint_strength': advanced_parameters.inpaint_strength, 'inpaint_respective_field': advanced_parameters.inpaint_respective_field, - } - - if 'cn' in goals: - metadata |= { - 'canny_low_threshold': advanced_parameters.canny_low_threshold, 'canny_high_threshold': advanced_parameters.canny_high_threshold, - } - - ip_list = {x: [] for x in flags.ip_list} - cn_task_index = 1 - for cn_type in ip_list: - for cn_task in cn_tasks[cn_type]: - cn_img, cn_stop, cn_weight = cn_task - metadata |= { - f'image_prompt_{cn_task_index}': { - 'cn_type': cn_type, 'cn_stop': cn_stop, 'cn_weight': cn_weight, - } - } - cn_task_index += 1 - - metadata |= { - 'software': f'Fooocus v{fooocus_version.version}', - } - if modules.config.metadata_created_by != '': - metadata |= { - 'created_by': modules.config.metadata_created_by - } - metadata_string = json.dumps(metadata, ensure_ascii=False) - elif save_metadata_to_images and metadata_scheme == 'a1111': - generation_params = { - "Steps": steps, - "Sampler": sampler_name, - "CFG scale": cfg_scale, - "Seed": task['task_seed'], - "Size": f"{width}x{height}", - "Model hash": base_model_hash, - "Model": base_model_name.split('.')[0], - "Lora hashes": lora_hashes_string, - "Denoising strength": denoising_strength, - "Version": f'Fooocus v{fooocus_version.version}' - } - - if modules.config.metadata_created_by != '': - generation_params |= { - 'Created By': f'{modules.config.metadata_created_by}' - } - - generation_params_text = ", ".join([k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if v is not None]) - positive_prompt_resolved = ', '.join(task['positive']) - negative_prompt_resolved = ', '.join(task['negative']) - negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" - metadata_string = f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() - for x in imgs: d = [ - ('Prompt', task['log_positive_prompt']), - ('Negative Prompt', task['log_negative_prompt']), - ('Fooocus V2 Expansion', task['expansion']), - ('Styles', str(raw_style_selections)), - ('Performance', performance_selection), - ('Resolution', str((width, height))), - ('Sharpness', sharpness), - ('Guidance Scale', guidance_scale), - ('ADM Guidance', str(( + ('Prompt', 'prompt', task['log_positive_prompt'], True, True), + ('Full Positive Prompt', 'full_prompt', task['positive'], False, False), + ('Negative Prompt', 'negative_prompt', task['log_negative_prompt'], True, True), + ('Full Negative Prompt', 'full_negative_prompt', task['negative'], False, False), + ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion'], True, True), + ('Styles', 'styles', str(raw_style_selections), True, True), + ('Performance', 'performance', performance_selection, True, True), + ('Steps', 'steps', steps, False, False), + ('Resolution', 'resolution', str((width, height)), True, True), + ('Sharpness', 'sharpness', sharpness, True, True), + ('Guidance Scale', 'guidance_scale', guidance_scale, True, True), + ('ADM Guidance', 'adm_guidance', str(( modules.patch.positive_adm_scale, modules.patch.negative_adm_scale, - modules.patch.adm_scaler_end))), - ('Base Model', base_model_name), - ('Refiner Model', refiner_model_name), - ('Refiner Switch', refiner_switch), - ('Sampler', sampler_name), - ('Scheduler', scheduler_name), - ('Seed', task['task_seed']), + modules.patch.adm_scaler_end)), True, True), + ('Base Model', 'base_model', base_model_name, True, True), + ('Refiner Model', 'refiner_model', refiner_model_name, True, True), + ('Refiner Switch', 'refiner_switch', refiner_switch, True, True), + ('Sampler', 'sampler', sampler_name, True, True), + ('Scheduler', 'scheduler', scheduler_name, True, True), + ('Seed', 'seed', task['task_seed'], True, True) ] for li, (n, w) in enumerate(loras): if n != 'None': - d.append((f'LoRA {li + 1}', f'{n} : {w}')) - d.append(('Version', 'v' + fooocus_version.version)) - log(x, d, metadata_string, save_metadata_to_images) + d.append((f'LoRA {li + 1}', f'lora{li + 1}_combined', f'{n} : {w}', True, True)) + # d.append((f'LoRA {li + 1} Name', f'lora{li + 1}_name', n, False, False)) + # d.append((f'LoRA {li + 1} Weight', f'lora{li + 1}_weight', n, False, False)) + d.append(('Version', 'version', 'v' + fooocus_version.version, True, True)) + log(x, d, save_metadata_to_images, metadata_scheme) yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1) except ldm_patched.modules.model_management.InterruptProcessingException as e: diff --git a/modules/config.py b/modules/config.py index 924284c5..4a9b6837 100644 --- a/modules/config.py +++ b/modules/config.py @@ -6,6 +6,7 @@ import args_manager import modules.flags import modules.sdxl_styles +from modules.metadata import MetadataScheme from modules.model_loader import load_file_from_url from modules.util import get_files_from_folder @@ -322,7 +323,7 @@ default_save_metadata_to_images = get_config_item_or_set_default( ) default_metadata_scheme = get_config_item_or_set_default( key='default_metadata_scheme', - default_value='fooocus', + default_value=MetadataScheme.FOOOCUS.value, validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x] ) metadata_created_by = get_config_item_or_set_default( diff --git a/modules/flags.py b/modules/flags.py index fd346f2e..abcd3f60 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -1,3 +1,5 @@ +from modules.metadata import MetadataScheme + disabled = 'Disabled' enabled = 'Enabled' subtle_variation = 'Vary (Subtle)' @@ -32,9 +34,10 @@ default_parameters = { cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0) } # stop, weight -metadata_scheme =[ - ('Fooocus (json)', 'fooocus'), - ('A1111 (plain text)', 'a1111'), +# TODO use translation here +metadata_scheme = [ + ('Fooocus (json)', MetadataScheme.FOOOCUS.value), + ('A1111 (plain text)', MetadataScheme.A1111.value), ] inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 07b42a16..e7cf8a47 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -1,138 +1,38 @@ import json + import gradio as gr + import modules.config -def load_parameter_button_click(raw_prompt_txt, is_generating): - loaded_parameter_dict = json.loads(raw_prompt_txt) +def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): + loaded_parameter_dict = raw_metadata + if isinstance(raw_metadata, str): + loaded_parameter_dict = json.loads(raw_metadata) assert isinstance(loaded_parameter_dict, dict) results = [True, 1] - try: - h = loaded_parameter_dict.get('Prompt', None) - assert isinstance(h, str) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Negative Prompt', None) - assert isinstance(h, str) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Styles', None) - h = eval(h) - assert isinstance(h, list) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Performance', None) - assert isinstance(h, str) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Resolution', None) - width, height = eval(h) - formatted = modules.config.add_ratio(f'{width}*{height}') - if formatted in modules.config.available_aspect_ratios: - results.append(formatted) - results.append(-1) - results.append(-1) - else: - results.append(gr.update()) - results.append(width) - results.append(height) - except: - results.append(gr.update()) - results.append(gr.update()) - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Sharpness', None) - assert h is not None - h = float(h) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Guidance Scale', None) - assert h is not None - h = float(h) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('ADM Guidance', None) - p, n, e = eval(h) - results.append(float(p)) - results.append(float(n)) - results.append(float(e)) - except: - results.append(gr.update()) - results.append(gr.update()) - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Base Model', None) - assert isinstance(h, str) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Refiner Model', None) - assert isinstance(h, str) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Refiner Switch', None) - assert h is not None - h = float(h) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Sampler', None) - assert isinstance(h, str) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Scheduler', None) - assert isinstance(h, str) - results.append(h) - except: - results.append(gr.update()) - - try: - h = loaded_parameter_dict.get('Seed', None) - assert h is not None - h = int(h) - results.append(False) - results.append(h) - except: - results.append(gr.update()) - results.append(gr.update()) + get_str('prompt', 'Prompt', loaded_parameter_dict, results) + get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results) + get_list('styles', 'Styles', loaded_parameter_dict, results) + get_str('performance', 'Performance', loaded_parameter_dict, results) + get_resolution('resolution', 'Resolution', loaded_parameter_dict, results) + get_float('sharpness', 'Sharpness', loaded_parameter_dict, results) + get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results) + get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results) + get_str('base_model', 'Base Model', loaded_parameter_dict, results) + get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results) + get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) + get_str('sampler', 'Sampler', loaded_parameter_dict, results) + get_str('scheduler', 'Scheduler', loaded_parameter_dict, results) + get_seed('seed', 'Seed', loaded_parameter_dict, results) if is_generating: results.append(gr.update()) else: results.append(gr.update(visible=True)) - + results.append(gr.update(visible=False)) for i in range(1, 6): @@ -146,3 +46,94 @@ def load_parameter_button_click(raw_prompt_txt, is_generating): results.append(gr.update()) return results + + +def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, default) + assert isinstance(h, str) + results.append(h) + except: + if fallback is not None: + get_str(fallback, None, source_dict, results, default) + return + results.append(gr.update()) + + +def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, default) + h = eval(h) + assert isinstance(h, list) + results.append(h) + except: + if fallback is not None: + get_list(fallback, None, source_dict, results, default) + return + results.append(gr.update()) + + +def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, default) + assert h is not None + h = float(h) + results.append(h) + except: + if fallback is not None: + get_float(fallback, None, source_dict, results, default) + return + results.append(gr.update()) + + +def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, default) + width, height = eval(h) + formatted = modules.config.add_ratio(f'{width}*{height}') + if formatted in modules.config.available_aspect_ratios: + results.append(formatted) + results.append(-1) + results.append(-1) + else: + results.append(gr.update()) + results.append(width) + results.append(height) + except: + if fallback is not None: + get_resolution(fallback, None, source_dict, results, default) + return + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) + + +def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, default) + assert h is not None + h = int(h) + results.append(False) + results.append(h) + except: + if fallback is not None: + get_seed(fallback, None, source_dict, results, default) + return + results.append(gr.update()) + results.append(gr.update()) + + +def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, default) + p, n, e = eval(h) + results.append(float(p)) + results.append(float(n)) + results.append(float(e)) + except: + if fallback is not None: + get_adm_guidance(fallback, None, source_dict, results, default) + return + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) diff --git a/modules/metadata.py b/modules/metadata.py new file mode 100644 index 00000000..397f7333 --- /dev/null +++ b/modules/metadata.py @@ -0,0 +1,209 @@ +import json +from abc import ABC, abstractmethod +from enum import Enum +from PIL import Image + +import modules.config +import fooocus_version +# import advanced_parameters +from modules.util import quote, is_json + + +class MetadataScheme(Enum): + FOOOCUS = 'fooocus' + A1111 = 'a1111' + + +class MetadataParser(ABC): + @abstractmethod + def parse_json(self, metadata: dict): + raise NotImplementedError + + # TODO add data to parse + @abstractmethod + def parse_string(self, metadata: dict) -> str: + raise NotImplementedError + + +class A1111MetadataParser(MetadataParser): + + def parse_json(self, metadata: dict): + # TODO add correct mapping + pass + + def parse_string(self, metadata: dict) -> str: + # TODO add correct mapping + + data = {k: v for _, k, v, _, _ in metadata} + + # TODO check if correct + width, heigth = data['resolution'].split(', ') + + generation_params = { + "Steps": data['steps'], + "Sampler": data['sampler'], + "CFG scale": data['guidance_scale'], + "Seed": data['seed'], + "Size": f"{width}x{heigth}", + # "Model hash": base_model_hash, + "Model": data['base_model'].split('.')[0], + # "Lora hashes": lora_hashes_string, + # "Denoising strength": data['denoising_strength'], + "Version": f"Fooocus {data['version']}" + } + + generation_params_text = ", ".join( + [k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if v is not None]) + positive_prompt_resolved = ', '.join(data['full_prompt']) + negative_prompt_resolved = ', '.join(data['full_negative_prompt']) + negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" + return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() + + +class FooocusMetadataParser(MetadataParser): + + def parse_json(self, metadata: dict): + # TODO add mapping if necessary + return metadata + + def parse_string(self, metadata: dict) -> str: + + return json.dumps({k: v for _, k, v, _, _ in metadata}) + # metadata = { + # # prompt with wildcards + # 'prompt': raw_prompt, 'negative_prompt': raw_negative_prompt, + # # prompt with resolved wildcards + # 'real_prompt': task['log_positive_prompt'], 'real_negative_prompt': task['log_negative_prompt'], + # # prompt with resolved wildcards, styles and prompt expansion + # 'complete_prompt_positive': task['positive'], 'complete_prompt_negative': task['negative'], + # 'styles': str(raw_style_selections), + # 'seed': task['task_seed'], 'width': width, 'height': height, + # 'sampler': sampler_name, 'scheduler': scheduler_name, 'performance': performance_selection, + # 'steps': steps, 'refiner_switch': refiner_switch, 'sharpness': sharpness, 'cfg': cfg_scale, + # 'base_model': base_model_name, 'base_model_hash': base_model_hash, 'refiner_model': refiner_model_name, + # 'denoising_strength': denoising_strength, + # 'freeu': advanced_parameters.freeu_enabled, + # 'img2img': input_image_checkbox, + # 'prompt_expansion': task['expansion'] + # } + # + # if advanced_parameters.freeu_enabled: + # metadata |= { + # 'freeu_b1': advanced_parameters.freeu_b1, 'freeu_b2': advanced_parameters.freeu_b2, + # 'freeu_s1': advanced_parameters.freeu_s1, 'freeu_s2': advanced_parameters.freeu_s2 + # } + # + # if 'vary' in goals: + # metadata |= { + # 'uov_method': uov_method + # } + # + # if 'upscale' in goals: + # metadata |= { + # 'uov_method': uov_method, 'scale': f + # } + # + # if 'inpaint' in goals: + # if len(outpaint_selections) > 0: + # metadata |= { + # 'outpaint_selections': outpaint_selections + # } + # + # metadata |= { + # 'inpaint_additional_prompt': inpaint_additional_prompt, + # 'inpaint_mask_upload': advanced_parameters.inpaint_mask_upload_checkbox, + # 'invert_mask': advanced_parameters.invert_mask_checkbox, + # 'inpaint_disable_initial_latent': advanced_parameters.inpaint_disable_initial_latent, + # 'inpaint_engine': advanced_parameters.inpaint_engine, + # 'inpaint_strength': advanced_parameters.inpaint_strength, + # 'inpaint_respective_field': advanced_parameters.inpaint_respective_field, + # } + # + # if 'cn' in goals: + # metadata |= { + # 'canny_low_threshold': advanced_parameters.canny_low_threshold, + # 'canny_high_threshold': advanced_parameters.canny_high_threshold, + # } + # + # ip_list = {x: [] for x in flags.ip_list} + # cn_task_index = 1 + # for cn_type in ip_list: + # for cn_task in cn_tasks[cn_type]: + # cn_img, cn_stop, cn_weight = cn_task + # metadata |= { + # f'image_prompt_{cn_task_index}': { + # 'cn_type': cn_type, 'cn_stop': cn_stop, 'cn_weight': cn_weight, + # } + # } + # cn_task_index += 1 + # + # metadata |= { + # 'software': f'Fooocus v{fooocus_version.version}', + # } + # if modules.config.metadata_created_by != '': + # metadata |= { + # 'created_by': modules.config.metadata_created_by + # } + # # return json.dumps(metadata, ensure_ascii=True) TODO check if possible + # return json.dumps(metadata, ensure_ascii=False) + + +def get_metadata_parser(metadata_scheme: str) -> MetadataParser: + match metadata_scheme: + case MetadataScheme.FOOOCUS.value: + return FooocusMetadataParser() + case MetadataScheme.A1111.value: + return A1111MetadataParser() + case _: + raise NotImplementedError + +# IGNORED_INFO_KEYS = { +# 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', +# 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression', +# 'icc_profile', 'chromaticity', 'photoshop', +# } + + +def read_info_from_image(filepath) -> tuple[str | None, dict, str | None]: + with Image.open(filepath) as image: + items = (image.info or {}).copy() + + parameters = items.pop('parameters', None) + if parameters is not None and is_json(parameters): + parameters = json.loads(parameters) + + metadata_scheme = items.pop('fooocus_scheme', None) + + # if "exif" in items: + # exif_data = items["exif"] + # try: + # exif = piexif.load(exif_data) + # except OSError: + # # memory / exif was not valid so piexif tried to read from a file + # exif = None + # exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'') + # try: + # exif_comment = piexif.helper.UserComment.load(exif_comment) + # except ValueError: + # exif_comment = exif_comment.decode('utf8', errors="ignore") + # + # if exif_comment: + # items['exif comment'] = exif_comment + # parameters = exif_comment + + # for field in IGNORED_INFO_KEYS: + # items.pop(field, None) + + # if items.get("Software", None) == "NovelAI": + # try: + # json_info = json.loads(items["Comment"]) + # sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a") + # + # geninfo = f"""{items["Description"]} + # Negative prompt: {json_info["uc"]} + # Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" + # except Exception: + # errors.report("Error parsing NovelAI image generation parameters", + # exc_info=True) + + return parameters, items, metadata_scheme diff --git a/modules/private_logger.py b/modules/private_logger.py index 223463b3..3e186e1a 100644 --- a/modules/private_logger.py +++ b/modules/private_logger.py @@ -7,6 +7,7 @@ import urllib.parse from PIL import Image from PIL.PngImagePlugin import PngInfo from modules.util import generate_temp_filename +from modules.metadata import MetadataScheme log_cache = {} @@ -19,7 +20,9 @@ def get_current_html_path(): return html_name -def log(img, dic, metadata=None, save_metadata_to_image=False): +def log(img, metadata, save_metadata_to_image=False, metadata_scheme: str = MetadataScheme.FOOOCUS.value): + assert metadata_scheme in [item.value for item in MetadataScheme] + if args_manager.args.disable_image_log: return @@ -27,8 +30,12 @@ def log(img, dic, metadata=None, save_metadata_to_image=False): os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True) if save_metadata_to_image: + metadata_parser = modules.metadata.get_metadata_parser(metadata_scheme) + parsed_parameters = metadata_parser.parse_string(metadata) + pnginfo = PngInfo() - pnginfo.add_text('parameters', metadata) + pnginfo.add_text('parameters', parsed_parameters) + pnginfo.add_text('fooocus_scheme', metadata_scheme) else: pnginfo = None Image.fromarray(img).save(local_temp_filename, pnginfo=pnginfo) @@ -40,7 +47,7 @@ def log(img, dic, metadata=None, save_metadata_to_image=False): "body { background-color: #121212; color: #E0E0E0; } " "a { color: #BB86FC; } " ".metadata { border-collapse: collapse; width: 100%; } " - ".metadata .key { width: 15%; } " + ".metadata .label { width: 15%; } " ".metadata .value { width: 85%; font-weight: bold; } " ".metadata th, .metadata td { border: 1px solid #4d4d4d; padding: 4px; } " ".image-container img { height: auto; max-width: 512px; display: block; padding-right:10px; } " @@ -93,12 +100,13 @@ def log(img, dic, metadata=None, save_metadata_to_image=False): item = f"

\n" item += f"" item += "" diff --git a/modules/util.py b/modules/util.py index 89270138..8e304d6c 100644 --- a/modules/util.py +++ b/modules/util.py @@ -178,6 +178,7 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None): return sorted(filenames, key=lambda x: -1 if os.sep in x else 1) + def calculate_sha256(filename): hash_sha256 = sha256() blksize = 1024 * 1024 @@ -188,8 +189,18 @@ def calculate_sha256(filename): return hash_sha256.hexdigest() + def quote(text): if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text): return text - return json.dumps(text, ensure_ascii=False) \ No newline at end of file + return json.dumps(text, ensure_ascii=False) + + +def is_json(data: str) -> bool: + try: + loaded_json = json.loads(data) + assert isinstance(loaded_json, dict) + except ValueError: + return False + return True diff --git a/webui.py b/webui.py index 21ed6210..0688215a 100644 --- a/webui.py +++ b/webui.py @@ -14,6 +14,7 @@ import modules.gradio_hijack as grh import modules.advanced_parameters as advanced_parameters import modules.style_sorter as style_sorter import modules.meta_parser +import modules.metadata import args_manager import copy @@ -21,6 +22,7 @@ from modules.sdxl_styles import legal_style_names from modules.private_logger import get_current_html_path from modules.ui_gradio_extensions import reload_javascript from modules.auth import auth_enabled, check_auth +from modules.util import is_json def generate_clicked(*args): @@ -208,6 +210,28 @@ with shared.gradio_root: value=flags.desc_type_photo) desc_btn = gr.Button(value='Describe this Image into Prompt') gr.HTML('\U0001F4D4 Document') + with gr.TabItem(label='Load Metadata') as load_tab: + with gr.Column(): + metadata_input_image = grh.Image(label='Drag any image generated by Fooocus here', source='upload', type='filepath') + metadata_json = gr.JSON(label='Metadata') + metadata_import_button = gr.Button(value='Overwrite Input Values') + + def trigger_metadata_preview(filepath): + parameters, items, metadata_scheme = modules.metadata.read_info_from_image(filepath) + + results = {} + if parameters is not None: + results['parameters'] = parameters + if items: + results['items'] = items + if metadata_scheme is not None: + results['metadata_scheme'] = metadata_scheme + + return results + + metadata_input_image.upload(trigger_metadata_preview, inputs=metadata_input_image, + outputs=metadata_json) + switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}" down_js = "() => {viewer_to_bottom();}" @@ -548,14 +572,8 @@ with shared.gradio_root: def parse_meta(raw_prompt_txt, is_generating): loaded_json = None - try: - if '{' in raw_prompt_txt: - if '}' in raw_prompt_txt: - if ':' in raw_prompt_txt: - loaded_json = json.loads(raw_prompt_txt) - assert isinstance(loaded_json, dict) - except: - loaded_json = None + if is_json(raw_prompt_txt): + loaded_json = json.loads(raw_prompt_txt) if loaded_json is None: if is_generating: @@ -567,31 +585,30 @@ with shared.gradio_root: prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False) - load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=[ - advanced_checkbox, - image_number, - prompt, - negative_prompt, - style_selections, - performance_selection, - aspect_ratios_selection, - overwrite_width, - overwrite_height, - sharpness, - guidance_scale, - adm_scaler_positive, - adm_scaler_negative, - adm_scaler_end, - base_model, - refiner_model, - refiner_switch, - sampler_name, - scheduler_name, - seed_random, - image_seed, - generate_button, - load_parameter_button - ] + lora_ctrls, queue=False, show_progress=False) + load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections, + performance_selection, aspect_ratios_selection, overwrite_width, overwrite_height, + sharpness, guidance_scale, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, + base_model, refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, + image_seed, generate_button, load_parameter_button] + lora_ctrls + + load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False) + + def trigger_metadata_import(filepath, state_is_generating): + parameters, items, metadata_scheme = modules.metadata.read_info_from_image(filepath) + + if parameters is None: + pass + + if metadata_scheme is None and isinstance(parameters, dict): + metadata_scheme = modules.metadata.MetadataScheme.FOOOCUS.value + + + metadata_parser = modules.metadata.get_metadata_parser(metadata_scheme) + parsed_parameters = metadata_parser.parse_json(parameters) + return modules.meta_parser.load_parameter_button_click(parsed_parameters, state_is_generating) + + + metadata_import_button.click(trigger_metadata_import, inputs=[metadata_input_image, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False) generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), [], True), outputs=[stop_button, skip_button, generate_button, gallery, state_is_generating]) \
{only_name}
" - for key, value in dic: - value_txt = str(value).replace('\n', '
') - item += f"\n" + for label, key, value, showable, copyable in metadata: + if showable: + value_txt = str(value).replace('\n', '
') + item += f"\n" item += "" - js_txt = urllib.parse.quote(json.dumps({k: v for k, v in dic}, indent=0), safe='') + js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v, _, copyable in metadata if copyable}, indent=0), safe='') item += f"
" item += "