From dbc844804b3eb3fbf0e779b54b3e40e87ed69393 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Wed, 12 Jun 2024 22:16:02 +0200 Subject: [PATCH] feat: add handling for stage2_mask_sam_max_num_boxes and config --- modules/async_worker.py | 21 ++++++++++++++------- modules/config.py | 10 ++++++++-- webui.py | 20 +++++++++++--------- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 0367d12a..3191895f 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -108,21 +108,26 @@ class AsyncTask: if cn_img is not None: self.cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight]) + self.debugging_dino = args.pop() + self.dino_erode_or_dilate = args.pop() + self.stage2_ctrls = [] - for _ in range(modules.config.default_max_stage2_tabs): + for _ in range(modules.config.default_stage2_tabs): stage2_enabled = args.pop() # stage2_mode = args.pop() stage2_mask_dino_prompt_text = args.pop() - stage2_mask_sam_model = args.pop() stage2_mask_box_threshold = args.pop() stage2_mask_text_threshold = args.pop() + stage2_mask_sam_max_num_boxes = args.pop() + stage2_mask_sam_model = args.pop() if stage2_enabled: self.stage2_ctrls.append([ # stage2_mode, stage2_mask_dino_prompt_text, - stage2_mask_sam_model, stage2_mask_box_threshold, - stage2_mask_text_threshold + stage2_mask_text_threshold, + stage2_mask_sam_max_num_boxes, + stage2_mask_sam_model, ]) @@ -1040,13 +1045,15 @@ def worker(): continue for img in imgs: - for stage2_mask_dino_prompt_text, stage2_mask_sam_model, stage2_mask_box_threshold, stage2_mask_text_threshold in async_task.stage2_ctrls: + for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls: mask = generate_mask_from_image(img, sam_options=SAMOptions( dino_prompt=stage2_mask_dino_prompt_text, - model_type=stage2_mask_sam_model, dino_box_threshold=stage2_mask_box_threshold, dino_text_threshold=stage2_mask_text_threshold, - dino_debug=True + dino_erode_or_dilate=async_task.dino_erode_or_dilate, + dino_debug=async_task.debugging_dino, + max_num_boxes=stage2_mask_sam_max_num_boxes, + model_type=stage2_mask_sam_model )) mask = mask[:, :, 0] diff --git a/modules/config.py b/modules/config.py index 16e3043a..ef8e9576 100644 --- a/modules/config.py +++ b/modules/config.py @@ -510,12 +510,18 @@ example_stage2_prompts = get_config_item_or_set_default( validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x), expected_type=list ) -default_max_stage2_tabs = get_config_item_or_set_default( - key='default_max_stage2_tabs', +default_stage2_tabs = get_config_item_or_set_default( + key='default_stage2_tabs', default_value=3, validator=lambda x: isinstance(x, int) and 1 <= x <= 5, expected_type=int ) +default_sam_max_num_boxes = get_config_item_or_set_default( + key='default_sam_max_num_boxes', + default_value=2, + validator=lambda x: isinstance(x, int) and 1 <= x <= 5, + expected_type=int +) default_black_out_nsfw = get_config_item_or_set_default( key='default_black_out_nsfw', default_value=False, diff --git a/webui.py b/webui.py index 2f0c8aa9..839cd1db 100644 --- a/webui.py +++ b/webui.py @@ -232,7 +232,7 @@ with shared.gradio_root: inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05) generate_mask_button = gr.Button(value='Generate mask from image') - def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, debug_dino): + def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, dino_debug): from extras.inpaint_mask import generate_mask_from_image extras = {} @@ -245,7 +245,7 @@ with shared.gradio_root: dino_box_threshold=box_threshold, dino_text_threshold=text_threshold, dino_erode_or_dilate=dino_erode_or_dilate, - dino_debug=debug_dino, + dino_debug=dino_debug, max_num_boxes=2, #TODO replace with actual value model_type=sam_model ) @@ -301,7 +301,7 @@ with shared.gradio_root: with gr.Row(visible=False) as stage2_input_panel: with gr.Tabs(): stage2_ctrls = [] - for index in range(modules.config.default_max_stage2_tabs): + for index in range(modules.config.default_stage2_tabs): with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item: stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False) with gr.Accordion('Options', visible=True, open=False) as stage2_accordion: @@ -317,14 +317,16 @@ with shared.gradio_root: stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True) stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True) stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True) + stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True) stage2_ctrls += [ stage2_enabled, # stage2_mode, stage2_mask_dino_prompt_text, - stage2_mask_sam_model, stage2_mask_box_threshold, - stage2_mask_text_threshold + stage2_mask_text_threshold, + stage2_mask_sam_max_num_boxes, + stage2_mask_sam_model, ] stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled, @@ -598,8 +600,8 @@ with shared.gradio_root: with gr.Tab(label='Inpaint'): debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False) - debug_dino = gr.Checkbox(label='Debug GroundingDINO', value=False, - info='Used for SAM object detection and box generation') + debugging_dino = gr.Checkbox(label='Debug GroundingDINO', value=False, + info='Used for SAM object detection and box generation') inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False) inpaint_engine = gr.Dropdown(label='Inpaint Engine', value=modules.config.default_inpaint_engine_version, @@ -779,7 +781,7 @@ with shared.gradio_root: inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category, inpaint_mask_dino_prompt_text, inpaint_mask_sam_model, inpaint_mask_box_threshold, inpaint_mask_text_threshold, dino_erode_or_dilate, - debug_dino], + debugging_dino], outputs=inpaint_mask_image, show_progress=True, queue=True) ctrls = [currentTask, generate_image_grid] @@ -807,7 +809,7 @@ with shared.gradio_root: ctrls += [save_metadata_to_images, metadata_scheme] ctrls += ip_ctrls - ctrls += stage2_ctrls + ctrls += [debugging_dino, dino_erode_or_dilate] + stage2_ctrls def parse_meta(raw_prompt_txt, is_generating): loaded_json = None