From b585d9dfa78dadefcf04866b2e60a2595098f681 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sun, 16 Jun 2024 15:57:53 +0200 Subject: [PATCH] fix: correctly count sam masks --- extras/inpaint_mask.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index f9025ef2..f8ecd2c8 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -85,10 +85,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= sam_predictor = SamPredictor(sam) final_mask_tensor = torch.zeros((image.shape[0], image.shape[1])) - dino_detection_count = boxes.size(0) - sam_detection_count = 0 - sam_detection_on_mask_count = 0 if dino_detection_count > 0: sam_predictor.set_image(image) @@ -118,11 +115,12 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= ) masks = optimize_masks(masks) - + sam_detection_count = len(masks) sam_objects = min(len(logits), sam_options.max_num_boxes) for obj_ind in range(sam_objects): mask_tensor = masks[obj_ind][0] final_mask_tensor += mask_tensor + sam_detection_on_mask_count += 1 final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy() mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255