wip: remove ultralytics, always use manual sam for image mask instead of rembg

This commit is contained in:
Manuel Schmid 2024-06-10 01:33:03 +02:00
parent 09e23f5509
commit 8a81993940
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
4 changed files with 146 additions and 155 deletions

View File

@ -1,120 +1,24 @@
import cv2
# https://github.com/sail-sg/EditAnything/blob/main/sam2groundingdino_edit.py
import numpy as np
import torch
from PIL import Image
from segment_anything.utils.amg import remove_small_regions
from extras.GroundingDINO.util.inference import default_groundingdino
from extras.adetailer.args import ADetailerArgs
from extras.adetailer.script import get_ad_model
from extras.adetailer.script import pred_preprocessing
from extras.adetailer.ultralytics_predict import ultralytics_predict
from extras.inpaint_mask import run_grounded_sam, generate_mask_from_image
from extras.inpaint_mask import SAMOptions, generate_mask_from_image
def main() -> None:
    """Manual smoke test: build a mask for 'cat.webp' and display it.

    Loads the image, asks generate_mask_from_image for a mask of everything
    matching the prompt 'eye', and opens the result in the default viewer.
    """
    # Force three channels: Image.open may yield a palette or alpha image,
    # while the mask code stacks three channels for its output.
    # NOTE(review): SamPredictor.set_image is documented to take an HWC uint8
    # RGB image - confirm against the installed segment_anything version.
    original_image = Image.open('cat.webp').convert('RGB')
    image = np.array(original_image, dtype=np.uint8)

    # NOTE(review): this checkpoint path ends in .safetensors while SAMOptions
    # defaults to a .pth file - confirm the SAM loader accepts this format.
    sam_options = SAMOptions(
        dino_prompt='eye',
        dino_box_threshold=0.3,
        dino_text_threshold=0.25,
        box_erode_or_dilate=0,
        max_num_boxes=2,
        sam_checkpoint="./models/sam/sam_vit_l.safetensors",
        model_type="vit_l"
    )

    mask_image = generate_mask_from_image(image, sam_options=sam_options)

    merged_masks_img = Image.fromarray(mask_image)
    merged_masks_img.show()


if __name__ == "__main__":
    main()

View File

@ -1,47 +1,129 @@
import numpy as np
import torch
from rembg import remove, new_session
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.amg import remove_small_regions
from extras.GroundingDINO.util.inference import default_groundingdino
def generate_mask_from_image(image: np.ndarray, mask_model: str, extras: dict, box_erode_or_dilate: int=0, debug_dino: bool=False) -> np.ndarray | None:
class SAMOptions:
    """Plain settings holder for the GroundingDINO + SAM mask pipeline.

    Every constructor argument is stored unchanged on the instance under the
    same name; nothing is validated here.

    GroundingDINO settings:
        dino_prompt: text passed as the detection caption.
        dino_box_threshold: box threshold passed to the detector.
        dino_text_threshold: text threshold passed to the detector.
        box_erode_or_dilate: box grow/shrink amount (stored only; the mask
            generator in this file still carries a TODO to honour it).

    SAM settings:
        max_num_boxes: upper bound on how many masks get merged.
        sam_checkpoint: path of the SAM weights file.
        model_type: key used to look up the SAM model builder.
    """

    def __init__(
            self,
            # GroundingDINO
            dino_prompt: str = '',
            dino_box_threshold: float = 0.3,
            dino_text_threshold: float = 0.25,
            box_erode_or_dilate: int = 0,
            # SAM
            max_num_boxes: int = 2,
            sam_checkpoint: str = "./models/sam/sam_vit_l_0b3195.pth",
            model_type: str = "vit_l",
    ):
        # SAM settings
        self.model_type = model_type
        self.sam_checkpoint = sam_checkpoint
        self.max_num_boxes = max_num_boxes

        # GroundingDINO settings
        self.box_erode_or_dilate = box_erode_or_dilate
        self.dino_text_threshold = dino_text_threshold
        self.dino_box_threshold = dino_box_threshold
        self.dino_prompt = dino_prompt
def optimize_masks(masks: torch.Tensor, area_threshold: float = 400) -> torch.Tensor:
    """
    Clean up each mask with remove_small_regions in "holes" mode.

    NOTE(review): only mode="holes" is applied, so small disconnected
    regions ("islands") are presumably left untouched - confirm against the
    segment_anything documentation if island removal is also wanted.

    :param masks: mask batch indexed as [num_masks, 1, h, w]
    :param area_threshold: area threshold handed to remove_small_regions
    :return: CPU tensor with the same [num_masks, 1, h, w] layout
    """
    if masks.shape[0] == 0:
        # np.stack raises ValueError on an empty sequence; an empty batch
        # simply has nothing to clean, so hand it back on the CPU.
        return masks.to('cpu')

    fine_masks = [
        remove_small_regions(mask[0], area_threshold, mode="holes")[0]
        for mask in masks.to('cpu').numpy()  # masks: [num_masks, 1, h, w]
    ]
    # Re-insert the singleton channel axis that mask[0] dropped above.
    stacked = np.stack(fine_masks, axis=0)[:, np.newaxis]
    return torch.from_numpy(stacked)
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
sam_options: SAMOptions | None = SAMOptions) -> np.ndarray | None:
if image is None:
return
if extras is None:
extras = {}
if 'image' in image:
image = image['image']
if mask_model == 'sam':
detections, _, _, _ = default_groundingdino(
image=image,
caption=extras['sam_prompt_text'],
box_threshold=extras['box_threshold'],
text_threshold=extras['text_threshold']
if mask_model != 'sam' and sam_options is None:
return remove(
image,
session=new_session(mask_model, **extras),
only_mask=True,
**extras
)
detection_boxes = detections.xyxy
# use full image if no box has been found
detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes
extras['sam_prompt'] = []
for idx, box in enumerate(detection_boxes):
box_list = box.tolist()
if box_erode_or_dilate != 0:
box_list[0] -= box_erode_or_dilate
box_list[1] -= box_erode_or_dilate
box_list[2] += box_erode_or_dilate
box_list[3] += box_erode_or_dilate
extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
assert sam_options is not None
if debug_dino:
from PIL import ImageDraw, Image
debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
draw = ImageDraw.Draw(debug_dino_image)
for box in extras['sam_prompt']:
draw.rectangle(box['data'], fill="white")
return np.array(debug_dino_image)
return remove(
image,
session=new_session(mask_model, **extras),
only_mask=True,
**extras
detections, boxes, logits, phrases = default_groundingdino(
image=image,
caption=sam_options.dino_prompt,
box_threshold=sam_options.dino_box_threshold,
text_threshold=sam_options.dino_text_threshold
)
# detection_boxes = detections.xyxy
# # use full image if no box has been found
# detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes
#
#
# for idx, box in enumerate(detection_boxes):
# box_list = box.tolist()
# if box_erode_or_dilate != 0:
# box_list[0] -= box_erode_or_dilate
# box_list[1] -= box_erode_or_dilate
# box_list[2] += box_erode_or_dilate
# box_list[3] += box_erode_or_dilate
# extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
#
# if debug_dino:
# from PIL import ImageDraw, Image
# debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
# draw = ImageDraw.Draw(debug_dino_image)
# for box in extras['sam_prompt']:
# draw.rectangle(box['data'], fill="white")
# return np.array(debug_dino_image)
# TODO add support for box_erode_or_dilate again
H, W = image.shape[0], image.shape[1]
boxes = boxes * torch.Tensor([W, H, W, H])
boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]
# TODO add model patcher for model logic and device management
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[sam_options.model_type](checkpoint=sam_options.sam_checkpoint)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
if boxes.size(0) > 0:
sam_predictor.set_image(image)
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
masks, _, _ = sam_predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes.to(device),
multimask_output=False,
)
masks = optimize_masks(masks)
num_obj = min(len(logits), sam_options.max_num_boxes)
for obj_ind in range(num_obj):
mask_tensor = masks[obj_ind][0]
final_mask_tensor += mask_tensor
final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
mask_image = np.array(mask_image, dtype=np.uint8)
return mask_image

View File

@ -20,5 +20,4 @@ timm==0.9.2
translators==5.8.9
rembg==2.0.53
groundingdino-py==0.4.0
ultralytics==8.2.28
segment_anything==1.0

View File

@ -16,6 +16,7 @@ import modules.meta_parser
import args_manager
import copy
import launch
from extras.inpaint_mask import SAMOptions
from modules.sdxl_styles import legal_style_names
from modules.private_logger import get_current_html_path
@ -223,7 +224,7 @@ with shared.gradio_root:
choices=flags.inpaint_mask_cloth_category,
value=modules.config.default_inpaint_mask_cloth_category,
visible=False)
inpaint_mask_sam_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False, info='Use singular whenever possible')
inpaint_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False, info='Use singular whenever possible')
with gr.Accordion("Advanced options", visible=False, open=False) as inpaint_mask_advanced_options:
inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model)
inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False)
@ -231,24 +232,29 @@ with shared.gradio_root:
inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
generate_mask_button = gr.Button(value='Generate mask from image')
def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, sam_quant, box_threshold, text_threshold, dino_erode_or_dilate, debug_dino):
    """Click handler: translate the mask UI inputs into a mask image.

    For 'u2net_cloth_seg' the cloth category is forwarded through ``extras``;
    for 'sam' a SAMOptions object is built from the prompt, thresholds and
    erode/dilate value. ``sam_model``, ``sam_quant`` and ``debug_dino`` are
    accepted so the positional input list stays aligned, but are not used here.
    """
    from extras.inpaint_mask import generate_mask_from_image

    extras = {}
    sam_options = None
    if mask_model == 'u2net_cloth_seg':
        extras['cloth_category'] = cloth_category
    elif mask_model == 'sam':
        sam_options = SAMOptions(
            dino_prompt=dino_prompt_text,
            dino_box_threshold=box_threshold,
            dino_text_threshold=text_threshold,
            box_erode_or_dilate=dino_erode_or_dilate,
            max_num_boxes=2, #TODO replace with actual value
            sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth", # TODO replace with actual value
            model_type="vit_l"
        )

    return generate_mask_from_image(image, mask_model, extras, sam_options)
inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
inputs=inpaint_mask_model,
outputs=[inpaint_mask_cloth_category, inpaint_mask_sam_prompt_text, inpaint_mask_advanced_options],
outputs=[inpaint_mask_cloth_category, inpaint_mask_dino_prompt_text, inpaint_mask_advanced_options],
queue=False, show_progress=False)
with gr.TabItem(label='Describe') as desc_tab:
@ -737,9 +743,9 @@ with shared.gradio_root:
generate_mask_button.click(fn=generate_mask,
inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category,
inpaint_mask_sam_prompt_text, inpaint_mask_sam_model, inpaint_mask_sam_quant,
inpaint_mask_box_threshold, inpaint_mask_text_threshold, debug_dino,
dino_erode_or_dilate],
inpaint_mask_dino_prompt_text, inpaint_mask_sam_model, inpaint_mask_sam_quant,
inpaint_mask_box_threshold, inpaint_mask_text_threshold, dino_erode_or_dilate,
debug_dino],
outputs=inpaint_mask_image, show_progress=True, queue=True)
ctrls = [currentTask, generate_image_grid]