fix: use correct LoRA mapping, add fallback for backwards compatibility

This commit is contained in:
Manuel Schmid 2024-01-29 15:45:55 +01:00
parent 20e53028a4
commit c80011b1d1
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 19 additions and 12 deletions

View File

@ -804,7 +804,7 @@ def worker():
]
for li, (n, w) in enumerate(loras):
if n != 'None':
d.append((f'LoRA {li + 1}', f'lora{li + 1}_combined', f'{n} : {w}', True, True))
d.append((f'LoRA {li + 1}', f'lora_combined_{li + 1}', f'{n} : {w}', True, True))
d.append((f'LoRA {li + 1} Name', f'lora_name_{li + 1}', n, False, False))
d.append((f'LoRA {li + 1} Weight', f'lora_weight_{li + 1}', w, False, False))
d.append((f'LoRA {li + 1} Hash', f'lora_hash_{li + 1}', lora_hashes[li], False, False))

View File

@ -3,7 +3,7 @@ import json
import gradio as gr
import modules.config
from modules.flags import lora_count_with_lcm
from modules.flags import lora_count
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
@ -29,6 +29,9 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
get_seed('seed', 'Seed', loaded_parameter_dict, results)
for i in range(lora_count):
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
if is_generating:
results.append(gr.update())
else:
@ -36,16 +39,6 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
results.append(gr.update(visible=False))
for i in range(1, lora_count_with_lcm):
try:
n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ')
w = float(w)
results.append(n)
results.append(w)
except:
results.append(gr.update())
results.append(gr.update())
return results
@ -138,3 +131,17 @@ def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results:
results.append(gr.update())
results.append(gr.update())
results.append(gr.update())
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
n, w = source_dict.get(key).split(' : ')
w = float(w)
results.append(n)
results.append(w)
except:
if fallback is not None:
get_lora(fallback, None, source_dict, results, default)
return
results.append(gr.update())
results.append(gr.update())