Add SAM support

This commit is contained in:
rayronvictor 2024-01-25 16:45:02 -03:00
parent f0fb9783e1
commit 228a0aaeea
6 changed files with 113 additions and 8 deletions

View File

@ -0,0 +1,43 @@
batch_size = 1
modelname = "groundingdino"
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

View File

@ -1,4 +1,39 @@
from PIL import Image
import numpy as np
import torch
from rembg import remove, new_session
from groundingdino.util.inference import Model as GroundingDinoModel
from modules.model_loader import load_file_from_url
from modules.config import path_inpaint
config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
groundingdino_model = None
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
global groundingdino_model
if groundingdino_model is None:
filename = load_file_from_url(
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
model_dir=path_inpaint)
groundingdino_model = GroundingDinoModel(model_config_path=config_file, model_checkpoint_path=filename, device=device)
# run grounding dino model
boxes, _ = groundingdino_model.predict_with_caption(
image=np.array(input_image),
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
return boxes.xyxy
def generate_mask_from_image(image, mask_model, extras):
@ -8,9 +43,15 @@ def generate_mask_from_image(image, mask_model, extras):
if 'image' in image:
image = image['image']
if mask_model == 'sam':
boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
extras['sam_prompt'] = []
for idx, box in enumerate(boxes):
extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}]
return remove(
image,
session=new_session(mask_model),
session=new_session(mask_model, **extras),
only_mask=True,
**extras
)

View File

@ -330,6 +330,12 @@ default_inpaint_mask_cloth_category = get_config_item_or_set_default(
validator=lambda x: x in modules.flags.inpaint_mask_cloth_category
)
default_inpaint_mask_sam_model = get_config_item_or_set_default(
key='default_inpaint_mask_sam_model',
default_value='sam_vit_b_01ec64',
validator=lambda x: x in modules.flags.inpaint_mask_sam_model
)
config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
possible_preset_keys = [

View File

@ -36,11 +36,13 @@ inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6']
performance_selections = ['Speed', 'Quality', 'Extreme Speed']
inpaint_mask_models = [
'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime'
'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam'
]
inpaint_mask_cloth_category = ['full', 'upper', 'lower']
inpaint_mask_sam_model = ['sam_vit_b_01ec64', 'sam_vit_h_4b8939', 'sam_vit_l_0b3195']
inpaint_option_default = 'Inpaint or Outpaint (default)'
inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)'
inpaint_option_modify = 'Modify Content (add objects, change background, etc.)'

View File

@ -16,4 +16,5 @@ opencv-contrib-python==4.8.0.74
httpx==0.24.1
onnxruntime==1.16.3
timm==0.9.2
rembg==2.0.53
rembg==2.0.53
groundingdino-py==0.4.0

View File

@ -210,24 +210,36 @@ with shared.gradio_root:
choices=flags.inpaint_mask_cloth_category,
value=modules.config.default_inpaint_mask_cloth_category,
visible=False)
inpaint_mask_sam_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False)
with gr.Accordion("Advanced options", visible=False, open=False) as inpaint_mask_advanced_options:
inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model)
inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False)
inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05)
inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
generate_mask_button = gr.Button(value='Generate mask from image', visible=False)
def generate_mask(image, mask_model, cloth_category):
def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold):
from extras.inpaint_mask import generate_mask_from_image
return generate_mask_from_image(image, mask_model, extras={"cloth_category": cloth_category})
return generate_mask_from_image(image, mask_model,
extras={"cloth_category": cloth_category, "sam_prompt_text": sam_prompt_text, "sam_model": sam_model, "sam_quant": sam_quant, "box_threshold": box_threshold, "text_threshold": text_threshold})
generate_mask_button.click(fn=generate_mask,
inputs=[
inpaint_input_image, inpaint_mask_model,
cloth_category
cloth_category,
inpaint_mask_sam_prompt_text,
inpaint_mask_sam_model,
inpaint_mask_sam_quant,
inpaint_mask_box_threshold,
inpaint_mask_text_threshold
],
outputs=inpaint_mask_image)
inpaint_mask_model.change(lambda x: gr.update(visible=x == 'u2net_cloth_seg'),
inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
inputs=inpaint_mask_model,
outputs=cloth_category,
outputs=[cloth_category, inpaint_mask_sam_prompt_text, inpaint_mask_advanced_options],
queue=False, show_progress=False)
with gr.TabItem(label='Describe') as desc_tab:
with gr.Row():