From 228a0aaeea3495c0e17d84199559d201855e7b9a Mon Sep 17 00:00:00 2001
From: rayronvictor
Date: Thu, 25 Jan 2024 16:45:02 -0300
Subject: [PATCH] Add SAM support

---
 .../config/GroundingDINO_SwinT_OGC.py | 43 +++++++++++++++++++
 extras/inpaint_mask.py                | 43 ++++++++++++++++++-
 modules/config.py                     |  6 +++
 modules/flags.py                      |  4 +-
 requirements_versions.txt             |  3 +-
 webui.py                              | 22 +++++++---
 6 files changed, 113 insertions(+), 8 deletions(-)
 create mode 100644 extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py

diff --git a/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py
new file mode 100644
index 00000000..c2e64680
--- /dev/null
+++ b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py
@@ -0,0 +1,43 @@
+batch_size = 1
+modelname = "groundingdino"
+backbone = "swin_T_224_1k"
+position_embedding = "sine"
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+two_stage_type = "standard"
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+transformer_activation = "relu"
+dec_pred_bbox_embed_share = True
+dn_box_noise_scale = 1.0
+dn_label_noise_ratio = 0.5
+dn_label_coef = 1.0
+dn_bbox_coef = 1.0
+embed_init_tgt = True
+dn_labelbook_size = 2000
+max_text_len = 256
+text_encoder_type = "bert-base-uncased"
+use_text_enhancer = True
+use_fusion_layer = True
+use_checkpoint = True
+use_transformer_ckpt = True
+use_text_cross_attention = True
+text_dropout = 0.0
+fusion_dropout = 0.0
+fusion_droppath = 0.1
+sub_sentence_present = True
\ No newline at end of file
diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py
index 08d3c3cc..a83c9f05 100644
--- a/extras/inpaint_mask.py
+++ b/extras/inpaint_mask.py
@@ -1,4 +1,39 @@
+from PIL import Image
+import numpy as np
+import torch
 from rembg import remove, new_session
+from groundingdino.util.inference import Model as GroundingDinoModel
+
+from modules.model_loader import load_file_from_url
+from modules.config import path_inpaint
+
+config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+groundingdino_model = None
+
+
+def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
+
+    global groundingdino_model
+
+    if groundingdino_model is None:
+        filename = load_file_from_url(
+            url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+            model_dir=path_inpaint)
+        groundingdino_model = GroundingDinoModel(model_config_path=config_file, model_checkpoint_path=filename, device=device)
+
+
+    # run grounding dino model
+    boxes, _ = groundingdino_model.predict_with_caption(
+        image=np.array(input_image),
+        caption=text_prompt,
+        box_threshold=box_threshold,
+        text_threshold=text_threshold
+    )
+
+    return boxes.xyxy
 
 
 def generate_mask_from_image(image, mask_model, extras):
@@ -8,9 +43,15 @@ def generate_mask_from_image(image, mask_model, extras):
     if 'image' in image:
         image = image['image']
 
+    if mask_model == 'sam':
+        boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
+        extras['sam_prompt'] = []
+        for idx, box in enumerate(boxes):
+            extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}]
+
     return remove(
         image,
-        session=new_session(mask_model),
+        session=new_session(mask_model, **extras),
         only_mask=True,
         **extras
     )
diff --git a/modules/config.py b/modules/config.py
index 228b9f54..b2b5915a 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -330,6 +330,12 @@ default_inpaint_mask_cloth_category = get_config_item_or_set_default(
     validator=lambda x: x in modules.flags.inpaint_mask_cloth_category
 )
 
+default_inpaint_mask_sam_model = get_config_item_or_set_default(
+    key='default_inpaint_mask_sam_model',
+    default_value='sam_vit_b_01ec64',
+    validator=lambda x: x in modules.flags.inpaint_mask_sam_model
+)
+
 config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
 
 possible_preset_keys = [
diff --git a/modules/flags.py b/modules/flags.py
index 7e621ec2..45c73008 100644
--- a/modules/flags.py
+++ b/modules/flags.py
@@ -36,11 +36,13 @@ inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6']
 performance_selections = ['Speed', 'Quality', 'Extreme Speed']
 
 inpaint_mask_models = [
-    'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime'
+    'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam'
 ]
 
 inpaint_mask_cloth_category = ['full', 'upper', 'lower']
 
+inpaint_mask_sam_model = ['sam_vit_b_01ec64', 'sam_vit_h_4b8939', 'sam_vit_l_0b3195']
+
 inpaint_option_default = 'Inpaint or Outpaint (default)'
 inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)'
 inpaint_option_modify = 'Modify Content (add objects, change background, etc.)'
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 9cca8fca..100d5f50 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -16,4 +16,5 @@ opencv-contrib-python==4.8.0.74
 httpx==0.24.1
 onnxruntime==1.16.3
 timm==0.9.2
-rembg==2.0.53
\ No newline at end of file
+rembg==2.0.53
+groundingdino-py==0.4.0
\ No newline at end of file
diff --git a/webui.py b/webui.py
index 2004e0dc..622f5af5 100644
--- a/webui.py
+++ b/webui.py
@@ -210,24 +210,36 @@ with shared.gradio_root:
                                                      choices=flags.inpaint_mask_cloth_category,
                                                      value=modules.config.default_inpaint_mask_cloth_category,
                                                      visible=False)
+                        inpaint_mask_sam_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False)
+                        with gr.Accordion("Advanced options", visible=False, open=False) as inpaint_mask_advanced_options:
+                            inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model)
+                            inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False)
+                            inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05)
+                            inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
 
                         generate_mask_button = gr.Button(value='Generate mask from image', visible=False)
 
 
-                        def generate_mask(image, mask_model, cloth_category):
+                        def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold):
                             from extras.inpaint_mask import generate_mask_from_image
-                            return generate_mask_from_image(image, mask_model, extras={"cloth_category": cloth_category})
+                            return generate_mask_from_image(image, mask_model,
+                                                            extras={"cloth_category": cloth_category, "sam_prompt_text": sam_prompt_text, "sam_model": sam_model, "sam_quant": sam_quant, "box_threshold": box_threshold, "text_threshold": text_threshold})
 
                         generate_mask_button.click(fn=generate_mask,
                                                    inputs=[
                                                        inpaint_input_image,
                                                        inpaint_mask_model,
-                                                       cloth_category
+                                                       cloth_category,
+                                                       inpaint_mask_sam_prompt_text,
+                                                       inpaint_mask_sam_model,
+                                                       inpaint_mask_sam_quant,
+                                                       inpaint_mask_box_threshold,
+                                                       inpaint_mask_text_threshold
                                                    ], outputs=inpaint_mask_image)
 
-                        inpaint_mask_model.change(lambda x: gr.update(visible=x == 'u2net_cloth_seg'),
+                        inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
                                                   inputs=inpaint_mask_model,
-                                                  outputs=cloth_category,
+                                                  outputs=[cloth_category, inpaint_mask_sam_prompt_text, inpaint_mask_advanced_options],
                                                   queue=False, show_progress=False)
                     with gr.TabItem(label='Describe') as desc_tab:
                         with gr.Row():
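
A quick way to exercise the new SAM path outside the Gradio UI is to call the patched generate_mask_from_image() directly, passing the same extras dict that webui.py's generate_mask() now assembles. The sketch below is not part of the patch: the image path and prompt text are placeholders, and it assumes the patch is applied, the script is run from the Fooocus repository root, and the GroundingDINO/SAM weights can be downloaded on first use.

# Sketch only: text-prompted masking via GroundingDINO boxes fed to rembg's SAM session.
import numpy as np
from PIL import Image

from extras.inpaint_mask import generate_mask_from_image

img = np.array(Image.open('dog.png').convert('RGB'))  # placeholder input image

mask = generate_mask_from_image(
    {'image': img},                        # same dict shape as the Gradio sketch payload
    mask_model='sam',
    extras={
        "cloth_category": 'full',          # only consumed by the u2net_cloth_seg path
        "sam_prompt_text": 'the dog',      # GroundingDINO turns this caption into boxes
        "sam_model": 'sam_vit_b_01ec64',   # one of modules.flags.inpaint_mask_sam_model
        "sam_quant": False,
        "box_threshold": 0.3,              # defaults shown on the new webui sliders
        "text_threshold": 0.25,
    })

Image.fromarray(mask).save('mask.png')     # remove(..., only_mask=True) returns just the mask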