diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 2fd776d8..a4f0e7c4 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -1,13 +1,12 @@ +import modules.config import numpy as np import torch +from extras.GroundingDINO.util.inference import default_groundingdino from extras.sam.predictor import SamPredictor from rembg import remove, new_session from segment_anything import sam_model_registry from segment_anything.utils.amg import remove_small_regions -from extras.GroundingDINO.util.inference import default_groundingdino -import modules.config - class SAMOptions: def __init__(self, @@ -26,6 +25,7 @@ class SAMOptions: self.dino_box_threshold = dino_box_threshold self.dino_text_threshold = dino_text_threshold self.dino_erode_or_dilate = dino_erode_or_dilate + self.dino_debug = dino_debug self.max_num_boxes = max_num_boxes self.model_type = model_type @@ -68,30 +68,6 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= box_threshold=sam_options.dino_box_threshold, text_threshold=sam_options.dino_text_threshold ) - # detection_boxes = detections.xyxy - # # use full image if no box has been found - # detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes - # - # - # for idx, box in enumerate(detection_boxes): - # box_list = box.tolist() - # if dino_erode_or_dilate != 0: - # box_list[0] -= dino_erode_or_dilate - # box_list[1] -= dino_erode_or_dilate - # box_list[2] += dino_erode_or_dilate - # box_list[3] += dino_erode_or_dilate - # extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}] - # - # if debug_dino: - # from PIL import ImageDraw, Image - # debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black") - # draw = ImageDraw.Draw(debug_dino_image) - # for box in extras['sam_prompt']: - # draw.rectangle(box['data'], fill="white") - # return np.array(debug_dino_image) - - # TODO add support for dino_erode_or_dilate again - # TODO add dino_debug again H, W = image.shape[0], image.shape[1] boxes = boxes * torch.Tensor([W, H, W, H]) @@ -107,6 +83,21 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= if boxes.size(0) > 0: sam_predictor.set_image(image) + for index in range(boxes.size(0)): + assert boxes.size(1) == 4 + boxes[index][0] -= sam_options.dino_erode_or_dilate + boxes[index][1] -= sam_options.dino_erode_or_dilate + boxes[index][2] += sam_options.dino_erode_or_dilate + boxes[index][3] += sam_options.dino_erode_or_dilate + + if sam_options.dino_debug: + from PIL import ImageDraw, Image + debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black") + draw = ImageDraw.Draw(debug_dino_image) + for box in boxes.numpy(): + draw.rectangle(box.tolist(), fill="white") + return np.array(debug_dino_image) + transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2]) masks, _, _ = sam_predictor.predict_torch( point_coords=None,