From 757863c023843ad61ea69dbec072a97b875cbc78 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 10 Jun 2024 22:42:35 +0200 Subject: [PATCH] feat: wrap sam model in model patcher for predict --- experiments_mask_generation.py | 4 +- extras/inpaint_mask.py | 9 +- extras/sam/predictor.py | 288 +++++++++++++++++++++++++++++++++ 3 files changed, 293 insertions(+), 8 deletions(-) create mode 100644 extras/sam/predictor.py diff --git a/experiments_mask_generation.py b/experiments_mask_generation.py index 538ad712..8e32c29b 100644 --- a/experiments_mask_generation.py +++ b/experiments_mask_generation.py @@ -12,9 +12,9 @@ sam_options = SAMOptions( dino_prompt='eye', dino_box_threshold=0.3, dino_text_threshold=0.25, - box_erode_or_dilate=0, + dino_erode_or_dilate=0, + dino_debug=False, max_num_boxes=2, - sam_checkpoint="./models/sam/sam_vit_l.safetensors", model_type="vit_l" ) diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 3b4d1cb6..2fd776d8 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -1,7 +1,8 @@ import numpy as np import torch +from extras.sam.predictor import SamPredictor from rembg import remove, new_session -from segment_anything import sam_model_registry, SamPredictor +from segment_anything import sam_model_registry from segment_anything.utils.amg import remove_small_regions from extras.GroundingDINO.util.inference import default_groundingdino @@ -97,12 +98,8 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2 boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2] - # TODO add model patcher for model logic and device management - device = "cuda" if torch.cuda.is_available() else "cpu" - sam_checkpoint = modules.config.download_sam_model(sam_options.model_type) sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint) - sam.to(device=device) sam_predictor = SamPredictor(sam) final_mask_tensor = torch.zeros((image.shape[0], image.shape[1])) @@ -114,7 +111,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras= masks, _, _ = sam_predictor.predict_torch( point_coords=None, point_labels=None, - boxes=transformed_boxes.to(device), + boxes=transformed_boxes, multimask_output=False, ) diff --git a/extras/sam/predictor.py b/extras/sam/predictor.py new file mode 100644 index 00000000..337c549b --- /dev/null +++ b/extras/sam/predictor.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from ldm_patched.modules import model_management +from ldm_patched.modules.model_patcher import ModelPatcher + +from segment_anything.modeling import Sam + +from typing import Optional, Tuple + +from segment_anything.utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + model: Sam, + load_device=model_management.text_encoder_device(), + offload_device=model_management.text_encoder_offload_device() + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + model (Sam): The model to use for mask prediction. 
+ """ + super().__init__() + + self.load_device = load_device + self.offload_device = offload_device + # can't use model.half() here as slow_conv2d_cpu is not implemented for half + model.to(self.offload_device) + + self.patcher = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device) + + self.transform = ResizeLongestSide(model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.patcher.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.load_device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.patcher.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.patcher.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + model_management.load_model_gpu(self.patcher) + input_image = self.patcher.model.preprocess(transformed_image.to(self.load_device)) + self.features = self.patcher.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. 
+ multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.load_device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.load_device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.load_device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.load_device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. 
Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords.to(self.load_device), point_labels.to(self.load_device)) + else: + points = None + + # load + if boxes is not None: + boxes = boxes.to(self.load_device) + if mask_input is not None: + mask_input = mask_input.to(self.load_device) + model_management.load_model_gpu(self.patcher) + + # Embed prompts + sparse_embeddings, dense_embeddings = self.patcher.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.patcher.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.patcher.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.patcher.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.patcher.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.patcher.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None
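
For reviewers, a minimal usage sketch of the patched predictor, mirroring the extras/inpaint_mask.py hunk above: device placement now lives inside SamPredictor via ModelPatcher, so callers no longer call sam.to(device) or move the box tensor themselves. This sketch is not part of the patch; names such as image, boxes and model_type are illustrative stand-ins, and the set_image / apply_boxes_torch steps are assumed from the unchanged parts of generate_mask_from_image.

    # Hedged usage sketch (assumes a Fooocus checkout with SAM models downloadable);
    # variable names are illustrative, not taken verbatim from the patch.
    import numpy as np
    import torch

    import modules.config
    from extras.sam.predictor import SamPredictor
    from segment_anything import sam_model_registry

    model_type = "vit_l"
    sam_checkpoint = modules.config.download_sam_model(model_type)
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

    # No manual sam.to(device): the predictor wraps the model in a ModelPatcher and
    # loads/offloads it through model_management when embeddings or masks are computed.
    sam_predictor = SamPredictor(sam)

    image = np.zeros((512, 512, 3), dtype=np.uint8)  # HWC uint8 stand-in image
    sam_predictor.set_image(image)

    # Example box prompt in XYXY pixel coordinates, resized to the model's input frame.
    boxes = torch.tensor([[64.0, 64.0, 448.0, 448.0]])
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])

    # Boxes no longer need an explicit .to(device); predict_torch moves prompts to the
    # patcher's load device itself (see the boxes=transformed_boxes change above).
    masks, scores, low_res_logits = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    print(masks.shape, scores.shape)

Keeping load/offload decisions inside the predictor is what allows the explicit device handling (the removed TODO, sam.to(device), and transformed_boxes.to(device)) to be dropped from inpaint_mask.py.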