From c3ab9f1f3071b9c32a74802288a3f8615ba1c6d8 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 29 Jan 2024 14:22:36 +0100 Subject: [PATCH] refactor: use central flag for LoRA count --- modules/async_worker.py | 4 ++-- modules/config.py | 4 ++-- modules/flags.py | 2 ++ modules/meta_parser.py | 3 ++- modules/metadata.py | 4 ++-- webui.py | 2 +- 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index e6ec8c39..1c788b62 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -44,7 +44,7 @@ def worker(): from modules.util import remove_empty_str, HWC3, resize_image, \ get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256 from modules.upscaler import perform_upscale - from modules.flags import Performance, MetadataScheme + from modules.flags import Performance, MetadataScheme, lora_count try: async_gradio_app = shared.gradio_root @@ -134,7 +134,7 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = [[str(args.pop()), float(args.pop())] for _ in range(5)] + loras = [[str(args.pop()), float(args.pop())] for _ in range(lora_count)] input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() diff --git a/modules/config.py b/modules/config.py index 6c5da0e9..e9152b40 100644 --- a/modules/config.py +++ b/modules/config.py @@ -8,7 +8,7 @@ import modules.sdxl_styles from modules.model_loader import load_file_from_url from modules.util import get_files_from_folder -from modules.flags import Performance, MetadataScheme +from modules.flags import Performance, MetadataScheme, lora_count config_path = os.path.abspath("./config.txt") config_example_path = os.path.abspath("config_modification_tutorial.txt") @@ -333,7 +333,7 @@ metadata_created_by = get_config_item_or_set_default( example_inpaint_prompts = [[x] for x in example_inpaint_prompts] -config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))] +config_dict["default_loras"] = default_loras = default_loras[:lora_count] + [['None', 1.0] for _ in range(lora_count - len(default_loras))] possible_preset_keys = [ "default_model", diff --git a/modules/flags.py b/modules/flags.py index 5008d5cd..f0297783 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -55,6 +55,8 @@ metadata_scheme = [ ('A1111 (plain text)', MetadataScheme.A1111.value), ] +lora_count = 5 +lora_count_with_lcm = lora_count + 1 class Steps(Enum): QUALITY = 60 diff --git a/modules/meta_parser.py b/modules/meta_parser.py index e7cf8a47..3c3e416b 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -3,6 +3,7 @@ import json import gradio as gr import modules.config +from modules.flags import lora_count_with_lcm def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): @@ -35,7 +36,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): results.append(gr.update(visible=False)) - for i in range(1, 6): + for i in range(1, lora_count_with_lcm): try: n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ') w = float(w) diff --git a/modules/metadata.py b/modules/metadata.py index 0d585f1c..0978e862 100644 --- a/modules/metadata.py +++ b/modules/metadata.py @@ -7,7 +7,7 @@ import modules.config import fooocus_version # import advanced_parameters from modules.util import quote, unquote, extract_styles_from_prompt, is_json -from modules.flags import MetadataScheme, Performance, Steps +from modules.flags import MetadataScheme, Performance, Steps, lora_count_with_lcm re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) @@ -106,7 +106,7 @@ class A1111MetadataParser(MetadataParser): width, heigth = eval(data['resolution']) lora_hashes = [] - for index in range(5): + for index in range(lora_count_with_lcm): key = f'lora_name_{index + 1}' if key in data: name = data[f'lora_name_{index + 1}'] diff --git a/webui.py b/webui.py index b245df17..60d6540d 100644 --- a/webui.py +++ b/webui.py @@ -502,7 +502,7 @@ with shared.gradio_root: modules.config.update_all_model_names() results = [] results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)] - for i in range(5): + for i in range(flags.lora_count): results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] return results