From 468d704b299e0bf10ace1662506289ccd85be018 Mon Sep 17 00:00:00 2001 From: MindOfMatter <35126123+MindOfMatter@users.noreply.github.com> Date: Sun, 25 Feb 2024 13:59:28 -0500 Subject: [PATCH] feat: add button to enable LoRAs (#2210) * Initial commit * Update README.md * sync with original main Fooocus repo * update with my gitignore setup * add max lora config feature * Revert "add max lora config feature" This reverts commit cfe7463fe25475b6d59f36072ade410a2d8d5124. * add lora enabler feature * Update README.md * Update .gitignore * update * merge * revert changes * revert * feat: change width of LoRA columns * refactor: rename lora_enable to lora_enabled, optimize code --------- Co-authored-by: Manuel Schmid --- modules/async_worker.py | 10 +++++++++- modules/html.py | 24 ++++++++++++++++++++++++ modules/meta_parser.py | 6 ++++-- webui.py | 9 ++++++--- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index a304e697..34cd2e5a 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -115,6 +115,14 @@ def worker(): # must use deep copy otherwise gradio is super laggy. Do not use list.append() . async_task.results = async_task.results + [wall] return + + def apply_enabled_loras(loras): + enabled_loras = [] + for lora_enabled, lora_model, lora_weight in loras: + if lora_enabled: + enabled_loras.append([lora_model, lora_weight]) + + return enabled_loras @torch.no_grad() @torch.inference_mode() @@ -137,7 +145,7 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = [[str(args.pop()), float(args.pop())] for _ in range(5)] + loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(5)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() diff --git a/modules/html.py b/modules/html.py index 3ec6f2d6..47a1483a 100644 --- a/modules/html.py +++ b/modules/html.py @@ -112,6 +112,30 @@ progress::after { margin-left: -5px !important; } +.lora_enable { + flex-grow: 1 !important; +} + +.lora_enable label { + height: 100%; +} + +.lora_enable label input { + margin: auto; +} + +.lora_enable label span { + display: none; +} + +.lora_model { + flex-grow: 5 !important; +} + +.lora_weight { + flex-grow: 5 !important; +} + ''' progress_html = '''
diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 07b42a16..bd8f555e 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -139,10 +139,12 @@ def load_parameter_button_click(raw_prompt_txt, is_generating): try: n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ') w = float(w) + results.append(True) results.append(n) results.append(w) except: - results.append(gr.update()) - results.append(gr.update()) + results.append(True) + results.append("None") + results.append(1.0) return results diff --git a/webui.py b/webui.py index 6d72c67c..1463ff90 100644 --- a/webui.py +++ b/webui.py @@ -322,11 +322,14 @@ with shared.gradio_root: for i, (n, v) in enumerate(modules.config.default_loras): with gr.Row(): + lora_enabled = gr.Checkbox(label='Enable', value=True, + elem_classes=['lora_enable', 'min_check']) lora_model = gr.Dropdown(label=f'LoRA {i + 1}', - choices=['None'] + modules.config.lora_filenames, value=n) + choices=['None'] + modules.config.lora_filenames, value=n, + elem_classes='lora_model') lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v, elem_classes='lora_weight') - lora_ctrls += [lora_model, lora_weight] + lora_ctrls += [lora_enabled, lora_model, lora_weight] with gr.Row(): model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button') @@ -471,7 +474,7 @@ with shared.gradio_root: results = [] results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)] for i in range(5): - results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] + results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update(), gr.update(interactive=True)] return results model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,