feat: add handling for stage2_mask_sam_max_num_boxes and config
This commit is contained in:
parent
9998b52dd2
commit
dbc844804b
|
|
@ -108,21 +108,26 @@ class AsyncTask:
|
|||
if cn_img is not None:
|
||||
self.cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
|
||||
|
||||
self.debugging_dino = args.pop()
|
||||
self.dino_erode_or_dilate = args.pop()
|
||||
|
||||
self.stage2_ctrls = []
|
||||
for _ in range(modules.config.default_max_stage2_tabs):
|
||||
for _ in range(modules.config.default_stage2_tabs):
|
||||
stage2_enabled = args.pop()
|
||||
# stage2_mode = args.pop()
|
||||
stage2_mask_dino_prompt_text = args.pop()
|
||||
stage2_mask_sam_model = args.pop()
|
||||
stage2_mask_box_threshold = args.pop()
|
||||
stage2_mask_text_threshold = args.pop()
|
||||
stage2_mask_sam_max_num_boxes = args.pop()
|
||||
stage2_mask_sam_model = args.pop()
|
||||
if stage2_enabled:
|
||||
self.stage2_ctrls.append([
|
||||
# stage2_mode,
|
||||
stage2_mask_dino_prompt_text,
|
||||
stage2_mask_sam_model,
|
||||
stage2_mask_box_threshold,
|
||||
stage2_mask_text_threshold
|
||||
stage2_mask_text_threshold,
|
||||
stage2_mask_sam_max_num_boxes,
|
||||
stage2_mask_sam_model,
|
||||
])
|
||||
|
||||
|
||||
|
|
@ -1040,13 +1045,15 @@ def worker():
|
|||
continue
|
||||
|
||||
for img in imgs:
|
||||
for stage2_mask_dino_prompt_text, stage2_mask_sam_model, stage2_mask_box_threshold, stage2_mask_text_threshold in async_task.stage2_ctrls:
|
||||
for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls:
|
||||
mask = generate_mask_from_image(img, sam_options=SAMOptions(
|
||||
dino_prompt=stage2_mask_dino_prompt_text,
|
||||
model_type=stage2_mask_sam_model,
|
||||
dino_box_threshold=stage2_mask_box_threshold,
|
||||
dino_text_threshold=stage2_mask_text_threshold,
|
||||
dino_debug=True
|
||||
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
|
||||
dino_debug=async_task.debugging_dino,
|
||||
max_num_boxes=stage2_mask_sam_max_num_boxes,
|
||||
model_type=stage2_mask_sam_model
|
||||
))
|
||||
mask = mask[:, :, 0]
|
||||
|
||||
|
|
|
|||
|
|
@ -510,12 +510,18 @@ example_stage2_prompts = get_config_item_or_set_default(
|
|||
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
|
||||
expected_type=list
|
||||
)
|
||||
default_max_stage2_tabs = get_config_item_or_set_default(
|
||||
key='default_max_stage2_tabs',
|
||||
default_stage2_tabs = get_config_item_or_set_default(
|
||||
key='default_stage2_tabs',
|
||||
default_value=3,
|
||||
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
|
||||
expected_type=int
|
||||
)
|
||||
default_sam_max_num_boxes = get_config_item_or_set_default(
|
||||
key='default_sam_max_num_boxes',
|
||||
default_value=2,
|
||||
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
|
||||
expected_type=int
|
||||
)
|
||||
default_black_out_nsfw = get_config_item_or_set_default(
|
||||
key='default_black_out_nsfw',
|
||||
default_value=False,
|
||||
|
|
|
|||
20
webui.py
20
webui.py
|
|
@ -232,7 +232,7 @@ with shared.gradio_root:
|
|||
inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
|
||||
generate_mask_button = gr.Button(value='Generate mask from image')
|
||||
|
||||
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, debug_dino):
|
||||
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, dino_erode_or_dilate, dino_debug):
|
||||
from extras.inpaint_mask import generate_mask_from_image
|
||||
|
||||
extras = {}
|
||||
|
|
@ -245,7 +245,7 @@ with shared.gradio_root:
|
|||
dino_box_threshold=box_threshold,
|
||||
dino_text_threshold=text_threshold,
|
||||
dino_erode_or_dilate=dino_erode_or_dilate,
|
||||
dino_debug=debug_dino,
|
||||
dino_debug=dino_debug,
|
||||
max_num_boxes=2, #TODO replace with actual value
|
||||
model_type=sam_model
|
||||
)
|
||||
|
|
@ -301,7 +301,7 @@ with shared.gradio_root:
|
|||
with gr.Row(visible=False) as stage2_input_panel:
|
||||
with gr.Tabs():
|
||||
stage2_ctrls = []
|
||||
for index in range(modules.config.default_max_stage2_tabs):
|
||||
for index in range(modules.config.default_stage2_tabs):
|
||||
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
|
||||
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
|
||||
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
|
||||
|
|
@ -317,14 +317,16 @@ with shared.gradio_root:
|
|||
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
|
||||
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
|
||||
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
|
||||
stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
|
||||
|
||||
stage2_ctrls += [
|
||||
stage2_enabled,
|
||||
# stage2_mode,
|
||||
stage2_mask_dino_prompt_text,
|
||||
stage2_mask_sam_model,
|
||||
stage2_mask_box_threshold,
|
||||
stage2_mask_text_threshold
|
||||
stage2_mask_text_threshold,
|
||||
stage2_mask_sam_max_num_boxes,
|
||||
stage2_mask_sam_model,
|
||||
]
|
||||
|
||||
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
|
||||
|
|
@ -598,8 +600,8 @@ with shared.gradio_root:
|
|||
|
||||
with gr.Tab(label='Inpaint'):
|
||||
debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False)
|
||||
debug_dino = gr.Checkbox(label='Debug GroundingDINO', value=False,
|
||||
info='Used for SAM object detection and box generation')
|
||||
debugging_dino = gr.Checkbox(label='Debug GroundingDINO', value=False,
|
||||
info='Used for SAM object detection and box generation')
|
||||
inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False)
|
||||
inpaint_engine = gr.Dropdown(label='Inpaint Engine',
|
||||
value=modules.config.default_inpaint_engine_version,
|
||||
|
|
@ -779,7 +781,7 @@ with shared.gradio_root:
|
|||
inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category,
|
||||
inpaint_mask_dino_prompt_text, inpaint_mask_sam_model,
|
||||
inpaint_mask_box_threshold, inpaint_mask_text_threshold, dino_erode_or_dilate,
|
||||
debug_dino],
|
||||
debugging_dino],
|
||||
outputs=inpaint_mask_image, show_progress=True, queue=True)
|
||||
|
||||
ctrls = [currentTask, generate_image_grid]
|
||||
|
|
@ -807,7 +809,7 @@ with shared.gradio_root:
|
|||
ctrls += [save_metadata_to_images, metadata_scheme]
|
||||
|
||||
ctrls += ip_ctrls
|
||||
ctrls += stage2_ctrls
|
||||
ctrls += [debugging_dino, dino_erode_or_dilate] + stage2_ctrls
|
||||
|
||||
def parse_meta(raw_prompt_txt, is_generating):
|
||||
loaded_json = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue