diff --git a/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py
new file mode 100644
index 00000000..c2e64680
--- /dev/null
+++ b/extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py
@@ -0,0 +1,43 @@
+batch_size = 1
+modelname = "groundingdino"
+backbone = "swin_T_224_1k"
+position_embedding = "sine"
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+two_stage_type = "standard"
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+transformer_activation = "relu"
+dec_pred_bbox_embed_share = True
+dn_box_noise_scale = 1.0
+dn_label_noise_ratio = 0.5
+dn_label_coef = 1.0
+dn_bbox_coef = 1.0
+embed_init_tgt = True
+dn_labelbook_size = 2000
+max_text_len = 256
+text_encoder_type = "bert-base-uncased"
+use_text_enhancer = True
+use_fusion_layer = True
+use_checkpoint = True
+use_transformer_ckpt = True
+use_text_cross_attention = True
+text_dropout = 0.0
+fusion_dropout = 0.0
+fusion_droppath = 0.1
+sub_sentence_present = True
\ No newline at end of file
diff --git a/extras/GroundingDINO/util/inference.py b/extras/GroundingDINO/util/inference.py
new file mode 100644
index 00000000..259094f2
--- /dev/null
+++ b/extras/GroundingDINO/util/inference.py
@@ -0,0 +1,98 @@
+from typing import Tuple, List
+
+import ldm_patched.modules.model_management as model_management
+from ldm_patched.modules.model_patcher import ModelPatcher
+from modules.config import path_inpaint
+from modules.model_loader import load_file_from_url
+
+import numpy as np
+import supervision as sv
+import torch
+from groundingdino.util.inference import Model
+from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap
+
+
+class GroundingDinoModel(Model):
+ def __init__(self):
+ self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
+ self.model = None
+ self.load_device = torch.device('cpu')
+ self.offload_device = torch.device('cpu')
+
+ def predict_with_caption(
+ self,
+ image: np.ndarray,
+ caption: str,
+ box_threshold: float = 0.35,
+ text_threshold: float = 0.25
+ ) -> Tuple[sv.Detections, List[str]]:
+ if self.model is None:
+ filename = load_file_from_url(
+ url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+ file_name='groundingdino_swint_ogc.pth',
+ model_dir=path_inpaint)
+ model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename)
+
+ self.load_device = model_management.text_encoder_device()
+ self.offload_device = model_management.text_encoder_offload_device()
+
+ model.to(self.offload_device)
+
+ self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
+
+ model_management.load_model_gpu(self.model)
+
+ processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device)
+ boxes, logits, phrases = predict(
+ model=self.model,
+ image=processed_image,
+ caption=caption,
+ box_threshold=box_threshold,
+ text_threshold=text_threshold,
+ device=self.load_device)
+ source_h, source_w, _ = image.shape
+ detections = GroundingDinoModel.post_process_result(
+ source_h=source_h,
+ source_w=source_w,
+ boxes=boxes,
+ logits=logits)
+ return detections, phrases
+
+
+def predict(
+ model,
+ image: torch.Tensor,
+ caption: str,
+ box_threshold: float,
+ text_threshold: float,
+ device: str = "cuda"
+) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
+ caption = preprocess_caption(caption=caption)
+
+ # override to use model wrapped by patcher
+ model = model.model.to(device)
+ image = image.to(device)
+
+ with torch.no_grad():
+ outputs = model(image[None], captions=[caption])
+
+ prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
+ prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
+
+ mask = prediction_logits.max(dim=1)[0] > box_threshold
+ logits = prediction_logits[mask] # logits.shape = (n, 256)
+ boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
+
+ tokenizer = model.tokenizer
+ tokenized = tokenizer(caption)
+
+ phrases = [
+ get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
+ for logit
+ in logits
+ ]
+
+ return boxes, logits.max(dim=1)[0], phrases
+
+
+default_groundingdino = GroundingDinoModel().predict_with_caption
diff --git a/extras/inpaint_mask.py b/extras/inpaint_mask.py
index 08d3c3cc..dfcb90a9 100644
--- a/extras/inpaint_mask.py
+++ b/extras/inpaint_mask.py
@@ -1,4 +1,23 @@
+from PIL import Image
+import numpy as np
+import torch
from rembg import remove, new_session
+from extras.GroundingDINO.util.inference import default_groundingdino
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
+
+ # run grounding dino model
+ boxes, _ = default_groundingdino(
+ image=np.array(input_image),
+ caption=text_prompt,
+ box_threshold=box_threshold,
+ text_threshold=text_threshold
+ )
+
+ return boxes.xyxy
def generate_mask_from_image(image, mask_model, extras):
@@ -8,9 +27,15 @@ def generate_mask_from_image(image, mask_model, extras):
if 'image' in image:
image = image['image']
+ if mask_model == 'sam':
+ boxes = run_grounded_sam(Image.fromarray(image), extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
+ extras['sam_prompt'] = []
+ for idx, box in enumerate(boxes):
+ extras['sam_prompt'] += [{"type": "rectangle", "data": box.tolist()}]
+
return remove(
image,
- session=new_session(mask_model),
+ session=new_session(mask_model, **extras),
only_mask=True,
**extras
)
diff --git a/launch.py b/launch.py
index e98045f6..63a38767 100644
--- a/launch.py
+++ b/launch.py
@@ -22,8 +22,9 @@ from build_launcher import build_launcher
from modules.launch_util import is_installed, run, python, run_pip, requirements_met
from modules.model_loader import load_file_from_url
from modules.config import path_checkpoints, path_loras, path_vae_approx, path_fooocus_expansion, \
- checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads
+ path_inpaint, checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads
+os.environ["U2NET_HOME"] = path_inpaint
REINSTALL_ALL = False
TRY_INSTALL_XFORMERS = False
diff --git a/modules/config.py b/modules/config.py
index 1010fe6a..07787015 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -396,6 +396,12 @@ default_inpaint_mask_cloth_category = get_config_item_or_set_default(
validator=lambda x: x in modules.flags.inpaint_mask_cloth_category
)
+default_inpaint_mask_sam_model = get_config_item_or_set_default(
+ key='default_inpaint_mask_sam_model',
+ default_value='sam_vit_b_01ec64',
+ validator=lambda x: x in modules.flags.inpaint_mask_sam_model
+)
+
config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
# mapping config to meta parameter
diff --git a/modules/flags.py b/modules/flags.py
index 7279dba5..06f52784 100644
--- a/modules/flags.py
+++ b/modules/flags.py
@@ -48,11 +48,13 @@ performance_selections = [
output_formats = ['png', 'jpg', 'webp']
inpaint_mask_models = [
- 'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime'
+ 'u2net', 'u2netp', 'u2net_human_seg', 'u2net_cloth_seg', 'silueta', 'isnet-general-use', 'isnet-anime', 'sam'
]
inpaint_mask_cloth_category = ['full', 'upper', 'lower']
+inpaint_mask_sam_model = ['sam_vit_b_01ec64', 'sam_vit_h_4b8939', 'sam_vit_l_0b3195']
+
inpaint_option_default = 'Inpaint or Outpaint (default)'
inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)'
inpaint_option_modify = 'Modify Content (add objects, change background, etc.)'
diff --git a/readme.md b/readme.md
index 90c37508..d5045849 100644
--- a/readme.md
+++ b/readme.md
@@ -27,7 +27,7 @@ Included adjustments:
* ✨ https://github.com/lllyasviel/Fooocus/pull/1938 - automatically describe image on uov image upload if prompt is empty
* ✨ https://github.com/lllyasviel/Fooocus/pull/1940 - meta data handling, schemes: Fooocus (json) and A1111 (plain text). Compatible with Civitai.
* ✨ https://github.com/lllyasviel/Fooocus/pull/1979 - prevent outdated history log link after midnight
-* ✨ https://github.com/lllyasviel/Fooocus/pull/2032 - add inpaint mask generation functionality using rembg
+* ✨ https://github.com/lllyasviel/Fooocus/pull/2032 - add inpaint mask generation functionality using rembg, incl. segmentation support
✨ = new feature
🐛 = bugfix
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 7b2a37e8..ebcd0297 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -18,4 +18,5 @@ httpx==0.24.1
onnxruntime==1.16.3
timm==0.9.2
translators==5.8.9
-rembg==2.0.53
\ No newline at end of file
+rembg==2.0.53
+groundingdino-py==0.4.0
\ No newline at end of file
diff --git a/webui.py b/webui.py
index bacd380f..8579c9df 100644
--- a/webui.py
+++ b/webui.py
@@ -211,31 +211,52 @@ with shared.gradio_root:
with gr.Column(visible=False) as inpaint_mask_generation_col:
inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy',
- height=500, visible=False)
+ height=500)
inpaint_mask_model = gr.Dropdown(label='Mask generation model',
choices=flags.inpaint_mask_models,
- value=modules.config.default_inpaint_mask_model,
- visible=False)
+ value=modules.config.default_inpaint_mask_model)
inpaint_mask_cloth_category = gr.Dropdown(label='Cloth category',
choices=flags.inpaint_mask_cloth_category,
value=modules.config.default_inpaint_mask_cloth_category,
visible=False)
- generate_mask_button = gr.Button(value='Generate mask from image', visible=False)
+ inpaint_mask_sam_prompt_text = gr.Textbox(label='Segmentation prompt', value='', visible=False)
+ with gr.Accordion("Advanced options", visible=False, open=False) as inpaint_mask_advanced_options:
+ inpaint_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model)
+ inpaint_mask_sam_quant = gr.Checkbox(label='Quantization', value=False)
+ inpaint_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05)
+ inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
+ generate_mask_button = gr.Button(value='Generate mask from image')
- def generate_mask(image, mask_model, cloth_category):
+ def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold):
from extras.inpaint_mask import generate_mask_from_image
- return generate_mask_from_image(image, mask_model, {"cloth_category": cloth_category})
+
+ extras = {}
+ if mask_model == 'u2net_cloth_seg':
+ extras['cloth_category'] = cloth_category
+ elif mask_model == 'sam':
+ extras['sam_prompt_text'] = sam_prompt_text
+ extras['sam_model'] = sam_model
+ extras['sam_quant'] = sam_quant
+ extras['box_threshold'] = box_threshold
+ extras['text_threshold'] = text_threshold
+
+ return generate_mask_from_image(image, mask_model, extras)
generate_mask_button.click(fn=generate_mask,
inputs=[
inpaint_input_image, inpaint_mask_model,
- inpaint_mask_cloth_category
+ inpaint_mask_cloth_category,
+ inpaint_mask_sam_prompt_text,
+ inpaint_mask_sam_model,
+ inpaint_mask_sam_quant,
+ inpaint_mask_box_threshold,
+ inpaint_mask_text_threshold
],
- outputs=inpaint_mask_image)
+ outputs=inpaint_mask_image, show_progress=True, queue=True)
- inpaint_mask_model.change(lambda x: gr.update(visible=x == 'u2net_cloth_seg'),
+ inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
inputs=inpaint_mask_model,
- outputs=inpaint_mask_cloth_category,
+ outputs=[inpaint_mask_cloth_category, inpaint_mask_sam_prompt_text, inpaint_mask_advanced_options],
queue=False, show_progress=False)
with gr.TabItem(label='Describe') as desc_tab:
@@ -518,10 +539,9 @@ with shared.gradio_root:
inpaint_strength, inpaint_respective_field,
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate]
- inpaint_mask_upload_checkbox.change(lambda x: [gr.update(visible=x)] * 4,
+ inpaint_mask_upload_checkbox.change(lambda x: [gr.update(visible=x)] * 2,
inputs=inpaint_mask_upload_checkbox,
- outputs=[inpaint_mask_image, generate_mask_button,
- inpaint_mask_model, inpaint_mask_generation_col],
+ outputs=[inpaint_mask_image, inpaint_mask_generation_col],
queue=False, show_progress=False)
with gr.Tab(label='FreeU'):