refactor: use central flag for LoRA count

This commit is contained in:
Manuel Schmid 2024-01-29 14:22:36 +01:00
parent 13d0341a02
commit c3ab9f1f30
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
6 changed files with 11 additions and 8 deletions

View File

@@ -44,7 +44,7 @@ def worker():
from modules.util import remove_empty_str, HWC3, resize_image, \
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256
from modules.upscaler import perform_upscale
from modules.flags import Performance, MetadataScheme
from modules.flags import Performance, MetadataScheme, lora_count
try:
async_gradio_app = shared.gradio_root
@@ -134,7 +134,7 @@ def worker():
base_model_name = args.pop()
refiner_model_name = args.pop()
refiner_switch = args.pop()
loras = [[str(args.pop()), float(args.pop())] for _ in range(5)]
loras = [[str(args.pop()), float(args.pop())] for _ in range(lora_count)]
input_image_checkbox = args.pop()
current_tab = args.pop()
uov_method = args.pop()

View File

@@ -8,7 +8,7 @@ import modules.sdxl_styles
from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder
from modules.flags import Performance, MetadataScheme
from modules.flags import Performance, MetadataScheme, lora_count
config_path = os.path.abspath("./config.txt")
config_example_path = os.path.abspath("config_modification_tutorial.txt")
@@ -333,7 +333,7 @@ metadata_created_by = get_config_item_or_set_default(
example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
config_dict["default_loras"] = default_loras = default_loras[:lora_count] + [['None', 1.0] for _ in range(lora_count - len(default_loras))]
possible_preset_keys = [
"default_model",

View File

@@ -55,6 +55,8 @@ metadata_scheme = [
('A1111 (plain text)', MetadataScheme.A1111.value),
]
lora_count = 5
lora_count_with_lcm = lora_count + 1
class Steps(Enum):
QUALITY = 60

View File

@@ -3,6 +3,7 @@ import json
import gradio as gr
import modules.config
from modules.flags import lora_count_with_lcm
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
@@ -35,7 +36,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
results.append(gr.update(visible=False))
for i in range(1, 6):
for i in range(1, lora_count_with_lcm):
try:
n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ')
w = float(w)

View File

@@ -7,7 +7,7 @@ import modules.config
import fooocus_version
# import advanced_parameters
from modules.util import quote, unquote, extract_styles_from_prompt, is_json
from modules.flags import MetadataScheme, Performance, Steps
from modules.flags import MetadataScheme, Performance, Steps, lora_count_with_lcm
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
@@ -106,7 +106,7 @@ class A1111MetadataParser(MetadataParser):
width, heigth = eval(data['resolution'])
lora_hashes = []
for index in range(5):
for index in range(lora_count_with_lcm):
key = f'lora_name_{index + 1}'
if key in data:
name = data[f'lora_name_{index + 1}']

View File

@@ -502,7 +502,7 @@ with shared.gradio_root:
modules.config.update_all_model_names()
results = []
results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
for i in range(5):
for i in range(flags.lora_count):
results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results