feat: add model patching

automatically unload model when not needed anymore
This commit is contained in:
Manuel Schmid 2024-01-26 02:05:38 +01:00
parent 860ac91b16
commit 90be73a6df
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 106 additions and 18 deletions

View File

@ -0,0 +1,104 @@
from typing import Tuple, List
import ldm_patched.modules.model_management as model_management
from ldm_patched.modules.model_patcher import ModelPatcher
from modules.config import path_inpaint
from modules.model_loader import load_file_from_url
import numpy as np
import supervision as sv
import torch
from groundingdino.util.inference import Model
from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap
class GroundingDinoModel(Model):
def __init__(self):
self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
self.model = None
self.load_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.offload_device = torch.device('cpu')
self.dtype = torch.float32
def predict_with_caption(
self,
image: np.ndarray,
caption: str,
box_threshold: float = 0.35,
text_threshold: float = 0.25
) -> Tuple[sv.Detections, List[str]]:
if self.model is None:
filename = load_file_from_url(
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
file_name='groundingdino_swint_ogc.pth',
model_dir=path_inpaint)
model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename)
self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32
model.to(self.offload_device)
if model_management.should_use_fp16(device=self.load_device):
model.half()
self.dtype = torch.float16
self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
model_management.load_model_gpu(self.model)
processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device)
boxes, logits, phrases = predict(
model=self.model,
image=processed_image,
caption=caption,
box_threshold=box_threshold,
text_threshold=text_threshold,
device=self.load_device)
source_h, source_w, _ = image.shape
detections = GroundingDinoModel.post_process_result(
source_h=source_h,
source_w=source_w,
boxes=boxes,
logits=logits)
return detections, phrases
def predict(
model,
image: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float,
device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
caption = preprocess_caption(caption=caption)
# override to use model wrapped by patcher
model = model.model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
mask = prediction_logits.max(dim=1)[0] > box_threshold
logits = prediction_logits[mask] # logits.shape = (n, 256)
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
for logit
in logits
]
return boxes, logits.max(dim=1)[0], phrases
default_groundingdino = GroundingDinoModel().predict_with_caption

View File

@ -2,31 +2,15 @@ from PIL import Image
import numpy as np
import torch
from rembg import remove, new_session
from groundingdino.util.inference import Model as GroundingDinoModel
from extras.GroundingDINO.util.inference import default_groundingdino
from modules.model_loader import load_file_from_url
from modules.config import path_inpaint
config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
groundingdino_model = None
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
global groundingdino_model
if groundingdino_model is None:
filename = load_file_from_url(
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
model_dir=path_inpaint)
groundingdino_model = GroundingDinoModel(model_config_path=config_file, model_checkpoint_path=filename, device=device)
# run grounding dino model
boxes, _ = groundingdino_model.predict_with_caption(
boxes, _ = default_groundingdino(
image=np.array(input_image),
caption=text_prompt,
box_threshold=box_threshold,