feat: add dino_erode_or_dilate and dino_debug again

This commit is contained in:
Manuel Schmid 2024-06-10 23:23:38 +02:00
parent 757863c023
commit 29967d3a18
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
1 changed files with 18 additions and 27 deletions

View File

@ -1,13 +1,12 @@
import modules.config
import numpy as np
import torch
from extras.GroundingDINO.util.inference import default_groundingdino
from extras.sam.predictor import SamPredictor
from rembg import remove, new_session
from segment_anything import sam_model_registry
from segment_anything.utils.amg import remove_small_regions
from extras.GroundingDINO.util.inference import default_groundingdino
import modules.config
class SAMOptions:
def __init__(self,
@ -26,6 +25,7 @@ class SAMOptions:
self.dino_box_threshold = dino_box_threshold
self.dino_text_threshold = dino_text_threshold
self.dino_erode_or_dilate = dino_erode_or_dilate
self.dino_debug = dino_debug
self.max_num_boxes = max_num_boxes
self.model_type = model_type
@ -68,30 +68,6 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
box_threshold=sam_options.dino_box_threshold,
text_threshold=sam_options.dino_text_threshold
)
# detection_boxes = detections.xyxy
# # use full image if no box has been found
# detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes
#
#
# for idx, box in enumerate(detection_boxes):
# box_list = box.tolist()
# if dino_erode_or_dilate != 0:
# box_list[0] -= dino_erode_or_dilate
# box_list[1] -= dino_erode_or_dilate
# box_list[2] += dino_erode_or_dilate
# box_list[3] += dino_erode_or_dilate
# extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
#
# if debug_dino:
# from PIL import ImageDraw, Image
# debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
# draw = ImageDraw.Draw(debug_dino_image)
# for box in extras['sam_prompt']:
# draw.rectangle(box['data'], fill="white")
# return np.array(debug_dino_image)
# TODO add support for dino_erode_or_dilate again
# TODO add dino_debug again
H, W = image.shape[0], image.shape[1]
boxes = boxes * torch.Tensor([W, H, W, H])
@ -107,6 +83,21 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
if boxes.size(0) > 0:
sam_predictor.set_image(image)
for index in range(boxes.size(0)):
assert boxes.size(1) == 4
boxes[index][0] -= sam_options.dino_erode_or_dilate
boxes[index][1] -= sam_options.dino_erode_or_dilate
boxes[index][2] += sam_options.dino_erode_or_dilate
boxes[index][3] += sam_options.dino_erode_or_dilate
if sam_options.dino_debug:
from PIL import ImageDraw, Image
debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
draw = ImageDraw.Draw(debug_dino_image)
for box in boxes.numpy():
draw.rectangle(box.tolist(), fill="white")
return np.array(debug_dino_image)
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
masks, _, _ = sam_predictor.predict_torch(
point_coords=None,