feat: add dino_erode_or_dilate and dino_debug again
This commit is contained in:
parent
757863c023
commit
29967d3a18
|
|
@ -1,13 +1,12 @@
|
|||
import modules.config
|
||||
import numpy as np
|
||||
import torch
|
||||
from extras.GroundingDINO.util.inference import default_groundingdino
|
||||
from extras.sam.predictor import SamPredictor
|
||||
from rembg import remove, new_session
|
||||
from segment_anything import sam_model_registry
|
||||
from segment_anything.utils.amg import remove_small_regions
|
||||
|
||||
from extras.GroundingDINO.util.inference import default_groundingdino
|
||||
import modules.config
|
||||
|
||||
|
||||
class SAMOptions:
|
||||
def __init__(self,
|
||||
|
|
@ -26,6 +25,7 @@ class SAMOptions:
|
|||
self.dino_box_threshold = dino_box_threshold
|
||||
self.dino_text_threshold = dino_text_threshold
|
||||
self.dino_erode_or_dilate = dino_erode_or_dilate
|
||||
self.dino_debug = dino_debug
|
||||
self.max_num_boxes = max_num_boxes
|
||||
self.model_type = model_type
|
||||
|
||||
|
|
@ -68,30 +68,6 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
box_threshold=sam_options.dino_box_threshold,
|
||||
text_threshold=sam_options.dino_text_threshold
|
||||
)
|
||||
# detection_boxes = detections.xyxy
|
||||
# # use full image if no box has been found
|
||||
# detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes
|
||||
#
|
||||
#
|
||||
# for idx, box in enumerate(detection_boxes):
|
||||
# box_list = box.tolist()
|
||||
# if dino_erode_or_dilate != 0:
|
||||
# box_list[0] -= dino_erode_or_dilate
|
||||
# box_list[1] -= dino_erode_or_dilate
|
||||
# box_list[2] += dino_erode_or_dilate
|
||||
# box_list[3] += dino_erode_or_dilate
|
||||
# extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
|
||||
#
|
||||
# if debug_dino:
|
||||
# from PIL import ImageDraw, Image
|
||||
# debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
|
||||
# draw = ImageDraw.Draw(debug_dino_image)
|
||||
# for box in extras['sam_prompt']:
|
||||
# draw.rectangle(box['data'], fill="white")
|
||||
# return np.array(debug_dino_image)
|
||||
|
||||
# TODO add support for dino_erode_or_dilate again
|
||||
# TODO add dino_debug again
|
||||
|
||||
H, W = image.shape[0], image.shape[1]
|
||||
boxes = boxes * torch.Tensor([W, H, W, H])
|
||||
|
|
@ -107,6 +83,21 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
|
|||
if boxes.size(0) > 0:
|
||||
sam_predictor.set_image(image)
|
||||
|
||||
for index in range(boxes.size(0)):
|
||||
assert boxes.size(1) == 4
|
||||
boxes[index][0] -= sam_options.dino_erode_or_dilate
|
||||
boxes[index][1] -= sam_options.dino_erode_or_dilate
|
||||
boxes[index][2] += sam_options.dino_erode_or_dilate
|
||||
boxes[index][3] += sam_options.dino_erode_or_dilate
|
||||
|
||||
if sam_options.dino_debug:
|
||||
from PIL import ImageDraw, Image
|
||||
debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
|
||||
draw = ImageDraw.Draw(debug_dino_image)
|
||||
for box in boxes.numpy():
|
||||
draw.rectangle(box.tolist(), fill="white")
|
||||
return np.array(debug_dino_image)
|
||||
|
||||
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
|
||||
masks, _, _ = sam_predictor.predict_torch(
|
||||
point_coords=None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue