From 8a81993940155e57b699c0e862f14515f28d3061 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 10 Jun 2024 01:33:03 +0200 Subject: [PATCH] wip: remove ultralytics, always use manual sam for image mask instead of rembg --- experiments_mask_generation.py | 124 ++++------------------------ extras/inpaint_mask.py | 146 +++++++++++++++++++++++++-------- requirements_versions.txt | 1 - webui.py | 30 ++++--- 4 files changed, 146 insertions(+), 155 deletions(-) diff --git a/experiments_mask_generation.py b/experiments_mask_generation.py index a28c66e2..538ad712 100644 --- a/experiments_mask_generation.py +++ b/experiments_mask_generation.py @@ -1,120 +1,24 @@ -import cv2 +# https://github.com/sail-sg/EditAnything/blob/main/sam2groundingdino_edit.py + import numpy as np -import torch from PIL import Image -from segment_anything.utils.amg import remove_small_regions -from extras.GroundingDINO.util.inference import default_groundingdino -from extras.adetailer.args import ADetailerArgs -from extras.adetailer.script import get_ad_model -from extras.adetailer.script import pred_preprocessing -from extras.adetailer.ultralytics_predict import ultralytics_predict -from extras.inpaint_mask import run_grounded_sam, generate_mask_from_image +from extras.inpaint_mask import SAMOptions, generate_mask_from_image -original_image1 = cv2.imread('cat.webp') -original_image = Image.fromarray(original_image1) -device = "cuda" if torch.cuda.is_available() else "cpu" +original_image = Image.open('cat.webp') +image = np.array(original_image, dtype=np.uint8) -# predictor = ultralytics_predict -# -# ad_model = get_ad_model('face_yolov8n.pt') -# -# kwargs = {} -# kwargs["device"] = torch.device('cpu') -# kwargs["classes"] = "" -# -# img2 = Image.fromarray(img) -# pred = predictor(ad_model, img2, **kwargs) -# -# if pred.preview is None: -# print('[ADetailer] nothing detected on image') -# -# args = ADetailerArgs() -# -# masks = pred_preprocessing(img, pred, args) -# merged_masks = np.maximum(*[np.array(mask) for mask in masks]) -# -# -# merged_masks_img = Image.fromarray(merged_masks) -# merged_masks_img.show() - -sam_prompt = 'eye' -sam_model = 'sam_vit_l_0b3195' -dino_box_threshold = 0.3 -dino_text_threshold = 0.25 -box_erode_or_dilate = 0 - -detections, boxes, logits, phrases = default_groundingdino( - image=np.array(original_image), - caption=sam_prompt, - box_threshold=dino_box_threshold, - text_threshold=dino_text_threshold +sam_options = SAMOptions( + dino_prompt='eye', + dino_box_threshold=0.3, + dino_text_threshold=0.25, + box_erode_or_dilate=0, + max_num_boxes=2, + sam_checkpoint="./models/sam/sam_vit_l.safetensors", + model_type="vit_l" ) -# for boxes.xyxy -#boxes = run_grounded_sam(img, sam_prompt, box_threshold=dino_box_threshold, text_threshold=dino_text_threshold) -#boxes = np.array([[0, 0, img.shape[1], img.shape[0]]]) if len(boxes) == 0 else boxes - -# from PIL import ImageDraw -# draw = ImageDraw.Draw(img) -# for idx, box in enumerate(boxes.xyxy): -# box_list = box.tolist() -# if box_erode_or_dilate != 0: -# box_list[0] -= box_erode_or_dilate -# box_list[1] -= box_erode_or_dilate -# box_list[2] += box_erode_or_dilate -# box_list[3] += box_erode_or_dilate -# draw.rectangle(box_list, fill=128, outline ="red") -# img.show() - -H, W = original_image.size[1], original_image.size[0] -boxes = boxes * torch.Tensor([W, H, W, H]) -boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2 -boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2] - - -from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor - -sam_checkpoint = "./models/sam/sam_vit_l_0b3195.pth" -model_type = "vit_l" -sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) -sam.to(device=device) -mask_generator = SamAutomaticMaskGenerator(sam) -num_boxes = 2 - -sam_predictor = SamPredictor(sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)) - -image_np = np.array(original_image, dtype=np.uint8) - -final_m = torch.zeros((image_np.shape[0], image_np.shape[1])) - -if boxes.size(0) > 0: - sam_predictor.set_image(image_np) - - transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image_np.shape[:2]) - masks, _, _ = sam_predictor.predict_torch( - point_coords=None, - point_labels=None, - boxes=transformed_boxes.to(device), - multimask_output=False, - ) - - # remove small disconnected regions and holes - fine_masks = [] - for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w] - fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0]) - masks = np.stack(fine_masks, axis=0)[:, np.newaxis] - masks = torch.from_numpy(masks) - - num_obj = min(len(logits), num_boxes) - for obj_ind in range(num_obj): - # box = boxes[obj_ind] - - m = masks[obj_ind][0] - final_m += m -final_m = (final_m > 0).to('cpu').numpy() -# print(final_m.max(), final_m.min()) -mask_image = np.array(np.dstack((final_m, final_m, final_m)) * 255, dtype=np.uint8) +mask_image = generate_mask_from_image(image, sam_options=sam_options) merged_masks_img = Image.fromarray(mask_image) merged_masks_img.show() diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 3ee00cf4..85cd7fc5 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -1,47 +1,129 @@ import numpy as np +import torch from rembg import remove, new_session +from segment_anything import sam_model_registry, SamPredictor +from segment_anything.utils.amg import remove_small_regions + from extras.GroundingDINO.util.inference import default_groundingdino -def generate_mask_from_image(image: np.ndarray, mask_model: str, extras: dict, box_erode_or_dilate: int=0, debug_dino: bool=False) -> np.ndarray | None: +class SAMOptions: + def __init__(self, + # GroundingDINO + dino_prompt: str = '', + dino_box_threshold=0.3, + dino_text_threshold=0.25, + box_erode_or_dilate=0, + + # SAM + max_num_boxes=2, + sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth", + model_type="vit_l" + ): + self.dino_prompt = dino_prompt + self.dino_box_threshold = dino_box_threshold + self.dino_text_threshold = dino_text_threshold + self.box_erode_or_dilate = box_erode_or_dilate + self.max_num_boxes = max_num_boxes + self.sam_checkpoint = sam_checkpoint + self.model_type = model_type + + +def optimize_masks(masks: torch.Tensor) -> torch.Tensor: + """ + removes small disconnected regions and holes + """ + fine_masks = [] + for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w] + fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0]) + masks = np.stack(fine_masks, axis=0)[:, np.newaxis] + return torch.from_numpy(masks) + + +def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None, + sam_options: SAMOptions | None = SAMOptions) -> np.ndarray | None: if image is None: return + if extras is None: + extras = {} + if 'image' in image: image = image['image'] - if mask_model == 'sam': - detections, _, _, _ = default_groundingdino( - image=image, - caption=extras['sam_prompt_text'], - box_threshold=extras['box_threshold'], - text_threshold=extras['text_threshold'] + if mask_model != 'sam' and sam_options is None: + return remove( + image, + session=new_session(mask_model, **extras), + only_mask=True, + **extras ) - detection_boxes = detections.xyxy - # use full image if no box has been found - detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes - extras['sam_prompt'] = [] - for idx, box in enumerate(detection_boxes): - box_list = box.tolist() - if box_erode_or_dilate != 0: - box_list[0] -= box_erode_or_dilate - box_list[1] -= box_erode_or_dilate - box_list[2] += box_erode_or_dilate - box_list[3] += box_erode_or_dilate - extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}] + assert sam_options is not None - if debug_dino: - from PIL import ImageDraw, Image - debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black") - draw = ImageDraw.Draw(debug_dino_image) - for box in extras['sam_prompt']: - draw.rectangle(box['data'], fill="white") - return np.array(debug_dino_image) - - return remove( - image, - session=new_session(mask_model, **extras), - only_mask=True, - **extras + detections, boxes, logits, phrases = default_groundingdino( + image=image, + caption=sam_options.dino_prompt, + box_threshold=sam_options.dino_box_threshold, + text_threshold=sam_options.dino_text_threshold ) + # detection_boxes = detections.xyxy + # # use full image if no box has been found + # detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes + # + # + # for idx, box in enumerate(detection_boxes): + # box_list = box.tolist() + # if box_erode_or_dilate != 0: + # box_list[0] -= box_erode_or_dilate + # box_list[1] -= box_erode_or_dilate + # box_list[2] += box_erode_or_dilate + # box_list[3] += box_erode_or_dilate + # extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}] + # + # if debug_dino: + # from PIL import ImageDraw, Image + # debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black") + # draw = ImageDraw.Draw(debug_dino_image) + # for box in extras['sam_prompt']: + # draw.rectangle(box['data'], fill="white") + # return np.array(debug_dino_image) + + # TODO add support for box_erode_or_dilate again + + H, W = image.shape[0], image.shape[1] + boxes = boxes * torch.Tensor([W, H, W, H]) + boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2 + boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2] + + # TODO add model patcher for model logic and device management + device = "cuda" if torch.cuda.is_available() else "cpu" + + sam = sam_model_registry[sam_options.model_type](checkpoint=sam_options.sam_checkpoint) + sam.to(device=device) + + sam_predictor = SamPredictor(sam) + final_mask_tensor = torch.zeros((image.shape[0], image.shape[1])) + + if boxes.size(0) > 0: + sam_predictor.set_image(image) + + transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2]) + masks, _, _ = sam_predictor.predict_torch( + point_coords=None, + point_labels=None, + boxes=transformed_boxes.to(device), + multimask_output=False, + ) + + masks = optimize_masks(masks) + + num_obj = min(len(logits), sam_options.max_num_boxes) + for obj_ind in range(num_obj): + mask_tensor = masks[obj_ind][0] + final_mask_tensor += mask_tensor + + final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy() + mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255 + mask_image = np.array(mask_image, dtype=np.uint8) + return mask_image diff --git a/requirements_versions.txt b/requirements_versions.txt index 095452b4..bc86caac 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -20,5 +20,4 @@ timm==0.9.2 translators==5.8.9 rembg==2.0.53 groundingdino-py==0.4.0 -ultralytics==8.2.28 segment_anything==1.0 \ No newline at end of file diff --git a/webui.py b/webui.py index a036dfb5..4b63cc2f 100644 --- a/webui.py +++ b/webui.py @@ -16,6 +16,7 @@ import modules.meta_parser import args_manager import copy import launch +from extras.inpaint_mask import SAMOptions from modules.sdxl_styles import legal_style_names from modules.private_logger import get_current_html_path @@ -223,7 +224,7 @@ with shared.gradio_root: choices=flags.inpaint_mask_cloth_category, value=modules.config.default_inpaint_mask_cloth_category, visible=False) - inpaint_mask_sam_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False, info='Use singular whenever possible') + inpaint_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False, info='Use singular whenever possible') with gr.Accordion("Advanced options", visible=False, open=False) as inpaint_mask_advanced_options: inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model) inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False) @@ -231,24 +232,29 @@ with shared.gradio_root: inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05) generate_mask_button = gr.Button(value='Generate mask from image') - def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold, debug_dino, dino_erode_or_dilate): + def generate_mask(image, mask_model, cloth_category, dino_prompt_text, sam_model, sam_quant, box_threshold, text_threshold, dino_erode_or_dilate, debug_dino): from extras.inpaint_mask import generate_mask_from_image extras = {} + sam_options = None if mask_model == 'u2net_cloth_seg': extras['cloth_category'] = cloth_category elif mask_model == 'sam': - extras['sam_prompt_text'] = sam_prompt_text - extras['sam_model'] = sam_model - extras['sam_quant'] = sam_quant - extras['box_threshold'] = box_threshold - extras['text_threshold'] = text_threshold + sam_options = SAMOptions( + dino_prompt=dino_prompt_text, + dino_box_threshold=box_threshold, + dino_text_threshold=text_threshold, + box_erode_or_dilate=dino_erode_or_dilate, + max_num_boxes=2, #TODO replace with actual value + sam_checkpoint="./models/sam/sam_vit_l_0b3195.pth", # TODO replace with actual value + model_type="vit_l" + ) - return generate_mask_from_image(image, mask_model, extras, dino_erode_or_dilate, debug_dino) + return generate_mask_from_image(image, mask_model, extras, sam_options) inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')], inputs=inpaint_mask_model, - outputs=[inpaint_mask_cloth_category, inpaint_mask_sam_prompt_text, inpaint_mask_advanced_options], + outputs=[inpaint_mask_cloth_category, inpaint_mask_dino_prompt_text, inpaint_mask_advanced_options], queue=False, show_progress=False) with gr.TabItem(label='Describe') as desc_tab: @@ -737,9 +743,9 @@ with shared.gradio_root: generate_mask_button.click(fn=generate_mask, inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category, - inpaint_mask_sam_prompt_text, inpaint_mask_sam_model, inpaint_mask_sam_quant, - inpaint_mask_box_threshold, inpaint_mask_text_threshold, debug_dino, - dino_erode_or_dilate], + inpaint_mask_dino_prompt_text, inpaint_mask_sam_model, inpaint_mask_sam_quant, + inpaint_mask_box_threshold, inpaint_mask_text_threshold, dino_erode_or_dilate, + debug_dino], outputs=inpaint_mask_image, show_progress=True, queue=True) ctrls = [currentTask, generate_image_grid]