From 229ff8173856d036270c5d64a0d9a4a6e7a69c40 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Tue, 18 Jun 2024 21:07:27 +0200 Subject: [PATCH] refactor: rename max_num_boxes to max_detections --- experiments_mask_generation.py | 2 +- extras/inpaint_mask.py | 10 +++++++--- modules/async_worker.py | 8 ++++---- modules/config.py | 8 ++++---- webui.py | 19 ++++++++++--------- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/experiments_mask_generation.py b/experiments_mask_generation.py index 0f6b960d..a27eb39c 100644 --- a/experiments_mask_generation.py +++ b/experiments_mask_generation.py @@ -14,7 +14,7 @@ sam_options = SAMOptions( dino_text_threshold=0.25, dino_erode_or_dilate=0, dino_debug=False, - max_num_boxes=2, + max_detections=2, model_type='vit_b' ) diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 1d04d86c..086b7da6 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -1,3 +1,5 @@ +import sys + import modules.config import numpy as np import torch @@ -18,7 +20,7 @@ class SAMOptions: dino_debug=False, # SAM - max_num_boxes=2, + max_detections=2, model_type='vit_b' ): self.dino_prompt = dino_prompt @@ -26,7 +28,7 @@ class SAMOptions: self.dino_text_threshold = dino_text_threshold self.dino_erode_or_dilate = dino_erode_or_dilate self.dino_debug = dino_debug - self.max_num_boxes = max_num_boxes + self.max_detections = max_detections self.model_type = model_type @@ -114,7 +116,9 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= masks = optimize_masks(masks) sam_detection_count = len(masks) - sam_objects = min(len(logits), sam_options.max_num_boxes) + if sam_options.max_detections == 0: + sam_options.max_detections = sys.maxsize + sam_objects = min(len(logits), sam_options.max_detections) for obj_ind in range(sam_objects): mask_tensor = masks[obj_ind][0] final_mask_tensor += mask_tensor diff --git a/modules/async_worker.py b/modules/async_worker.py index e100dc9f..8864582b 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -123,7 +123,7 @@ class AsyncTask: enhance_mask_sam_model = args.pop() enhance_mask_text_threshold = args.pop() enhance_mask_box_threshold = args.pop() - enhance_mask_sam_max_num_boxes = args.pop() + enhance_mask_sam_max_detections = args.pop() enhance_inpaint_disable_initial_latent = args.pop() enhance_inpaint_engine = args.pop() enhance_inpaint_strength = args.pop() @@ -137,7 +137,7 @@ class AsyncTask: enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, - enhance_mask_sam_max_num_boxes, + enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, @@ -1160,7 +1160,7 @@ def worker(): current_task_id = -1 for imgs in generated_imgs.values(): for img in imgs: - for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls: + for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls: current_task_id += 1 current_progress = int(base_progress + (100 - preparation_steps) * float(current_task_id * async_task.steps) / float(all_steps)) progressbar(async_task, current_progress, f'Preparing enhancement {current_task_id + 1}/{total_count} ...') @@ -1176,7 +1176,7 @@ def worker(): dino_text_threshold=enhance_mask_text_threshold, dino_erode_or_dilate=async_task.dino_erode_or_dilate, dino_debug=async_task.debugging_dino, - max_num_boxes=enhance_mask_sam_max_num_boxes, + max_detections=enhance_mask_sam_max_detections, model_type=enhance_mask_sam_model )) if len(mask.shape) == 3: diff --git a/modules/config.py b/modules/config.py index 929dd9ce..d3240888 100644 --- a/modules/config.py +++ b/modules/config.py @@ -516,10 +516,10 @@ default_enhance_tabs = get_config_item_or_set_default( validator=lambda x: isinstance(x, int) and 1 <= x <= 5, expected_type=int ) -default_sam_max_num_boxes = get_config_item_or_set_default( - key='default_sam_max_num_boxes', - default_value=2, - validator=lambda x: isinstance(x, int) and 1 <= x <= 5, +default_sam_max_detections = get_config_item_or_set_default( + key='default_sam_max_detections', + default_value=0, + validator=lambda x: isinstance(x, int) and 0 <= x <= 10, expected_type=int ) default_black_out_nsfw = get_config_item_or_set_default( diff --git a/webui.py b/webui.py index 9185dd10..021ed55b 100644 --- a/webui.py +++ b/webui.py @@ -258,10 +258,10 @@ with shared.gradio_root: inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model) inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05) inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05) - inpaint_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True) + inpaint_mask_sam_num_boxes = gr.Slider(label="Maximum number of detections", info="Set to 0 to detect all", minimum=0, maximum=10, value=modules.config.default_sam_max_detections, step=1, interactive=True) generate_mask_button = gr.Button(value='Generate mask from image') - def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, sam_max_num_boxes, dino_erode_or_dilate, dino_debug): + def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, box_threshold, text_threshold, sam_max_detections, dino_erode_or_dilate, dino_debug): from extras.inpaint_mask import generate_mask_from_image extras = {} @@ -275,7 +275,7 @@ with shared.gradio_root: dino_text_threshold=text_threshold, dino_erode_or_dilate=dino_erode_or_dilate, dino_debug=dino_debug, - max_num_boxes=sam_max_num_boxes, + max_detections=sam_max_detections, model_type=sam_model ) @@ -380,10 +380,11 @@ with shared.gradio_root: enhance_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True) - enhance_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", - minimum=1, maximum=5, - value=modules.config.default_sam_max_num_boxes, - step=1, interactive=True) + enhance_mask_sam_max_detections = gr.Slider(label="Maximum number of detections", + info="Set to 0 to detect all", + minimum=0, maximum=10, + value=modules.config.default_sam_max_detections, + step=1, interactive=True) with gr.Accordion("Inpaint", visible=True, open=False): enhance_inpaint_mode = gr.Dropdown(choices=modules.flags.inpaint_options, @@ -420,7 +421,7 @@ with shared.gradio_root: enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, - enhance_mask_sam_max_num_boxes, + enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, @@ -868,7 +869,7 @@ with shared.gradio_root: inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category, inpaint_mask_dino_prompt_text, inpaint_mask_sam_model, inpaint_mask_box_threshold, inpaint_mask_text_threshold, - inpaint_mask_sam_max_num_boxes, dino_erode_or_dilate, debugging_dino], + inpaint_mask_sam_num_boxes, dino_erode_or_dilate, debugging_dino], outputs=inpaint_mask_image, show_progress=True, queue=True) ctrls = [currentTask, generate_image_grid]