fix: correctly count sam masks
This commit is contained in:
parent
ef9fd293ff
commit
b585d9dfa7
|
|
@ -85,10 +85,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
|
||||
sam_predictor = SamPredictor(sam)
|
||||
final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
|
||||
|
||||
dino_detection_count = boxes.size(0)
|
||||
sam_detection_count = 0
|
||||
sam_detection_on_mask_count = 0
|
||||
|
||||
if dino_detection_count > 0:
|
||||
sam_predictor.set_image(image)
|
||||
|
|
@ -118,11 +115,12 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
)
|
||||
|
||||
masks = optimize_masks(masks)
|
||||
|
||||
sam_detection_count = len(masks)
|
||||
sam_objects = min(len(logits), sam_options.max_num_boxes)
|
||||
for obj_ind in range(sam_objects):
|
||||
mask_tensor = masks[obj_ind][0]
|
||||
final_mask_tensor += mask_tensor
|
||||
sam_detection_on_mask_count += 1
|
||||
|
||||
final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
|
||||
mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
|
||||
|
|
|
|||
Loading…
Reference in New Issue