feat: add box_erode_or_dilate to generate_mask_from_image, expose more sam return values
This commit is contained in:
parent
9affa32583
commit
ff9fa6c837
|
|
@ -25,7 +25,7 @@ class GroundingDinoModel(Model):
|
||||||
caption: str,
|
caption: str,
|
||||||
box_threshold: float = 0.35,
|
box_threshold: float = 0.35,
|
||||||
text_threshold: float = 0.25
|
text_threshold: float = 0.25
|
||||||
) -> Tuple[sv.Detections, List[str]]:
|
) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
filename = load_file_from_url(
|
filename = load_file_from_url(
|
||||||
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
|
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
|
||||||
|
|
@ -56,7 +56,7 @@ class GroundingDinoModel(Model):
|
||||||
source_w=source_w,
|
source_w=source_w,
|
||||||
boxes=boxes,
|
boxes=boxes,
|
||||||
logits=logits)
|
logits=logits)
|
||||||
return detections, phrases
|
return detections, boxes, logits, phrases
|
||||||
|
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
|
|
|
||||||
|
|
@ -10,17 +10,17 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
|
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
|
||||||
|
|
||||||
# run grounding dino model
|
# run grounding dino model
|
||||||
boxes, _ = default_groundingdino(
|
detections, _, _, _ = default_groundingdino(
|
||||||
image=np.array(input_image),
|
image=np.array(input_image),
|
||||||
caption=text_prompt,
|
caption=text_prompt,
|
||||||
box_threshold=box_threshold,
|
box_threshold=box_threshold,
|
||||||
text_threshold=text_threshold
|
text_threshold=text_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
return boxes.xyxy
|
return detections.xyxy
|
||||||
|
|
||||||
|
|
||||||
def generate_mask_from_image(image, mask_model, extras):
|
def generate_mask_from_image(image, mask_model, extras, box_erode_or_dilate: int=0):
|
||||||
if image is None:
|
if image is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -28,15 +28,29 @@ def generate_mask_from_image(image, mask_model, extras):
|
||||||
image = image['image']
|
image = image['image']
|
||||||
|
|
||||||
if mask_model == 'sam':
|
if mask_model == 'sam':
|
||||||
boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
|
img = Image.fromarray(image)
|
||||||
|
boxes = run_grounded_sam(img, extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
|
||||||
|
# use full image if no box has been found
|
||||||
boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(boxes) == 0 else boxes
|
boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(boxes) == 0 else boxes
|
||||||
|
|
||||||
extras['sam_prompt'] = []
|
extras['sam_prompt'] = []
|
||||||
|
# from PIL import ImageDraw
|
||||||
|
# draw = ImageDraw.Draw(img)
|
||||||
for idx, box in enumerate(boxes):
|
for idx, box in enumerate(boxes):
|
||||||
extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}]
|
box_list = box.tolist()
|
||||||
|
if box_erode_or_dilate != 0:
|
||||||
|
box_list[0] -= box_erode_or_dilate
|
||||||
|
box_list[1] -= box_erode_or_dilate
|
||||||
|
box_list[2] += box_erode_or_dilate
|
||||||
|
box_list[3] += box_erode_or_dilate
|
||||||
|
# draw.rectangle(box_list, fill=128, outline ="red")
|
||||||
|
extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
|
||||||
|
# img.show()
|
||||||
|
|
||||||
return remove(
|
return remove(
|
||||||
image,
|
image,
|
||||||
session=new_session(mask_model, **extras),
|
session=new_session(mask_model, **extras),
|
||||||
only_mask=True,
|
only_mask=True,
|
||||||
|
# post_process_mask=True,
|
||||||
**extras
|
**extras
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue