feat: add handling for stage2_mask_sam_max_num_boxes and config

This commit is contained in:
Manuel Schmid 2024-06-12 22:16:02 +02:00
parent 9998b52dd2
commit dbc844804b
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 33 additions and 18 deletions

View File

@@ -108,21 +108,26 @@ class AsyncTask:
if cn_img is not None:
self.cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
self.debugging_dino = args.pop()
self.dino_erode_or_dilate = args.pop()
self.stage2_ctrls = []
for _ in range(modules.config.default_max_stage2_tabs):
for _ in range(modules.config.default_stage2_tabs):
stage2_enabled = args.pop()
# stage2_mode = args.pop()
stage2_mask_dino_prompt_text = args.pop()
stage2_mask_sam_model = args.pop()
stage2_mask_box_threshold = args.pop()
stage2_mask_text_threshold = args.pop()
stage2_mask_sam_max_num_boxes = args.pop()
stage2_mask_sam_model = args.pop()
if stage2_enabled:
self.stage2_ctrls.append([
# stage2_mode,
stage2_mask_dino_prompt_text,
stage2_mask_sam_model,
stage2_mask_box_threshold,
stage2_mask_text_threshold
stage2_mask_text_threshold,
stage2_mask_sam_max_num_boxes,
stage2_mask_sam_model,
])
@@ -1040,13 +1045,15 @@ def worker():
continue
for img in imgs:
for stage2_mask_dino_prompt_text, stage2_mask_sam_model, stage2_mask_box_threshold, stage2_mask_text_threshold in async_task.stage2_ctrls:
for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls:
mask = generate_mask_from_image(img, sam_options=SAMOptions(
dino_prompt=stage2_mask_dino_prompt_text,
model_type=stage2_mask_sam_model,
dino_box_threshold=stage2_mask_box_threshold,
dino_text_threshold=stage2_mask_text_threshold,
dino_debug=True
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
dino_debug=async_task.debugging_dino,
max_num_boxes=stage2_mask_sam_max_num_boxes,
model_type=stage2_mask_sam_model
))
mask = mask[:, :, 0]

View File

@@ -510,12 +510,18 @@ example_stage2_prompts = get_config_item_or_set_default(
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
expected_type=list
)
default_max_stage2_tabs = get_config_item_or_set_default(
key='default_max_stage2_tabs',
default_stage2_tabs = get_config_item_or_set_default(
key='default_stage2_tabs',
default_value=3,
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
expected_type=int
)
# Default for the stage-2 "Maximum number of box detections" slider in the UI.
# The validator accepts only ints in the inclusive range 1..5, the same range
# as the slider's minimum/maximum.
default_sam_max_num_boxes = get_config_item_or_set_default(
key='default_sam_max_num_boxes',
default_value=2,
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
expected_type=int
)
default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw',
default_value=False,

View File

@@ -232,7 +232,7 @@ with shared.gradio_root:
inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
generate_mask_button = gr.Button(value='Generate mask from image')
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, debug_dino):
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, dino_debug):
from extras.inpaint_mask import generate_mask_from_image
extras = {}
@@ -245,7 +245,7 @@ with shared.gradio_root:
dino_box_threshold=box_threshold,
dino_text_threshold=text_threshold,
dino_erode_or_dilate=dino_erode_or_dilate,
dino_debug=debug_dino,
dino_debug=dino_debug,
max_num_boxes=2, #TODO replace with actual value
model_type=sam_model
)
@@ -301,7 +301,7 @@ with shared.gradio_root:
with gr.Row(visible=False) as stage2_input_panel:
with gr.Tabs():
stage2_ctrls = []
for index in range(modules.config.default_max_stage2_tabs):
for index in range(modules.config.default_stage2_tabs):
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
@@ -317,14 +317,16 @@ with shared.gradio_root:
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
stage2_ctrls += [
stage2_enabled,
# stage2_mode,
stage2_mask_dino_prompt_text,
stage2_mask_sam_model,
stage2_mask_box_threshold,
stage2_mask_text_threshold
stage2_mask_text_threshold,
stage2_mask_sam_max_num_boxes,
stage2_mask_sam_model,
]
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
@@ -598,8 +600,8 @@ with shared.gradio_root:
with gr.Tab(label='Inpaint'):
debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False)
debug_dino = gr.Checkbox(label='Debug GroundingDINO', value=False,
info='Used for SAM object detection and box generation')
debugging_dino = gr.Checkbox(label='Debug GroundingDINO', value=False,
info='Used for SAM object detection and box generation')
inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False)
inpaint_engine = gr.Dropdown(label='Inpaint Engine',
value=modules.config.default_inpaint_engine_version,
@@ -779,7 +781,7 @@ with shared.gradio_root:
inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category,
inpaint_mask_dino_prompt_text, inpaint_mask_sam_model,
inpaint_mask_box_threshold, inpaint_mask_text_threshold, dino_erode_or_dilate,
debug_dino],
debugging_dino],
outputs=inpaint_mask_image, show_progress=True, queue=True)
ctrls = [currentTask, generate_image_grid]
@@ -807,7 +809,7 @@ with shared.gradio_root:
ctrls += [save_metadata_to_images, metadata_scheme]
ctrls += ip_ctrls
ctrls += stage2_ctrls
ctrls += [debugging_dino, dino_erode_or_dilate] + stage2_ctrls
def parse_meta(raw_prompt_txt, is_generating):
loaded_json = None