diff --git a/modules/async_worker.py b/modules/async_worker.py index b2939d2b..f9e277db 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -9,7 +9,7 @@ patch_all() class AsyncTask: def __init__(self, args): - from modules.flags import Performance, MetadataScheme, ip_list, controlnet_image_count, disabled + from modules.flags import Performance, MetadataScheme, ip_list, disabled from modules.util import get_enabled_loras from modules.config import default_max_lora_number import args_manager @@ -101,7 +101,7 @@ class AsyncTask: args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS self.cn_tasks = {x: [] for x in ip_list} - for _ in range(controlnet_image_count): + for _ in range(modules.config.default_controlnet_image_count): cn_img = args.pop() cn_stop = args.pop() cn_weight = args.pop() diff --git a/modules/config.py b/modules/config.py index 49208e71..7d556641 100644 --- a/modules/config.py +++ b/modules/config.py @@ -499,6 +499,46 @@ default_uov_method = get_config_item_or_set_default( validator=lambda x: x in modules.flags.uov_list, expected_type=int ) +default_controlnet_image_count = get_config_item_or_set_default( + key='default_controlnet_image_count', + default_value=4, + validator=lambda x: x > 0, + expected_type=int +) +default_ip_images = {} +default_ip_stop_ats = {} +default_ip_weights = {} +default_ip_types = {} + +for image_count in range(default_controlnet_image_count): + default_ip_images[image_count] = get_config_item_or_set_default( + key=f'default_ip_image_{image_count}', + default_value=None, + validator=lambda x: x is None or isinstance(x, str) and os.path.exists(x), + expected_type=str + ) + default_ip_types[image_count] = get_config_item_or_set_default( + key=f'default_ip_type_{image_count}', + default_value=modules.flags.default_ip, + validator=lambda x: x in modules.flags.ip_list, + expected_type=str + ) + + default_end, default_weight = modules.flags.default_parameters[default_ip_types[image_count]] + + default_ip_stop_ats[image_count] = get_config_item_or_set_default( + key=f'default_ip_stop_at_{image_count}', + default_value=default_end, + validator=lambda x: x is None or isinstance(x, str) and os.path.exists(x), + expected_type=float + ) + default_ip_weights[image_count] = get_config_item_or_set_default( + key=f'default_ip_weight_{image_count}', + default_value=default_weight, + validator=lambda x: x is None or isinstance(x, str) and os.path.exists(x), + expected_type=float + ) + default_inpaint_advanced_masking_checkbox = get_config_item_or_set_default( key='default_inpaint_advanced_masking_checkbox', default_value=False, diff --git a/modules/flags.py b/modules/flags.py index 4b6cd7b5..dc1c98ec 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -113,8 +113,6 @@ metadata_scheme = [ (f'{MetadataScheme.A1111.value} (plain text)', MetadataScheme.A1111.value), ] -controlnet_image_count = 4 - class OutputFormat(Enum): PNG = 'png' diff --git a/webui.py b/webui.py index 8d308da0..72f16afe 100644 --- a/webui.py +++ b/webui.py @@ -220,24 +220,22 @@ with shared.gradio_root: ip_weights = [] ip_ctrls = [] ip_ad_cols = [] - for _ in range(flags.controlnet_image_count): + for image_count in range(modules.config.default_controlnet_image_count): with gr.Column(): - ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300) + ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300, value=modules.config.default_ip_images[image_count]) ip_images.append(ip_image) ip_ctrls.append(ip_image) with gr.Column(visible=modules.config.default_image_prompt_advanced_checkbox) as ad_col: with gr.Row(): - default_end, default_weight = flags.default_parameters[flags.default_ip] - - ip_stop = gr.Slider(label='Stop At', minimum=0.0, maximum=1.0, step=0.001, value=default_end) + ip_stop = gr.Slider(label='Stop At', minimum=0.0, maximum=1.0, step=0.001, value=modules.config.default_ip_stop_ats[image_count]) ip_stops.append(ip_stop) ip_ctrls.append(ip_stop) - ip_weight = gr.Slider(label='Weight', minimum=0.0, maximum=2.0, step=0.001, value=default_weight) + ip_weight = gr.Slider(label='Weight', minimum=0.0, maximum=2.0, step=0.001, value=modules.config.default_ip_weights[image_count]) ip_weights.append(ip_weight) ip_ctrls.append(ip_weight) - ip_type = gr.Radio(label='Type', choices=flags.ip_list, value=flags.default_ip, container=False) + ip_type = gr.Radio(label='Type', choices=flags.ip_list, value=modules.config.default_ip_types[image_count], container=False) ip_types.append(ip_type) ip_ctrls.append(ip_type)