refactor: rename box_erode_or_dilate to dino_erode_or_dilate, add option dino_debug

This commit is contained in:
Manuel Schmid 2024-06-10 20:47:07 +02:00
parent b8578a080a
commit 651f9c5cfd
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 12 additions and 9 deletions

View File

@ -14,7 +14,8 @@ class SAMOptions:
dino_prompt: str = '',
dino_box_threshold=0.3,
dino_text_threshold=0.25,
box_erode_or_dilate=0,
dino_erode_or_dilate=0,
dino_debug=False,
# SAM
max_num_boxes=2,
@ -23,7 +24,7 @@ class SAMOptions:
self.dino_prompt = dino_prompt
self.dino_box_threshold = dino_box_threshold
self.dino_text_threshold = dino_text_threshold
self.box_erode_or_dilate = box_erode_or_dilate
self.dino_erode_or_dilate = dino_erode_or_dilate
self.max_num_boxes = max_num_boxes
self.model_type = model_type
@ -73,11 +74,11 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
#
# for idx, box in enumerate(detection_boxes):
# box_list = box.tolist()
# if box_erode_or_dilate != 0:
# box_list[0] -= box_erode_or_dilate
# box_list[1] -= box_erode_or_dilate
# box_list[2] += box_erode_or_dilate
# box_list[3] += box_erode_or_dilate
# if dino_erode_or_dilate != 0:
# box_list[0] -= dino_erode_or_dilate
# box_list[1] -= dino_erode_or_dilate
# box_list[2] += dino_erode_or_dilate
# box_list[3] += dino_erode_or_dilate
# extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
#
# if debug_dino:
@ -88,7 +89,8 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
# draw.rectangle(box['data'], fill="white")
# return np.array(debug_dino_image)
# TODO add support for box_erode_or_dilate again
# TODO add support for dino_erode_or_dilate again
# TODO add dino_debug again
H, W = image.shape[0], image.shape[1]
boxes = boxes * torch.Tensor([W, H, W, H])

View File

@ -243,7 +243,8 @@ with shared.gradio_root:
dino_prompt=dino_prompt_text,
dino_box_threshold=box_threshold,
dino_text_threshold=text_threshold,
box_erode_or_dilate=dino_erode_or_dilate,
dino_erode_or_dilate=dino_erode_or_dilate,
dino_debug=debug_dino,
max_num_boxes=2, #TODO replace with actual value
model_type=sam_model
)