From 9bdb65ec5d77410ce92c75299476ad7d5193e686 Mon Sep 17 00:00:00 2001
From: Manuel Schmid
Date: Wed, 31 Jan 2024 01:18:09 +0100
Subject: [PATCH] feat: add metadata handling for all non-img2img parameters

---
 modules/async_worker.py | 71 +++++++++++++++++++++++--------------
 modules/meta_parser.py  | 77 +++++++++++++++++++++++++----------------
 modules/metadata.py     | 35 +++++++++++++++++--
 webui.py                |  9 ++---
 4 files changed, 130 insertions(+), 62 deletions(-)

diff --git a/modules/async_worker.py b/modules/async_worker.py
index f58ab79f..ff12f5ec 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -274,7 +274,7 @@ def worker():
                 and isinstance(inpaint_input_image, dict):
             inpaint_image = inpaint_input_image['image']
             inpaint_mask = inpaint_input_image['mask'][:, :, 0]
-            
+
             if advanced_parameters.inpaint_mask_upload_checkbox:
                 if isinstance(inpaint_mask_image_upload, np.ndarray):
                     if inpaint_mask_image_upload.ndim == 3:
@@ -777,38 +777,57 @@ def worker():
                 imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
 
             for x in imgs:
-                d = [
-                    ('Prompt', 'prompt', task['log_positive_prompt'], True, True),
-                    ('Full Positive Prompt', 'full_prompt', task['positive'], False, False),
-                    ('Negative Prompt', 'negative_prompt', task['log_negative_prompt'], True, True),
-                    ('Full Negative Prompt', 'full_negative_prompt', task['negative'], False, False),
-                    ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion'], True, True),
-                    ('Styles', 'styles', str(raw_style_selections), True, True),
-                    ('Performance', 'performance', performance_selection.value, True, True),
-                    ('Steps', 'steps', steps, False, False),
-                    ('Resolution', 'resolution', str((width, height)), True, True),
-                    ('Sharpness', 'sharpness', sharpness, True, True),
-                    ('Guidance Scale', 'guidance_scale', guidance_scale, True, True),
-                    ('ADM Guidance', 'adm_guidance', str((
-                        modules.patch.positive_adm_scale,
-                        modules.patch.negative_adm_scale,
-                        modules.patch.adm_scaler_end)), True, True),
-                    ('Base Model', 'base_model', base_model_name, True, True),
-                    ('Base Model Hash', 'base_model_hash', base_model_hash, False, False),
-                    ('Refiner Model', 'refiner_model', refiner_model_name, True, True),
-                    ('Refiner Model Hash', 'refiner_model_hash', refiner_model_hash, False, False),
-                    ('Refiner Switch', 'refiner_switch', refiner_switch, True, True),
-                    ('Sampler', 'sampler', sampler_name, True, True),
-                    ('Scheduler', 'scheduler', scheduler_name, True, True),
-                    ('Seed', 'seed', task['task_seed'], True, True)
-                ]
+                d = [('Prompt', 'prompt', task['log_positive_prompt'], True, True),
+                     ('Full Positive Prompt', 'full_prompt', task['positive'], False, False),
+                     ('Negative Prompt', 'negative_prompt', task['log_negative_prompt'], True, True),
+                     ('Full Negative Prompt', 'full_negative_prompt', task['negative'], False, False),
+                     ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion'], True, True),
+                     ('Styles', 'styles', str(raw_style_selections), True, True),
+                     ('Performance', 'performance', performance_selection.value, True, True),
+                     ('Steps', 'steps', steps, False, False),
+                     ('Resolution', 'resolution', str((width, height)), True, True),
+                     ('Guidance Scale', 'guidance_scale', guidance_scale, True, True),
+                     ('Sharpness', 'sharpness', sharpness, True, True),
+                     ('ADM Guidance', 'adm_guidance', str((
+                         modules.patch.positive_adm_scale,
+                         modules.patch.negative_adm_scale,
+                         modules.patch.adm_scaler_end)), True, True),
+                     ('Base Model', 'base_model', base_model_name, True, True),
+                     ('Base Model Hash', 'base_model_hash', base_model_hash, False, False), # 
TODO move to metadata and use cache + ('Refiner Model', 'refiner_model', refiner_model_name, True, True), + ('Refiner Model Hash', 'refiner_model_hash', refiner_model_hash, False, False), # TODO move to metadata and use cache + ('Refiner Switch', 'refiner_switch', refiner_switch, True, True)] + + # TODO evaluate if this should always be added + if refiner_model_name != 'None': + if advanced_parameters.overwrite_switch > 0: + d.append(('Overwrite Switch', 'overwrite_switch', advanced_parameters.overwrite_switch, True, True)) + if refiner_swap_method != flags.refiner_swap_method: + d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method, True, True)) + if advanced_parameters.adaptive_cfg != modules.config.default_cfg_tsnr: + d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', advanced_parameters.adaptive_cfg, True, True)) + + d.append(('Sampler', 'sampler', sampler_name, True, True)) + d.append(('Scheduler', 'scheduler', scheduler_name, True, True)) + d.append(('Seed', 'seed', task['task_seed'], True, True)) + + if advanced_parameters.freeu_enabled: + d.append(('FreeU', 'freeu', str(( + advanced_parameters.freeu_b1, + advanced_parameters.freeu_b2, + advanced_parameters.freeu_s1, + advanced_parameters.freeu_s2)), True, True)) + for li, (n, w) in enumerate(loras): if n != 'None': d.append((f'LoRA {li + 1}', f'lora_combined_{li + 1}', f'{n} : {w}', True, True)) d.append((f'LoRA {li + 1} Name', f'lora_name_{li + 1}', n, False, False)) d.append((f'LoRA {li + 1} Weight', f'lora_weight_{li + 1}', w, False, False)) + # TODO move hashes to metadata handling d.append((f'LoRA {li + 1} Hash', f'lora_hash_{li + 1}', lora_hashes[li], False, False)) + d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version, True, True)) + if modules.config.metadata_created_by != '': d.append(('Created By', 'created_by', modules.config.metadata_created_by, False, False)) diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 33772140..6d4e542c 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -3,7 +3,7 @@ import json import gradio as gr import modules.config -from modules.flags import lora_count +from modules.flags import lora_count, Steps def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): @@ -18,10 +18,14 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results) get_list('styles', 'Styles', loaded_parameter_dict, results) get_str('performance', 'Performance', loaded_parameter_dict, results) + get_steps('steps', 'Steps', loaded_parameter_dict, results) + get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results) get_resolution('resolution', 'Resolution', loaded_parameter_dict, results) - get_float('sharpness', 'Sharpness', loaded_parameter_dict, results) get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results) + get_float('sharpness', 'Sharpness', loaded_parameter_dict, results) get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results) + get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results) + get_str('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results) get_str('base_model', 'Base Model', loaded_parameter_dict, results) get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results) get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) @@ -36,6 +40,8 @@ def 
load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): results.append(gr.update(visible=False)) + get_freeu('freeu', 'FreeU', loaded_parameter_dict, results) + for i in range(lora_count): get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results) @@ -44,45 +50,51 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: - h = source_dict.get(key, default) + h = source_dict.get(key, source_dict.get(fallback, default)) assert isinstance(h, str) results.append(h) except: - if fallback is not None: - get_str(fallback, None, source_dict, results, default) - return results.append(gr.update()) def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: - h = source_dict.get(key, default) + h = source_dict.get(key, source_dict.get(fallback, default)) h = eval(h) assert isinstance(h, list) results.append(h) except: - if fallback is not None: - get_list(fallback, None, source_dict, results, default) - return results.append(gr.update()) +# TODO try get generic + def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: - h = source_dict.get(key, default) + h = source_dict.get(key, source_dict.get(fallback, default)) assert h is not None h = float(h) results.append(h) except: - if fallback is not None: - get_float(fallback, None, source_dict, results, default) - return results.append(gr.update()) +def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + assert h is not None + h = int(h) + if h not in set(item.value for item in Steps): + results.append(h) + return + results.append(-1) + except: + results.append(-1) + + def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: - h = source_dict.get(key, default) + h = source_dict.get(key, source_dict.get(fallback, default)) width, height = eval(h) formatted = modules.config.add_ratio(f'{width}*{height}') if formatted in modules.config.available_aspect_ratios: @@ -94,9 +106,6 @@ def get_resolution(key: str, fallback: str | None, source_dict: dict, results: l results.append(width) results.append(height) except: - if fallback is not None: - get_resolution(fallback, None, source_dict, results, default) - return results.append(gr.update()) results.append(gr.update()) results.append(gr.update()) @@ -104,30 +113,41 @@ def get_resolution(key: str, fallback: str | None, source_dict: dict, results: l def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: - h = source_dict.get(key, default) + h = source_dict.get(key, source_dict.get(fallback, default)) assert h is not None h = int(h) results.append(False) results.append(h) except: - if fallback is not None: - get_seed(fallback, None, source_dict, results, default) - return results.append(gr.update()) results.append(gr.update()) def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: - h = source_dict.get(key, default) + h = source_dict.get(key, source_dict.get(fallback, default)) p, n, e = eval(h) results.append(float(p)) results.append(float(n)) results.append(float(e)) except: - if fallback is not None: - get_adm_guidance(fallback, None, source_dict, results, default) - return + results.append(gr.update()) + 
results.append(gr.update()) + results.append(gr.update()) + + +def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + b1, b2, s1, s2 = eval(h) + results.append(True) + results.append(float(b1)) + results.append(float(b2)) + results.append(float(s1)) + results.append(float(s2)) + except: + results.append(False) + results.append(gr.update()) results.append(gr.update()) results.append(gr.update()) results.append(gr.update()) @@ -135,13 +155,10 @@ def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: - n, w = source_dict.get(key).split(' : ') + n, w = source_dict.get(key, source_dict.get(fallback)).split(' : ') w = float(w) results.append(n) results.append(w) except: - if fallback is not None: - get_lora(fallback, None, source_dict, results, default) - return results.append('None') results.append(1) diff --git a/modules/metadata.py b/modules/metadata.py index 9d68fd17..34cc1923 100644 --- a/modules/metadata.py +++ b/modules/metadata.py @@ -31,11 +31,18 @@ class A1111MetadataParser(MetadataParser): fooocus_to_a1111 = { 'negative_prompt': 'Negative prompt', 'styles': 'Styles', + 'performance': 'Performance', 'steps': 'Steps', 'sampler': 'Sampler', 'guidance_scale': 'CFG scale', 'seed': 'Seed', 'resolution': 'Size', + 'sharpness': 'Sharpness', + 'adm_guidance': 'ADM Guidance', + 'refiner_swap_method': 'Refiner Swap Method', + 'adaptive_cfg': 'Adaptive CFG', + 'overwrite_switch': 'Overwrite Switch', + 'freeu': 'FreeU', 'base_model': 'Model', 'base_model_hash': 'Model hash', 'refiner_model': 'Refiner', @@ -87,8 +94,8 @@ class A1111MetadataParser(MetadataParser): except Exception: print(f"Error parsing \"{k}: {v}\"") - # try to load performance based on steps - if 'steps' in data: + # try to load performance based on steps, fallback for direct A1111 imports + if 'steps' in data and 'performance' not in data: try: data['performance'] = Performance[Steps(int(data['steps'])).name].value except Exception: @@ -132,10 +139,14 @@ class A1111MetadataParser(MetadataParser): lora_hashes_string = ', '.join(lora_hashes) generation_params = { + self.fooocus_to_a1111['performance']: data['performance'], self.fooocus_to_a1111['steps']: data['steps'], self.fooocus_to_a1111['sampler']: data['sampler'], self.fooocus_to_a1111['seed']: data['seed'], self.fooocus_to_a1111['resolution']: f'{width}x{heigth}', + self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'], + self.fooocus_to_a1111['sharpness']: data['sharpness'], + self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'], # TODO load model by name / hash self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem, self.fooocus_to_a1111['base_model_hash']: data['base_model_hash'] @@ -147,6 +158,26 @@ class A1111MetadataParser(MetadataParser): self.fooocus_to_a1111['refiner_model_hash']: data['refiner_model_hash'] } + if 'refiner_swap_method' in data: + generation_params |= { + self.fooocus_to_a1111['refiner_swap_method']: data['refiner_swap_method'], + } + + # TODO unify with for and call with key + + if 'freeu' in data: + generation_params |= { + self.fooocus_to_a1111['freeu']: data['freeu'], + } + if 'adaptive_cfg' in data: + generation_params |= { + self.fooocus_to_a1111['adaptive_cfg']: data['adaptive_cfg'], + } + if 'overwrite_switch' in data: + generation_params |= { + 
self.fooocus_to_a1111['overwrite_switch']: data['overwrite_switch'], + } + generation_params |= { self.fooocus_to_a1111['lora_hashes']: lora_hashes_string, self.fooocus_to_a1111['version']: data['version'] diff --git a/webui.py b/webui.py index 816a122f..51193fac 100644 --- a/webui.py +++ b/webui.py @@ -588,10 +588,11 @@ with shared.gradio_root: prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False) load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections, - performance_selection, aspect_ratios_selection, overwrite_width, overwrite_height, - sharpness, guidance_scale, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, - base_model, refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, - image_seed, generate_button, load_parameter_button] + lora_ctrls + performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection, + overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive, + adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model, + refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed, + generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False)
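
Notes on two of the TODOs above (reviewer sketches, not part of the patch):

1. modules/meta_parser.py, "# TODO try get generic": the getters now share the lookup
   chain source_dict.get(key, source_dict.get(fallback, default)), so the single-value
   ones could collapse into one helper parameterized by a cast. A minimal sketch,
   assuming a hypothetical name get_generic and keeping the bare-except/gr.update()
   fallback behavior of the current getters:

       import gradio as gr

       def get_generic(key, fallback, source_dict, results, cast, default=None):
           # Fooocus key first, then the A1111 label, then the caller's default.
           try:
               h = source_dict.get(key, source_dict.get(fallback, default))
               assert h is not None
               results.append(cast(h))
           except:
               results.append(gr.update())

       # Usage, equivalent to the current get_str / get_float:
       # get_generic('base_model', 'Model', d, results, str)
       # get_generic('guidance_scale', 'CFG scale', d, results, float)

   get_list, get_resolution, get_seed, get_adm_guidance and get_freeu would keep their
   own wrappers, since they eval() the value or append more than one result.

2. modules/metadata.py, "# TODO unify with for and call with key": the repeated
   generation_params |= {...} blocks for the optional keys could become a loop, since
   the fooocus_to_a1111 mapping added in this patch already covers each of them:

       for key in ['refiner_swap_method', 'freeu', 'adaptive_cfg', 'overwrite_switch']:
           if key in data:
               generation_params[self.fooocus_to_a1111[key]] = data[key]

   Both are untested sketches against the code as it stands in this patch.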