diff --git a/css/style.css b/css/style.css index 6ed0f628..b10e644b 100644 --- a/css/style.css +++ b/css/style.css @@ -99,7 +99,7 @@ div:has(> #positive_prompt) { } .advanced_check_row { - width: 250px !important; + width: 310px !important; } .min_check { diff --git a/modules/async_worker.py b/modules/async_worker.py index 7343d903..0367d12a 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -2,6 +2,7 @@ import threading from extras.inpaint_mask import generate_mask_from_image, SAMOptions from modules.patch import PatchSettings, patch_settings, patch_all +import modules.config patch_all() @@ -107,6 +108,23 @@ class AsyncTask: if cn_img is not None: self.cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight]) + self.stage2_ctrls = [] + for _ in range(modules.config.default_max_stage2_tabs): + stage2_enabled = args.pop() + # stage2_mode = args.pop() + stage2_mask_dino_prompt_text = args.pop() + stage2_mask_sam_model = args.pop() + stage2_mask_box_threshold = args.pop() + stage2_mask_text_threshold = args.pop() + if stage2_enabled: + self.stage2_ctrls.append([ + # stage2_mode, + stage2_mask_dino_prompt_text, + stage2_mask_sam_model, + stage2_mask_box_threshold, + stage2_mask_text_threshold + ]) + async_tasks = [] @@ -131,7 +149,6 @@ def worker(): import modules.default_pipeline as pipeline import modules.core as core import modules.flags as flags - import modules.config import modules.patch import ldm_patched.modules.model_management import extras.preprocessors as preprocessors @@ -1019,37 +1036,44 @@ def worker(): # stage2 progressbar(async_task, current_progress, 'Processing stage2 ...') final_unet = pipeline.final_unet.clone() + if len(async_task.stage2_ctrls) == 0: + continue for img in imgs: - # TODO add stage2 check and options from inputs here - mask = generate_mask_from_image(img, sam_options=SAMOptions( - dino_prompt='eye' - )) - mask = mask[:, :, 0] + for stage2_mask_dino_prompt_text, stage2_mask_sam_model, stage2_mask_box_threshold, stage2_mask_text_threshold in async_task.stage2_ctrls: + mask = generate_mask_from_image(img, sam_options=SAMOptions( + dino_prompt=stage2_mask_dino_prompt_text, + model_type=stage2_mask_sam_model, + dino_box_threshold=stage2_mask_box_threshold, + dino_text_threshold=stage2_mask_text_threshold, + dino_debug=True + )) + mask = mask[:, :, 0] - async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)]) - # TODO also show do_not_show_finished_images=len(tasks) == 1 - yield_result(async_task, mask, async_task.black_out_nsfw, False, - do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results) - # TODO make configurable - denoising_strength_stage2 = 0.5 - inpaint_respective_field_stage2 = 0.0 - inpaint_head_model_path_stage2 = None - inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail - goals_stage2 = ['inpaint'] - denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint( - async_task, None, inpaint_head_model_path_stage2, img, mask, - inpaint_parameterized_stage2, denoising_strength_stage2, - inpaint_respective_field_stage2, switch, current_progress, True) + async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)]) + # TODO also show do_not_show_finished_images=len(tasks) == 1 + yield_result(async_task, mask, async_task.black_out_nsfw, False, + do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results) + # TODO make configurable + denoising_strength_stage2 = 0.5 + inpaint_respective_field_stage2 = 0.0 + inpaint_head_model_path_stage2 = None + inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail + goals_stage2 = ['inpaint'] + denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint( + async_task, None, inpaint_head_model_path_stage2, img, mask, + inpaint_parameterized_stage2, denoising_strength_stage2, + inpaint_respective_field_stage2, switch, current_progress, True) - process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, - current_task_id, denoising_strength_stage2, final_scheduler_name, goals_stage2, - initial_latent_stage2, switch, task, tasks, tiled, use_expansion, width_stage2, - height_stage2) + imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, + current_task_id, denoising_strength_stage2, final_scheduler_name, goals_stage2, + initial_latent_stage2, switch, task, tasks, tiled, use_expansion, width_stage2, + height_stage2) - # reset unet and inpaint_worker - pipeline.final_unet = final_unet - inpaint_worker.current_task = None + # reset and prepare next iteration + img = imgs2[0] + pipeline.final_unet = final_unet + inpaint_worker.current_task = None except ldm_patched.modules.model_management.InterruptProcessingException: if async_task.last_stop == 'skip': diff --git a/modules/config.py b/modules/config.py index 1fa7d87b..16e3043a 100644 --- a/modules/config.py +++ b/modules/config.py @@ -502,6 +502,20 @@ example_inpaint_prompts = get_config_item_or_set_default( validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x), expected_type=list ) +example_stage2_prompts = get_config_item_or_set_default( + key='example_stage2_prompts', + default_value=[ + 'face', 'eye', 'mouth', 'hair', 'hand', 'body' + ], + validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x), + expected_type=list +) +default_max_stage2_tabs = get_config_item_or_set_default( + key='default_max_stage2_tabs', + default_value=3, + validator=lambda x: isinstance(x, int) and 1 <= x <= 5, + expected_type=int +) default_black_out_nsfw = get_config_item_or_set_default( key='default_black_out_nsfw', default_value=False, @@ -528,6 +542,7 @@ metadata_created_by = get_config_item_or_set_default( ) example_inpaint_prompts = [[x] for x in example_inpaint_prompts] +example_stage2_prompts = [[x] for x in example_stage2_prompts] default_inpaint_mask_model = get_config_item_or_set_default( key='default_inpaint_mask_model', diff --git a/modules/flags.py b/modules/flags.py index ed9a5606..1169bd5b 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -76,7 +76,7 @@ output_formats = ['png', 'jpeg', 'webp'] inpaint_mask_models = ['u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam'] inpaint_mask_cloth_category = ['full', 'upper', 'lower'] -inpaint_mask_sam_model = [('base', 'vit_b'), ('large', 'vit_l'), ('huge', 'vit_h')] +inpaint_mask_sam_model = ['vit_b', 'vit_l', 'vit_h'] inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] inpaint_option_default = 'Inpaint or Outpaint (default)' diff --git a/webui.py b/webui.py index 5af17a00..2f0c8aa9 100644 --- a/webui.py +++ b/webui.py @@ -147,6 +147,7 @@ with shared.gradio_root: skip_button.click(skip_clicked, inputs=currentTask, outputs=currentTask, queue=False, show_progress=False) with gr.Row(elem_classes='advanced_check_row'): input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check') + stage2_checkbox = gr.Checkbox(label='Stage2', value=False, container=False, elem_classes='min_check') advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check') with gr.Row(visible=False) as image_input_panel: with gr.Tabs(): @@ -297,6 +298,37 @@ with shared.gradio_root: metadata_input_image.upload(trigger_metadata_preview, inputs=metadata_input_image, outputs=metadata_json, queue=False, show_progress=True) + with gr.Row(visible=False) as stage2_input_panel: + with gr.Tabs(): + stage2_ctrls = [] + for index in range(modules.config.default_max_stage2_tabs): + with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item: + stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False) + with gr.Accordion('Options', visible=True, open=False) as stage2_accordion: + # stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True) + stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True) + example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts, + label='Additional Prompt Quick List', + components=[stage2_mask_dino_prompt_text], + visible=True) + example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False) + + with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options: + stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True) + stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True) + stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True) + + stage2_ctrls += [ + stage2_enabled, + # stage2_mode, + stage2_mask_dino_prompt_text, + stage2_mask_sam_model, + stage2_mask_box_threshold, + stage2_mask_text_threshold + ] + + stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled, + outputs=stage2_accordion, queue=False, show_progress=False) switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}" down_js = "() => {viewer_to_bottom();}" @@ -311,6 +343,9 @@ with shared.gradio_root: desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False) metadata_tab.select(lambda: 'metadata', outputs=current_tab, queue=False, _js=down_js, show_progress=False) + stage2_checkbox.change(lambda x: gr.update(visible=x), inputs=stage2_checkbox, + outputs=stage2_input_panel, queue=False, show_progress=False, _js=switch_js) + with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column: with gr.Tab(label='Settings'): if not args_manager.args.disable_preset_selection: @@ -772,6 +807,7 @@ with shared.gradio_root: ctrls += [save_metadata_to_images, metadata_scheme] ctrls += ip_ctrls + ctrls += stage2_ctrls def parse_meta(raw_prompt_txt, is_generating): loaded_json = None