refactor: rename max_num_boxes to max_detections
This commit is contained in:
parent
b7fb42436c
commit
229ff81738
|
|
@ -14,7 +14,7 @@ sam_options = SAMOptions(
|
|||
dino_text_threshold=0.25,
|
||||
dino_erode_or_dilate=0,
|
||||
dino_debug=False,
|
||||
max_num_boxes=2,
|
||||
max_detections=2,
|
||||
model_type='vit_b'
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import sys
|
||||
|
||||
import modules.config
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -18,7 +20,7 @@ class SAMOptions:
|
|||
dino_debug=False,
|
||||
|
||||
# SAM
|
||||
max_num_boxes=2,
|
||||
max_detections=2,
|
||||
model_type='vit_b'
|
||||
):
|
||||
self.dino_prompt = dino_prompt
|
||||
|
|
@ -26,7 +28,7 @@ class SAMOptions:
|
|||
self.dino_text_threshold = dino_text_threshold
|
||||
self.dino_erode_or_dilate = dino_erode_or_dilate
|
||||
self.dino_debug = dino_debug
|
||||
self.max_num_boxes = max_num_boxes
|
||||
self.max_detections = max_detections
|
||||
self.model_type = model_type
|
||||
|
||||
|
||||
|
|
@ -114,7 +116,9 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
|
||||
masks = optimize_masks(masks)
|
||||
sam_detection_count = len(masks)
|
||||
sam_objects = min(len(logits), sam_options.max_num_boxes)
|
||||
if sam_options.max_detections == 0:
|
||||
sam_options.max_detections = sys.maxsize
|
||||
sam_objects = min(len(logits), sam_options.max_detections)
|
||||
for obj_ind in range(sam_objects):
|
||||
mask_tensor = masks[obj_ind][0]
|
||||
final_mask_tensor += mask_tensor
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ class AsyncTask:
|
|||
enhance_mask_sam_model = args.pop()
|
||||
enhance_mask_text_threshold = args.pop()
|
||||
enhance_mask_box_threshold = args.pop()
|
||||
enhance_mask_sam_max_num_boxes = args.pop()
|
||||
enhance_mask_sam_max_detections = args.pop()
|
||||
enhance_inpaint_disable_initial_latent = args.pop()
|
||||
enhance_inpaint_engine = args.pop()
|
||||
enhance_inpaint_strength = args.pop()
|
||||
|
|
@ -137,7 +137,7 @@ class AsyncTask:
|
|||
enhance_mask_sam_model,
|
||||
enhance_mask_text_threshold,
|
||||
enhance_mask_box_threshold,
|
||||
enhance_mask_sam_max_num_boxes,
|
||||
enhance_mask_sam_max_detections,
|
||||
enhance_inpaint_disable_initial_latent,
|
||||
enhance_inpaint_engine,
|
||||
enhance_inpaint_strength,
|
||||
|
|
@ -1160,7 +1160,7 @@ def worker():
|
|||
current_task_id = -1
|
||||
for imgs in generated_imgs.values():
|
||||
for img in imgs:
|
||||
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
|
||||
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
|
||||
current_task_id += 1
|
||||
current_progress = int(base_progress + (100 - preparation_steps) * float(current_task_id * async_task.steps) / float(all_steps))
|
||||
progressbar(async_task, current_progress, f'Preparing enhancement {current_task_id + 1}/{total_count} ...')
|
||||
|
|
@ -1176,7 +1176,7 @@ def worker():
|
|||
dino_text_threshold=enhance_mask_text_threshold,
|
||||
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
|
||||
dino_debug=async_task.debugging_dino,
|
||||
max_num_boxes=enhance_mask_sam_max_num_boxes,
|
||||
max_detections=enhance_mask_sam_max_detections,
|
||||
model_type=enhance_mask_sam_model
|
||||
))
|
||||
if len(mask.shape) == 3:
|
||||
|
|
|
|||
|
|
@ -516,10 +516,10 @@ default_enhance_tabs = get_config_item_or_set_default(
|
|||
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
|
||||
expected_type=int
|
||||
)
|
||||
default_sam_max_num_boxes = get_config_item_or_set_default(
|
||||
key='default_sam_max_num_boxes',
|
||||
default_value=2,
|
||||
validator=lambda x: isinstance(x, int) and 1 <= x <= 5,
|
||||
default_sam_max_detections = get_config_item_or_set_default(
|
||||
key='default_sam_max_detections',
|
||||
default_value=0,
|
||||
validator=lambda x: isinstance(x, int) and 0 <= x <= 10,
|
||||
expected_type=int
|
||||
)
|
||||
default_black_out_nsfw = get_config_item_or_set_default(
|
||||
|
|
|
|||
19
webui.py
19
webui.py
|
|
@ -258,10 +258,10 @@ with shared.gradio_root:
|
|||
inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model)
|
||||
inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05)
|
||||
inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
|
||||
inpaint_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
|
||||
inpaint_mask_sam_num_boxes = gr.Slider(label="Maximum number of detections", info="Set to 0 to detect all", minimum=0, maximum=10, value=modules.config.default_sam_max_detections, step=1, interactive=True)
|
||||
generate_mask_button = gr.Button(value='Generate mask from image')
|
||||
|
||||
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, sam_max_num_boxes, dino_erode_or_dilate, dino_debug):
|
||||
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, sam_max_detections, dino_erode_or_dilate, dino_debug):
|
||||
from extras.inpaint_mask import generate_mask_from_image
|
||||
|
||||
extras = {}
|
||||
|
|
@ -275,7 +275,7 @@ with shared.gradio_root:
|
|||
dino_text_threshold=text_threshold,
|
||||
dino_erode_or_dilate=dino_erode_or_dilate,
|
||||
dino_debug=dino_debug,
|
||||
max_num_boxes=sam_max_num_boxes,
|
||||
max_detections=sam_max_detections,
|
||||
model_type=sam_model
|
||||
)
|
||||
|
||||
|
|
@ -380,10 +380,11 @@ with shared.gradio_root:
|
|||
enhance_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0,
|
||||
maximum=1.0, value=0.25, step=0.05,
|
||||
interactive=True)
|
||||
enhance_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections",
|
||||
minimum=1, maximum=5,
|
||||
value=modules.config.default_sam_max_num_boxes,
|
||||
step=1, interactive=True)
|
||||
enhance_mask_sam_max_detections = gr.Slider(label="Maximum number of detections",
|
||||
info="Set to 0 to detect all",
|
||||
minimum=0, maximum=10,
|
||||
value=modules.config.default_sam_max_detections,
|
||||
step=1, interactive=True)
|
||||
|
||||
with gr.Accordion("Inpaint", visible=True, open=False):
|
||||
enhance_inpaint_mode = gr.Dropdown(choices=modules.flags.inpaint_options,
|
||||
|
|
@ -420,7 +421,7 @@ with shared.gradio_root:
|
|||
enhance_mask_sam_model,
|
||||
enhance_mask_text_threshold,
|
||||
enhance_mask_box_threshold,
|
||||
enhance_mask_sam_max_num_boxes,
|
||||
enhance_mask_sam_max_detections,
|
||||
enhance_inpaint_disable_initial_latent,
|
||||
enhance_inpaint_engine,
|
||||
enhance_inpaint_strength,
|
||||
|
|
@ -868,7 +869,7 @@ with shared.gradio_root:
|
|||
inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category,
|
||||
inpaint_mask_dino_prompt_text, inpaint_mask_sam_model,
|
||||
inpaint_mask_box_threshold, inpaint_mask_text_threshold,
|
||||
inpaint_mask_sam_max_num_boxes, dino_erode_or_dilate, debugging_dino],
|
||||
inpaint_mask_sam_num_boxes, dino_erode_or_dilate, debugging_dino],
|
||||
outputs=inpaint_mask_image, show_progress=True, queue=True)
|
||||
|
||||
ctrls = [currentTask, generate_image_grid]
|
||||
|
|
|
|||
Loading…
Reference in New Issue