From 228a0aaeea3495c0e17d84199559d201855e7b9a Mon Sep 17 00:00:00 2001 From: rayronvictor Date: Thu, 25 Jan 2024 16:45:02 -0300 Subject: [PATCH 1/6] Add SAM support --- .../config/GroundingDINO_SwinT_OGC.py | 43 +++++++++++++++++++ extras/inpaint_mask.py | 43 ++++++++++++++++++- modules/config.py | 6 +++ modules/flags.py | 4 +- requirements_versions.txt | 3 +- webui.py | 22 +++++++--- 6 files changed, 113 insertions(+), 8 deletions(-) create mode 100644 extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py diff --git a/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py new file mode 100644 index 00000000..c2e64680 --- /dev/null +++ b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py @@ -0,0 +1,43 @@ +batch_size = 1 +modelname = "groundingdino" +backbone = "swin_T_224_1k" +position_embedding = "sine" +pe_temperatureH = 20 +pe_temperatureW = 20 +return_interm_indices = [1, 2, 3] +backbone_freeze_keywords = None +enc_layers = 6 +dec_layers = 6 +pre_norm = False +dim_feedforward = 2048 +hidden_dim = 256 +dropout = 0.0 +nheads = 8 +num_queries = 900 +query_dim = 4 +num_patterns = 0 +num_feature_levels = 4 +enc_n_points = 4 +dec_n_points = 4 +two_stage_type = "standard" +two_stage_bbox_embed_share = False +two_stage_class_embed_share = False +transformer_activation = "relu" +dec_pred_bbox_embed_share = True +dn_box_noise_scale = 1.0 +dn_label_noise_ratio = 0.5 +dn_label_coef = 1.0 +dn_bbox_coef = 1.0 +embed_init_tgt = True +dn_labelbook_size = 2000 +max_text_len = 256 +text_encoder_type = "bert-base-uncased" +use_text_enhancer = True +use_fusion_layer = True +use_checkpoint = True +use_transformer_ckpt = True +use_text_cross_attention = True +text_dropout = 0.0 +fusion_dropout = 0.0 +fusion_droppath = 0.1 +sub_sentence_present = True \ No newline at end of file diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 08d3c3cc..a83c9f05 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -1,4 +1,39 @@ +from PIL import Image +import numpy as np +import torch from rembg import remove, new_session +from groundingdino.util.inference import Model as GroundingDinoModel + +from modules.model_loader import load_file_from_url +from modules.config import path_inpaint + +config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py' +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +groundingdino_model = None + + +def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold): + + global groundingdino_model + + if groundingdino_model is None: + filename = load_file_from_url( + url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth", + model_dir=path_inpaint) + groundingdino_model = GroundingDinoModel(model_config_path=config_file, model_checkpoint_path=filename, device=device) + + + # run grounding dino model + boxes, _ = groundingdino_model.predict_with_caption( + image=np.array(input_image), + caption=text_prompt, + box_threshold=box_threshold, + text_threshold=text_threshold + ) + + return boxes.xyxy def generate_mask_from_image(image, mask_model, extras): @@ -8,9 +43,15 @@ def generate_mask_from_image(image, mask_model, extras): if 'image' in image: image = image['image'] + if mask_model == 'sam': + boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold']) + extras['sam_prompt'] = [] + for idx, box in enumerate(boxes): + extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}] + return remove( image, - session=new_session(mask_model), + session=new_session(mask_model, **extras), only_mask=True, **extras ) diff --git a/modules/config.py b/modules/config.py index 228b9f54..b2b5915a 100644 --- a/modules/config.py +++ b/modules/config.py @@ -330,6 +330,12 @@ default_inpaint_mask_cloth_category = get_config_item_or_set_default( validator=lambda x: x in modules.flags.inpaint_mask_cloth_category ) +default_inpaint_mask_sam_model = get_config_item_or_set_default( + key='default_inpaint_mask_sam_model', + default_value='sam_vit_b_01ec64', + validator=lambda x: x in modules.flags.inpaint_mask_sam_model +) + config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))] possible_preset_keys = [ diff --git a/modules/flags.py b/modules/flags.py index 7e621ec2..45c73008 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -36,11 +36,13 @@ inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] performance_selections = ['Speed', 'Quality', 'Extreme Speed'] inpaint_mask_models = [ - 'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime' + 'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam' ] inpaint_mask_cloth_category = ['full', 'upper', 'lower'] +inpaint_mask_sam_model = ['sam_vit_b_01ec64', 'sam_vit_h_4b8939', 'sam_vit_l_0b3195'] + inpaint_option_default = 'Inpaint or Outpaint (default)' inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)' inpaint_option_modify = 'Modify Content (add objects, change background, etc.)' diff --git a/requirements_versions.txt b/requirements_versions.txt index 9cca8fca..100d5f50 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -16,4 +16,5 @@ opencv-contrib-python==4.8.0.74 httpx==0.24.1 onnxruntime==1.16.3 timm==0.9.2 -rembg==2.0.53 \ No newline at end of file +rembg==2.0.53 +groundingdino-py==0.4.0 \ No newline at end of file diff --git a/webui.py b/webui.py index 2004e0dc..622f5af5 100644 --- a/webui.py +++ b/webui.py @@ -210,24 +210,36 @@ with shared.gradio_root: choices=flags.inpaint_mask_cloth_category, value=modules.config.default_inpaint_mask_cloth_category, visible=False) + inpaint_mask_sam_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False) + with gr.Accordion("Advanced options", visible=False, open=False) as inpaint_mask_advanced_options: + inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model) + inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False) + inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05) + inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05) generate_mask_button = gr.Button(value='Generate mask from image', visible=False) - def generate_mask(image, mask_model, cloth_category): + def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold): from extras.inpaint_mask import generate_mask_from_image - return generate_mask_from_image(image, mask_model, extras={"cloth_category": cloth_category}) + return generate_mask_from_image(image, mask_model, + extras={"cloth_category": cloth_category, "sam_prompt_text": sam_prompt_text, "sam_model": sam_model, "sam_quant": sam_quant, "box_threshold": box_threshold, "text_threshold": text_threshold}) generate_mask_button.click(fn=generate_mask, inputs=[ inpaint_input_image, inpaint_mask_model, - cloth_category + cloth_category, + inpaint_mask_sam_prompt_text, + inpaint_mask_sam_model, + inpaint_mask_sam_quant, + inpaint_mask_box_threshold, + inpaint_mask_text_threshold ], outputs=inpaint_mask_image) - inpaint_mask_model.change(lambda x: gr.update(visible=x == 'u2net_cloth_seg'), + inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')], inputs=inpaint_mask_model, - outputs=cloth_category, + outputs=[cloth_category, inpaint_mask_sam_prompt_text, inpaint_mask_advanced_options], queue=False, show_progress=False) with gr.TabItem(label='Describe') as desc_tab: with gr.Row(): From 90be73a6df86d467d83da157c253af9cae919dcf Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 26 Jan 2024 02:05:38 +0100 Subject: [PATCH 2/6] feat: add model patching automatically unload model when not needed anymore --- extras/GroundingDINO/util/inference.py | 104 +++++++++++++++++++++++++ extras/inpaint_mask.py | 20 +---- 2 files changed, 106 insertions(+), 18 deletions(-) create mode 100644 extras/GroundingDINO/util/inference.py diff --git a/extras/GroundingDINO/util/inference.py b/extras/GroundingDINO/util/inference.py new file mode 100644 index 00000000..e4b723a5 --- /dev/null +++ b/extras/GroundingDINO/util/inference.py @@ -0,0 +1,104 @@ +from typing import Tuple, List + +import ldm_patched.modules.model_management as model_management +from ldm_patched.modules.model_patcher import ModelPatcher +from modules.config import path_inpaint +from modules.model_loader import load_file_from_url + +import numpy as np +import supervision as sv +import torch +from groundingdino.util.inference import Model +from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap + + +class GroundingDinoModel(Model): + def __init__(self): + self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py' + self.model = None + self.load_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.offload_device = torch.device('cpu') + self.dtype = torch.float32 + + def predict_with_caption( + self, + image: np.ndarray, + caption: str, + box_threshold: float = 0.35, + text_threshold: float = 0.25 + ) -> Tuple[sv.Detections, List[str]]: + if self.model is None: + filename = load_file_from_url( + url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth", + file_name='groundingdino_swint_ogc.pth', + model_dir=path_inpaint) + model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename) + + self.load_device = model_management.text_encoder_device() + self.offload_device = model_management.text_encoder_offload_device() + self.dtype = torch.float32 + + model.to(self.offload_device) + + if model_management.should_use_fp16(device=self.load_device): + model.half() + self.dtype = torch.float16 + + self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device) + + model_management.load_model_gpu(self.model) + + processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device) + boxes, logits, phrases = predict( + model=self.model, + image=processed_image, + caption=caption, + box_threshold=box_threshold, + text_threshold=text_threshold, + device=self.load_device) + source_h, source_w, _ = image.shape + detections = GroundingDinoModel.post_process_result( + source_h=source_h, + source_w=source_w, + boxes=boxes, + logits=logits) + return detections, phrases + + +def predict( + model, + image: torch.Tensor, + caption: str, + box_threshold: float, + text_threshold: float, + device: str = "cuda" +) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: + caption = preprocess_caption(caption=caption) + + # override to use model wrapped by patcher + model = model.model.to(device) + image = image.to(device) + + with torch.no_grad(): + outputs = model(image[None], captions=[caption]) + + prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256) + prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4) + + mask = prediction_logits.max(dim=1)[0] > box_threshold + logits = prediction_logits[mask] # logits.shape = (n, 256) + boxes = prediction_boxes[mask] # boxes.shape = (n, 4) + + tokenizer = model.tokenizer + tokenized = tokenizer(caption) + + phrases = [ + get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') + for logit + in logits + ] + + return boxes, logits.max(dim=1)[0], phrases + + +default_groundingdino = GroundingDinoModel().predict_with_caption diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index a83c9f05..dfcb90a9 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -2,31 +2,15 @@ from PIL import Image import numpy as np import torch from rembg import remove, new_session -from groundingdino.util.inference import Model as GroundingDinoModel +from extras.GroundingDINO.util.inference import default_groundingdino -from modules.model_loader import load_file_from_url -from modules.config import path_inpaint - -config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -groundingdino_model = None - - def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold): - global groundingdino_model - - if groundingdino_model is None: - filename = load_file_from_url( - url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth", - model_dir=path_inpaint) - groundingdino_model = GroundingDinoModel(model_config_path=config_file, model_checkpoint_path=filename, device=device) - - # run grounding dino model - boxes, _ = groundingdino_model.predict_with_caption( + boxes, _ = default_groundingdino( image=np.array(input_image), caption=text_prompt, box_threshold=box_threshold, From d515d0f0746c76bd58e138ec882063c1ba768960 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 26 Jan 2024 02:06:17 +0100 Subject: [PATCH 3/6] fix: remove unnecessary fp32 / fp16 handling --- extras/GroundingDINO/util/inference.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/extras/GroundingDINO/util/inference.py b/extras/GroundingDINO/util/inference.py index e4b723a5..3f61e946 100644 --- a/extras/GroundingDINO/util/inference.py +++ b/extras/GroundingDINO/util/inference.py @@ -18,7 +18,6 @@ class GroundingDinoModel(Model): self.model = None self.load_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.offload_device = torch.device('cpu') - self.dtype = torch.float32 def predict_with_caption( self, @@ -36,14 +35,9 @@ class GroundingDinoModel(Model): self.load_device = model_management.text_encoder_device() self.offload_device = model_management.text_encoder_offload_device() - self.dtype = torch.float32 model.to(self.offload_device) - if model_management.should_use_fp16(device=self.load_device): - model.half() - self.dtype = torch.float16 - self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device) model_management.load_model_gpu(self.model) From 338004c2e5677ae9e77ec8966511ee379da3a7d3 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 26 Jan 2024 02:08:27 +0100 Subject: [PATCH 4/6] feat: optimize gradio element visibility changes and data provisioning --- webui.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/webui.py b/webui.py index b2fcf857..36face51 100644 --- a/webui.py +++ b/webui.py @@ -201,11 +201,10 @@ with shared.gradio_root: with gr.Column(visible=False) as inpaint_mask_generation_col: inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy', - height=500, visible=False) + height=500) inpaint_mask_model = gr.Dropdown(label='Mask generation model', choices=flags.inpaint_mask_models, - value=modules.config.default_inpaint_mask_model, - visible=False) + value=modules.config.default_inpaint_mask_model) inpaint_mask_cloth_category = gr.Dropdown(label='Cloth category', choices=flags.inpaint_mask_cloth_category, value=modules.config.default_inpaint_mask_cloth_category, @@ -216,14 +215,23 @@ with shared.gradio_root: inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False) inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05) inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05) - generate_mask_button = gr.Button(value='Generate mask from image', visible=False) + generate_mask_button = gr.Button(value='Generate mask from image') def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold): from extras.inpaint_mask import generate_mask_from_image - return generate_mask_from_image(image, mask_model, - {"cloth_category": cloth_category, "sam_prompt_text": sam_prompt_text, "sam_model": sam_model, "sam_quant": sam_quant, "box_threshold": box_threshold, "text_threshold": text_threshold}) + extras = {} + if mask_model == 'u2net_cloth_seg': + extras['cloth_category'] = cloth_category + elif mask_model == 'sam': + extras['sam_prompt_text'] = sam_prompt_text + extras['sam_model'] = sam_model + extras['sam_quant'] = sam_quant + extras['box_threshold'] = box_threshold + extras['text_threshold'] = text_threshold + + return generate_mask_from_image(image, mask_model, extras) generate_mask_button.click(fn=generate_mask, inputs=[ @@ -235,7 +243,7 @@ with shared.gradio_root: inpaint_mask_box_threshold, inpaint_mask_text_threshold ], - outputs=inpaint_mask_image) + outputs=inpaint_mask_image, show_progress=True, queue=True) inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')], inputs=inpaint_mask_model, @@ -478,10 +486,9 @@ with shared.gradio_root: inpaint_strength, inpaint_respective_field, inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate] - inpaint_mask_upload_checkbox.change(lambda x: [gr.update(visible=x)] * 4, + inpaint_mask_upload_checkbox.change(lambda x: [gr.update(visible=x)] * 2, inputs=inpaint_mask_upload_checkbox, - outputs=[inpaint_mask_image, generate_mask_button, - inpaint_mask_model, inpaint_mask_generation_col], + outputs=[inpaint_mask_image, inpaint_mask_generation_col], queue=False, show_progress=False) with gr.Tab(label='FreeU'): From 62fe86f1e8f85d47dd5fdfdb79d3379c60061f7d Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 26 Jan 2024 02:34:25 +0100 Subject: [PATCH 5/6] chore: always use cpu as default device this is overridden anyways --- extras/GroundingDINO/util/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extras/GroundingDINO/util/inference.py b/extras/GroundingDINO/util/inference.py index 3f61e946..259094f2 100644 --- a/extras/GroundingDINO/util/inference.py +++ b/extras/GroundingDINO/util/inference.py @@ -16,7 +16,7 @@ class GroundingDinoModel(Model): def __init__(self): self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py' self.model = None - self.load_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.load_device = torch.device('cpu') self.offload_device = torch.device('cpu') def predict_with_caption( From dd2fd04fd7f5dd0fee2e683bff5ac2aab6c4c32c Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 26 Jan 2024 11:17:58 +0100 Subject: [PATCH 6/6] feat: set U2NET_HOME env var to path_inpaint previously was using the user dir, see https://github.com/danielgatis/rembg/blob/49d1686f65b71b72df1c3b420551406d13074426/rembg/sessions/base.py#L78 --- launch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index e98045f6..63a38767 100644 --- a/launch.py +++ b/launch.py @@ -22,8 +22,9 @@ from build_launcher import build_launcher from modules.launch_util import is_installed, run, python, run_pip, requirements_met from modules.model_loader import load_file_from_url from modules.config import path_checkpoints, path_loras, path_vae_approx, path_fooocus_expansion, \ - checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads + path_inpaint, checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads +os.environ["U2NET_HOME"] = path_inpaint REINSTALL_ALL = False TRY_INSTALL_XFORMERS = False