# Standalone experiment: text-prompted segmentation.
# GroundingDINO finds boxes for a text prompt; SAM turns the boxes into masks.
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from segment_anything.utils.amg import remove_small_regions
|
|
|
|
from extras.GroundingDINO.util.inference import default_groundingdino
|
|
from extras.adetailer.args import ADetailerArgs
|
|
from extras.adetailer.script import get_ad_model
|
|
from extras.adetailer.script import pred_preprocessing
|
|
from extras.adetailer.ultralytics_predict import ultralytics_predict
|
|
from extras.inpaint_mask import run_grounded_sam, generate_mask_from_image
|
|
|
|
# Load the input image. cv2.imread returns the pixel data in BGR channel
# order, so it must be converted to RGB before wrapping in a PIL Image —
# otherwise red/blue are swapped for every downstream model.
original_image1 = cv2.imread('cat.webp')
if original_image1 is None:
    # cv2.imread returns None (no exception) when the file is missing or unreadable.
    raise FileNotFoundError("could not read 'cat.webp'")
original_image = Image.fromarray(cv2.cvtColor(original_image1, cv2.COLOR_BGR2RGB))

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
# NOTE(review): alternative detection path using ADetailer's YOLO face model,
# kept for reference — produces a merged face mask instead of the
# DINO+SAM pipeline below.
# predictor = ultralytics_predict
#
# ad_model = get_ad_model('face_yolov8n.pt')
#
# kwargs = {}
# kwargs["device"] = torch.device('cpu')
# kwargs["classes"] = ""
#
# img2 = Image.fromarray(img)
# pred = predictor(ad_model, img2, **kwargs)
#
# if pred.preview is None:
#     print('[ADetailer] nothing detected on image')
#
# args = ADetailerArgs()
#
# masks = pred_preprocessing(img, pred, args)
# merged_masks = np.maximum(*[np.array(mask) for mask in masks])
#
# merged_masks_img = Image.fromarray(merged_masks)
# merged_masks_img.show()
|
|
|
|
# GroundingDINO configuration: text prompt plus detection thresholds.
sam_prompt = 'eye'
sam_model = 'sam_vit_l_0b3195'
dino_box_threshold = 0.3
dino_text_threshold = 0.25
box_erode_or_dilate = 0

# Detect regions matching the text prompt. The returned `boxes` are
# normalized (cx, cy, w, h) coordinates — converted to pixels further down.
dino_kwargs = {
    'image': np.array(original_image),
    'caption': sam_prompt,
    'box_threshold': dino_box_threshold,
    'text_threshold': dino_text_threshold,
}
detections, boxes, logits, phrases = default_groundingdino(**dino_kwargs)
|
|
|
|
# for boxes.xyxy
# boxes = run_grounded_sam(img, sam_prompt, box_threshold=dino_box_threshold, text_threshold=dino_text_threshold)
# boxes = np.array([[0, 0, img.shape[1], img.shape[0]]]) if len(boxes) == 0 else boxes

# NOTE(review): debug visualization of the raw detection boxes, kept for reference.
# from PIL import ImageDraw
# draw = ImageDraw.Draw(img)
# for idx, box in enumerate(boxes.xyxy):
#     box_list = box.tolist()
#     if box_erode_or_dilate != 0:
#         box_list[0] -= box_erode_or_dilate
#         box_list[1] -= box_erode_or_dilate
#         box_list[2] += box_erode_or_dilate
#         box_list[3] += box_erode_or_dilate
#     draw.rectangle(box_list, fill=128, outline="red")
# img.show()
|
|
|
|
# Convert DINO boxes from normalized (cx, cy, w, h) to absolute pixel
# (x0, y0, x1, y1) coordinates. PIL's .size is (width, height).
W, H = original_image.size
boxes = boxes * torch.Tensor([W, H, W, H])
centers, sizes = boxes[:, :2], boxes[:, 2:]
top_left = centers - sizes / 2
boxes = torch.cat([top_left, top_left + sizes], dim=1)
|
|
|
|
|
|
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# SAM (ViT-L) setup. Build the model once and share it between the automatic
# mask generator and the box-prompted predictor — the original code loaded
# the checkpoint from disk a second time just to construct SamPredictor.
sam_checkpoint = "./models/sam/sam_vit_l_0b3195.pth"
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

# Keep at most this many detected objects when merging masks below.
num_boxes = 2

sam_predictor = SamPredictor(sam)
|
|
|
|
image_np = np.array(original_image, dtype=np.uint8)

# Accumulator for the union of the per-object masks (H x W); stays all-zero
# (black mask) when DINO detected nothing.
final_m = torch.zeros((image_np.shape[0], image_np.shape[1]))

if boxes.size(0) > 0:
    sam_predictor.set_image(image_np)

    # SAM wants the prompt boxes mapped into its internal input resolution.
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image_np.shape[:2])
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes.to(device),
        multimask_output=False,
    )

    # Remove small disconnected holes from each mask. masks: [num_masks, 1, h, w].
    cleaned = [
        remove_small_regions(raw[0], 400, mode="holes")[0]
        for raw in masks.to('cpu').numpy()
    ]
    masks = torch.from_numpy(np.stack(cleaned, axis=0)[:, np.newaxis])

    # Merge at most `num_boxes` of the detected objects into the accumulator.
    for obj_ind in range(min(len(logits), num_boxes)):
        final_m += masks[obj_ind][0]

# Binarize, then expand to a 3-channel uint8 image for display.
final_m = (final_m > 0).to('cpu').numpy()
mask_image = np.array(np.dstack((final_m, final_m, final_m)) * 255, dtype=np.uint8)

merged_masks_img = Image.fromarray(mask_image)
merged_masks_img.show()
|