feat: add box_erode_or_dilate to generate_mask_from_image, expose more sam return values
This commit is contained in:
parent
9affa32583
commit
ff9fa6c837
|
|
@ -25,7 +25,7 @@ class GroundingDinoModel(Model):
|
|||
caption: str,
|
||||
box_threshold: float = 0.35,
|
||||
text_threshold: float = 0.25
|
||||
) -> Tuple[sv.Detections, List[str]]:
|
||||
) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
|
||||
if self.model is None:
|
||||
filename = load_file_from_url(
|
||||
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
|
||||
|
|
@ -56,7 +56,7 @@ class GroundingDinoModel(Model):
|
|||
source_w=source_w,
|
||||
boxes=boxes,
|
||||
logits=logits)
|
||||
return detections, phrases
|
||||
return detections, boxes, logits, phrases
|
||||
|
||||
|
||||
def predict(
|
||||
|
|
|
|||
|
|
@ -10,17 +10,17 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|||
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
|
||||
|
||||
# run grounding dino model
|
||||
boxes, _ = default_groundingdino(
|
||||
detections, _, _, _ = default_groundingdino(
|
||||
image=np.array(input_image),
|
||||
caption=text_prompt,
|
||||
box_threshold=box_threshold,
|
||||
text_threshold=text_threshold
|
||||
)
|
||||
|
||||
return boxes.xyxy
|
||||
return detections.xyxy
|
||||
|
||||
|
||||
def generate_mask_from_image(image, mask_model, extras):
|
||||
def generate_mask_from_image(image, mask_model, extras, box_erode_or_dilate: int=0):
|
||||
if image is None:
|
||||
return
|
||||
|
||||
|
|
@ -28,15 +28,29 @@ def generate_mask_from_image(image, mask_model, extras):
|
|||
image = image['image']
|
||||
|
||||
if mask_model == 'sam':
|
||||
boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
|
||||
img = Image.fromarray(image)
|
||||
boxes = run_grounded_sam(img, extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
|
||||
# use full image if no box has been found
|
||||
boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(boxes) == 0 else boxes
|
||||
|
||||
extras['sam_prompt'] = []
|
||||
# from PIL import ImageDraw
|
||||
# draw = ImageDraw.Draw(img)
|
||||
for idx, box in enumerate(boxes):
|
||||
extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}]
|
||||
box_list = box.tolist()
|
||||
if box_erode_or_dilate != 0:
|
||||
box_list[0] -= box_erode_or_dilate
|
||||
box_list[1] -= box_erode_or_dilate
|
||||
box_list[2] += box_erode_or_dilate
|
||||
box_list[3] += box_erode_or_dilate
|
||||
# draw.rectangle(box_list, fill=128, outline ="red")
|
||||
extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
|
||||
# img.show()
|
||||
|
||||
return remove(
|
||||
image,
|
||||
session=new_session(mask_model, **extras),
|
||||
only_mask=True,
|
||||
# post_process_mask=True,
|
||||
**extras
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue