feat: add configs for controlnet

default_controlnet_image_count, ip_images, ip_stop_ats, ip_weights and ip_types
This commit is contained in:
Manuel Schmid 2024-07-27 21:15:51 +02:00
parent f5906f27a0
commit 0b1fe42971
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
4 changed files with 47 additions and 11 deletions

View File

@ -9,7 +9,7 @@ patch_all()
class AsyncTask:
def __init__(self, args):
from modules.flags import Performance, MetadataScheme, ip_list, controlnet_image_count, disabled
from modules.flags import Performance, MetadataScheme, ip_list, disabled
from modules.util import get_enabled_loras
from modules.config import default_max_lora_number
import args_manager
@ -101,7 +101,7 @@ class AsyncTask:
args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
self.cn_tasks = {x: [] for x in ip_list}
for _ in range(controlnet_image_count):
for _ in range(modules.config.default_controlnet_image_count):
cn_img = args.pop()
cn_stop = args.pop()
cn_weight = args.pop()

View File

@ -499,6 +499,46 @@ default_uov_method = get_config_item_or_set_default(
validator=lambda x: x in modules.flags.uov_list,
expected_type=int
)
default_controlnet_image_count = get_config_item_or_set_default(
key='default_controlnet_image_count',
default_value=4,
validator=lambda x: x > 0,
expected_type=int
)
default_ip_images = {}
default_ip_stop_ats = {}
default_ip_weights = {}
default_ip_types = {}
for image_count in range(default_controlnet_image_count):
default_ip_images[image_count] = get_config_item_or_set_default(
key=f'default_ip_image_{image_count}',
default_value=None,
validator=lambda x: x is None or isinstance(x, str) and os.path.exists(x),
expected_type=str
)
default_ip_types[image_count] = get_config_item_or_set_default(
key=f'default_ip_type_{image_count}',
default_value=modules.flags.default_ip,
validator=lambda x: x in modules.flags.ip_list,
expected_type=str
)
default_end, default_weight = modules.flags.default_parameters[default_ip_types[image_count]]
default_ip_stop_ats[image_count] = get_config_item_or_set_default(
key=f'default_ip_stop_at_{image_count}',
default_value=default_end,
validator=lambda x: x is None or isinstance(x, str) and os.path.exists(x),
expected_type=float
)
default_ip_weights[image_count] = get_config_item_or_set_default(
key=f'default_ip_weight_{image_count}',
default_value=default_weight,
validator=lambda x: x is None or isinstance(x, str) and os.path.exists(x),
expected_type=float
)
default_inpaint_advanced_masking_checkbox = get_config_item_or_set_default(
key='default_inpaint_advanced_masking_checkbox',
default_value=False,

View File

@ -113,8 +113,6 @@ metadata_scheme = [
(f'{MetadataScheme.A1111.value} (plain text)', MetadataScheme.A1111.value),
]
controlnet_image_count = 4
class OutputFormat(Enum):
PNG = 'png'

View File

@ -220,24 +220,22 @@ with shared.gradio_root:
ip_weights = []
ip_ctrls = []
ip_ad_cols = []
for _ in range(flags.controlnet_image_count):
for image_count in range(modules.config.default_controlnet_image_count):
with gr.Column():
ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300)
ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300, value=modules.config.default_ip_images[image_count])
ip_images.append(ip_image)
ip_ctrls.append(ip_image)
with gr.Column(visible=modules.config.default_image_prompt_advanced_checkbox) as ad_col:
with gr.Row():
default_end, default_weight = flags.default_parameters[flags.default_ip]
ip_stop = gr.Slider(label='Stop At', minimum=0.0, maximum=1.0, step=0.001, value=default_end)
ip_stop = gr.Slider(label='Stop At', minimum=0.0, maximum=1.0, step=0.001, value=modules.config.default_ip_stop_ats[image_count])
ip_stops.append(ip_stop)
ip_ctrls.append(ip_stop)
ip_weight = gr.Slider(label='Weight', minimum=0.0, maximum=2.0, step=0.001, value=default_weight)
ip_weight = gr.Slider(label='Weight', minimum=0.0, maximum=2.0, step=0.001, value=modules.config.default_ip_weights[image_count])
ip_weights.append(ip_weight)
ip_ctrls.append(ip_weight)
ip_type = gr.Radio(label='Type', choices=flags.ip_list, value=flags.default_ip, container=False)
ip_type = gr.Radio(label='Type', choices=flags.ip_list, value=modules.config.default_ip_types[image_count], container=False)
ip_types.append(ip_type)
ip_ctrls.append(ip_type)