101 lines
3.5 KiB
Python
101 lines
3.5 KiB
Python
from typing import Tuple, List
|
|
|
|
import ldm_patched.modules.model_management as model_management
|
|
from ldm_patched.modules.model_patcher import ModelPatcher
|
|
from modules.config import path_inpaint
|
|
from modules.model_loader import load_file_from_url
|
|
|
|
import numpy as np
|
|
import supervision as sv
|
|
import torch
|
|
from groundingdino.util.inference import Model
|
|
from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap
|
|
|
|
|
|
class GroundingDinoModel(Model):
    """GroundingDINO wrapper that defers loading the checkpoint until first use.

    ``__init__`` only stores configuration (it does not call the parent
    initialiser); the weights are downloaded, loaded and wrapped in a
    ``ModelPatcher`` the first time ``predict_with_caption`` runs.
    """

    def __init__(self):
        # No weights yet: ``predict_with_caption`` fills this in on demand.
        self.model = None
        self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
        # Placeholder devices; replaced by the values reported by
        # ``model_management`` when the model is loaded.
        self.load_device = torch.device('cpu')
        self.offload_device = torch.device('cpu')

    def _load_patched_model(self):
        """Download the checkpoint, build the network and wrap it in a patcher."""
        checkpoint_path = load_file_from_url(
            url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
            file_name='groundingdino_swint_ogc.pth',
            model_dir=path_inpaint,
        )
        network = load_model(
            model_config_path=self.config_file,
            model_checkpoint_path=checkpoint_path,
        )

        self.load_device = model_management.text_encoder_device()
        self.offload_device = model_management.text_encoder_offload_device()

        # Park the fresh network on the offload device before handing it to
        # the patcher.
        network.to(self.offload_device)

        self.model = ModelPatcher(
            network,
            load_device=self.load_device,
            offload_device=self.offload_device,
        )

    @torch.no_grad()
    @torch.inference_mode()
    def predict_with_caption(
        self,
        image: np.ndarray,
        caption: str,
        box_threshold: float = 0.35,
        text_threshold: float = 0.25
    ) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
        """Detect the objects described by ``caption`` in ``image``.

        Args:
            image: Image array, forwarded to ``preprocess_image`` as
                ``image_bgr``; its shape must unpack as ``(height, width, _)``.
            caption: Text prompt describing what to look for.
            box_threshold: Minimum score for a box to be kept.
            text_threshold: Minimum per-token score used when extracting the
                phrase for each kept box.

        Returns:
            A tuple ``(detections, boxes, logits, phrases)``: ``detections``
            is the result of ``post_process_result``; the other three are the
            values returned by the module-level ``predict`` function.
        """
        if self.model is None:
            self._load_patched_model()

        # Requested on every call, not only right after loading -- presumably
        # so a model that was offloaded in the meantime is brought back;
        # confirm against model_management.
        model_management.load_model_gpu(self.model)

        image_tensor = GroundingDinoModel.preprocess_image(image_bgr=image)
        image_tensor = image_tensor.to(self.load_device)

        boxes, logits, phrases = predict(
            model=self.model,
            image=image_tensor,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.load_device,
        )

        height, width, _ = image.shape
        detections = GroundingDinoModel.post_process_result(
            source_h=height,
            source_w=width,
            boxes=boxes,
            logits=logits,
        )
        return detections, boxes, logits, phrases
|
|
|
|
|
|
def predict(
        model,
        image: torch.Tensor,
        caption: str,
        box_threshold: float,
        text_threshold: float,
        device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """Run one GroundingDINO forward pass and keep the confident boxes.

    Args:
        model: Patcher-style wrapper; the actual network is read from its
            ``model`` attribute.
        image: Preprocessed image tensor without a batch dimension (one is
            added here before the forward pass).
        caption: Text prompt; normalised with ``preprocess_caption``.
        box_threshold: A query is kept when its highest token score exceeds
            this value.
        text_threshold: Token scores above this value contribute to the
            phrase extracted for a kept query.
        device: Device the network and the image are moved to.

    Returns:
        ``(boxes, scores, phrases)`` where ``boxes`` has shape ``(n, 4)``,
        ``scores`` holds the highest token score of each kept query, and
        ``phrases`` has one string (with ``'.'`` removed) per kept query.
    """
    prompt = preprocess_caption(caption=caption)

    # The caller passes a patcher wrapper; unwrap it to reach the real network.
    network = model.model.to(device)
    batch = image.to(device)[None]

    with torch.no_grad():
        outputs = network(batch, captions=[prompt])

    # Move to CPU first, then squash to probabilities; index 0 selects the
    # only item in the batch.
    all_scores = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    all_boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    keep = all_scores.max(dim=1).values > box_threshold
    kept_scores = all_scores[keep]  # (n, 256)
    kept_boxes = all_boxes[keep]  # (n, 4)

    tokenizer = network.tokenizer
    tokenized = tokenizer(prompt)

    phrases = []
    for token_scores in kept_scores:
        phrase = get_phrases_from_posmap(token_scores > text_threshold, tokenized, tokenizer)
        phrases.append(phrase.replace('.', ''))

    return kept_boxes, kept_scores.max(dim=1).values, phrases
|
|
|
|
|
|
# Module-level entry point: the bound method of one shared instance, so every
# caller reuses the same model object once the first call has loaded it.
default_groundingdino = GroundingDinoModel().predict_with_caption
|