diff --git a/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py new file mode 100644 index 00000000..c2e64680 --- /dev/null +++ b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py @@ -0,0 +1,43 @@ +batch_size = 1 +modelname = "groundingdino" +backbone = "swin_T_224_1k" +position_embedding = "sine" +pe_temperatureH = 20 +pe_temperatureW = 20 +return_interm_indices = [1, 2, 3] +backbone_freeze_keywords = None +enc_layers = 6 +dec_layers = 6 +pre_norm = False +dim_feedforward = 2048 +hidden_dim = 256 +dropout = 0.0 +nheads = 8 +num_queries = 900 +query_dim = 4 +num_patterns = 0 +num_feature_levels = 4 +enc_n_points = 4 +dec_n_points = 4 +two_stage_type = "standard" +two_stage_bbox_embed_share = False +two_stage_class_embed_share = False +transformer_activation = "relu" +dec_pred_bbox_embed_share = True +dn_box_noise_scale = 1.0 +dn_label_noise_ratio = 0.5 +dn_label_coef = 1.0 +dn_bbox_coef = 1.0 +embed_init_tgt = True +dn_labelbook_size = 2000 +max_text_len = 256 +text_encoder_type = "bert-base-uncased" +use_text_enhancer = True +use_fusion_layer = True +use_checkpoint = True +use_transformer_ckpt = True +use_text_cross_attention = True +text_dropout = 0.0 +fusion_dropout = 0.0 +fusion_droppath = 0.1 +sub_sentence_present = True \ No newline at end of file diff --git a/extras/GroundingDINO/util/inference.py b/extras/GroundingDINO/util/inference.py new file mode 100644 index 00000000..259094f2 --- /dev/null +++ b/extras/GroundingDINO/util/inference.py @@ -0,0 +1,98 @@ +from typing import Tuple, List + +import ldm_patched.modules.model_management as model_management +from ldm_patched.modules.model_patcher import ModelPatcher +from modules.config import path_inpaint +from modules.model_loader import load_file_from_url + +import numpy as np +import supervision as sv +import torch +from groundingdino.util.inference import Model +from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap + + +class GroundingDinoModel(Model): + def __init__(self): + self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py' + self.model = None + self.load_device = torch.device('cpu') + self.offload_device = torch.device('cpu') + + def predict_with_caption( + self, + image: np.ndarray, + caption: str, + box_threshold: float = 0.35, + text_threshold: float = 0.25 + ) -> Tuple[sv.Detections, List[str]]: + if self.model is None: + filename = load_file_from_url( + url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth", + file_name='groundingdino_swint_ogc.pth', + model_dir=path_inpaint) + model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename) + + self.load_device = model_management.text_encoder_device() + self.offload_device = model_management.text_encoder_offload_device() + + model.to(self.offload_device) + + self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device) + + model_management.load_model_gpu(self.model) + + processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device) + boxes, logits, phrases = predict( + model=self.model, + image=processed_image, + caption=caption, + box_threshold=box_threshold, + text_threshold=text_threshold, + device=self.load_device) + source_h, source_w, _ = image.shape + detections = GroundingDinoModel.post_process_result( + source_h=source_h, + source_w=source_w, + boxes=boxes, + logits=logits) + return detections, phrases + + +def predict( + model, + image: torch.Tensor, + caption: str, + box_threshold: float, + text_threshold: float, + device: str = "cuda" +) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: + caption = preprocess_caption(caption=caption) + + # override to use model wrapped by patcher + model = model.model.to(device) + image = image.to(device) + + with torch.no_grad(): + outputs = model(image[None], captions=[caption]) + + prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256) + prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4) + + mask = prediction_logits.max(dim=1)[0] > box_threshold + logits = prediction_logits[mask] # logits.shape = (n, 256) + boxes = prediction_boxes[mask] # boxes.shape = (n, 4) + + tokenizer = model.tokenizer + tokenized = tokenizer(caption) + + phrases = [ + get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') + for logit + in logits + ] + + return boxes, logits.max(dim=1)[0], phrases + + +default_groundingdino = GroundingDinoModel().predict_with_caption diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py index 08d3c3cc..dfcb90a9 100644 --- a/extras/inpaint_mask.py +++ b/extras/inpaint_mask.py @@ -1,4 +1,23 @@ +from PIL import Image +import numpy as np +import torch from rembg import remove, new_session +from extras.GroundingDINO.util.inference import default_groundingdino + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold): + + # run grounding dino model + boxes, _ = default_groundingdino( + image=np.array(input_image), + caption=text_prompt, + box_threshold=box_threshold, + text_threshold=text_threshold + ) + + return boxes.xyxy def generate_mask_from_image(image, mask_model, extras): @@ -8,9 +27,15 @@ def generate_mask_from_image(image, mask_model, extras): if 'image' in image: image = image['image'] + if mask_model == 'sam': + boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold']) + extras['sam_prompt'] = [] + for idx, box in enumerate(boxes): + extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}] + return remove( image, - session=new_session(mask_model), + session=new_session(mask_model, **extras), only_mask=True, **extras ) diff --git a/launch.py b/launch.py index e98045f6..63a38767 100644 --- a/launch.py +++ b/launch.py @@ -22,8 +22,9 @@ from build_launcher import build_launcher from modules.launch_util import is_installed, run, python, run_pip, requirements_met from modules.model_loader import load_file_from_url from modules.config import path_checkpoints, path_loras, path_vae_approx, path_fooocus_expansion, \ - checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads + path_inpaint, checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads +os.environ["U2NET_HOME"] = path_inpaint REINSTALL_ALL = False TRY_INSTALL_XFORMERS = False diff --git a/modules/config.py b/modules/config.py index 1010fe6a..07787015 100644 --- a/modules/config.py +++ b/modules/config.py @@ -396,6 +396,12 @@ default_inpaint_mask_cloth_category = get_config_item_or_set_default( validator=lambda x: x in modules.flags.inpaint_mask_cloth_category ) +default_inpaint_mask_sam_model = get_config_item_or_set_default( + key='default_inpaint_mask_sam_model', + default_value='sam_vit_b_01ec64', + validator=lambda x: x in modules.flags.inpaint_mask_sam_model +) + config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))] # mapping config to meta parameter diff --git a/modules/flags.py b/modules/flags.py index 7279dba5..06f52784 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -48,11 +48,13 @@ performance_selections = [ output_formats = ['png', 'jpg', 'webp'] inpaint_mask_models = [ - 'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime' + 'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam' ] inpaint_mask_cloth_category = ['full', 'upper', 'lower'] +inpaint_mask_sam_model = ['sam_vit_b_01ec64', 'sam_vit_h_4b8939', 'sam_vit_l_0b3195'] + inpaint_option_default = 'Inpaint or Outpaint (default)' inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)' inpaint_option_modify = 'Modify Content (add objects, change background, etc.)' diff --git a/readme.md b/readme.md index 90c37508..d5045849 100644 --- a/readme.md +++ b/readme.md @@ -27,7 +27,7 @@ Included adjustments: * ✨ https://github.com/lllyasviel/Fooocus/pull/1938 - automatically describe image on uov image upload if prompt is empty * ✨ https://github.com/lllyasviel/Fooocus/pull/1940 - meta data handling, schemes: Fooocus (json) and A1111 (plain text). Compatible with Civitai. * ✨ https://github.com/lllyasviel/Fooocus/pull/1979 - prevent outdated history log link after midnight -* ✨ https://github.com/lllyasviel/Fooocus/pull/2032 - add inpaint mask generation functionality using rembg +* ✨ https://github.com/lllyasviel/Fooocus/pull/2032 - add inpaint mask generation functionality using rembg, incl. segmentation support ✨ = new feature
🐛 = bugfix
diff --git a/requirements_versions.txt b/requirements_versions.txt index 7b2a37e8..ebcd0297 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -18,4 +18,5 @@ httpx==0.24.1 onnxruntime==1.16.3 timm==0.9.2 translators==5.8.9 -rembg==2.0.53 \ No newline at end of file +rembg==2.0.53 +groundingdino-py==0.4.0 \ No newline at end of file diff --git a/webui.py b/webui.py index bacd380f..8579c9df 100644 --- a/webui.py +++ b/webui.py @@ -211,31 +211,52 @@ with shared.gradio_root: with gr.Column(visible=False) as inpaint_mask_generation_col: inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy', - height=500, visible=False) + height=500) inpaint_mask_model = gr.Dropdown(label='Mask generation model', choices=flags.inpaint_mask_models, - value=modules.config.default_inpaint_mask_model, - visible=False) + value=modules.config.default_inpaint_mask_model) inpaint_mask_cloth_category = gr.Dropdown(label='Cloth category', choices=flags.inpaint_mask_cloth_category, value=modules.config.default_inpaint_mask_cloth_category, visible=False) - generate_mask_button = gr.Button(value='Generate mask from image', visible=False) + inpaint_mask_sam_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False) + with gr.Accordion("Advanced options", visible=False, open=False) as inpaint_mask_advanced_options: + inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model) + inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False) + inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05) + inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05) + generate_mask_button = gr.Button(value='Generate mask from image') - def generate_mask(image, mask_model, cloth_category): + def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold): from extras.inpaint_mask import generate_mask_from_image - return generate_mask_from_image(image, mask_model, {"cloth_category": cloth_category}) + + extras = {} + if mask_model == 'u2net_cloth_seg': + extras['cloth_category'] = cloth_category + elif mask_model == 'sam': + extras['sam_prompt_text'] = sam_prompt_text + extras['sam_model'] = sam_model + extras['sam_quant'] = sam_quant + extras['box_threshold'] = box_threshold + extras['text_threshold'] = text_threshold + + return generate_mask_from_image(image, mask_model, extras) generate_mask_button.click(fn=generate_mask, inputs=[ inpaint_input_image, inpaint_mask_model, - inpaint_mask_cloth_category + inpaint_mask_cloth_category, + inpaint_mask_sam_prompt_text, + inpaint_mask_sam_model, + inpaint_mask_sam_quant, + inpaint_mask_box_threshold, + inpaint_mask_text_threshold ], - outputs=inpaint_mask_image) + outputs=inpaint_mask_image, show_progress=True, queue=True) - inpaint_mask_model.change(lambda x: gr.update(visible=x == 'u2net_cloth_seg'), + inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')], inputs=inpaint_mask_model, - outputs=inpaint_mask_cloth_category, + outputs=[inpaint_mask_cloth_category, inpaint_mask_sam_prompt_text, inpaint_mask_advanced_options], queue=False, show_progress=False) with gr.TabItem(label='Describe') as desc_tab: @@ -518,10 +539,9 @@ with shared.gradio_root: inpaint_strength, inpaint_respective_field, inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate] - inpaint_mask_upload_checkbox.change(lambda x: [gr.update(visible=x)] * 4, + inpaint_mask_upload_checkbox.change(lambda x: [gr.update(visible=x)] * 2, inputs=inpaint_mask_upload_checkbox, - outputs=[inpaint_mask_image, generate_mask_button, - inpaint_mask_model, inpaint_mask_generation_col], + outputs=[inpaint_mask_image, inpaint_mask_generation_col], queue=False, show_progress=False) with gr.Tab(label='FreeU'):