From 18f9f7dc313ee279fd3241784aafad9e948b402b Mon Sep 17 00:00:00 2001 From: MindOfMatter <35126123+MindOfMatter@users.noreply.github.com> Date: Sun, 25 Feb 2024 15:12:26 -0500 Subject: [PATCH] feat: make lora number editable in config (#2215) * Initial commit * Update README.md * sync with original main Fooocus repo * update with my gitignore setup * add max lora config feature * Revert "add max lora config feature" This reverts commit cfe7463fe25475b6d59f36072ade410a2d8d5124. * add max loras config feature * Update README.md * Update .gitignore * update * merge * revert * refactor: rename default_loras_max_number to default_max_lora_number, validate config for int * fix: add missing patch_all call and imports again --------- Co-authored-by: Manuel Schmid --- modules/async_worker.py | 7 +++---- modules/config.py | 8 +++++++- modules/meta_parser.py | 6 +++--- webui.py | 8 ++++---- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 34cd2e5a..47848ad6 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -4,7 +4,6 @@ from modules.patch import PatchSettings, patch_settings, patch_all patch_all() - class AsyncTask: def __init__(self, args): self.args = args @@ -115,13 +114,13 @@ def worker(): # must use deep copy otherwise gradio is super laggy. Do not use list.append() . async_task.results = async_task.results + [wall] return - + def apply_enabled_loras(loras): enabled_loras = [] for lora_enabled, lora_model, lora_weight in loras: if lora_enabled: enabled_loras.append([lora_model, lora_weight]) - + return enabled_loras @torch.no_grad() @@ -145,7 +144,7 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(5)]) + loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() diff --git a/modules/config.py b/modules/config.py index 6f713916..bb1ee26c 100644 --- a/modules/config.py +++ b/modules/config.py @@ -235,6 +235,11 @@ default_loras = get_config_item_or_set_default( ], validator=lambda x: isinstance(x, list) and all(len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) for y in x) ) +default_max_lora_number = get_config_item_or_set_default( + key='default_max_lora_number', + default_value=len(default_loras), + validator=lambda x: isinstance(x, int) and x >= 1 +) default_cfg_scale = get_config_item_or_set_default( key='default_cfg_scale', default_value=7.0, @@ -357,13 +362,14 @@ example_inpaint_prompts = get_config_item_or_set_default( example_inpaint_prompts = [[x] for x in example_inpaint_prompts] -config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))] +config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [['None', 1.0] for _ in range(default_max_lora_number - len(default_loras))] possible_preset_keys = [ "default_model", "default_refiner", "default_refiner_switch", "default_loras", + "default_max_lora_number", "default_cfg_scale", "default_sample_sharpness", "default_sampler", diff --git a/modules/meta_parser.py b/modules/meta_parser.py index bd8f555e..061e1f8a 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -135,16 +135,16 @@ def load_parameter_button_click(raw_prompt_txt, is_generating): results.append(gr.update(visible=False)) - for i in range(1, 6): + for i in range(1, modules.config.default_max_lora_number + 1): try: - n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ') + n, w = loaded_parameter_dict.get(f'LoRA {i}', ' : ').split(' : ') w = float(w) results.append(True) results.append(n) results.append(w) except: results.append(True) - results.append("None") + results.append('None') results.append(1.0) return results diff --git a/webui.py b/webui.py index 1463ff90..270f0ffa 100644 --- a/webui.py +++ b/webui.py @@ -471,10 +471,10 @@ with shared.gradio_root: def model_refresh_clicked(): modules.config.update_all_model_names() - results = [] - results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)] - for i in range(5): - results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update(), gr.update(interactive=True)] + results = [gr.update(choices=modules.config.model_filenames)] + results += [gr.update(choices=['None'] + modules.config.model_filenames)] + for i in range(modules.config.default_max_lora_number): + results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] return results model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,