feat: add box_erode_or_dilate to generate_mask_from_image, expose more sam return values

This commit is contained in:
Manuel Schmid 2024-06-09 18:45:13 +02:00
parent 9affa32583
commit ff9fa6c837
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 21 additions and 7 deletions

View File

@ -25,7 +25,7 @@ class GroundingDinoModel(Model):
caption: str,
box_threshold: float = 0.35,
text_threshold: float = 0.25
) -> Tuple[sv.Detections, List[str]]:
) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
if self.model is None:
filename = load_file_from_url(
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
@ -56,7 +56,7 @@ class GroundingDinoModel(Model):
source_w=source_w,
boxes=boxes,
logits=logits)
return detections, phrases
return detections, boxes, logits, phrases
def predict(

View File

@ -10,17 +10,17 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
# run grounding dino model
boxes, _ = default_groundingdino(
detections, _, _, _ = default_groundingdino(
image=np.array(input_image),
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
return boxes.xyxy
return detections.xyxy
def generate_mask_from_image(image, mask_model, extras):
def generate_mask_from_image(image, mask_model, extras, box_erode_or_dilate: int=0):
if image is None:
return
@ -28,15 +28,29 @@ def generate_mask_from_image(image, mask_model, extras):
image = image['image']
if mask_model == 'sam':
boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
img = Image.fromarray(image)
boxes = run_grounded_sam(img, extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
# use full image if no box has been found
boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(boxes) == 0 else boxes
extras['sam_prompt'] = []
# from PIL import ImageDraw
# draw = ImageDraw.Draw(img)
for idx, box in enumerate(boxes):
extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}]
box_list = box.tolist()
if box_erode_or_dilate != 0:
box_list[0] -= box_erode_or_dilate
box_list[1] -= box_erode_or_dilate
box_list[2] += box_erode_or_dilate
box_list[3] += box_erode_or_dilate
# draw.rectangle(box_list, fill=128, outline ="red")
extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
# img.show()
return remove(
image,
session=new_session(mask_model, **extras),
only_mask=True,
# post_process_mask=True,
**extras
)