diff --git a/extras/GroundingDINO/util/inference.py b/extras/GroundingDINO/util/inference.py index 259094f2..bc8b6429 100644 --- a/extras/GroundingDINO/util/inference.py +++ b/extras/GroundingDINO/util/inference.py @@ -25,7 +25,7 @@ class GroundingDinoModel(Model): caption: str, box_threshold: float = 0.35, text_threshold: float = 0.25 - ) -> Tuple[sv.Detections, List[str]]: + ) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]: if self.model is None: filename = load_file_from_url( url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth", @@ -56,7 +56,7 @@ class GroundingDinoModel(Model): source_w=source_w, boxes=boxes, logits=logits) - return detections, phrases + return detections, boxes, logits, phrases def predict( diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 4999f258..ea6e8819 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -10,17 +10,17 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold): # run grounding dino model - boxes, _ = default_groundingdino( + detections, _, _, _ = default_groundingdino( image=np.array(input_image), caption=text_prompt, box_threshold=box_threshold, text_threshold=text_threshold ) - return boxes.xyxy + return detections.xyxy -def generate_mask_from_image(image, mask_model, extras): +def generate_mask_from_image(image, mask_model, extras, box_erode_or_dilate: int=0): if image is None: return @@ -28,15 +28,29 @@ def generate_mask_from_image(image, mask_model, extras): image = image['image'] if mask_model == 'sam': - boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold']) + img = Image.fromarray(image) + boxes = run_grounded_sam(img, extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold']) + # use full image if no box has been found boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(boxes) == 0 else boxes + extras['sam_prompt'] = [] + # from PIL import ImageDraw + # draw = ImageDraw.Draw(img) for idx, box in enumerate(boxes): - extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}] + box_list = box.tolist() + if box_erode_or_dilate != 0: + box_list[0] -= box_erode_or_dilate + box_list[1] -= box_erode_or_dilate + box_list[2] += box_erode_or_dilate + box_list[3] += box_erode_or_dilate + # draw.rectangle(box_list, fill=128, outline ="red") + extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}] + # img.show() return remove( image, session=new_session(mask_model, **extras), only_mask=True, + # post_process_mask=True, **extras )