43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
from PIL import Image
|
|
import numpy as np
|
|
import torch
|
|
from rembg import remove, new_session
|
|
from extras.GroundingDINO.util.inference import default_groundingdino
|
|
|
|
# Torch device chosen once at import time: CUDA when available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
|
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
    """Run the GroundingDINO detector for a text prompt and return its boxes.

    Despite the name, this function only calls the GroundingDINO helper; no
    segmentation model is invoked here.

    Args:
        input_image: Image to search; it is converted with ``np.array`` before
            being handed to the detector (callers in this file pass a PIL image).
        text_prompt: Caption describing what to detect.
        box_threshold: Forwarded unchanged as ``box_threshold``.
        text_threshold: Forwarded unchanged as ``text_threshold``.

    Returns:
        The ``xyxy`` attribute of the first value returned by
        ``default_groundingdino``.
        NOTE(review): presumably an (N, 4) array of corner coordinates --
        confirm against the GroundingDINO wrapper.
    """
    image_array = np.array(input_image)

    # The second element of the returned pair is not needed here.
    detections, _ = default_groundingdino(
        image=image_array,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )

    return detections.xyxy
|
|
|
|
|
|
def generate_mask_from_image(image, mask_model, extras):
    """Produce a mask for ``image`` using rembg with the given model.

    Args:
        image: ``None``, an image array, or a dict holding the array under the
            ``'image'`` key (in which case the array is unwrapped).
        mask_model: Model name handed to ``new_session``. When it equals
            ``'sam'``, GroundingDINO boxes are computed first and passed on as
            rectangle prompts.
        extras: Dict of extra keyword arguments forwarded to both
            ``new_session`` and ``remove``. For ``'sam'`` it must contain
            ``'sam_prompt_text'``, ``'box_threshold'`` and ``'text_threshold'``.
            NOTE: for ``'sam'`` this dict is modified in place -- its
            ``'sam_prompt'`` entry is overwritten with the detected rectangles.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise whatever
        ``remove(..., only_mask=True)`` returns.
    """
    if image is None:
        return None

    # Only unwrap real dicts. Running `'image' in <ndarray>` would make NumPy
    # compare every element against the string instead of testing for a key.
    if isinstance(image, dict) and 'image' in image:
        image = image['image']

    if mask_model == 'sam':
        boxes = run_grounded_sam(
            Image.fromarray(image),
            extras['sam_prompt_text'],
            box_threshold=extras['box_threshold'],
            text_threshold=extras['text_threshold'],
        )

        # Nothing detected: fall back to one box covering the whole image
        # (shape[1] is used as the x extent, shape[0] as the y extent).
        if len(boxes) == 0:
            boxes = np.array([[0, 0, image.shape[1], image.shape[0]]])

        # One rectangle prompt per box.
        extras['sam_prompt'] = [
            {"type": "rectangle", "data": box.tolist()} for box in boxes
        ]

    return remove(
        image,
        session=new_session(mask_model, **extras),
        only_mask=True,
        **extras
    )
|