feat: add handling for stage2_mask_sam_max_num_boxes and config

This commit is contained in:
Manuel Schmid 2024-06-12 22:16:02 +02:00
parent 9998b52dd2
commit dbc844804b
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 33 additions and 18 deletions

View File

@ -108,21 +108,26 @@ class AsyncTask:
if cn_img is not None: if cn_img is not None:
self.cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight]) self.cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
self.debugging_dino = args.pop()
self.dino_erode_or_dilate = args.pop()
self.stage2_ctrls = [] self.stage2_ctrls = []
for _ in range(modules.config.default_max_stage2_tabs): for _ in range(modules.config.default_stage2_tabs):
stage2_enabled = args.pop() stage2_enabled = args.pop()
# stage2_mode = args.pop() # stage2_mode = args.pop()
stage2_mask_dino_prompt_text = args.pop() stage2_mask_dino_prompt_text = args.pop()
stage2_mask_sam_model = args.pop()
stage2_mask_box_threshold = args.pop() stage2_mask_box_threshold = args.pop()
stage2_mask_text_threshold = args.pop() stage2_mask_text_threshold = args.pop()
stage2_mask_sam_max_num_boxes = args.pop()
stage2_mask_sam_model = args.pop()
if stage2_enabled: if stage2_enabled:
self.stage2_ctrls.append([ self.stage2_ctrls.append([
# stage2_mode, # stage2_mode,
stage2_mask_dino_prompt_text, stage2_mask_dino_prompt_text,
stage2_mask_sam_model,
stage2_mask_box_threshold, stage2_mask_box_threshold,
stage2_mask_text_threshold stage2_mask_text_threshold,
stage2_mask_sam_max_num_boxes,
stage2_mask_sam_model,
]) ])
@ -1040,13 +1045,15 @@ def worker():
continue continue
for img in imgs: for img in imgs:
for stage2_mask_dino_prompt_text, stage2_mask_sam_model, stage2_mask_box_threshold, stage2_mask_text_threshold in async_task.stage2_ctrls: for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls:
mask = generate_mask_from_image(img, sam_options=SAMOptions( mask = generate_mask_from_image(img, sam_options=SAMOptions(
dino_prompt=stage2_mask_dino_prompt_text, dino_prompt=stage2_mask_dino_prompt_text,
model_type=stage2_mask_sam_model,
dino_box_threshold=stage2_mask_box_threshold, dino_box_threshold=stage2_mask_box_threshold,
dino_text_threshold=stage2_mask_text_threshold, dino_text_threshold=stage2_mask_text_threshold,
dino_debug=True dino_erode_or_dilate=async_task.dino_erode_or_dilate,
dino_debug=async_task.debugging_dino,
max_num_boxes=stage2_mask_sam_max_num_boxes,
model_type=stage2_mask_sam_model
)) ))
mask = mask[:, :, 0] mask = mask[:, :, 0]

View File

@ -510,12 +510,18 @@ example_stage2_prompts = get_config_item_or_set_default(
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x), validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
expected_type=list expected_type=list
) )
default_max_stage2_tabs = get_config_item_or_set_default( default_stage2_tabs = get_config_item_or_set_default(
key='default_max_stage2_tabs', key='default_stage2_tabs',
default_value=3, default_value=3,
validator=lambda x: isinstance(x, int) and 1 <= x <= 5, validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
expected_type=int expected_type=int
) )
default_sam_max_num_boxes = get_config_item_or_set_default(
key='default_sam_max_num_boxes',
default_value=2,
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
expected_type=int
)
default_black_out_nsfw = get_config_item_or_set_default( default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw', key='default_black_out_nsfw',
default_value=False, default_value=False,

View File

@ -232,7 +232,7 @@ with shared.gradio_root:
inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05) inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
generate_mask_button = gr.Button(value='Generate mask from image') generate_mask_button = gr.Button(value='Generate mask from image')
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, debug_dino): def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, dino_debug):
from extras.inpaint_mask import generate_mask_from_image from extras.inpaint_mask import generate_mask_from_image
extras = {} extras = {}
@ -245,7 +245,7 @@ with shared.gradio_root:
dino_box_threshold=box_threshold, dino_box_threshold=box_threshold,
dino_text_threshold=text_threshold, dino_text_threshold=text_threshold,
dino_erode_or_dilate=dino_erode_or_dilate, dino_erode_or_dilate=dino_erode_or_dilate,
dino_debug=debug_dino, dino_debug=dino_debug,
max_num_boxes=2, #TODO replace with actual value max_num_boxes=2, #TODO replace with actual value
model_type=sam_model model_type=sam_model
) )
@ -301,7 +301,7 @@ with shared.gradio_root:
with gr.Row(visible=False) as stage2_input_panel: with gr.Row(visible=False) as stage2_input_panel:
with gr.Tabs(): with gr.Tabs():
stage2_ctrls = [] stage2_ctrls = []
for index in range(modules.config.default_max_stage2_tabs): for index in range(modules.config.default_stage2_tabs):
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item: with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False) stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion: with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
@ -317,14 +317,16 @@ with shared.gradio_root:
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True) stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True) stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True) stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
stage2_ctrls += [ stage2_ctrls += [
stage2_enabled, stage2_enabled,
# stage2_mode, # stage2_mode,
stage2_mask_dino_prompt_text, stage2_mask_dino_prompt_text,
stage2_mask_sam_model,
stage2_mask_box_threshold, stage2_mask_box_threshold,
stage2_mask_text_threshold stage2_mask_text_threshold,
stage2_mask_sam_max_num_boxes,
stage2_mask_sam_model,
] ]
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled, stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
@ -598,8 +600,8 @@ with shared.gradio_root:
with gr.Tab(label='Inpaint'): with gr.Tab(label='Inpaint'):
debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False) debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False)
debug_dino = gr.Checkbox(label='Debug GroundingDINO', value=False, debugging_dino = gr.Checkbox(label='Debug GroundingDINO', value=False,
info='Used for SAM object detection and box generation') info='Used for SAM object detection and box generation')
inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False) inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False)
inpaint_engine = gr.Dropdown(label='Inpaint Engine', inpaint_engine = gr.Dropdown(label='Inpaint Engine',
value=modules.config.default_inpaint_engine_version, value=modules.config.default_inpaint_engine_version,
@ -779,7 +781,7 @@ with shared.gradio_root:
inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category, inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category,
inpaint_mask_dino_prompt_text, inpaint_mask_sam_model, inpaint_mask_dino_prompt_text, inpaint_mask_sam_model,
inpaint_mask_box_threshold, inpaint_mask_text_threshold, dino_erode_or_dilate, inpaint_mask_box_threshold, inpaint_mask_text_threshold, dino_erode_or_dilate,
debug_dino], debugging_dino],
outputs=inpaint_mask_image, show_progress=True, queue=True) outputs=inpaint_mask_image, show_progress=True, queue=True)
ctrls = [currentTask, generate_image_grid] ctrls = [currentTask, generate_image_grid]
@ -807,7 +809,7 @@ with shared.gradio_root:
ctrls += [save_metadata_to_images, metadata_scheme] ctrls += [save_metadata_to_images, metadata_scheme]
ctrls += ip_ctrls ctrls += ip_ctrls
ctrls += stage2_ctrls ctrls += [debugging_dino, dino_erode_or_dilate] + stage2_ctrls
def parse_meta(raw_prompt_txt, is_generating): def parse_meta(raw_prompt_txt, is_generating):
loaded_json = None loaded_json = None