feat: add debug dino and mask dilate and erode

This commit is contained in:
Manuel Schmid 2024-06-09 22:31:41 +02:00
parent f2e7b65ed3
commit 57c049858c
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 38 additions and 41 deletions

View File

@ -1,26 +1,9 @@
from PIL import Image
import numpy as np
import torch
from rembg import remove, new_session
from extras.GroundingDINO.util.inference import default_groundingdino
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
# run grounding dino model
detections, _, _, _ = default_groundingdino(
image=np.array(input_image),
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
return detections.xyxy
def generate_mask_from_image(image, mask_model, extras, box_erode_or_dilate: int=0):
def generate_mask_from_image(image: np.ndarray, mask_model: str, extras: dict, box_erode_or_dilate: int=0, debug_dino: bool=False) -> np.ndarray | None:
if image is None:
return
@ -28,29 +11,37 @@ def generate_mask_from_image(image, mask_model, extras, box_erode_or_dilate: int
image = image['image']
if mask_model == 'sam':
img = Image.fromarray(image)
boxes = run_grounded_sam(img, extras['sam_prompt_text'], box_threshold=extras['box_threshold'], text_threshold=extras['text_threshold'])
detections, _, _, _ = default_groundingdino(
image=image,
caption=extras['sam_prompt_text'],
box_threshold=extras['box_threshold'],
text_threshold=extras['text_threshold']
)
detection_boxes = detections.xyxy
# use full image if no box has been found
boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(boxes) == 0 else boxes
detection_boxes = np.array([[0, 0, image.shape[1], image.shape[0]]]) if len(detection_boxes) == 0 else detection_boxes
extras['sam_prompt'] = []
# from PIL import ImageDraw
# draw = ImageDraw.Draw(img)
for idx, box in enumerate(boxes):
for idx, box in enumerate(detection_boxes):
box_list = box.tolist()
if box_erode_or_dilate != 0:
box_list[0] -= box_erode_or_dilate
box_list[1] -= box_erode_or_dilate
box_list[2] += box_erode_or_dilate
box_list[3] += box_erode_or_dilate
# draw.rectangle(box_list, fill=128, outline ="red")
extras['sam_prompt'] += [{"type": "rectangle", "data": box_list}]
# img.show()
if debug_dino:
from PIL import ImageDraw, Image
image_with_boxes = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
draw = ImageDraw.Draw(image_with_boxes)
for box in extras['sam_prompt']:
draw.rectangle(box['data'], fill="white")
return np.array(image_with_boxes)
return remove(
image,
session=new_session(mask_model, **extras),
only_mask=True,
# post_process_mask=True,
**extras
)

View File

@ -377,10 +377,15 @@
"Disable preview during generation.": "Disable preview during generation.",
"Disable Intermediate Results": "Disable Intermediate Results",
"Disable intermediate results during generation, only show final gallery.": "Disable intermediate results during generation, only show final gallery.",
"Debug Inpaint Preprocessing": "Debug Inpaint Preprocessing",
"Debug GroundingDINO": "Debug GroundingDINO",
"Used for SAM object detection and box generation": "Used for SAM object detection and box generation",
"GroundingDINO Box Erode or Dilate": "GroundingDINO Box Erode or Dilate",
"Inpaint Engine": "Inpaint Engine",
"v1": "v1",
"Version of Fooocus inpaint model": "Version of Fooocus inpaint model",
"v2.5": "v2.5",
"v2.6": "v2.6",
"Control Debug": "Control Debug",
"Debug Preprocessors": "Debug Preprocessors",
"Mixing Image Prompt and Vary/Upscale": "Mixing Image Prompt and Vary/Upscale",

View File

@ -231,7 +231,7 @@ with shared.gradio_root:
inpaint_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05)
generate_mask_button = gr.Button(value='Generate mask from image')
def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold):
def generate_mask(image, mask_model, cloth_category, sam_prompt_text, sam_model, sam_quant, box_threshold, text_threshold, debug_dino, dino_erode_or_dilate):
from extras.inpaint_mask import generate_mask_from_image
extras = {}
@ -244,19 +244,7 @@ with shared.gradio_root:
extras['box_threshold'] = box_threshold
extras['text_threshold'] = text_threshold
return generate_mask_from_image(image, mask_model, extras)
generate_mask_button.click(fn=generate_mask,
inputs=[
inpaint_input_image, inpaint_mask_model,
inpaint_mask_cloth_category,
inpaint_mask_sam_prompt_text,
inpaint_mask_sam_model,
inpaint_mask_sam_quant,
inpaint_mask_box_threshold,
inpaint_mask_text_threshold
],
outputs=inpaint_mask_image, show_progress=True, queue=True)
return generate_mask_from_image(image, mask_model, extras, dino_erode_or_dilate, debug_dino)
inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
inputs=inpaint_mask_model,
@ -570,6 +558,8 @@ with shared.gradio_root:
with gr.Tab(label='Inpaint'):
debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False)
debug_dino = gr.Checkbox(label='Debug GroundingDINO', value=False,
info='Used for SAM object detection and box generation')
inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False)
inpaint_engine = gr.Dropdown(label='Inpaint Engine',
value=modules.config.default_inpaint_engine_version,
@ -592,6 +582,10 @@ with shared.gradio_root:
info='Positive value will make white area in the mask larger, '
'negative value will make white area smaller.'
'(default is 0, always process before any mask invert)')
dino_erode_or_dilate = gr.Slider(label='GroundingDINO Box Erode or Dilate',
minimum=-64, maximum=64, step=1, value=0,
info='Positive value will make white area in the mask larger, '
'negative value will make white area smaller.')
inpaint_mask_upload_checkbox = gr.Checkbox(label='Enable Mask Upload', value=False)
invert_mask_checkbox = gr.Checkbox(label='Invert Mask', value=False)
@ -741,6 +735,13 @@ with shared.gradio_root:
inpaint_strength, inpaint_respective_field
], show_progress=False, queue=False)
generate_mask_button.click(fn=generate_mask,
inputs=[inpaint_input_image, inpaint_mask_model, inpaint_mask_cloth_category,
inpaint_mask_sam_prompt_text, inpaint_mask_sam_model, inpaint_mask_sam_quant,
inpaint_mask_box_threshold, inpaint_mask_text_threshold, debug_dino,
dino_erode_or_dilate],
outputs=inpaint_mask_image, show_progress=True, queue=True)
ctrls = [currentTask, generate_image_grid]
ctrls += [
prompt, negative_prompt, translate_prompts, style_selections,