feat: add disclaimer + skipping, stage2 won't properly work when used with inpaint or outpaint

This commit is contained in:
Manuel Schmid 2024-06-13 01:13:21 +02:00
parent dbc844804b
commit f8f36828c7
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
4 changed files with 91 additions and 44 deletions

View File

@ -18,7 +18,7 @@ sam_options = SAMOptions(
model_type='vit_b'
)
mask_image = generate_mask_from_image(image, sam_options=sam_options)
mask_image, _, _, _ = generate_mask_from_image(image, sam_options=sam_options)
merged_masks_img = Image.fromarray(mask_image)
merged_masks_img.show()

View File

@ -42,9 +42,13 @@ def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
sam_options: SAMOptions | None = SAMOptions) -> np.ndarray | None:
sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
dino_detection_count = 0
sam_detection_count = 0
sam_detection_on_mask_count = 0
if image is None:
return
return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
if extras is None:
extras = {}
@ -53,13 +57,15 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
image = image['image']
if mask_model != 'sam' and sam_options is None:
return remove(
result = remove(
image,
session=new_session(mask_model, **extras),
only_mask=True,
**extras
)
return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
assert sam_options is not None
detections, boxes, logits, phrases = default_groundingdino(
@ -80,7 +86,11 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
sam_predictor = SamPredictor(sam)
final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
if boxes.size(0) > 0:
dino_detection_count = boxes.size(0)
sam_detection_count = 0
sam_detection_on_mask_count = 0
if dino_detection_count > 0:
sam_predictor.set_image(image)
if sam_options.dino_erode_or_dilate != 0:
@ -97,7 +107,7 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
draw = ImageDraw.Draw(debug_dino_image)
for box in boxes.numpy():
draw.rectangle(box.tolist(), fill="white")
return np.array(debug_dino_image)
return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
masks, _, _ = sam_predictor.predict_torch(
@ -109,12 +119,12 @@ def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=
masks = optimize_masks(masks)
num_obj = min(len(logits), sam_options.max_num_boxes)
for obj_ind in range(num_obj):
sam_objects = min(len(logits), sam_options.max_num_boxes)
for obj_ind in range(sam_objects):
mask_tensor = masks[obj_ind][0]
final_mask_tensor += mask_tensor
final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
mask_image = np.array(mask_image, dtype=np.uint8)
return mask_image
return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

View File

@ -1040,13 +1040,15 @@ def worker():
# stage2
progressbar(async_task, current_progress, 'Processing stage2 ...')
final_unet = pipeline.final_unet.clone()
if len(async_task.stage2_ctrls) == 0:
final_unet = pipeline.final_unet
if len(async_task.stage2_ctrls) == 0 or 'inpaint' in goals:
print(f'[Stage2] Skipping, preconditions aren\'t met')
continue
for img in imgs:
for stage2_mask_dino_prompt_text, stage2_mask_box_threshold, stage2_mask_text_threshold, stage2_mask_sam_max_num_boxes, stage2_mask_sam_model in async_task.stage2_ctrls:
mask = generate_mask_from_image(img, sam_options=SAMOptions(
print(f'[Stage2] Searching for "{stage2_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions(
dino_prompt=stage2_mask_dino_prompt_text,
dino_box_threshold=stage2_mask_box_threshold,
dino_text_threshold=stage2_mask_text_threshold,
@ -1060,12 +1062,43 @@ def worker():
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
# TODO also show do_not_show_finished_images=len(tasks) == 1
yield_result(async_task, mask, async_task.black_out_nsfw, False,
do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
do_not_show_finished_images=len(
tasks) == 1 or async_task.disable_intermediate_results)
print(f'[Stage2] {dino_detection_count} boxes detected')
print(f'[Stage2] {sam_detection_count} segments detected in boxes')
print(f'[Stage2] {sam_detection_on_mask_count} segments applied to mask')
if dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0:
print(f'[Stage2] Skipping')
continue
# TODO make configurable
# # do not apply loras / controlnets / etc. twice (samplers are needed though)
# pipeline.final_unet = pipeline.model_base.unet.clone()
# pipeline.refresh_everything(refiner_model_name=async_task.refiner_model_name,
# base_model_name=async_task.base_model_name,
# loras=[],
# base_model_additional_loras=[],
# use_synthetic_refiner=use_synthetic_refiner,
# vae_name=async_task.vae_name)
# pipeline.set_clip_skip(async_task.clip_skip)
#
# # patch everything again except original inpainting
# if 'cn' in goals:
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
# if async_task.freeu_enabled:
# apply_freeu(async_task)
# patch_samplers(async_task)
# defaults from inpaint mode improve details
denoising_strength_stage2 = 0.5
inpaint_respective_field_stage2 = 0.0
inpaint_head_model_path_stage2 = None
inpaint_parameterized_stage2 = False # inpaint_engine = None, improve detail
goals_stage2 = ['inpaint']
denoising_strength_stage2, initial_latent_stage2, width_stage2, height_stage2 = apply_inpaint(
async_task, None, inpaint_head_model_path_stage2, img, mask,
@ -1080,7 +1113,6 @@ def worker():
# reset and prepare next iteration
img = imgs2[0]
pipeline.final_unet = final_unet
inpaint_worker.current_task = None
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':

View File

@ -250,7 +250,9 @@ with shared.gradio_root:
model_type=sam_model
)
return generate_mask_from_image(image, mask_model, extras, sam_options)
mask, _, _, _ = generate_mask_from_image(image, mask_model, extras, sam_options)
return mask
inpaint_mask_model.change(lambda x: [gr.update(visible=x == 'u2net_cloth_seg'), gr.update(visible=x == 'sam'), gr.update(visible=x == 'sam')],
inputs=inpaint_mask_model,
@ -299,38 +301,41 @@ with shared.gradio_root:
outputs=metadata_json, queue=False, show_progress=True)
with gr.Row(visible=False) as stage2_input_panel:
with gr.Tabs():
stage2_ctrls = []
for index in range(modules.config.default_stage2_tabs):
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
# stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True)
stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True)
example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts,
label='Additional Prompt Quick List',
components=[stage2_mask_dino_prompt_text],
visible=True)
example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False)
with gr.Column():
gr.HTML('DISCLAIMER: Stage2 will be skipped when used in combination with Inpaint or Outpaint!')
with gr.Row():
with gr.Tabs():
stage2_ctrls = []
for index in range(modules.config.default_stage2_tabs):
with gr.TabItem(label=f'Iteration #{index + 1}') as stage2_tab_item:
stage2_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check', container=False)
with gr.Accordion('Options', visible=True, open=False) as stage2_accordion:
# stage2_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_detail, label='Method', interactive=True)
stage2_mask_dino_prompt_text = gr.Textbox(label='Segmentation prompt', info='Use singular whenever possible', interactive=True)
example_stage2_mask_dino_prompt_text = gr.Dataset(samples=modules.config.example_stage2_prompts,
label='Additional Prompt Quick List',
components=[stage2_mask_dino_prompt_text],
visible=True)
example_stage2_mask_dino_prompt_text.click(lambda x: x[0], inputs=example_stage2_mask_dino_prompt_text, outputs=stage2_mask_dino_prompt_text, show_progress=False, queue=False)
with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options:
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
with gr.Accordion("Advanced options", visible=True, open=False) as inpaint_mask_advanced_options:
stage2_mask_sam_model = gr.Dropdown(label='SAM model', choices=flags.inpaint_mask_sam_model, value=modules.config.default_inpaint_mask_sam_model, interactive=True)
stage2_mask_box_threshold = gr.Slider(label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True)
stage2_mask_text_threshold = gr.Slider(label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05, interactive=True)
stage2_mask_sam_max_num_boxes = gr.Slider(label="Maximum number of box detections", minimum=1, maximum=5, value=modules.config.default_sam_max_num_boxes, step=1, interactive=True)
stage2_ctrls += [
stage2_enabled,
# stage2_mode,
stage2_mask_dino_prompt_text,
stage2_mask_box_threshold,
stage2_mask_text_threshold,
stage2_mask_sam_max_num_boxes,
stage2_mask_sam_model,
]
stage2_ctrls += [
stage2_enabled,
# stage2_mode,
stage2_mask_dino_prompt_text,
stage2_mask_box_threshold,
stage2_mask_text_threshold,
stage2_mask_sam_max_num_boxes,
stage2_mask_sam_model,
]
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
outputs=stage2_accordion, queue=False, show_progress=False)
stage2_enabled.change(lambda x: gr.update(open=x), inputs=stage2_enabled,
outputs=stage2_accordion, queue=False, show_progress=False)
switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
down_js = "() => {viewer_to_bottom();}"