feat: add gradio elements for input

This commit is contained in:
Manuel Schmid 2024-06-12 21:52:48 +02:00
parent 190c4b0a6f
commit 9998b52dd2
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
5 changed files with 104 additions and 29 deletions

View File

@ -99,7 +99,7 @@ div:has(> #positive_prompt) {
}
.advanced_check_row {
width: 250px !important;
width: 310px !important;
}
.min_check {

View File

@ -2,6 +2,7 @@ import threading
from extras.inpaint_mask import generate_mask_from_image, SAMOptions
from modules.patch import PatchSettings, patch_settings, patch_all
import modules.config
patch_all()
@ -107,6 +108,23 @@ class AsyncTask:
if cn_img is not None:
self.cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
self.stage2_ctrls = []
for _ in range(modules.config.default_max_stage2_tabs):
stage2_enabled = args.pop()
# stage2_mode = args.pop()
stage2_mask_dino_prompt_text = args.pop()
stage2_mask_sam_model = args.pop()
stage2_mask_box_threshold = args.pop()
stage2_mask_text_threshold = args.pop()
if stage2_enabled:
self.stage2_ctrls.append([
# stage2_mode,
stage2_mask_dino_prompt_text,
stage2_mask_sam_model,
stage2_mask_box_threshold,
stage2_mask_text_threshold
])
async_tasks = []
@ -131,7 +149,6 @@ def worker():
import modules.default_pipeline as pipeline
import modules.core as core
import modules.flags as flags
import modules.config
import modules.patch
import ldm_patched.modules.model_management
import extras.preprocessors as preprocessors
@ -1019,37 +1036,44 @@ def worker():
# stage2
progressbar(async_task, current_progress, 'Processing stage2 ...')
final_unet = pipeline.final_unet.clone()
if len(async_task.stage2_ctrls) == 0:
continue
for img in imgs:
# TODO add stage2 check and options from inputs here
mask = generate_mask_from_image(img, sam_options=SAMOptions(
dino_prompt='eye'
))
mask = mask[:, :, 0]
for stage2_mask_dino_prompt_text, stage2_mask_sam_model, stage2_mask_box_threshold, stage2_mask_text_threshold in async_task.stage2_ctrls:
mask = generate_mask_from_image(img, sam_options=SAMOptions(
dino_prompt=stage2_mask_dino_prompt_text,
model_type=stage2_mask_sam_model,
dino_box_threshold=stage2_mask_box_threshold,
dino_text_threshold=stage2_mask_text_threshold,
dino_debug=True
))
mask = mask[:, :, 0]
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
# TODO also show do_not_show_finished_images=len(tasks) == 1
yield_result(async_task, mask, async_task.black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
# TODO make configurable
denoising_strength_stage2 = 0.5
inpaint_respective_field_stage2 = 0.0
inpaint_head_model_path_stage2 = None
inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail
goals_stage2 = ['inpaint']
denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint(
async_task, None, inpaint_head_model_path_stage2, img, mask,
inpaint_parameterized_stage2, denoising_strength_stage2,
inpaint_respective_field_stage2, switch, current_progress, True)
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
# TODO also show do_not_show_finished_images=len(tasks) == 1
yield_result(async_task, mask, async_task.black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
# TODO make configurable
denoising_strength_stage2 = 0.5
inpaint_respective_field_stage2 = 0.0
inpaint_head_model_path_stage2 = None
inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail
goals_stage2 = ['inpaint']
denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint(
async_task, None, inpaint_head_model_path_stage2, img, mask,
inpaint_parameterized_stage2, denoising_strength_stage2,
inpaint_respective_field_stage2, switch, current_progress, True)
process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength_stage2, final_scheduler_name, goals_stage2,
initial_latent_stage2, switch, task, tasks, tiled, use_expansion, width_stage2,
height_stage2)
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength_stage2, final_scheduler_name, goals_stage2,
initial_latent_stage2, switch, task, tasks, tiled, use_expansion, width_stage2,
height_stage2)
# reset unet and inpaint_worker
pipeline.final_unet = final_unet
inpaint_worker.current_task = None
# reset and prepare next iteration
img = imgs2[0]
pipeline.final_unet = final_unet
inpaint_worker.current_task = None
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':

View File

@ -502,6 +502,20 @@ example_inpaint_prompts = get_config_item_or_set_default(
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
expected_type=list
)
example_stage2_prompts = get_config_item_or_set_default(
key='example_stage2_prompts',
default_value=[
'face', 'eye', 'mouth', 'hair', 'hand', 'body'
],
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
expected_type=list
)
default_max_stage2_tabs = get_config_item_or_set_default(
key='default_max_stage2_tabs',
default_value=3,
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
expected_type=int
)
default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw',
default_value=False,
@ -528,6 +542,7 @@ metadata_created_by = get_config_item_or_set_default(
)
example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
example_stage2_prompts = [[x] for x in example_stage2_prompts]
default_inpaint_mask_model = get_config_item_or_set_default(
key='default_inpaint_mask_model',

View File

@ -76,7 +76,7 @@ output_formats = ['png', 'jpeg', 'webp']
inpaint_mask_models = ['u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam']
inpaint_mask_cloth_category = ['full', 'upper', 'lower']
inpaint_mask_sam_model = [('base', 'vit_b'), ('large', 'vit_l'), ('huge', 'vit_h')]
inpaint_mask_sam_model = ['vit_b', 'vit_l', 'vit_h']
inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6']
inpaint_option_default = 'Inpaint or Outpaint (default)'

View File

@ -147,6 +147,7 @@ with shared.gradio_root:
skip_button.click(skip_clicked, inputs=currentTask, outputs=currentTask, queue=False, show_progress=False)
with gr.Row(elem_classes='advanced_check_row'):
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
stage2_checkbox = gr.Checkbox(label='Stage2', value=False, container=False, elem_classes='min_check')
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
with gr.Row(visible=False) as image_input_panel:
with gr.Tabs():
@ -297,6 +298,37 @@ with shared.gradio_root:
metadata_input_image.upload(trigger_metadata_preview, inputs=metadata_input_image,
outputs=metadata_json, queue=False, show_progress=True)
with gr.Row(visible=False) as stage2_input_panel:
with gr.Tabs():
stage2_ctrls = []
for index in range(modules.config.default_max_stage2_tabs):
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
# stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True)
stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True)
example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts,
label='Additional Prompt Quick List',
components=[stage2_mask_dino_prompt_text],
visible=True)
example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False)
with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options:
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
stage2_ctrls += [
stage2_enabled,
# stage2_mode,
stage2_mask_dino_prompt_text,
stage2_mask_sam_model,
stage2_mask_box_threshold,
stage2_mask_text_threshold
]
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
outputs=stage2_accordion, queue=False, show_progress=False)
switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
down_js = "() => {viewer_to_bottom();}"
@ -311,6 +343,9 @@ with shared.gradio_root:
desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
metadata_tab.select(lambda: 'metadata', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
stage2_checkbox.change(lambda x: gr.update(visible=x), inputs=stage2_checkbox,
outputs=stage2_input_panel, queue=False, show_progress=False, _js=switch_js)
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
with gr.Tab(label='Settings'):
if not args_manager.args.disable_preset_selection:
@ -772,6 +807,7 @@ with shared.gradio_root:
ctrls += [save_metadata_to_images, metadata_scheme]
ctrls += ip_ctrls
ctrls += stage2_ctrls
def parse_meta(raw_prompt_txt, is_generating):
loaded_json = None