diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index f8ecd2c8..1d04d86c 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -56,7 +56,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= if 'image' in image: image = image['image'] - if mask_model != 'sam' and sam_options is None: + if mask_model != 'sam' or sam_options is None: result = remove( image, session=new_session(mask_model, **extras), @@ -66,8 +66,6 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count - assert sam_options is not None - detections, boxes, logits, phrases = default_groundingdino( image=image, caption=sam_options.dino_prompt, diff --git a/modules/async_worker.py b/modules/async_worker.py index 8925431b..92918c73 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -118,6 +118,7 @@ class AsyncTask: enhance_mask_dino_prompt_text = args.pop() enhance_prompt = args.pop() enhance_negative_prompt = args.pop() + enhance_mask_model = args.pop() enhance_mask_sam_model = args.pop() enhance_mask_text_threshold = args.pop() enhance_mask_box_threshold = args.pop() @@ -131,6 +132,7 @@ class AsyncTask: enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, + enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, @@ -1080,9 +1082,11 @@ def worker(): progressbar(async_task, current_progress, 'Processing enhance ...') for img in imgs: - for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls: - print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"') - mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions( + for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls: + if enhance_mask_model == 'sam': + print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"') + + mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, mask_model=enhance_mask_model, sam_options=SAMOptions( dino_prompt=enhance_mask_dino_prompt_text, dino_box_threshold=enhance_mask_box_threshold, dino_text_threshold=enhance_mask_text_threshold, @@ -1091,7 +1095,8 @@ def worker(): max_num_boxes=enhance_mask_sam_max_num_boxes, model_type=enhance_mask_sam_model )) - mask = mask[:, :, 0] + if len(mask.shape) == 3: + mask = mask[:, :, 0] if int(async_task.inpaint_erode_or_dilate) != 0: mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate) @@ -1106,7 +1111,7 @@ def worker(): print(f'[Enhance] {sam_detection_count} segments detected in boxes') print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask') - if dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0: + if enhance_mask_model == 'sam' and (dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0): print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping') continue diff --git a/webui.py b/webui.py index 8f2f86ad..2e650aa4 100644 --- a/webui.py +++ b/webui.py @@ -416,6 +416,7 @@ with shared.gradio_root: enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, + enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold,